refactor(api): replace json.loads with Pydantic validation in services layer (#33704)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
This commit is contained in:
Dream 2026-03-30 04:22:29 -04:00 committed by GitHub
parent 456684dfc3
commit 944db46d4f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 114 additions and 48 deletions

View File

@ -120,7 +120,8 @@ class DatasourceOAuthCallback(Resource):
if context is None:
raise Forbidden("Invalid context_id")
user_id, tenant_id = context.get("user_id"), context.get("tenant_id")
user_id: str = context["user_id"]
tenant_id: str = context["tenant_id"]
datasource_provider_id = DatasourceProviderID(provider_id)
plugin_id = datasource_provider_id.plugin_id
datasource_provider_service = DatasourceProviderService()
@ -141,7 +142,7 @@ class DatasourceOAuthCallback(Resource):
system_credentials=oauth_client_params,
request=request,
)
credential_id = context.get("credential_id")
credential_id: str | None = context.get("credential_id")
if credential_id:
datasource_provider_service.reauthorize_datasource_oauth_provider(
tenant_id=tenant_id,
@ -150,7 +151,7 @@ class DatasourceOAuthCallback(Resource):
name=oauth_response.metadata.get("name") or None,
expire_at=oauth_response.expires_at,
credentials=dict(oauth_response.credentials),
credential_id=context.get("credential_id"),
credential_id=credential_id,
)
else:
datasource_provider_service.add_datasource_oauth_provider(

View File

@ -832,7 +832,8 @@ class ToolOAuthCallback(Resource):
tool_provider = ToolProviderID(provider)
plugin_id = tool_provider.plugin_id
provider_name = tool_provider.provider_name
user_id, tenant_id = context.get("user_id"), context.get("tenant_id")
user_id: str = context["user_id"]
tenant_id: str = context["tenant_id"]
oauth_handler = OAuthHandler()
oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id, provider)

View File

@ -499,9 +499,9 @@ class TriggerOAuthCallbackApi(Resource):
provider_id = TriggerProviderID(provider)
plugin_id = provider_id.plugin_id
provider_name = provider_id.provider_name
user_id = context.get("user_id")
tenant_id = context.get("tenant_id")
subscription_builder_id = context.get("subscription_builder_id")
user_id: str = context["user_id"]
tenant_id: str = context["tenant_id"]
subscription_builder_id: str = context["subscription_builder_id"]
# Get OAuth client configuration
oauth_client_params = TriggerProviderService.get_oauth_client(

View File

@ -7,9 +7,19 @@ from datetime import UTC, datetime, timedelta
from hashlib import sha256
from typing import Any, cast
from pydantic import BaseModel
from pydantic import BaseModel, TypeAdapter
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from typing_extensions import TypedDict
class InvitationData(TypedDict):
    """Decoded member-invitation record, validated from cached JSON.

    Replaces the former untyped ``json.loads`` result returned by
    ``RegisterService.get_invitation_by_token``.
    """

    account_id: str  # invited account's ID — presumably the Account PK; confirm against writer
    email: str  # invitee email (hashed into the cache key by the caller)
    workspace_id: str  # workspace/tenant the invitation targets

# Module-level adapter: built once so every validate_json call reuses the
# compiled pydantic schema instead of rebuilding it per lookup.
_invitation_adapter: TypeAdapter[InvitationData] = TypeAdapter(InvitationData)
from werkzeug.exceptions import Unauthorized
from configs import dify_config
@ -1571,7 +1581,7 @@ class RegisterService:
@classmethod
def get_invitation_by_token(
cls, token: str, workspace_id: str | None = None, email: str | None = None
) -> dict[str, str] | None:
) -> InvitationData | None:
if workspace_id is not None and email is not None:
email_hash = sha256(email.encode()).hexdigest()
cache_key = f"member_invite_token:{workspace_id}, {email_hash}:{token}"
@ -1590,7 +1600,7 @@ class RegisterService:
if not data:
return None
invitation: dict = json.loads(data)
invitation = _invitation_adapter.validate_json(data)
return invitation
@classmethod

View File

