diff --git a/api/controllers/console/notification.py b/api/controllers/console/notification.py index 180167402a..5d46470173 100644 --- a/api/controllers/console/notification.py +++ b/api/controllers/console/notification.py @@ -1,3 +1,4 @@ +from collections.abc import Mapping from typing import TypedDict from flask import request @@ -13,6 +14,14 @@ from services.billing_service import BillingService _FALLBACK_LANG = "en-US" +class NotificationLangContent(TypedDict, total=False): + lang: str + title: str + subtitle: str + body: str + titlePicUrl: str + + class NotificationItemDict(TypedDict): notification_id: str | None frequency: str | None @@ -28,9 +37,11 @@ class NotificationResponseDict(TypedDict): notifications: list[NotificationItemDict] -def _pick_lang_content(contents: dict, lang: str) -> dict: +def _pick_lang_content(contents: Mapping[str, NotificationLangContent], lang: str) -> NotificationLangContent: """Return the single LangContent for *lang*, falling back to English.""" - return contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), {}) + return ( + contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), NotificationLangContent()) + ) class DismissNotificationPayload(BaseModel): @@ -71,7 +82,7 @@ class NotificationApi(Resource): notifications: list[NotificationItemDict] = [] for notification in result.get("notifications") or []: - contents: dict = notification.get("contents") or {} + contents: Mapping[str, NotificationLangContent] = notification.get("contents") or {} lang_content = _pick_lang_content(contents, lang) item: NotificationItemDict = { "notification_id": notification.get("notificationId"), diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index af25669ae0..c35006a7ee 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -1,7 +1,7 @@ from __future__ import annotations from datetime import datetime -from typing import Literal +from typing import Any, Literal import pytz from flask import request @@ -174,7 +174,7 @@ reg(CheckEmailUniquePayload) register_schema_models(console_ns, AccountResponse) -def _serialize_account(account) -> dict: +def _serialize_account(account) -> dict[str, Any]: return AccountResponse.model_validate(account, from_attributes=True).model_dump(mode="json") diff --git a/api/controllers/mcp/mcp.py b/api/controllers/mcp/mcp.py index d2ce0ea543..8066f198bb 100644 --- a/api/controllers/mcp/mcp.py +++ b/api/controllers/mcp/mcp.py @@ -2,7 +2,7 @@ from typing import Any, Union from flask import Response from flask_restx import Resource -from graphon.variables.input_entities import VariableEntity +from graphon.variables.input_entities import VariableEntity, VariableEntityType from pydantic import BaseModel, Field, ValidationError from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker @@ -158,14 +158,20 @@ class MCPAppApi(Resource): except ValidationError as e: raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid user_input_form: {str(e)}") - def _convert_user_input_form(self, raw_form: list[dict]) -> list[VariableEntity]: + def _convert_user_input_form(self, raw_form: list[dict[str, Any]]) -> list[VariableEntity]: """Convert raw user input form to VariableEntity objects""" return [self._create_variable_entity(item) for item in raw_form] - def _create_variable_entity(self, item: dict) -> VariableEntity: + def _create_variable_entity(self, item: dict[str, Any]) -> VariableEntity: """Create a single VariableEntity from raw form item""" - variable_type = item.get("type", "") or list(item.keys())[0] - variable = item[variable_type] + variable_type_raw: str = item.get("type", "") or list(item.keys())[0] + try: + variable_type = VariableEntityType(variable_type_raw) + except ValueError as e: + raise MCPRequestError( + mcp_types.INVALID_PARAMS, f"Invalid user_input_form variable type: {variable_type_raw}" + ) from e + variable = item[variable_type_raw] return VariableEntity( type=variable_type, @@ -178,7 +184,7 @@ class MCPAppApi(Resource): json_schema=variable.get("json_schema"), ) - def _parse_mcp_request(self, args: dict) -> mcp_types.ClientRequest | mcp_types.ClientNotification: + def _parse_mcp_request(self, args: dict[str, Any]) -> mcp_types.ClientRequest | mcp_types.ClientNotification: """Parse and validate MCP request""" try: return mcp_types.ClientRequest.model_validate(args) diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index 9ad999b93e..971b63577c 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -33,25 +33,25 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS from services.summary_index_service import SummaryIndexService -def _marshal_segment_with_summary(segment, dataset_id: str) -> dict: +def _marshal_segment_with_summary(segment, dataset_id: str) -> dict[str, Any]: """Marshal a single segment and enrich it with summary content.""" - segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type] + segment_dict: dict[str, Any] = dict(marshal(segment, segment_fields)) # type: ignore[arg-type] summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id) segment_dict["summary"] = summary.summary_content if summary else None return segment_dict -def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict]: +def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict[str, Any]]: """Marshal multiple segments and enrich them with summary content (batch query).""" segment_ids = [segment.id for segment in segments] - summaries: dict = {} + summaries: dict[str, str | None] = {} if segment_ids: summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id) summaries = {chunk_id: record.summary_content for chunk_id, record in summary_records.items()} - result = [] + result: list[dict[str, Any]] = [] for segment in segments: - segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type] + segment_dict: dict[str, Any] = dict(marshal(segment, segment_fields)) # type: ignore[arg-type] segment_dict["summary"] = summaries.get(segment.id) result.append(segment_dict) return result diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py index 66082893b8..0293df74b0 100644 --- a/api/controllers/web/passport.py +++ b/api/controllers/web/passport.py @@ -1,5 +1,6 @@ import uuid from datetime import UTC, datetime, timedelta +from typing import Any from flask import make_response, request from flask_restx import Resource @@ -103,21 +104,23 @@ class PassportResource(Resource): return response -def decode_enterprise_webapp_user_id(jwt_token: str | None): +def decode_enterprise_webapp_user_id(jwt_token: str | None) -> dict[str, Any] | None: """ Decode the enterprise user session from the Authorization header. """ if not jwt_token: return None - decoded = PassportService().verify(jwt_token) + decoded: dict[str, Any] = PassportService().verify(jwt_token) source = decoded.get("token_source") if not source or source != "webapp_login_token": raise Unauthorized("Invalid token source. Expected 'webapp_login_token'.") return decoded -def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: dict, auth_type: WebAppAuthType): +def exchange_token_for_existing_web_user( + app_code: str, enterprise_user_decoded: dict[str, Any], auth_type: WebAppAuthType +): """ Exchange a token for an existing web user session. """ diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py index 1a0c6d4252..7d2080dd91 100644 --- a/api/controllers/web/site.py +++ b/api/controllers/web/site.py @@ -1,4 +1,4 @@ -from typing import cast +from typing import Any, cast from flask_restx import fields, marshal, marshal_with from sqlalchemy import select @@ -113,12 +113,12 @@ class AppSiteInfo: } -def serialize_site(site: Site) -> dict: +def serialize_site(site: Site) -> dict[str, Any]: """Serialize Site model using the same schema as AppSiteApi.""" - return cast(dict, marshal(site, AppSiteApi.site_fields)) + return cast(dict[str, Any], marshal(site, AppSiteApi.site_fields)) -def serialize_app_site_payload(app_model: App, site: Site, end_user_id: str | None) -> dict: +def serialize_app_site_payload(app_model: App, site: Site, end_user_id: str | None) -> dict[str, Any]: can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo app_site_info = AppSiteInfo(app_model.tenant, app_model, site, end_user_id, can_replace_logo) - return cast(dict, marshal(app_site_info, AppSiteApi.app_fields)) + return cast(dict[str, Any], marshal(app_site_info, AppSiteApi.app_fields))