mirror of
https://github.com/langgenius/dify.git
synced 2026-04-10 11:37:11 +08:00
refactor(api): replace json.loads with Pydantic validation in services layer (#33704)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
This commit is contained in:
parent
456684dfc3
commit
944db46d4f
@ -120,7 +120,8 @@ class DatasourceOAuthCallback(Resource):
|
||||
if context is None:
|
||||
raise Forbidden("Invalid context_id")
|
||||
|
||||
user_id, tenant_id = context.get("user_id"), context.get("tenant_id")
|
||||
user_id: str = context["user_id"]
|
||||
tenant_id: str = context["tenant_id"]
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
plugin_id = datasource_provider_id.plugin_id
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
@ -141,7 +142,7 @@ class DatasourceOAuthCallback(Resource):
|
||||
system_credentials=oauth_client_params,
|
||||
request=request,
|
||||
)
|
||||
credential_id = context.get("credential_id")
|
||||
credential_id: str | None = context.get("credential_id")
|
||||
if credential_id:
|
||||
datasource_provider_service.reauthorize_datasource_oauth_provider(
|
||||
tenant_id=tenant_id,
|
||||
@ -150,7 +151,7 @@ class DatasourceOAuthCallback(Resource):
|
||||
name=oauth_response.metadata.get("name") or None,
|
||||
expire_at=oauth_response.expires_at,
|
||||
credentials=dict(oauth_response.credentials),
|
||||
credential_id=context.get("credential_id"),
|
||||
credential_id=credential_id,
|
||||
)
|
||||
else:
|
||||
datasource_provider_service.add_datasource_oauth_provider(
|
||||
|
||||
@ -832,7 +832,8 @@ class ToolOAuthCallback(Resource):
|
||||
tool_provider = ToolProviderID(provider)
|
||||
plugin_id = tool_provider.plugin_id
|
||||
provider_name = tool_provider.provider_name
|
||||
user_id, tenant_id = context.get("user_id"), context.get("tenant_id")
|
||||
user_id: str = context["user_id"]
|
||||
tenant_id: str = context["tenant_id"]
|
||||
|
||||
oauth_handler = OAuthHandler()
|
||||
oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id, provider)
|
||||
|
||||
@ -499,9 +499,9 @@ class TriggerOAuthCallbackApi(Resource):
|
||||
provider_id = TriggerProviderID(provider)
|
||||
plugin_id = provider_id.plugin_id
|
||||
provider_name = provider_id.provider_name
|
||||
user_id = context.get("user_id")
|
||||
tenant_id = context.get("tenant_id")
|
||||
subscription_builder_id = context.get("subscription_builder_id")
|
||||
user_id: str = context["user_id"]
|
||||
tenant_id: str = context["tenant_id"]
|
||||
subscription_builder_id: str = context["subscription_builder_id"]
|
||||
|
||||
# Get OAuth client configuration
|
||||
oauth_client_params = TriggerProviderService.get_oauth_client(
|
||||
|
||||
@ -7,9 +7,19 @@ from datetime import UTC, datetime, timedelta
|
||||
from hashlib import sha256
|
||||
from typing import Any, cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
class InvitationData(TypedDict):
|
||||
account_id: str
|
||||
email: str
|
||||
workspace_id: str
|
||||
|
||||
|
||||
_invitation_adapter: TypeAdapter[InvitationData] = TypeAdapter(InvitationData)
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
@ -1571,7 +1581,7 @@ class RegisterService:
|
||||
@classmethod
|
||||
def get_invitation_by_token(
|
||||
cls, token: str, workspace_id: str | None = None, email: str | None = None
|
||||
) -> dict[str, str] | None:
|
||||
) -> InvitationData | None:
|
||||
if workspace_id is not None and email is not None:
|
||||
email_hash = sha256(email.encode()).hexdigest()
|
||||
cache_key = f"member_invite_token:{workspace_id}, {email_hash}:{token}"
|
||||
@ -1590,7 +1600,7 @@ class RegisterService:
|
||||
if not data:
|
||||
return None
|
||||
|
||||
invitation: dict = json.loads(data)
|
||||
invitation = _invitation_adapter.validate_json(data)
|
||||
return invitation
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Union
|
||||
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
||||
@ -17,7 +17,7 @@ from extensions.ext_database import db
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from models import Account
|
||||
from models.enums import FeedbackFromSource, FeedbackRating
|
||||
from models.model import App, AppMode, AppModelConfig, EndUser, Message, MessageFeedback
|
||||
from models.model import App, AppMode, AppModelConfig, AppModelConfigDict, EndUser, Message, MessageFeedback
|
||||
from repositories.execution_extra_content_repository import ExecutionExtraContentRepository
|
||||
from repositories.sqlalchemy_execution_extra_content_repository import (
|
||||
SQLAlchemyExecutionExtraContentRepository,
|
||||
@ -31,6 +31,8 @@ from services.errors.message import (
|
||||
)
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
_app_model_config_adapter: TypeAdapter[AppModelConfigDict] = TypeAdapter(AppModelConfigDict)
|
||||
|
||||
|
||||
def _create_execution_extra_content_repository() -> ExecutionExtraContentRepository:
|
||||
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
@ -286,7 +288,9 @@ class MessageService:
|
||||
.first()
|
||||
)
|
||||
else:
|
||||
conversation_override_model_configs = json.loads(conversation.override_model_configs)
|
||||
conversation_override_model_configs = _app_model_config_adapter.validate_json(
|
||||
conversation.override_model_configs
|
||||
)
|
||||
app_model_config = AppModelConfig(
|
||||
app_id=app_model.id,
|
||||
)
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
from json import JSONDecodeError
|
||||
from typing import Union
|
||||
from typing import Any, Union
|
||||
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.entities.provider_entities import (
|
||||
@ -168,10 +167,10 @@ class ModelLoadBalancingService:
|
||||
|
||||
try:
|
||||
if load_balancing_config.encrypted_config:
|
||||
credentials: dict[str, object] = json.loads(load_balancing_config.encrypted_config)
|
||||
credentials: dict[str, Any] = json.loads(load_balancing_config.encrypted_config)
|
||||
else:
|
||||
credentials = {}
|
||||
except JSONDecodeError:
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
credentials = {}
|
||||
|
||||
# Get provider credential secret variables
|
||||
@ -256,7 +255,7 @@ class ModelLoadBalancingService:
|
||||
credentials = json.loads(load_balancing_model_config.encrypted_config)
|
||||
else:
|
||||
credentials = {}
|
||||
except JSONDecodeError:
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
credentials = {}
|
||||
|
||||
# Get credential form schemas from model credential schema or provider credential schema
|
||||
@ -575,7 +574,7 @@ class ModelLoadBalancingService:
|
||||
original_credentials = json.loads(load_balancing_model_config.encrypted_config)
|
||||
else:
|
||||
original_credentials = {}
|
||||
except JSONDecodeError:
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
original_credentials = {}
|
||||
|
||||
# encrypt credentials
|
||||
|
||||
@ -12,7 +12,9 @@ import click
|
||||
import sqlalchemy as sa
|
||||
import tqdm
|
||||
from flask import Flask, current_app
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy.orm import Session
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.helper import marketplace
|
||||
@ -33,6 +35,14 @@ logger = logging.getLogger(__name__)
|
||||
excluded_providers = ["time", "audio", "code", "webscraper"]
|
||||
|
||||
|
||||
class _TenantPluginRecord(TypedDict):
|
||||
tenant_id: str
|
||||
plugins: list[str]
|
||||
|
||||
|
||||
_tenant_plugin_adapter: TypeAdapter[_TenantPluginRecord] = TypeAdapter(_TenantPluginRecord)
|
||||
|
||||
|
||||
class PluginMigration:
|
||||
@classmethod
|
||||
def extract_plugins(cls, filepath: str, workers: int):
|
||||
@ -308,9 +318,8 @@ class PluginMigration:
|
||||
logger.info("Extracting unique plugins from %s", extracted_plugins)
|
||||
with open(extracted_plugins) as f:
|
||||
for line in f:
|
||||
data = json.loads(line)
|
||||
new_plugin_ids = data.get("plugins", [])
|
||||
for plugin_id in new_plugin_ids:
|
||||
data = _tenant_plugin_adapter.validate_json(line)
|
||||
for plugin_id in data["plugins"]:
|
||||
if plugin_id not in plugin_ids:
|
||||
plugin_ids.append(plugin_id)
|
||||
|
||||
@ -381,21 +390,23 @@ class PluginMigration:
|
||||
Read line by line, and install plugins for each tenant.
|
||||
"""
|
||||
for line in f:
|
||||
data = json.loads(line)
|
||||
tenant_id = data.get("tenant_id")
|
||||
plugin_ids = data.get("plugins", [])
|
||||
current_not_installed = {
|
||||
"tenant_id": tenant_id,
|
||||
"plugin_not_exist": [],
|
||||
}
|
||||
data = _tenant_plugin_adapter.validate_json(line)
|
||||
tenant_id = data["tenant_id"]
|
||||
plugin_ids = data["plugins"]
|
||||
plugin_not_exist: list[str] = []
|
||||
# get plugin unique identifier
|
||||
for plugin_id in plugin_ids:
|
||||
unique_identifier = plugins.get(plugin_id)
|
||||
if unique_identifier:
|
||||
current_not_installed["plugin_not_exist"].append(plugin_id)
|
||||
plugin_not_exist.append(plugin_id)
|
||||
|
||||
if current_not_installed["plugin_not_exist"]:
|
||||
not_installed.append(current_not_installed)
|
||||
if plugin_not_exist:
|
||||
not_installed.append(
|
||||
{
|
||||
"tenant_id": tenant_id,
|
||||
"plugin_not_exist": plugin_not_exist,
|
||||
}
|
||||
)
|
||||
|
||||
thread_pool.submit(install, tenant_id, plugin_ids)
|
||||
|
||||
|
||||
@ -6,7 +6,6 @@ back to the database.
|
||||
"""
|
||||
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import zipfile
|
||||
@ -17,8 +16,23 @@ from datetime import datetime
|
||||
from typing import Any, cast
|
||||
|
||||
import click
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.engine import CursorResult
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
class _TableInfo(TypedDict, total=False):
|
||||
row_count: int
|
||||
|
||||
|
||||
class ArchiveManifest(TypedDict, total=False):
|
||||
tables: dict[str, _TableInfo]
|
||||
schema_version: str
|
||||
|
||||
|
||||
_manifest_adapter: TypeAdapter[ArchiveManifest] = TypeAdapter(ArchiveManifest)
|
||||
|
||||
from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker
|
||||
|
||||
from extensions.ext_database import db
|
||||
@ -239,12 +253,12 @@ class WorkflowRunRestore:
|
||||
return self.workflow_run_repo
|
||||
|
||||
@staticmethod
|
||||
def _load_manifest_from_zip(archive: zipfile.ZipFile) -> dict[str, Any]:
|
||||
def _load_manifest_from_zip(archive: zipfile.ZipFile) -> ArchiveManifest:
|
||||
try:
|
||||
data = archive.read("manifest.json")
|
||||
except KeyError as e:
|
||||
raise ValueError("manifest.json missing from archive bundle") from e
|
||||
return json.loads(data.decode("utf-8"))
|
||||
return _manifest_adapter.validate_json(data)
|
||||
|
||||
def _restore_table_records(
|
||||
self,
|
||||
@ -332,7 +346,7 @@ class WorkflowRunRestore:
|
||||
|
||||
return result
|
||||
|
||||
def _get_schema_version(self, manifest: dict[str, Any]) -> str:
|
||||
def _get_schema_version(self, manifest: ArchiveManifest) -> str:
|
||||
schema_version = manifest.get("schema_version")
|
||||
if not schema_version:
|
||||
logger.warning("Manifest missing schema_version; defaulting to 1.0")
|
||||
|
||||
@ -3,7 +3,7 @@ import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Union
|
||||
|
||||
from pydantic import ValidationError
|
||||
from pydantic import TypeAdapter, ValidationError
|
||||
from yarl import URL
|
||||
|
||||
from configs import dify_config
|
||||
@ -31,6 +31,8 @@ from services.plugin.plugin_service import PluginService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_mcp_tools_adapter: TypeAdapter[list[MCPTool]] = TypeAdapter(list[MCPTool])
|
||||
|
||||
|
||||
class ToolTransformService:
|
||||
_MCP_SCHEMA_TYPE_RESOLUTION_MAX_DEPTH = 10
|
||||
@ -53,7 +55,7 @@ class ToolTransformService:
|
||||
if isinstance(icon, str):
|
||||
return json.loads(icon)
|
||||
return icon
|
||||
except Exception:
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
elif provider_type == ToolProviderType.MCP:
|
||||
return icon
|
||||
@ -247,8 +249,8 @@ class ToolTransformService:
|
||||
|
||||
response = provider_entity.to_api_response(user_name=user_name, include_sensitive=include_sensitive)
|
||||
try:
|
||||
mcp_tools = [MCPTool(**tool) for tool in json.loads(db_provider.tools)]
|
||||
except (ValidationError, json.JSONDecodeError):
|
||||
mcp_tools = _mcp_tools_adapter.validate_json(db_provider.tools)
|
||||
except (ValidationError, ValueError):
|
||||
mcp_tools = []
|
||||
# Add additional fields specific to the transform
|
||||
response["id"] = db_provider.server_identifier if not for_list else db_provider.id
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
@ -7,6 +6,7 @@ from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from flask import Request, Response
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.entities.request import TriggerDispatchResponse
|
||||
@ -29,6 +29,8 @@ from services.trigger.trigger_provider_service import TriggerProviderService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_request_logs_adapter: TypeAdapter[list[RequestLog]] = TypeAdapter(list[RequestLog])
|
||||
|
||||
|
||||
class TriggerSubscriptionBuilderService:
|
||||
"""Service for managing trigger providers and credentials"""
|
||||
@ -398,7 +400,7 @@ class TriggerSubscriptionBuilderService:
|
||||
cache_key = cls.encode_cache_key(endpoint_id)
|
||||
subscription_cache = redis_client.get(cache_key)
|
||||
if subscription_cache:
|
||||
return SubscriptionBuilder.model_validate(json.loads(subscription_cache))
|
||||
return SubscriptionBuilder.model_validate_json(subscription_cache)
|
||||
|
||||
return None
|
||||
|
||||
@ -423,12 +425,16 @@ class TriggerSubscriptionBuilderService:
|
||||
)
|
||||
|
||||
key = f"trigger:subscription:builder:logs:{endpoint_id}"
|
||||
logs = json.loads(redis_client.get(key) or "[]")
|
||||
logs.append(log.model_dump(mode="json"))
|
||||
logs = _request_logs_adapter.validate_json(redis_client.get(key) or b"[]")
|
||||
logs.append(log)
|
||||
|
||||
# Keep last N logs
|
||||
logs = logs[-cls.__VALIDATION_REQUEST_CACHE_COUNT__ :]
|
||||
redis_client.setex(key, cls.__VALIDATION_REQUEST_CACHE_EXPIRE_SECONDS__, json.dumps(logs, default=str))
|
||||
redis_client.setex(
|
||||
key,
|
||||
cls.__VALIDATION_REQUEST_CACHE_EXPIRE_SECONDS__,
|
||||
_request_logs_adapter.dump_json(logs),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def list_logs(cls, endpoint_id: str) -> list[RequestLog]:
|
||||
@ -437,7 +443,7 @@ class TriggerSubscriptionBuilderService:
|
||||
logs_json = redis_client.get(key)
|
||||
if not logs_json:
|
||||
return []
|
||||
return [RequestLog.model_validate(log) for log in json.loads(logs_json)]
|
||||
return _request_logs_adapter.validate_json(logs_json)
|
||||
|
||||
@classmethod
|
||||
def process_builder_validation_endpoint(cls, endpoint_id: str, request: Request) -> Response | None:
|
||||
|
||||
@ -1118,7 +1118,7 @@ class WorkflowService:
|
||||
continue
|
||||
try:
|
||||
payload = json.loads(recipient.recipient_payload)
|
||||
except Exception:
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
logger.exception("Failed to parse human input recipient payload for delivery test.")
|
||||
continue
|
||||
email = payload.get("email")
|
||||
|
||||
@ -93,3 +93,20 @@ class TestUseProxyContext:
|
||||
assert result == stored
|
||||
expected_key = "oauth_proxy_context:valid-id"
|
||||
redis_client.delete.assert_called_once_with(expected_key)
|
||||
|
||||
def test_returns_context_with_credential_id(self):
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
stored = {
|
||||
"user_id": "u1",
|
||||
"tenant_id": "t1",
|
||||
"plugin_id": "p1",
|
||||
"provider": "github",
|
||||
"credential_id": "cred-42",
|
||||
}
|
||||
redis_client.get.return_value = json.dumps(stored).encode()
|
||||
|
||||
result = OAuthProxyService.use_proxy_context("ctx-with-cred")
|
||||
|
||||
assert result["credential_id"] == "cred-42"
|
||||
assert result["tenant_id"] == "t1"
|
||||
|
||||
@ -13,6 +13,7 @@ from datetime import datetime
|
||||
from unittest.mock import Mock, create_autospec, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import Column, Integer, MetaData, String, Table
|
||||
|
||||
from libs.archive_storage import ArchiveStorageNotConfiguredError
|
||||
@ -292,7 +293,7 @@ class TestLoadManifestFromZip:
|
||||
zip_buffer.seek(0)
|
||||
|
||||
with zipfile.ZipFile(zip_buffer, "r") as archive:
|
||||
with pytest.raises(json.JSONDecodeError):
|
||||
with pytest.raises(ValidationError):
|
||||
WorkflowRunRestore._load_manifest_from_zip(archive)
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user