@ -1,8 +1,8 @@
import json
from collections.abc import Sequence
from typing import Union
from graphon.model_runtime.entities.model_entities import ModelType
from pydantic import TypeAdapter
from sqlalchemy.orm import sessionmaker
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
@ -17,7 +17,7 @@ from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models import Account
from models.enums import FeedbackFromSource, FeedbackRating
from models.model import App, AppMode, AppModelConfig, EndUser, Message, MessageFeedback
from models.model import App, AppMode, AppModelConfig, AppModelConfigDict, EndUser, Message, MessageFeedback
from repositories.execution_extra_content_repository import ExecutionExtraContentRepository
from repositories.sqlalchemy_execution_extra_content_repository import (
SQLAlchemyExecutionExtraContentRepository,
@ -31,6 +31,8 @@ from services.errors.message import (
)
from services.workflow_service import WorkflowService
_app_model_config_adapter: TypeAdapter[AppModelConfigDict] = TypeAdapter(AppModelConfigDict)
def _create_execution_extra_content_repository() -> ExecutionExtraContentRepository:
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
@ -286,7 +288,9 @@ class MessageService:
.first()
)
else:
conversation_override_model_configs = json.loads(conversation.override_model_configs)
conversation_override_model_configs = _app_model_config_adapter.validate_json(
conversation.override_model_configs
)
app_model_config = AppModelConfig(
app_id=app_model.id,
)

View File

