mirror of
https://github.com/langgenius/dify.git
synced 2026-04-28 11:56:55 +08:00
refactor: replace bare dict with typed annotations in controllers (#35095)
This commit is contained in:
parent
14d83c8bac
commit
b0bf7ca486
@ -1,3 +1,4 @@
|
|||||||
|
from collections.abc import Mapping
|
||||||
from typing import TypedDict
|
from typing import TypedDict
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
@ -13,6 +14,14 @@ from services.billing_service import BillingService
|
|||||||
_FALLBACK_LANG = "en-US"
|
_FALLBACK_LANG = "en-US"
|
||||||
|
|
||||||
|
|
||||||
|
class NotificationLangContent(TypedDict, total=False):
|
||||||
|
lang: str
|
||||||
|
title: str
|
||||||
|
subtitle: str
|
||||||
|
body: str
|
||||||
|
titlePicUrl: str
|
||||||
|
|
||||||
|
|
||||||
class NotificationItemDict(TypedDict):
|
class NotificationItemDict(TypedDict):
|
||||||
notification_id: str | None
|
notification_id: str | None
|
||||||
frequency: str | None
|
frequency: str | None
|
||||||
@ -28,9 +37,11 @@ class NotificationResponseDict(TypedDict):
|
|||||||
notifications: list[NotificationItemDict]
|
notifications: list[NotificationItemDict]
|
||||||
|
|
||||||
|
|
||||||
def _pick_lang_content(contents: dict, lang: str) -> dict:
|
def _pick_lang_content(contents: Mapping[str, NotificationLangContent], lang: str) -> NotificationLangContent:
|
||||||
"""Return the single LangContent for *lang*, falling back to English."""
|
"""Return the single LangContent for *lang*, falling back to English."""
|
||||||
return contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), {})
|
return (
|
||||||
|
contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), NotificationLangContent())
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class DismissNotificationPayload(BaseModel):
|
class DismissNotificationPayload(BaseModel):
|
||||||
@ -71,7 +82,7 @@ class NotificationApi(Resource):
|
|||||||
|
|
||||||
notifications: list[NotificationItemDict] = []
|
notifications: list[NotificationItemDict] = []
|
||||||
for notification in result.get("notifications") or []:
|
for notification in result.get("notifications") or []:
|
||||||
contents: dict = notification.get("contents") or {}
|
contents: Mapping[str, NotificationLangContent] = notification.get("contents") or {}
|
||||||
lang_content = _pick_lang_content(contents, lang)
|
lang_content = _pick_lang_content(contents, lang)
|
||||||
item: NotificationItemDict = {
|
item: NotificationItemDict = {
|
||||||
"notification_id": notification.get("notificationId"),
|
"notification_id": notification.get("notificationId"),
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
import pytz
|
import pytz
|
||||||
from flask import request
|
from flask import request
|
||||||
@ -174,7 +174,7 @@ reg(CheckEmailUniquePayload)
|
|||||||
register_schema_models(console_ns, AccountResponse)
|
register_schema_models(console_ns, AccountResponse)
|
||||||
|
|
||||||
|
|
||||||
def _serialize_account(account) -> dict:
|
def _serialize_account(account) -> dict[str, Any]:
|
||||||
return AccountResponse.model_validate(account, from_attributes=True).model_dump(mode="json")
|
return AccountResponse.model_validate(account, from_attributes=True).model_dump(mode="json")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -2,7 +2,7 @@ from typing import Any, Union
|
|||||||
|
|
||||||
from flask import Response
|
from flask import Response
|
||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
from graphon.variables.input_entities import VariableEntity
|
from graphon.variables.input_entities import VariableEntity, VariableEntityType
|
||||||
from pydantic import BaseModel, Field, ValidationError
|
from pydantic import BaseModel, Field, ValidationError
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session, sessionmaker
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
@ -158,14 +158,20 @@ class MCPAppApi(Resource):
|
|||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid user_input_form: {str(e)}")
|
raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid user_input_form: {str(e)}")
|
||||||
|
|
||||||
def _convert_user_input_form(self, raw_form: list[dict]) -> list[VariableEntity]:
|
def _convert_user_input_form(self, raw_form: list[dict[str, Any]]) -> list[VariableEntity]:
|
||||||
"""Convert raw user input form to VariableEntity objects"""
|
"""Convert raw user input form to VariableEntity objects"""
|
||||||
return [self._create_variable_entity(item) for item in raw_form]
|
return [self._create_variable_entity(item) for item in raw_form]
|
||||||
|
|
||||||
def _create_variable_entity(self, item: dict) -> VariableEntity:
|
def _create_variable_entity(self, item: dict[str, Any]) -> VariableEntity:
|
||||||
"""Create a single VariableEntity from raw form item"""
|
"""Create a single VariableEntity from raw form item"""
|
||||||
variable_type = item.get("type", "") or list(item.keys())[0]
|
variable_type_raw: str = item.get("type", "") or list(item.keys())[0]
|
||||||
variable = item[variable_type]
|
try:
|
||||||
|
variable_type = VariableEntityType(variable_type_raw)
|
||||||
|
except ValueError as e:
|
||||||
|
raise MCPRequestError(
|
||||||
|
mcp_types.INVALID_PARAMS, f"Invalid user_input_form variable type: {variable_type_raw}"
|
||||||
|
) from e
|
||||||
|
variable = item[variable_type_raw]
|
||||||
|
|
||||||
return VariableEntity(
|
return VariableEntity(
|
||||||
type=variable_type,
|
type=variable_type,
|
||||||
@ -178,7 +184,7 @@ class MCPAppApi(Resource):
|
|||||||
json_schema=variable.get("json_schema"),
|
json_schema=variable.get("json_schema"),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _parse_mcp_request(self, args: dict) -> mcp_types.ClientRequest | mcp_types.ClientNotification:
|
def _parse_mcp_request(self, args: dict[str, Any]) -> mcp_types.ClientRequest | mcp_types.ClientNotification:
|
||||||
"""Parse and validate MCP request"""
|
"""Parse and validate MCP request"""
|
||||||
try:
|
try:
|
||||||
return mcp_types.ClientRequest.model_validate(args)
|
return mcp_types.ClientRequest.model_validate(args)
|
||||||
|
|||||||
@ -33,25 +33,25 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
|
|||||||
from services.summary_index_service import SummaryIndexService
|
from services.summary_index_service import SummaryIndexService
|
||||||
|
|
||||||
|
|
||||||
def _marshal_segment_with_summary(segment, dataset_id: str) -> dict:
|
def _marshal_segment_with_summary(segment, dataset_id: str) -> dict[str, Any]:
|
||||||
"""Marshal a single segment and enrich it with summary content."""
|
"""Marshal a single segment and enrich it with summary content."""
|
||||||
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
|
segment_dict: dict[str, Any] = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
|
||||||
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
|
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
|
||||||
segment_dict["summary"] = summary.summary_content if summary else None
|
segment_dict["summary"] = summary.summary_content if summary else None
|
||||||
return segment_dict
|
return segment_dict
|
||||||
|
|
||||||
|
|
||||||
def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict]:
|
def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict[str, Any]]:
|
||||||
"""Marshal multiple segments and enrich them with summary content (batch query)."""
|
"""Marshal multiple segments and enrich them with summary content (batch query)."""
|
||||||
segment_ids = [segment.id for segment in segments]
|
segment_ids = [segment.id for segment in segments]
|
||||||
summaries: dict = {}
|
summaries: dict[str, str | None] = {}
|
||||||
if segment_ids:
|
if segment_ids:
|
||||||
summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id)
|
summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id)
|
||||||
summaries = {chunk_id: record.summary_content for chunk_id, record in summary_records.items()}
|
summaries = {chunk_id: record.summary_content for chunk_id, record in summary_records.items()}
|
||||||
|
|
||||||
result = []
|
result: list[dict[str, Any]] = []
|
||||||
for segment in segments:
|
for segment in segments:
|
||||||
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
|
segment_dict: dict[str, Any] = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
|
||||||
segment_dict["summary"] = summaries.get(segment.id)
|
segment_dict["summary"] = summaries.get(segment.id)
|
||||||
result.append(segment_dict)
|
result.append(segment_dict)
|
||||||
return result
|
return result
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
import uuid
|
import uuid
|
||||||
from datetime import UTC, datetime, timedelta
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from flask import make_response, request
|
from flask import make_response, request
|
||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
@ -103,21 +104,23 @@ class PassportResource(Resource):
|
|||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
def decode_enterprise_webapp_user_id(jwt_token: str | None):
|
def decode_enterprise_webapp_user_id(jwt_token: str | None) -> dict[str, Any] | None:
|
||||||
"""
|
"""
|
||||||
Decode the enterprise user session from the Authorization header.
|
Decode the enterprise user session from the Authorization header.
|
||||||
"""
|
"""
|
||||||
if not jwt_token:
|
if not jwt_token:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
decoded = PassportService().verify(jwt_token)
|
decoded: dict[str, Any] = PassportService().verify(jwt_token)
|
||||||
source = decoded.get("token_source")
|
source = decoded.get("token_source")
|
||||||
if not source or source != "webapp_login_token":
|
if not source or source != "webapp_login_token":
|
||||||
raise Unauthorized("Invalid token source. Expected 'webapp_login_token'.")
|
raise Unauthorized("Invalid token source. Expected 'webapp_login_token'.")
|
||||||
return decoded
|
return decoded
|
||||||
|
|
||||||
|
|
||||||
def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: dict, auth_type: WebAppAuthType):
|
def exchange_token_for_existing_web_user(
|
||||||
|
app_code: str, enterprise_user_decoded: dict[str, Any], auth_type: WebAppAuthType
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Exchange a token for an existing web user session.
|
Exchange a token for an existing web user session.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from flask_restx import fields, marshal, marshal_with
|
from flask_restx import fields, marshal, marshal_with
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
@ -113,12 +113,12 @@ class AppSiteInfo:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def serialize_site(site: Site) -> dict:
|
def serialize_site(site: Site) -> dict[str, Any]:
|
||||||
"""Serialize Site model using the same schema as AppSiteApi."""
|
"""Serialize Site model using the same schema as AppSiteApi."""
|
||||||
return cast(dict, marshal(site, AppSiteApi.site_fields))
|
return cast(dict[str, Any], marshal(site, AppSiteApi.site_fields))
|
||||||
|
|
||||||
|
|
||||||
def serialize_app_site_payload(app_model: App, site: Site, end_user_id: str | None) -> dict:
|
def serialize_app_site_payload(app_model: App, site: Site, end_user_id: str | None) -> dict[str, Any]:
|
||||||
can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo
|
can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo
|
||||||
app_site_info = AppSiteInfo(app_model.tenant, app_model, site, end_user_id, can_replace_logo)
|
app_site_info = AppSiteInfo(app_model.tenant, app_model, site, end_user_id, can_replace_logo)
|
||||||
return cast(dict, marshal(app_site_info, AppSiteApi.app_fields))
|
return cast(dict[str, Any], marshal(app_site_info, AppSiteApi.app_fields))
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user