refactor: replace bare dict with typed annotations in controllers (#35095)

This commit is contained in:
dataCenter430 2026-04-13 12:19:52 -07:00 committed by GitHub
parent 14d83c8bac
commit b0bf7ca486
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 45 additions and 25 deletions

View File

@ -1,3 +1,4 @@
from collections.abc import Mapping
from typing import TypedDict from typing import TypedDict
from flask import request from flask import request
@ -13,6 +14,14 @@ from services.billing_service import BillingService
_FALLBACK_LANG = "en-US" _FALLBACK_LANG = "en-US"
class NotificationLangContent(TypedDict, total=False):
lang: str
title: str
subtitle: str
body: str
titlePicUrl: str
class NotificationItemDict(TypedDict): class NotificationItemDict(TypedDict):
notification_id: str | None notification_id: str | None
frequency: str | None frequency: str | None
@ -28,9 +37,11 @@ class NotificationResponseDict(TypedDict):
notifications: list[NotificationItemDict] notifications: list[NotificationItemDict]
def _pick_lang_content(contents: dict, lang: str) -> dict: def _pick_lang_content(contents: Mapping[str, NotificationLangContent], lang: str) -> NotificationLangContent:
"""Return the single LangContent for *lang*, falling back to English.""" """Return the single LangContent for *lang*, falling back to English."""
return contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), {}) return (
contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), NotificationLangContent())
)
class DismissNotificationPayload(BaseModel): class DismissNotificationPayload(BaseModel):
@ -71,7 +82,7 @@ class NotificationApi(Resource):
notifications: list[NotificationItemDict] = [] notifications: list[NotificationItemDict] = []
for notification in result.get("notifications") or []: for notification in result.get("notifications") or []:
contents: dict = notification.get("contents") or {} contents: Mapping[str, NotificationLangContent] = notification.get("contents") or {}
lang_content = _pick_lang_content(contents, lang) lang_content = _pick_lang_content(contents, lang)
item: NotificationItemDict = { item: NotificationItemDict = {
"notification_id": notification.get("notificationId"), "notification_id": notification.get("notificationId"),

View File

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import datetime
from typing import Literal from typing import Any, Literal
import pytz import pytz
from flask import request from flask import request
@ -174,7 +174,7 @@ reg(CheckEmailUniquePayload)
register_schema_models(console_ns, AccountResponse) register_schema_models(console_ns, AccountResponse)
def _serialize_account(account) -> dict: def _serialize_account(account) -> dict[str, Any]:
return AccountResponse.model_validate(account, from_attributes=True).model_dump(mode="json") return AccountResponse.model_validate(account, from_attributes=True).model_dump(mode="json")

View File

@ -2,7 +2,7 @@ from typing import Any, Union
from flask import Response from flask import Response
from flask_restx import Resource from flask_restx import Resource
from graphon.variables.input_entities import VariableEntity from graphon.variables.input_entities import VariableEntity, VariableEntityType
from pydantic import BaseModel, Field, ValidationError from pydantic import BaseModel, Field, ValidationError
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.orm import Session, sessionmaker
@ -158,14 +158,20 @@ class MCPAppApi(Resource):
except ValidationError as e: except ValidationError as e:
raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid user_input_form: {str(e)}") raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid user_input_form: {str(e)}")
def _convert_user_input_form(self, raw_form: list[dict]) -> list[VariableEntity]: def _convert_user_input_form(self, raw_form: list[dict[str, Any]]) -> list[VariableEntity]:
"""Convert raw user input form to VariableEntity objects""" """Convert raw user input form to VariableEntity objects"""
return [self._create_variable_entity(item) for item in raw_form] return [self._create_variable_entity(item) for item in raw_form]
def _create_variable_entity(self, item: dict) -> VariableEntity: def _create_variable_entity(self, item: dict[str, Any]) -> VariableEntity:
"""Create a single VariableEntity from raw form item""" """Create a single VariableEntity from raw form item"""
variable_type = item.get("type", "") or list(item.keys())[0] variable_type_raw: str = item.get("type", "") or list(item.keys())[0]
variable = item[variable_type] try:
variable_type = VariableEntityType(variable_type_raw)
except ValueError as e:
raise MCPRequestError(
mcp_types.INVALID_PARAMS, f"Invalid user_input_form variable type: {variable_type_raw}"
) from e
variable = item[variable_type_raw]
return VariableEntity( return VariableEntity(
type=variable_type, type=variable_type,
@ -178,7 +184,7 @@ class MCPAppApi(Resource):
json_schema=variable.get("json_schema"), json_schema=variable.get("json_schema"),
) )
def _parse_mcp_request(self, args: dict) -> mcp_types.ClientRequest | mcp_types.ClientNotification: def _parse_mcp_request(self, args: dict[str, Any]) -> mcp_types.ClientRequest | mcp_types.ClientNotification:
"""Parse and validate MCP request""" """Parse and validate MCP request"""
try: try:
return mcp_types.ClientRequest.model_validate(args) return mcp_types.ClientRequest.model_validate(args)

View File

@ -33,25 +33,25 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
from services.summary_index_service import SummaryIndexService from services.summary_index_service import SummaryIndexService
def _marshal_segment_with_summary(segment, dataset_id: str) -> dict: def _marshal_segment_with_summary(segment, dataset_id: str) -> dict[str, Any]:
"""Marshal a single segment and enrich it with summary content.""" """Marshal a single segment and enrich it with summary content."""
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type] segment_dict: dict[str, Any] = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id) summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
segment_dict["summary"] = summary.summary_content if summary else None segment_dict["summary"] = summary.summary_content if summary else None
return segment_dict return segment_dict
def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict]: def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict[str, Any]]:
"""Marshal multiple segments and enrich them with summary content (batch query).""" """Marshal multiple segments and enrich them with summary content (batch query)."""
segment_ids = [segment.id for segment in segments] segment_ids = [segment.id for segment in segments]
summaries: dict = {} summaries: dict[str, str | None] = {}
if segment_ids: if segment_ids:
summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id) summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id)
summaries = {chunk_id: record.summary_content for chunk_id, record in summary_records.items()} summaries = {chunk_id: record.summary_content for chunk_id, record in summary_records.items()}
result = [] result: list[dict[str, Any]] = []
for segment in segments: for segment in segments:
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type] segment_dict: dict[str, Any] = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
segment_dict["summary"] = summaries.get(segment.id) segment_dict["summary"] = summaries.get(segment.id)
result.append(segment_dict) result.append(segment_dict)
return result return result