@ -1,7 +1,6 @@
import json
import logging
from json import JSONDecodeError
from typing import Union
from typing import Any, Union
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.entities.provider_entities import (
@ -168,10 +167,10 @@ class ModelLoadBalancingService:
try:
if load_balancing_config.encrypted_config:
credentials: dict[str, object] = json.loads(load_balancing_config.encrypted_config)
credentials: dict[str, Any] = json.loads(load_balancing_config.encrypted_config)
else:
credentials = {}
except JSONDecodeError:
except (json.JSONDecodeError, ValueError):
credentials = {}
# Get provider credential secret variables
@ -256,7 +255,7 @@ class ModelLoadBalancingService:
credentials = json.loads(load_balancing_model_config.encrypted_config)
else:
credentials = {}
except JSONDecodeError:
except (json.JSONDecodeError, ValueError):
credentials = {}
# Get credential form schemas from model credential schema or provider credential schema
@ -575,7 +574,7 @@ class ModelLoadBalancingService:
original_credentials = json.loads(load_balancing_model_config.encrypted_config)
else:
original_credentials = {}
except JSONDecodeError:
except (json.JSONDecodeError, ValueError):
original_credentials = {}
# encrypt credentials

View File

@ -12,7 +12,9 @@ import click
import sqlalchemy as sa
import tqdm
from flask import Flask, current_app
from pydantic import TypeAdapter
from sqlalchemy.orm import Session
from typing_extensions import TypedDict
from core.agent.entities import AgentToolEntity
from core.helper import marketplace
@ -33,6 +35,14 @@ logger = logging.getLogger(__name__)
excluded_providers = ["time", "audio", "code", "webscraper"]
class _TenantPluginRecord(TypedDict):
    """One JSONL line of the extracted-plugins file: a tenant and its plugin IDs."""

    tenant_id: str  # tenant that owns the listed plugins
    plugins: list[str]  # plugin IDs extracted for that tenant

# Shared adapter so per-line validate_json calls reuse one compiled schema.
_tenant_plugin_adapter: TypeAdapter[_TenantPluginRecord] = TypeAdapter(_TenantPluginRecord)
class PluginMigration:
@classmethod
def extract_plugins(cls, filepath: str, workers: int):
@ -308,9 +318,8 @@ class PluginMigration:
logger.info("Extracting unique plugins from %s", extracted_plugins)
with open(extracted_plugins) as f:
for line in f:
data = json.loads(line)
new_plugin_ids = data.get("plugins", [])
for plugin_id in new_plugin_ids:
data = _tenant_plugin_adapter.validate_json(line)
for plugin_id in data["plugins"]:
if plugin_id not in plugin_ids:
plugin_ids.append(plugin_id)
@ -381,21 +390,23 @@ class PluginMigration:
Read line by line, and install plugins for each tenant.
"""
for line in f:
data = json.loads(line)
tenant_id = data.get("tenant_id")
plugin_ids = data.get("plugins", [])
current_not_installed = {
"tenant_id": tenant_id,
"plugin_not_exist": [],
}
data = _tenant_plugin_adapter.validate_json(line)
tenant_id = data["tenant_id"]
plugin_ids = data["plugins"]
plugin_not_exist: list[str] = []
# get plugin unique identifier
for plugin_id in plugin_ids:
unique_identifier = plugins.get(plugin_id)
if unique_identifier:
current_not_installed["plugin_not_exist"].append(plugin_id)
plugin_not_exist.append(plugin_id)
if current_not_installed["plugin_not_exist"]:
not_installed.append(current_not_installed)
if plugin_not_exist:
not_installed.append(
{
"tenant_id": tenant_id,
"plugin_not_exist": plugin_not_exist,
}
)
thread_pool.submit(install, tenant_id, plugin_ids)

View File

@ -6,7 +6,6 @@ back to the database.
"""
import io
import json
import logging
import time
import zipfile
@ -17,8 +16,23 @@ from datetime import datetime
from typing import Any, cast
import click
from pydantic import TypeAdapter
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.engine import CursorResult
from typing_extensions import TypedDict
class _TableInfo(TypedDict, total=False):
    """Per-table metadata inside an archive manifest; all keys optional."""

    row_count: int  # rows recorded for the table — assumed informational; TODO confirm against writer


class ArchiveManifest(TypedDict, total=False):
    """Parsed ``manifest.json`` of a workflow-run archive bundle.

    ``total=False``: both keys may be absent, so readers must handle
    missing values (e.g. ``_get_schema_version`` falls back with a warning).
    """

    tables: dict[str, _TableInfo]  # table name -> per-table info
    schema_version: str  # manifest format version; may be missing in old bundles

# Module-level adapter reused for every manifest validation.
_manifest_adapter: TypeAdapter[ArchiveManifest] = TypeAdapter(ArchiveManifest)
from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker
from extensions.ext_database import db
@ -239,12 +253,12 @@ class WorkflowRunRestore:
return self.workflow_run_repo
@staticmethod
def _load_manifest_from_zip(archive: zipfile.ZipFile) -> dict[str, Any]:
def _load_manifest_from_zip(archive: zipfile.ZipFile) -> ArchiveManifest:
try:
data = archive.read("manifest.json")
except KeyError as e:
raise ValueError("manifest.json missing from archive bundle") from e
return json.loads(data.decode("utf-8"))
return _manifest_adapter.validate_json(data)
def _restore_table_records(
self,
@ -332,7 +346,7 @@ class WorkflowRunRestore:
return result
def _get_schema_version(self, manifest: dict[str, Any]) -> str:
def _get_schema_version(self, manifest: ArchiveManifest) -> str:
schema_version = manifest.get("schema_version")
if not schema_version:
logger.warning("Manifest missing schema_version; defaulting to 1.0")

View File

@ -3,7 +3,7 @@ import logging
from collections.abc import Mapping
from typing import Any, Union
from pydantic import ValidationError
from pydantic import TypeAdapter, ValidationError
from yarl import URL
from configs import dify_config
@ -31,6 +31,8 @@ from services.plugin.plugin_service import PluginService
logger = logging.getLogger(__name__)
_mcp_tools_adapter: TypeAdapter[list[MCPTool]] = TypeAdapter(list[MCPTool])
class ToolTransformService:
_MCP_SCHEMA_TYPE_RESOLUTION_MAX_DEPTH = 10
@ -53,7 +55,7 @@ class ToolTransformService:
if isinstance(icon, str):
return json.loads(icon)
return icon
except Exception:
except (json.JSONDecodeError, ValueError):
return {"background": "#252525", "content": "\ud83d\ude01"}
elif provider_type == ToolProviderType.MCP:
return icon
@ -247,8 +249,8 @@ class ToolTransformService:
response = provider_entity.to_api_response(user_name=user_name, include_sensitive=include_sensitive)
try:
mcp_tools = [MCPTool(**tool) for tool in json.loads(db_provider.tools)]
except (ValidationError, json.JSONDecodeError):
mcp_tools = _mcp_tools_adapter.validate_json(db_provider.tools)
except (ValidationError, ValueError):
mcp_tools = []
# Add additional fields specific to the transform
response["id"] = db_provider.server_identifier if not for_list else db_provider.id

View File

@ -1,4 +1,3 @@
import json
import logging
import uuid
from collections.abc import Mapping
@ -7,6 +6,7 @@ from datetime import datetime
from typing import Any
from flask import Request, Response
from pydantic import TypeAdapter
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.entities.request import TriggerDispatchResponse
@ -29,6 +29,8 @@ from services.trigger.trigger_provider_service import TriggerProviderService
logger = logging.getLogger(__name__)
_request_logs_adapter: TypeAdapter[list[RequestLog]] = TypeAdapter(list[RequestLog])
class TriggerSubscriptionBuilderService:
"""Service for managing trigger providers and credentials"""
@ -398,7 +400,7 @@ class TriggerSubscriptionBuilderService:
cache_key = cls.encode_cache_key(endpoint_id)
subscription_cache = redis_client.get(cache_key)
if subscription_cache:
return SubscriptionBuilder.model_validate(json.loads(subscription_cache))
return SubscriptionBuilder.model_validate_json(subscription_cache)
return None
@ -423,12 +425,16 @@ class TriggerSubscriptionBuilderService:
)
key = f"trigger:subscription:builder:logs:{endpoint_id}"
logs = json.loads(redis_client.get(key) or "[]")
logs.append(log.model_dump(mode="json"))
logs = _request_logs_adapter.validate_json(redis_client.get(key) or b"[]")
logs.append(log)
# Keep last N logs
logs = logs[-cls.__VALIDATION_REQUEST_CACHE_COUNT__ :]
redis_client.setex(key, cls.__VALIDATION_REQUEST_CACHE_EXPIRE_SECONDS__, json.dumps(logs, default=str))
redis_client.setex(
key,
cls.__VALIDATION_REQUEST_CACHE_EXPIRE_SECONDS__,
_request_logs_adapter.dump_json(logs),
)
@classmethod
def list_logs(cls, endpoint_id: str) -> list[RequestLog]:
@ -437,7 +443,7 @@ class TriggerSubscriptionBuilderService:
logs_json = redis_client.get(key)
if not logs_json:
return []
return [RequestLog.model_validate(log) for log in json.loads(logs_json)]
return _request_logs_adapter.validate_json(logs_json)
@classmethod
def process_builder_validation_endpoint(cls, endpoint_id: str, request: Request) -> Response | None:

View File

@ -1118,7 +1118,7 @@ class WorkflowService:
continue
try:
payload = json.loads(recipient.recipient_payload)
except Exception:
except (json.JSONDecodeError, ValueError):
logger.exception("Failed to parse human input recipient payload for delivery test.")
continue
email = payload.get("email")

View File

@ -93,3 +93,20 @@ class TestUseProxyContext:
assert result == stored
expected_key = "oauth_proxy_context:valid-id"
redis_client.delete.assert_called_once_with(expected_key)
    def test_returns_context_with_credential_id(self):
        """A stored context that includes credential_id is returned intact."""
        from extensions.ext_redis import redis_client

        # Payload shaped like what the proxy-context writer persists;
        # credential_id is the optional field under test.
        stored = {
            "user_id": "u1",
            "tenant_id": "t1",
            "plugin_id": "p1",
            "provider": "github",
            "credential_id": "cred-42",
        }
        # redis_client.get is a mock in this suite (return_value is assigned
        # directly), so the service reads this encoded JSON, not real Redis.
        redis_client.get.return_value = json.dumps(stored).encode()
        result = OAuthProxyService.use_proxy_context("ctx-with-cred")
        # Both the optional credential_id and a core field survive the round trip.
        assert result["credential_id"] == "cred-42"
        assert result["tenant_id"] == "t1"

View File

@ -13,6 +13,7 @@ from datetime import datetime
from unittest.mock import Mock, create_autospec, patch
import pytest
from pydantic import ValidationError
from sqlalchemy import Column, Integer, MetaData, String, Table
from libs.archive_storage import ArchiveStorageNotConfiguredError
@ -292,7 +293,7 @@ class TestLoadManifestFromZip:
zip_buffer.seek(0)
with zipfile.ZipFile(zip_buffer, "r") as archive:
with pytest.raises(json.JSONDecodeError):
with pytest.raises(ValidationError):
WorkflowRunRestore._load_manifest_from_zip(archive)