From 944db46d4fc622c9cbbc6e2fb0ceb34c69911391 Mon Sep 17 00:00:00 2001 From: Dream <42954461+eureka928@users.noreply.github.com> Date: Mon, 30 Mar 2026 04:22:29 -0400 Subject: [PATCH] refactor(api): replace json.loads with Pydantic validation in services layer (#33704) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Asuka Minato --- .../datasets/rag_pipeline/datasource_auth.py | 7 ++-- .../console/workspace/tool_providers.py | 3 +- .../console/workspace/trigger_providers.py | 6 +-- api/services/account_service.py | 16 ++++++-- api/services/message_service.py | 10 +++-- api/services/model_load_balancing_service.py | 11 +++--- api/services/plugin/plugin_migration.py | 37 ++++++++++++------- .../restore_archived_workflow_run.py | 22 +++++++++-- api/services/tools/tools_transform_service.py | 10 +++-- .../trigger_subscription_builder_service.py | 18 ++++++--- api/services/workflow_service.py | 2 +- .../services/plugin/test_oauth_service.py | 17 +++++++++ .../test_restore_archived_workflow_run.py | 3 +- 13 files changed, 114 insertions(+), 48 deletions(-) diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py index 1976a6bc8a..bdf83b991e 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py @@ -120,7 +120,8 @@ class DatasourceOAuthCallback(Resource): if context is None: raise Forbidden("Invalid context_id") - user_id, tenant_id = context.get("user_id"), context.get("tenant_id") + user_id: str = context["user_id"] + tenant_id: str = context["tenant_id"] datasource_provider_id = DatasourceProviderID(provider_id) plugin_id = datasource_provider_id.plugin_id datasource_provider_service = DatasourceProviderService() @@ -141,7 +142,7 @@ class DatasourceOAuthCallback(Resource): system_credentials=oauth_client_params, request=request, ) - credential_id = context.get("credential_id") + credential_id: str | None = context.get("credential_id") if credential_id: datasource_provider_service.reauthorize_datasource_oauth_provider( tenant_id=tenant_id, @@ -150,7 +151,7 @@ class DatasourceOAuthCallback(Resource): name=oauth_response.metadata.get("name") or None, expire_at=oauth_response.expires_at, credentials=dict(oauth_response.credentials), - credential_id=context.get("credential_id"), + credential_id=credential_id, ) else: datasource_provider_service.add_datasource_oauth_provider( diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 02eb0adc94..80216915cd 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -832,7 +832,8 @@ class ToolOAuthCallback(Resource): tool_provider = ToolProviderID(provider) plugin_id = tool_provider.plugin_id provider_name = tool_provider.provider_name - user_id, tenant_id = context.get("user_id"), context.get("tenant_id") + user_id: str = context["user_id"] + tenant_id: str = context["tenant_id"] oauth_handler = OAuthHandler() oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id, provider) diff --git a/api/controllers/console/workspace/trigger_providers.py b/api/controllers/console/workspace/trigger_providers.py index 265b6ecd9a..76d64cb97c 100644 --- a/api/controllers/console/workspace/trigger_providers.py +++ b/api/controllers/console/workspace/trigger_providers.py @@ -499,9 +499,9 @@ class TriggerOAuthCallbackApi(Resource): provider_id = TriggerProviderID(provider) plugin_id = provider_id.plugin_id provider_name = provider_id.provider_name - user_id = context.get("user_id") - tenant_id = context.get("tenant_id") - subscription_builder_id = context.get("subscription_builder_id") + user_id: str = context["user_id"] + tenant_id: str = context["tenant_id"] + subscription_builder_id: str = context["subscription_builder_id"] # Get OAuth client configuration oauth_client_params = TriggerProviderService.get_oauth_client( diff --git a/api/services/account_service.py b/api/services/account_service.py index bd520f54cf..cc8ef08857 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -7,9 +7,19 @@ from datetime import UTC, datetime, timedelta from hashlib import sha256 from typing import Any, cast -from pydantic import BaseModel +from pydantic import BaseModel, TypeAdapter from sqlalchemy import func, select from sqlalchemy.orm import Session +from typing_extensions import TypedDict + + +class InvitationData(TypedDict): + account_id: str + email: str + workspace_id: str + + +_invitation_adapter: TypeAdapter[InvitationData] = TypeAdapter(InvitationData) from werkzeug.exceptions import Unauthorized from configs import dify_config @@ -1571,7 +1581,7 @@ class RegisterService: @classmethod def get_invitation_by_token( cls, token: str, workspace_id: str | None = None, email: str | None = None - ) -> dict[str, str] | None: + ) -> InvitationData | None: if workspace_id is not None and email is not None: email_hash = sha256(email.encode()).hexdigest() cache_key = f"member_invite_token:{workspace_id}, {email_hash}:{token}" @@ -1590,7 +1600,7 @@ class RegisterService: if not data: return None - invitation: dict = json.loads(data) + invitation = _invitation_adapter.validate_json(data) return invitation @classmethod diff --git a/api/services/message_service.py b/api/services/message_service.py index e5389ef659..a04f9cbe01 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -1,8 +1,8 @@ -import json from collections.abc import Sequence from typing import Union from graphon.model_runtime.entities.model_entities import ModelType +from pydantic import TypeAdapter from sqlalchemy.orm import sessionmaker from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager @@ -17,7 +17,7 @@ from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account from models.enums import FeedbackFromSource, FeedbackRating -from models.model import App, AppMode, AppModelConfig, EndUser, Message, MessageFeedback +from models.model import App, AppMode, AppModelConfig, AppModelConfigDict, EndUser, Message, MessageFeedback from repositories.execution_extra_content_repository import ExecutionExtraContentRepository from repositories.sqlalchemy_execution_extra_content_repository import ( SQLAlchemyExecutionExtraContentRepository, @@ -31,6 +31,8 @@ from services.errors.message import ( ) from services.workflow_service import WorkflowService +_app_model_config_adapter: TypeAdapter[AppModelConfigDict] = TypeAdapter(AppModelConfigDict) + def _create_execution_extra_content_repository() -> ExecutionExtraContentRepository: session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) @@ -286,7 +288,9 @@ class MessageService: .first() ) else: - conversation_override_model_configs = json.loads(conversation.override_model_configs) + conversation_override_model_configs = _app_model_config_adapter.validate_json( + conversation.override_model_configs + ) app_model_config = AppModelConfig( app_id=app_model.id, ) diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index 91cca5cb6d..25de411e43 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -1,7 +1,6 @@ import json import logging -from json import JSONDecodeError -from typing import Union +from typing import Any, Union from graphon.model_runtime.entities.model_entities import ModelType from graphon.model_runtime.entities.provider_entities import ( @@ -168,10 +167,10 @@ class ModelLoadBalancingService: try: if load_balancing_config.encrypted_config: - credentials: dict[str, object] = json.loads(load_balancing_config.encrypted_config) + credentials: dict[str, Any] = json.loads(load_balancing_config.encrypted_config) else: credentials = {} - except JSONDecodeError: + except (json.JSONDecodeError, ValueError): credentials = {} # Get provider credential secret variables @@ -256,7 +255,7 @@ class ModelLoadBalancingService: credentials = json.loads(load_balancing_model_config.encrypted_config) else: credentials = {} - except JSONDecodeError: + except (json.JSONDecodeError, ValueError): credentials = {} # Get credential form schemas from model credential schema or provider credential schema @@ -575,7 +574,7 @@ class ModelLoadBalancingService: original_credentials = json.loads(load_balancing_model_config.encrypted_config) else: original_credentials = {} - except JSONDecodeError: + except (json.JSONDecodeError, ValueError): original_credentials = {} # encrypt credentials diff --git a/api/services/plugin/plugin_migration.py b/api/services/plugin/plugin_migration.py index df5fa3e233..1562d4e696 100644 --- a/api/services/plugin/plugin_migration.py +++ b/api/services/plugin/plugin_migration.py @@ -12,7 +12,9 @@ import click import sqlalchemy as sa import tqdm from flask import Flask, current_app +from pydantic import TypeAdapter from sqlalchemy.orm import Session +from typing_extensions import TypedDict from core.agent.entities import AgentToolEntity from core.helper import marketplace @@ -33,6 +35,14 @@ logger = logging.getLogger(__name__) excluded_providers = ["time", "audio", "code", "webscraper"] +class _TenantPluginRecord(TypedDict): + tenant_id: str + plugins: list[str] + + +_tenant_plugin_adapter: TypeAdapter[_TenantPluginRecord] = TypeAdapter(_TenantPluginRecord) + + class PluginMigration: @classmethod def extract_plugins(cls, filepath: str, workers: int): @@ -308,9 +318,8 @@ class PluginMigration: logger.info("Extracting unique plugins from %s", extracted_plugins) with open(extracted_plugins) as f: for line in f: - data = json.loads(line) - new_plugin_ids = data.get("plugins", []) - for plugin_id in new_plugin_ids: + data = _tenant_plugin_adapter.validate_json(line) + for plugin_id in data["plugins"]: if plugin_id not in plugin_ids: plugin_ids.append(plugin_id) @@ -381,21 +390,23 @@ class PluginMigration: Read line by line, and install plugins for each tenant. """ for line in f: - data = json.loads(line) - tenant_id = data.get("tenant_id") - plugin_ids = data.get("plugins", []) - current_not_installed = { - "tenant_id": tenant_id, - "plugin_not_exist": [], - } + data = _tenant_plugin_adapter.validate_json(line) + tenant_id = data["tenant_id"] + plugin_ids = data["plugins"] + plugin_not_exist: list[str] = [] # get plugin unique identifier for plugin_id in plugin_ids: unique_identifier = plugins.get(plugin_id) if unique_identifier: - current_not_installed["plugin_not_exist"].append(plugin_id) + plugin_not_exist.append(plugin_id) - if current_not_installed["plugin_not_exist"]: - not_installed.append(current_not_installed) + if plugin_not_exist: + not_installed.append( + { + "tenant_id": tenant_id, + "plugin_not_exist": plugin_not_exist, + } + ) thread_pool.submit(install, tenant_id, plugin_ids) diff --git a/api/services/retention/workflow_run/restore_archived_workflow_run.py b/api/services/retention/workflow_run/restore_archived_workflow_run.py index 64dad7ba52..c8362738ee 100644 --- a/api/services/retention/workflow_run/restore_archived_workflow_run.py +++ b/api/services/retention/workflow_run/restore_archived_workflow_run.py @@ -6,7 +6,6 @@ back to the database. """ import io -import json import logging import time import zipfile @@ -17,8 +16,23 @@ from datetime import datetime from typing import Any, cast import click +from pydantic import TypeAdapter from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.engine import CursorResult +from typing_extensions import TypedDict + + +class _TableInfo(TypedDict, total=False): + row_count: int + + +class ArchiveManifest(TypedDict, total=False): + tables: dict[str, _TableInfo] + schema_version: str + + +_manifest_adapter: TypeAdapter[ArchiveManifest] = TypeAdapter(ArchiveManifest) + from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker from extensions.ext_database import db @@ -239,12 +253,12 @@ class WorkflowRunRestore: return self.workflow_run_repo @staticmethod - def _load_manifest_from_zip(archive: zipfile.ZipFile) -> dict[str, Any]: + def _load_manifest_from_zip(archive: zipfile.ZipFile) -> ArchiveManifest: try: data = archive.read("manifest.json") except KeyError as e: raise ValueError("manifest.json missing from archive bundle") from e - return json.loads(data.decode("utf-8")) + return _manifest_adapter.validate_json(data) def _restore_table_records( self, @@ -332,7 +346,7 @@ class WorkflowRunRestore: return result - def _get_schema_version(self, manifest: dict[str, Any]) -> str: + def _get_schema_version(self, manifest: ArchiveManifest) -> str: schema_version = manifest.get("schema_version") if not schema_version: logger.warning("Manifest missing schema_version; defaulting to 1.0") diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index b276146066..7cd61e3162 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -3,7 +3,7 @@ import logging from collections.abc import Mapping from typing import Any, Union -from pydantic import ValidationError +from pydantic import TypeAdapter, ValidationError from yarl import URL from configs import dify_config @@ -31,6 +31,8 @@ from services.plugin.plugin_service import PluginService logger = logging.getLogger(__name__) +_mcp_tools_adapter: TypeAdapter[list[MCPTool]] = TypeAdapter(list[MCPTool]) + class ToolTransformService: _MCP_SCHEMA_TYPE_RESOLUTION_MAX_DEPTH = 10 @@ -53,7 +55,7 @@ class ToolTransformService: if isinstance(icon, str): return json.loads(icon) return icon - except Exception: + except (json.JSONDecodeError, ValueError): return {"background": "#252525", "content": "\ud83d\ude01"} elif provider_type == ToolProviderType.MCP: return icon @@ -247,8 +249,8 @@ class ToolTransformService: response = provider_entity.to_api_response(user_name=user_name, include_sensitive=include_sensitive) try: - mcp_tools = [MCPTool(**tool) for tool in json.loads(db_provider.tools)] - except (ValidationError, json.JSONDecodeError): + mcp_tools = _mcp_tools_adapter.validate_json(db_provider.tools) + except (ValidationError, ValueError): mcp_tools = [] # Add additional fields specific to the transform response["id"] = db_provider.server_identifier if not for_list else db_provider.id diff --git a/api/services/trigger/trigger_subscription_builder_service.py b/api/services/trigger/trigger_subscription_builder_service.py index 37f852da3e..889717df72 100644 --- a/api/services/trigger/trigger_subscription_builder_service.py +++ b/api/services/trigger/trigger_subscription_builder_service.py @@ -1,4 +1,3 @@ -import json import logging import uuid from collections.abc import Mapping @@ -7,6 +6,7 @@ from datetime import datetime from typing import Any from flask import Request, Response +from pydantic import TypeAdapter from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.entities.request import TriggerDispatchResponse @@ -29,6 +29,8 @@ from services.trigger.trigger_provider_service import TriggerProviderService logger = logging.getLogger(__name__) +_request_logs_adapter: TypeAdapter[list[RequestLog]] = TypeAdapter(list[RequestLog]) + class TriggerSubscriptionBuilderService: """Service for managing trigger providers and credentials""" @@ -398,7 +400,7 @@ class TriggerSubscriptionBuilderService: cache_key = cls.encode_cache_key(endpoint_id) subscription_cache = redis_client.get(cache_key) if subscription_cache: - return SubscriptionBuilder.model_validate(json.loads(subscription_cache)) + return SubscriptionBuilder.model_validate_json(subscription_cache) return None @@ -423,12 +425,16 @@ class TriggerSubscriptionBuilderService: ) key = f"trigger:subscription:builder:logs:{endpoint_id}" - logs = json.loads(redis_client.get(key) or "[]") - logs.append(log.model_dump(mode="json")) + logs = _request_logs_adapter.validate_json(redis_client.get(key) or b"[]") + logs.append(log) # Keep last N logs logs = logs[-cls.__VALIDATION_REQUEST_CACHE_COUNT__ :] - redis_client.setex(key, cls.__VALIDATION_REQUEST_CACHE_EXPIRE_SECONDS__, json.dumps(logs, default=str)) + redis_client.setex( + key, + cls.__VALIDATION_REQUEST_CACHE_EXPIRE_SECONDS__, + _request_logs_adapter.dump_json(logs), + ) @classmethod def list_logs(cls, endpoint_id: str) -> list[RequestLog]: @@ -437,7 +443,7 @@ class TriggerSubscriptionBuilderService: logs_json = redis_client.get(key) if not logs_json: return [] - return [RequestLog.model_validate(log) for log in json.loads(logs_json)] + return _request_logs_adapter.validate_json(logs_json) @classmethod def process_builder_validation_endpoint(cls, endpoint_id: str, request: Request) -> Response | None: diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index b555676704..3b3ee6dd92 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -1118,7 +1118,7 @@ class WorkflowService: continue try: payload = json.loads(recipient.recipient_payload) - except Exception: + except (json.JSONDecodeError, ValueError): logger.exception("Failed to parse human input recipient payload for delivery test.") continue email = payload.get("email") diff --git a/api/tests/unit_tests/services/plugin/test_oauth_service.py b/api/tests/unit_tests/services/plugin/test_oauth_service.py index 6511385000..eee65b3a18 100644 --- a/api/tests/unit_tests/services/plugin/test_oauth_service.py +++ b/api/tests/unit_tests/services/plugin/test_oauth_service.py @@ -93,3 +93,20 @@ class TestUseProxyContext: assert result == stored expected_key = "oauth_proxy_context:valid-id" redis_client.delete.assert_called_once_with(expected_key) + + def test_returns_context_with_credential_id(self): + from extensions.ext_redis import redis_client + + stored = { + "user_id": "u1", + "tenant_id": "t1", + "plugin_id": "p1", + "provider": "github", + "credential_id": "cred-42", + } + redis_client.get.return_value = json.dumps(stored).encode() + + result = OAuthProxyService.use_proxy_context("ctx-with-cred") + + assert result["credential_id"] == "cred-42" + assert result["tenant_id"] == "t1" diff --git a/api/tests/unit_tests/services/retention/workflow_run/test_restore_archived_workflow_run.py b/api/tests/unit_tests/services/retention/workflow_run/test_restore_archived_workflow_run.py index 4bfdba87a0..628e4e594d 100644 --- a/api/tests/unit_tests/services/retention/workflow_run/test_restore_archived_workflow_run.py +++ b/api/tests/unit_tests/services/retention/workflow_run/test_restore_archived_workflow_run.py @@ -13,6 +13,7 @@ from datetime import datetime from unittest.mock import Mock, create_autospec, patch import pytest +from pydantic import ValidationError from sqlalchemy import Column, Integer, MetaData, String, Table from libs.archive_storage import ArchiveStorageNotConfiguredError @@ -292,7 +293,7 @@ class TestLoadManifestFromZip: zip_buffer.seek(0) with zipfile.ZipFile(zip_buffer, "r") as archive: - with pytest.raises(json.JSONDecodeError): + with pytest.raises(ValidationError): WorkflowRunRestore._load_manifest_from_zip(archive)