mirror of
https://github.com/langgenius/dify.git
synced 2026-04-28 20:17:29 +08:00
refactor(api): replace json.loads with Pydantic validation in services layer (#33704)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
This commit is contained in:
parent
456684dfc3
commit
944db46d4f
@ -120,7 +120,8 @@ class DatasourceOAuthCallback(Resource):
|
|||||||
if context is None:
|
if context is None:
|
||||||
raise Forbidden("Invalid context_id")
|
raise Forbidden("Invalid context_id")
|
||||||
|
|
||||||
user_id, tenant_id = context.get("user_id"), context.get("tenant_id")
|
user_id: str = context["user_id"]
|
||||||
|
tenant_id: str = context["tenant_id"]
|
||||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||||
plugin_id = datasource_provider_id.plugin_id
|
plugin_id = datasource_provider_id.plugin_id
|
||||||
datasource_provider_service = DatasourceProviderService()
|
datasource_provider_service = DatasourceProviderService()
|
||||||
@ -141,7 +142,7 @@ class DatasourceOAuthCallback(Resource):
|
|||||||
system_credentials=oauth_client_params,
|
system_credentials=oauth_client_params,
|
||||||
request=request,
|
request=request,
|
||||||
)
|
)
|
||||||
credential_id = context.get("credential_id")
|
credential_id: str | None = context.get("credential_id")
|
||||||
if credential_id:
|
if credential_id:
|
||||||
datasource_provider_service.reauthorize_datasource_oauth_provider(
|
datasource_provider_service.reauthorize_datasource_oauth_provider(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
@ -150,7 +151,7 @@ class DatasourceOAuthCallback(Resource):
|
|||||||
name=oauth_response.metadata.get("name") or None,
|
name=oauth_response.metadata.get("name") or None,
|
||||||
expire_at=oauth_response.expires_at,
|
expire_at=oauth_response.expires_at,
|
||||||
credentials=dict(oauth_response.credentials),
|
credentials=dict(oauth_response.credentials),
|
||||||
credential_id=context.get("credential_id"),
|
credential_id=credential_id,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
datasource_provider_service.add_datasource_oauth_provider(
|
datasource_provider_service.add_datasource_oauth_provider(
|
||||||
|
|||||||
@ -832,7 +832,8 @@ class ToolOAuthCallback(Resource):
|
|||||||
tool_provider = ToolProviderID(provider)
|
tool_provider = ToolProviderID(provider)
|
||||||
plugin_id = tool_provider.plugin_id
|
plugin_id = tool_provider.plugin_id
|
||||||
provider_name = tool_provider.provider_name
|
provider_name = tool_provider.provider_name
|
||||||
user_id, tenant_id = context.get("user_id"), context.get("tenant_id")
|
user_id: str = context["user_id"]
|
||||||
|
tenant_id: str = context["tenant_id"]
|
||||||
|
|
||||||
oauth_handler = OAuthHandler()
|
oauth_handler = OAuthHandler()
|
||||||
oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id, provider)
|
oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id, provider)
|
||||||
|
|||||||
@ -499,9 +499,9 @@ class TriggerOAuthCallbackApi(Resource):
|
|||||||
provider_id = TriggerProviderID(provider)
|
provider_id = TriggerProviderID(provider)
|
||||||
plugin_id = provider_id.plugin_id
|
plugin_id = provider_id.plugin_id
|
||||||
provider_name = provider_id.provider_name
|
provider_name = provider_id.provider_name
|
||||||
user_id = context.get("user_id")
|
user_id: str = context["user_id"]
|
||||||
tenant_id = context.get("tenant_id")
|
tenant_id: str = context["tenant_id"]
|
||||||
subscription_builder_id = context.get("subscription_builder_id")
|
subscription_builder_id: str = context["subscription_builder_id"]
|
||||||
|
|
||||||
# Get OAuth client configuration
|
# Get OAuth client configuration
|
||||||
oauth_client_params = TriggerProviderService.get_oauth_client(
|
oauth_client_params = TriggerProviderService.get_oauth_client(
|
||||||
|
|||||||
@ -7,9 +7,19 @@ from datetime import UTC, datetime, timedelta
|
|||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, TypeAdapter
|
||||||
from sqlalchemy import func, select
|
from sqlalchemy import func, select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
|
|
||||||
|
class InvitationData(TypedDict):
|
||||||
|
account_id: str
|
||||||
|
email: str
|
||||||
|
workspace_id: str
|
||||||
|
|
||||||
|
|
||||||
|
_invitation_adapter: TypeAdapter[InvitationData] = TypeAdapter(InvitationData)
|
||||||
from werkzeug.exceptions import Unauthorized
|
from werkzeug.exceptions import Unauthorized
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
@ -1571,7 +1581,7 @@ class RegisterService:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def get_invitation_by_token(
|
def get_invitation_by_token(
|
||||||
cls, token: str, workspace_id: str | None = None, email: str | None = None
|
cls, token: str, workspace_id: str | None = None, email: str | None = None
|
||||||
) -> dict[str, str] | None:
|
) -> InvitationData | None:
|
||||||
if workspace_id is not None and email is not None:
|
if workspace_id is not None and email is not None:
|
||||||
email_hash = sha256(email.encode()).hexdigest()
|
email_hash = sha256(email.encode()).hexdigest()
|
||||||
cache_key = f"member_invite_token:{workspace_id}, {email_hash}:{token}"
|
cache_key = f"member_invite_token:{workspace_id}, {email_hash}:{token}"
|
||||||
@ -1590,7 +1600,7 @@ class RegisterService:
|
|||||||
if not data:
|
if not data:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
invitation: dict = json.loads(data)
|
invitation = _invitation_adapter.validate_json(data)
|
||||||
return invitation
|
return invitation
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@ -1,8 +1,8 @@
|
|||||||
import json
|
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
from graphon.model_runtime.entities.model_entities import ModelType
|
from graphon.model_runtime.entities.model_entities import ModelType
|
||||||
|
from pydantic import TypeAdapter
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
||||||
@ -17,7 +17,7 @@ from extensions.ext_database import db
|
|||||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||||
from models import Account
|
from models import Account
|
||||||
from models.enums import FeedbackFromSource, FeedbackRating
|
from models.enums import FeedbackFromSource, FeedbackRating
|
||||||
from models.model import App, AppMode, AppModelConfig, EndUser, Message, MessageFeedback
|
from models.model import App, AppMode, AppModelConfig, AppModelConfigDict, EndUser, Message, MessageFeedback
|
||||||
from repositories.execution_extra_content_repository import ExecutionExtraContentRepository
|
from repositories.execution_extra_content_repository import ExecutionExtraContentRepository
|
||||||
from repositories.sqlalchemy_execution_extra_content_repository import (
|
from repositories.sqlalchemy_execution_extra_content_repository import (
|
||||||
SQLAlchemyExecutionExtraContentRepository,
|
SQLAlchemyExecutionExtraContentRepository,
|
||||||
@ -31,6 +31,8 @@ from services.errors.message import (
|
|||||||
)
|
)
|
||||||
from services.workflow_service import WorkflowService
|
from services.workflow_service import WorkflowService
|
||||||
|
|
||||||
|
_app_model_config_adapter: TypeAdapter[AppModelConfigDict] = TypeAdapter(AppModelConfigDict)
|
||||||
|
|
||||||
|
|
||||||
def _create_execution_extra_content_repository() -> ExecutionExtraContentRepository:
|
def _create_execution_extra_content_repository() -> ExecutionExtraContentRepository:
|
||||||
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||||
@ -286,7 +288,9 @@ class MessageService:
|
|||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
conversation_override_model_configs = json.loads(conversation.override_model_configs)
|
conversation_override_model_configs = _app_model_config_adapter.validate_json(
|
||||||
|
conversation.override_model_configs
|
||||||
|
)
|
||||||
app_model_config = AppModelConfig(
|
app_model_config = AppModelConfig(
|
||||||
app_id=app_model.id,
|
app_id=app_model.id,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from json import JSONDecodeError
|
from typing import Any, Union
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
from graphon.model_runtime.entities.model_entities import ModelType
|
from graphon.model_runtime.entities.model_entities import ModelType
|
||||||
from graphon.model_runtime.entities.provider_entities import (
|
from graphon.model_runtime.entities.provider_entities import (
|
||||||
@ -168,10 +167,10 @@ class ModelLoadBalancingService:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
if load_balancing_config.encrypted_config:
|
if load_balancing_config.encrypted_config:
|
||||||
credentials: dict[str, object] = json.loads(load_balancing_config.encrypted_config)
|
credentials: dict[str, Any] = json.loads(load_balancing_config.encrypted_config)
|
||||||
else:
|
else:
|
||||||
credentials = {}
|
credentials = {}
|
||||||
except JSONDecodeError:
|
except (json.JSONDecodeError, ValueError):
|
||||||
credentials = {}
|
credentials = {}
|
||||||
|
|
||||||
# Get provider credential secret variables
|
# Get provider credential secret variables
|
||||||
@ -256,7 +255,7 @@ class ModelLoadBalancingService:
|
|||||||
credentials = json.loads(load_balancing_model_config.encrypted_config)
|
credentials = json.loads(load_balancing_model_config.encrypted_config)
|
||||||
else:
|
else:
|
||||||
credentials = {}
|
credentials = {}
|
||||||
except JSONDecodeError:
|
except (json.JSONDecodeError, ValueError):
|
||||||
credentials = {}
|
credentials = {}
|
||||||
|
|
||||||
# Get credential form schemas from model credential schema or provider credential schema
|
# Get credential form schemas from model credential schema or provider credential schema
|
||||||
@ -575,7 +574,7 @@ class ModelLoadBalancingService:
|
|||||||
original_credentials = json.loads(load_balancing_model_config.encrypted_config)
|
original_credentials = json.loads(load_balancing_model_config.encrypted_config)
|
||||||
else:
|
else:
|
||||||
original_credentials = {}
|
original_credentials = {}
|
||||||
except JSONDecodeError:
|
except (json.JSONDecodeError, ValueError):
|
||||||
original_credentials = {}
|
original_credentials = {}
|
||||||
|
|
||||||
# encrypt credentials
|
# encrypt credentials
|
||||||
|
|||||||
@ -12,7 +12,9 @@ import click
|
|||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
import tqdm
|
import tqdm
|
||||||
from flask import Flask, current_app
|
from flask import Flask, current_app
|
||||||
|
from pydantic import TypeAdapter
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from core.agent.entities import AgentToolEntity
|
from core.agent.entities import AgentToolEntity
|
||||||
from core.helper import marketplace
|
from core.helper import marketplace
|
||||||
@ -33,6 +35,14 @@ logger = logging.getLogger(__name__)
|
|||||||
excluded_providers = ["time", "audio", "code", "webscraper"]
|
excluded_providers = ["time", "audio", "code", "webscraper"]
|
||||||
|
|
||||||
|
|
||||||
|
class _TenantPluginRecord(TypedDict):
|
||||||
|
tenant_id: str
|
||||||
|
plugins: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
_tenant_plugin_adapter: TypeAdapter[_TenantPluginRecord] = TypeAdapter(_TenantPluginRecord)
|
||||||
|
|
||||||
|
|
||||||
class PluginMigration:
|
class PluginMigration:
|
||||||
@classmethod
|
@classmethod
|
||||||
def extract_plugins(cls, filepath: str, workers: int):
|
def extract_plugins(cls, filepath: str, workers: int):
|
||||||
@ -308,9 +318,8 @@ class PluginMigration:
|
|||||||
logger.info("Extracting unique plugins from %s", extracted_plugins)
|
logger.info("Extracting unique plugins from %s", extracted_plugins)
|
||||||
with open(extracted_plugins) as f:
|
with open(extracted_plugins) as f:
|
||||||
for line in f:
|
for line in f:
|
||||||
data = json.loads(line)
|
data = _tenant_plugin_adapter.validate_json(line)
|
||||||
new_plugin_ids = data.get("plugins", [])
|
for plugin_id in data["plugins"]:
|
||||||
for plugin_id in new_plugin_ids:
|
|
||||||
if plugin_id not in plugin_ids:
|
if plugin_id not in plugin_ids:
|
||||||
plugin_ids.append(plugin_id)
|
plugin_ids.append(plugin_id)
|
||||||
|
|
||||||
@ -381,21 +390,23 @@ class PluginMigration:
|
|||||||
Read line by line, and install plugins for each tenant.
|
Read line by line, and install plugins for each tenant.
|
||||||
"""
|
"""
|
||||||
for line in f:
|
for line in f:
|
||||||
data = json.loads(line)
|
data = _tenant_plugin_adapter.validate_json(line)
|
||||||
tenant_id = data.get("tenant_id")
|
tenant_id = data["tenant_id"]
|
||||||
plugin_ids = data.get("plugins", [])
|
plugin_ids = data["plugins"]
|
||||||
current_not_installed = {
|
plugin_not_exist: list[str] = []
|
||||||
"tenant_id": tenant_id,
|
|
||||||
"plugin_not_exist": [],
|
|
||||||
}
|
|
||||||
# get plugin unique identifier
|
# get plugin unique identifier
|
||||||
for plugin_id in plugin_ids:
|
for plugin_id in plugin_ids:
|
||||||
unique_identifier = plugins.get(plugin_id)
|
unique_identifier = plugins.get(plugin_id)
|
||||||
if unique_identifier:
|
if unique_identifier:
|
||||||
current_not_installed["plugin_not_exist"].append(plugin_id)
|
plugin_not_exist.append(plugin_id)
|
||||||
|
|
||||||
if current_not_installed["plugin_not_exist"]:
|
if plugin_not_exist:
|
||||||
not_installed.append(current_not_installed)
|
not_installed.append(
|
||||||
|
{
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"plugin_not_exist": plugin_not_exist,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
thread_pool.submit(install, tenant_id, plugin_ids)
|
thread_pool.submit(install, tenant_id, plugin_ids)
|
||||||
|
|
||||||
|
|||||||
@ -6,7 +6,6 @@ back to the database.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import io
|
import io
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
import zipfile
|
import zipfile
|
||||||
@ -17,8 +16,23 @@ from datetime import datetime
|
|||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
from pydantic import TypeAdapter
|
||||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||||
from sqlalchemy.engine import CursorResult
|
from sqlalchemy.engine import CursorResult
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
|
|
||||||
|
class _TableInfo(TypedDict, total=False):
|
||||||
|
row_count: int
|
||||||
|
|
||||||
|
|
||||||
|
class ArchiveManifest(TypedDict, total=False):
|
||||||
|
tables: dict[str, _TableInfo]
|
||||||
|
schema_version: str
|
||||||
|
|
||||||
|
|
||||||
|
_manifest_adapter: TypeAdapter[ArchiveManifest] = TypeAdapter(ArchiveManifest)
|
||||||
|
|
||||||
from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker
|
from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker
|
||||||
|
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@ -239,12 +253,12 @@ class WorkflowRunRestore:
|
|||||||
return self.workflow_run_repo
|
return self.workflow_run_repo
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _load_manifest_from_zip(archive: zipfile.ZipFile) -> dict[str, Any]:
|
def _load_manifest_from_zip(archive: zipfile.ZipFile) -> ArchiveManifest:
|
||||||
try:
|
try:
|
||||||
data = archive.read("manifest.json")
|
data = archive.read("manifest.json")
|
||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
raise ValueError("manifest.json missing from archive bundle") from e
|
raise ValueError("manifest.json missing from archive bundle") from e
|
||||||
return json.loads(data.decode("utf-8"))
|
return _manifest_adapter.validate_json(data)
|
||||||
|
|
||||||
def _restore_table_records(
|
def _restore_table_records(
|
||||||
self,
|
self,
|
||||||
@ -332,7 +346,7 @@ class WorkflowRunRestore:
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _get_schema_version(self, manifest: dict[str, Any]) -> str:
|
def _get_schema_version(self, manifest: ArchiveManifest) -> str:
|
||||||
schema_version = manifest.get("schema_version")
|
schema_version = manifest.get("schema_version")
|
||||||
if not schema_version:
|
if not schema_version:
|
||||||
logger.warning("Manifest missing schema_version; defaulting to 1.0")
|
logger.warning("Manifest missing schema_version; defaulting to 1.0")
|
||||||
|
|||||||
@ -3,7 +3,7 @@ import logging
|
|||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import Any, Union
|
from typing import Any, Union
|
||||||
|
|
||||||
from pydantic import ValidationError
|
from pydantic import TypeAdapter, ValidationError
|
||||||
from yarl import URL
|
from yarl import URL
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
@ -31,6 +31,8 @@ from services.plugin.plugin_service import PluginService
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_mcp_tools_adapter: TypeAdapter[list[MCPTool]] = TypeAdapter(list[MCPTool])
|
||||||
|
|
||||||
|
|
||||||
class ToolTransformService:
|
class ToolTransformService:
|
||||||
_MCP_SCHEMA_TYPE_RESOLUTION_MAX_DEPTH = 10
|
_MCP_SCHEMA_TYPE_RESOLUTION_MAX_DEPTH = 10
|
||||||
@ -53,7 +55,7 @@ class ToolTransformService:
|
|||||||
if isinstance(icon, str):
|
if isinstance(icon, str):
|
||||||
return json.loads(icon)
|
return json.loads(icon)
|
||||||
return icon
|
return icon
|
||||||
except Exception:
|
except (json.JSONDecodeError, ValueError):
|
||||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||||
elif provider_type == ToolProviderType.MCP:
|
elif provider_type == ToolProviderType.MCP:
|
||||||
return icon
|
return icon
|
||||||
@ -247,8 +249,8 @@ class ToolTransformService:
|
|||||||
|
|
||||||
response = provider_entity.to_api_response(user_name=user_name, include_sensitive=include_sensitive)
|
response = provider_entity.to_api_response(user_name=user_name, include_sensitive=include_sensitive)
|
||||||
try:
|
try:
|
||||||
mcp_tools = [MCPTool(**tool) for tool in json.loads(db_provider.tools)]
|
mcp_tools = _mcp_tools_adapter.validate_json(db_provider.tools)
|
||||||
except (ValidationError, json.JSONDecodeError):
|
except (ValidationError, ValueError):
|
||||||
mcp_tools = []
|
mcp_tools = []
|
||||||
# Add additional fields specific to the transform
|
# Add additional fields specific to the transform
|
||||||
response["id"] = db_provider.server_identifier if not for_list else db_provider.id
|
response["id"] = db_provider.server_identifier if not for_list else db_provider.id
|
||||||
|
|||||||
@ -1,4 +1,3 @@
|
|||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
@ -7,6 +6,7 @@ from datetime import datetime
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from flask import Request, Response
|
from flask import Request, Response
|
||||||
|
from pydantic import TypeAdapter
|
||||||
|
|
||||||
from core.plugin.entities.plugin_daemon import CredentialType
|
from core.plugin.entities.plugin_daemon import CredentialType
|
||||||
from core.plugin.entities.request import TriggerDispatchResponse
|
from core.plugin.entities.request import TriggerDispatchResponse
|
||||||
@ -29,6 +29,8 @@ from services.trigger.trigger_provider_service import TriggerProviderService
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_request_logs_adapter: TypeAdapter[list[RequestLog]] = TypeAdapter(list[RequestLog])
|
||||||
|
|
||||||
|
|
||||||
class TriggerSubscriptionBuilderService:
|
class TriggerSubscriptionBuilderService:
|
||||||
"""Service for managing trigger providers and credentials"""
|
"""Service for managing trigger providers and credentials"""
|
||||||
@ -398,7 +400,7 @@ class TriggerSubscriptionBuilderService:
|
|||||||
cache_key = cls.encode_cache_key(endpoint_id)
|
cache_key = cls.encode_cache_key(endpoint_id)
|
||||||
subscription_cache = redis_client.get(cache_key)
|
subscription_cache = redis_client.get(cache_key)
|
||||||
if subscription_cache:
|
if subscription_cache:
|
||||||
return SubscriptionBuilder.model_validate(json.loads(subscription_cache))
|
return SubscriptionBuilder.model_validate_json(subscription_cache)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -423,12 +425,16 @@ class TriggerSubscriptionBuilderService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
key = f"trigger:subscription:builder:logs:{endpoint_id}"
|
key = f"trigger:subscription:builder:logs:{endpoint_id}"
|
||||||
logs = json.loads(redis_client.get(key) or "[]")
|
logs = _request_logs_adapter.validate_json(redis_client.get(key) or b"[]")
|
||||||
logs.append(log.model_dump(mode="json"))
|
logs.append(log)
|
||||||
|
|
||||||
# Keep last N logs
|
# Keep last N logs
|
||||||
logs = logs[-cls.__VALIDATION_REQUEST_CACHE_COUNT__ :]
|
logs = logs[-cls.__VALIDATION_REQUEST_CACHE_COUNT__ :]
|
||||||
redis_client.setex(key, cls.__VALIDATION_REQUEST_CACHE_EXPIRE_SECONDS__, json.dumps(logs, default=str))
|
redis_client.setex(
|
||||||
|
key,
|
||||||
|
cls.__VALIDATION_REQUEST_CACHE_EXPIRE_SECONDS__,
|
||||||
|
_request_logs_adapter.dump_json(logs),
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def list_logs(cls, endpoint_id: str) -> list[RequestLog]:
|
def list_logs(cls, endpoint_id: str) -> list[RequestLog]:
|
||||||
@ -437,7 +443,7 @@ class TriggerSubscriptionBuilderService:
|
|||||||
logs_json = redis_client.get(key)
|
logs_json = redis_client.get(key)
|
||||||
if not logs_json:
|
if not logs_json:
|
||||||
return []
|
return []
|
||||||
return [RequestLog.model_validate(log) for log in json.loads(logs_json)]
|
return _request_logs_adapter.validate_json(logs_json)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def process_builder_validation_endpoint(cls, endpoint_id: str, request: Request) -> Response | None:
|
def process_builder_validation_endpoint(cls, endpoint_id: str, request: Request) -> Response | None:
|
||||||
|
|||||||
@ -1118,7 +1118,7 @@ class WorkflowService:
|
|||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
payload = json.loads(recipient.recipient_payload)
|
payload = json.loads(recipient.recipient_payload)
|
||||||
except Exception:
|
except (json.JSONDecodeError, ValueError):
|
||||||
logger.exception("Failed to parse human input recipient payload for delivery test.")
|
logger.exception("Failed to parse human input recipient payload for delivery test.")
|
||||||
continue
|
continue
|
||||||
email = payload.get("email")
|
email = payload.get("email")
|
||||||
|
|||||||
@ -93,3 +93,20 @@ class TestUseProxyContext:
|
|||||||
assert result == stored
|
assert result == stored
|
||||||
expected_key = "oauth_proxy_context:valid-id"
|
expected_key = "oauth_proxy_context:valid-id"
|
||||||
redis_client.delete.assert_called_once_with(expected_key)
|
redis_client.delete.assert_called_once_with(expected_key)
|
||||||
|
|
||||||
|
def test_returns_context_with_credential_id(self):
|
||||||
|
from extensions.ext_redis import redis_client
|
||||||
|
|
||||||
|
stored = {
|
||||||
|
"user_id": "u1",
|
||||||
|
"tenant_id": "t1",
|
||||||
|
"plugin_id": "p1",
|
||||||
|
"provider": "github",
|
||||||
|
"credential_id": "cred-42",
|
||||||
|
}
|
||||||
|
redis_client.get.return_value = json.dumps(stored).encode()
|
||||||
|
|
||||||
|
result = OAuthProxyService.use_proxy_context("ctx-with-cred")
|
||||||
|
|
||||||
|
assert result["credential_id"] == "cred-42"
|
||||||
|
assert result["tenant_id"] == "t1"
|
||||||
|
|||||||
@ -13,6 +13,7 @@ from datetime import datetime
|
|||||||
from unittest.mock import Mock, create_autospec, patch
|
from unittest.mock import Mock, create_autospec, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from pydantic import ValidationError
|
||||||
from sqlalchemy import Column, Integer, MetaData, String, Table
|
from sqlalchemy import Column, Integer, MetaData, String, Table
|
||||||
|
|
||||||
from libs.archive_storage import ArchiveStorageNotConfiguredError
|
from libs.archive_storage import ArchiveStorageNotConfiguredError
|
||||||
@ -292,7 +293,7 @@ class TestLoadManifestFromZip:
|
|||||||
zip_buffer.seek(0)
|
zip_buffer.seek(0)
|
||||||
|
|
||||||
with zipfile.ZipFile(zip_buffer, "r") as archive:
|
with zipfile.ZipFile(zip_buffer, "r") as archive:
|
||||||
with pytest.raises(json.JSONDecodeError):
|
with pytest.raises(ValidationError):
|
||||||
WorkflowRunRestore._load_manifest_from_zip(archive)
|
WorkflowRunRestore._load_manifest_from_zip(archive)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user