View File

@ -1,5 +1,6 @@
import uuid import uuid
from datetime import UTC, datetime, timedelta from datetime import UTC, datetime, timedelta
from typing import Any
from flask import make_response, request from flask import make_response, request
from flask_restx import Resource from flask_restx import Resource
@ -103,21 +104,23 @@ class PassportResource(Resource):
return response return response
def decode_enterprise_webapp_user_id(jwt_token: str | None): def decode_enterprise_webapp_user_id(jwt_token: str | None) -> dict[str, Any] | None:
""" """
Decode the enterprise user session from the Authorization header. Decode the enterprise user session from the Authorization header.
""" """
if not jwt_token: if not jwt_token:
return None return None
decoded = PassportService().verify(jwt_token) decoded: dict[str, Any] = PassportService().verify(jwt_token)
source = decoded.get("token_source") source = decoded.get("token_source")
if not source or source != "webapp_login_token": if not source or source != "webapp_login_token":
raise Unauthorized("Invalid token source. Expected 'webapp_login_token'.") raise Unauthorized("Invalid token source. Expected 'webapp_login_token'.")
return decoded return decoded
def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: dict, auth_type: WebAppAuthType): def exchange_token_for_existing_web_user(
app_code: str, enterprise_user_decoded: dict[str, Any], auth_type: WebAppAuthType
):
""" """
Exchange a token for an existing web user session. Exchange a token for an existing web user session.
""" """

View File

@ -1,4 +1,4 @@
from typing import cast from typing import Any, cast
from flask_restx import fields, marshal, marshal_with from flask_restx import fields, marshal, marshal_with
from sqlalchemy import select from sqlalchemy import select
@ -113,12 +113,12 @@ class AppSiteInfo:
} }
def serialize_site(site: Site) -> dict: def serialize_site(site: Site) -> dict[str, Any]:
"""Serialize Site model using the same schema as AppSiteApi.""" """Serialize Site model using the same schema as AppSiteApi."""
return cast(dict, marshal(site, AppSiteApi.site_fields)) return cast(dict[str, Any], marshal(site, AppSiteApi.site_fields))
def serialize_app_site_payload(app_model: App, site: Site, end_user_id: str | None) -> dict: def serialize_app_site_payload(app_model: App, site: Site, end_user_id: str | None) -> dict[str, Any]:
can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo
app_site_info = AppSiteInfo(app_model.tenant, app_model, site, end_user_id, can_replace_logo) app_site_info = AppSiteInfo(app_model.tenant, app_model, site, end_user_id, can_replace_logo)
return cast(dict, marshal(app_site_info, AppSiteApi.app_fields)) return cast(dict[str, Any], marshal(app_site_info, AppSiteApi.app_fields))