diff --git a/api/.ruff.toml b/api/.ruff.toml index 8db0cbcb21..3301452ad9 100644 --- a/api/.ruff.toml +++ b/api/.ruff.toml @@ -53,6 +53,7 @@ select = [ "S301", # suspicious-pickle-usage, disallow use of `pickle` and its wrappers. "S302", # suspicious-marshal-usage, disallow use of `marshal` module "S311", # suspicious-non-cryptographic-random-usage, + "TID", # flake8-tidy-imports ] @@ -88,6 +89,7 @@ ignore = [ "SIM113", # enumerate-for-loop "SIM117", # multiple-with-statements "SIM210", # if-expr-with-true-false + "TID252", # allow relative imports from parent modules ] [lint.per-file-ignores] @@ -109,10 +111,20 @@ ignore = [ "S110", # allow ignoring exceptions in tests code (currently) ] +"controllers/console/explore/trial.py" = ["TID251"] +"controllers/console/human_input_form.py" = ["TID251"] +"controllers/web/human_input_form.py" = ["TID251"] [lint.pyflakes] allowed-unused-imports = [ - "_pytest.monkeypatch", "tests.integration_tests", "tests.unit_tests", ] + +[lint.flake8-tidy-imports] + +[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse"] +msg = "Use Pydantic payload/query models instead of reqparse." + +[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse.RequestParser"] +msg = "Use Pydantic payload/query models instead of reqparse." diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index d34fd5088d..29b6b64b94 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -1,10 +1,9 @@ import json import logging from typing import Any, Literal, cast -from uuid import UUID from flask import abort, request -from flask_restx import Resource, marshal_with, reqparse # type: ignore +from flask_restx import Resource, marshal_with # type: ignore from pydantic import BaseModel, Field from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, InternalServerError, NotFound @@ -38,7 +37,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db from factories import variable_factory from libs import helper -from libs.helper import TimestampField +from libs.helper import TimestampField, UUIDStrOrEmpty from libs.login import current_account_with_tenant, current_user, login_required from models import Account from models.dataset import Pipeline @@ -110,7 +109,7 @@ class NodeIdQuery(BaseModel): class WorkflowRunQuery(BaseModel): - last_id: UUID | None = None + last_id: UUIDStrOrEmpty | None = None limit: int = Field(default=20, ge=1, le=100) @@ -121,6 +120,10 @@ class DatasourceVariablesPayload(BaseModel): start_node_title: str +class RagPipelineRecommendedPluginQuery(BaseModel): + type: str = "all" + + register_schema_models( console_ns, DraftWorkflowSyncPayload, @@ -135,6 +138,7 @@ register_schema_models( NodeIdQuery, WorkflowRunQuery, DatasourceVariablesPayload, + RagPipelineRecommendedPluginQuery, ) @@ -975,11 +979,8 @@ class RagPipelineRecommendedPluginApi(Resource): @login_required @account_initialization_required def get(self): - parser = reqparse.RequestParser() - parser.add_argument("type", type=str, location="args", required=False, default="all") - args = parser.parse_args() - type = args["type"] + query = RagPipelineRecommendedPluginQuery.model_validate(request.args.to_dict()) rag_pipeline_service = RagPipelineService() - recommended_plugins = rag_pipeline_service.get_recommended_plugins(type) + recommended_plugins = rag_pipeline_service.get_recommended_plugins(query.type) return recommended_plugins diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index e9e7b72718..5bfa895849 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -1,16 +1,16 @@ import io import logging +from typing import Any, Literal from urllib.parse import urlparse from flask import make_response, redirect, request, send_file -from flask_restx import ( - Resource, - reqparse, -) +from flask_restx import Resource +from pydantic import BaseModel, Field, HttpUrl, field_validator, model_validator from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden from configs import dify_config +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import ( account_initialization_required, @@ -26,8 +26,9 @@ from core.mcp.mcp_client import MCPClient from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.oauth import OAuthHandler +from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration from extensions.ext_database import db -from libs.helper import StrLen, alphanumeric, uuid_value +from libs.helper import alphanumeric, uuid_value from libs.login import current_account_with_tenant, login_required from models.provider_ids import ToolProviderID @@ -52,24 +53,209 @@ def is_valid_url(url: str) -> bool: parsed = urlparse(url) return all([parsed.scheme, parsed.netloc]) and parsed.scheme in ["http", "https"] except (ValueError, TypeError): - # ValueError: Invalid URL format - # TypeError: url is not a string return False -parser_tool = reqparse.RequestParser().add_argument( - "type", - type=str, - choices=["builtin", "model", "api", "workflow", "mcp"], - required=False, - nullable=True, - location="args", +class ToolProviderListQuery(BaseModel): + type: Literal["builtin", "model", "api", "workflow", "mcp"] | None = None + + +class BuiltinToolCredentialDeletePayload(BaseModel): + credential_id: str + + +class BuiltinToolAddPayload(BaseModel): + credentials: dict[str, Any] + name: str | None = Field(default=None, max_length=30) + type: CredentialType + + +class BuiltinToolUpdatePayload(BaseModel): + credential_id: str + credentials: dict[str, Any] | None = None + name: str | None = Field(default=None, max_length=30) + + +class ApiToolProviderBasePayload(BaseModel): + credentials: dict[str, Any] + schema_type: ApiProviderSchemaType + schema_: str = Field(alias="schema") + provider: str + icon: dict[str, Any] + privacy_policy: str | None = None + labels: list[str] | None = None + custom_disclaimer: str = "" + + +class ApiToolProviderAddPayload(ApiToolProviderBasePayload): + pass + + +class ApiToolProviderUpdatePayload(ApiToolProviderBasePayload): + original_provider: str + + +class UrlQuery(BaseModel): + url: HttpUrl + + +class ProviderQuery(BaseModel): + provider: str + + +class ApiToolProviderDeletePayload(BaseModel): + provider: str + + +class ApiToolSchemaPayload(BaseModel): + schema_: str = Field(alias="schema") + + +class ApiToolTestPayload(BaseModel): + tool_name: str + provider_name: str | None = None + credentials: dict[str, Any] + parameters: dict[str, Any] + schema_type: ApiProviderSchemaType + schema_: str = Field(alias="schema") + + +class WorkflowToolBasePayload(BaseModel): + name: str + label: str + description: str + icon: dict[str, Any] + parameters: list[WorkflowToolParameterConfiguration] = Field(default_factory=list) + privacy_policy: str | None = "" + labels: list[str] | None = None + + @field_validator("name") + @classmethod + def validate_name(cls, value: str) -> str: + return alphanumeric(value) + + +class WorkflowToolCreatePayload(WorkflowToolBasePayload): + workflow_app_id: str + + @field_validator("workflow_app_id") + @classmethod + def validate_workflow_app_id(cls, value: str) -> str: + return uuid_value(value) + + +class WorkflowToolUpdatePayload(WorkflowToolBasePayload): + workflow_tool_id: str + + @field_validator("workflow_tool_id") + @classmethod + def validate_workflow_tool_id(cls, value: str) -> str: + return uuid_value(value) + + +class WorkflowToolDeletePayload(BaseModel): + workflow_tool_id: str + + @field_validator("workflow_tool_id") + @classmethod + def validate_workflow_tool_id(cls, value: str) -> str: + return uuid_value(value) + + +class WorkflowToolGetQuery(BaseModel): + workflow_tool_id: str | None = None + workflow_app_id: str | None = None + + @field_validator("workflow_tool_id", "workflow_app_id") + @classmethod + def validate_ids(cls, value: str | None) -> str | None: + if value is None: + return value + return uuid_value(value) + + @model_validator(mode="after") + def ensure_one(self) -> "WorkflowToolGetQuery": + if not self.workflow_tool_id and not self.workflow_app_id: + raise ValueError("workflow_tool_id or workflow_app_id is required") + return self + + +class WorkflowToolListQuery(BaseModel): + workflow_tool_id: str + + @field_validator("workflow_tool_id") + @classmethod + def validate_workflow_tool_id(cls, value: str) -> str: + return uuid_value(value) + + +class BuiltinProviderDefaultCredentialPayload(BaseModel): + id: str + + +class ToolOAuthCustomClientPayload(BaseModel): + client_params: dict[str, Any] | None = None + enable_oauth_custom_client: bool | None = True + + +class MCPProviderBasePayload(BaseModel): + server_url: str + name: str + icon: str + icon_type: str + icon_background: str = "" + server_identifier: str + configuration: dict[str, Any] | None = Field(default_factory=dict) + headers: dict[str, Any] | None = Field(default_factory=dict) + authentication: dict[str, Any] | None = Field(default_factory=dict) + + +class MCPProviderCreatePayload(MCPProviderBasePayload): + pass + + +class MCPProviderUpdatePayload(MCPProviderBasePayload): + provider_id: str + + +class MCPProviderDeletePayload(BaseModel): + provider_id: str + + +class MCPAuthPayload(BaseModel): + provider_id: str + authorization_code: str | None = None + + +class MCPCallbackQuery(BaseModel): + code: str + state: str + + +register_schema_models( + console_ns, + BuiltinToolCredentialDeletePayload, + BuiltinToolAddPayload, + BuiltinToolUpdatePayload, + ApiToolProviderAddPayload, + ApiToolProviderUpdatePayload, + ApiToolProviderDeletePayload, + ApiToolSchemaPayload, + ApiToolTestPayload, + WorkflowToolCreatePayload, + WorkflowToolUpdatePayload, + WorkflowToolDeletePayload, + BuiltinProviderDefaultCredentialPayload, + ToolOAuthCustomClientPayload, + MCPProviderCreatePayload, + MCPProviderUpdatePayload, + MCPProviderDeletePayload, + MCPAuthPayload, ) @console_ns.route("/workspaces/current/tool-providers") class ToolProviderListApi(Resource): - @console_ns.expect(parser_tool) @setup_required @login_required @account_initialization_required @@ -78,9 +264,10 @@ class ToolProviderListApi(Resource): user_id = user.id - args = parser_tool.parse_args() + raw_args = request.args.to_dict() + query = ToolProviderListQuery.model_validate(raw_args) - return ToolCommonService.list_tool_providers(user_id, tenant_id, args.get("type", None)) + return ToolCommonService.list_tool_providers(user_id, tenant_id, query.type) # type: ignore @console_ns.route("/workspaces/current/tool-provider/builtin//tools") @@ -110,14 +297,9 @@ class ToolBuiltinProviderInfoApi(Resource): return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider)) -parser_delete = reqparse.RequestParser().add_argument( - "credential_id", type=str, required=True, nullable=False, location="json" -) - - @console_ns.route("/workspaces/current/tool-provider/builtin//delete") class ToolBuiltinProviderDeleteApi(Resource): - @console_ns.expect(parser_delete) + @console_ns.expect(console_ns.models[BuiltinToolCredentialDeletePayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -125,26 +307,18 @@ class ToolBuiltinProviderDeleteApi(Resource): def post(self, provider): _, tenant_id = current_account_with_tenant() - args = parser_delete.parse_args() + payload = BuiltinToolCredentialDeletePayload.model_validate(console_ns.payload or {}) return BuiltinToolManageService.delete_builtin_tool_provider( tenant_id, provider, - args["credential_id"], + payload.credential_id, ) -parser_add = ( - reqparse.RequestParser() - .add_argument("credentials", type=dict, required=True, nullable=False, location="json") - .add_argument("name", type=StrLen(30), required=False, nullable=False, location="json") - .add_argument("type", type=str, required=True, nullable=False, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/builtin//add") class ToolBuiltinProviderAddApi(Resource): - @console_ns.expect(parser_add) + @console_ns.expect(console_ns.models[BuiltinToolAddPayload.__name__]) @setup_required @login_required @account_initialization_required @@ -153,32 +327,21 @@ class ToolBuiltinProviderAddApi(Resource): user_id = user.id - args = parser_add.parse_args() - - if args["type"] not in CredentialType.values(): - raise ValueError(f"Invalid credential type: {args['type']}") + payload = BuiltinToolAddPayload.model_validate(console_ns.payload or {}) return BuiltinToolManageService.add_builtin_tool_provider( user_id=user_id, tenant_id=tenant_id, provider=provider, - credentials=args["credentials"], - name=args["name"], - api_type=CredentialType.of(args["type"]), + credentials=payload.credentials, + name=payload.name, + api_type=CredentialType.of(payload.type), ) -parser_update = ( - reqparse.RequestParser() - .add_argument("credential_id", type=str, required=True, nullable=False, location="json") - .add_argument("credentials", type=dict, required=False, nullable=True, location="json") - .add_argument("name", type=StrLen(30), required=False, nullable=True, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/builtin//update") class ToolBuiltinProviderUpdateApi(Resource): - @console_ns.expect(parser_update) + @console_ns.expect(console_ns.models[BuiltinToolUpdatePayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -187,15 +350,15 @@ class ToolBuiltinProviderUpdateApi(Resource): user, tenant_id = current_account_with_tenant() user_id = user.id - args = parser_update.parse_args() + payload = BuiltinToolUpdatePayload.model_validate(console_ns.payload or {}) result = BuiltinToolManageService.update_builtin_tool_provider( user_id=user_id, tenant_id=tenant_id, provider=provider, - credential_id=args["credential_id"], - credentials=args.get("credentials", None), - name=args.get("name", ""), + credential_id=payload.credential_id, + credentials=payload.credentials, + name=payload.name or "", ) return result @@ -225,22 +388,9 @@ class ToolBuiltinProviderIconApi(Resource): return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age) -parser_api_add = ( - reqparse.RequestParser() - .add_argument("credentials", type=dict, required=True, nullable=False, location="json") - .add_argument("schema_type", type=str, required=True, nullable=False, location="json") - .add_argument("schema", type=str, required=True, nullable=False, location="json") - .add_argument("provider", type=str, required=True, nullable=False, location="json") - .add_argument("icon", type=dict, required=True, nullable=False, location="json") - .add_argument("privacy_policy", type=str, required=False, nullable=True, location="json") - .add_argument("labels", type=list[str], required=False, nullable=True, location="json", default=[]) - .add_argument("custom_disclaimer", type=str, required=False, nullable=True, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/api/add") class ToolApiProviderAddApi(Resource): - @console_ns.expect(parser_api_add) + @console_ns.expect(console_ns.models[ApiToolProviderAddPayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -250,28 +400,24 @@ class ToolApiProviderAddApi(Resource): user_id = user.id - args = parser_api_add.parse_args() + payload = ApiToolProviderAddPayload.model_validate(console_ns.payload or {}) return ApiToolManageService.create_api_tool_provider( user_id, tenant_id, - args["provider"], - args["icon"], - args["credentials"], - args["schema_type"], - args["schema"], - args.get("privacy_policy", ""), - args.get("custom_disclaimer", ""), - args.get("labels", []), + payload.provider, + payload.icon, + payload.credentials, + payload.schema_type, + payload.schema_, + payload.privacy_policy or "", + payload.custom_disclaimer or "", + payload.labels or [], ) -parser_remote = reqparse.RequestParser().add_argument("url", type=str, required=True, nullable=False, location="args") - - @console_ns.route("/workspaces/current/tool-provider/api/remote") class ToolApiProviderGetRemoteSchemaApi(Resource): - @console_ns.expect(parser_remote) @setup_required @login_required @account_initialization_required @@ -280,23 +426,18 @@ class ToolApiProviderGetRemoteSchemaApi(Resource): user_id = user.id - args = parser_remote.parse_args() + raw_args = request.args.to_dict() + query = UrlQuery.model_validate(raw_args) return ApiToolManageService.get_api_tool_provider_remote_schema( user_id, tenant_id, - args["url"], + str(query.url), ) -parser_tools = reqparse.RequestParser().add_argument( - "provider", type=str, required=True, nullable=False, location="args" -) - - @console_ns.route("/workspaces/current/tool-provider/api/tools") class ToolApiProviderListToolsApi(Resource): - @console_ns.expect(parser_tools) @setup_required @login_required @account_initialization_required @@ -305,34 +446,21 @@ class ToolApiProviderListToolsApi(Resource): user_id = user.id - args = parser_tools.parse_args() + raw_args = request.args.to_dict() + query = ProviderQuery.model_validate(raw_args) return jsonable_encoder( ApiToolManageService.list_api_tool_provider_tools( user_id, tenant_id, - args["provider"], + query.provider, ) ) -parser_api_update = ( - reqparse.RequestParser() - .add_argument("credentials", type=dict, required=True, nullable=False, location="json") - .add_argument("schema_type", type=str, required=True, nullable=False, location="json") - .add_argument("schema", type=str, required=True, nullable=False, location="json") - .add_argument("provider", type=str, required=True, nullable=False, location="json") - .add_argument("original_provider", type=str, required=True, nullable=False, location="json") - .add_argument("icon", type=dict, required=True, nullable=False, location="json") - .add_argument("privacy_policy", type=str, required=True, nullable=True, location="json") - .add_argument("labels", type=list[str], required=False, nullable=True, location="json") - .add_argument("custom_disclaimer", type=str, required=True, nullable=True, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/api/update") class ToolApiProviderUpdateApi(Resource): - @console_ns.expect(parser_api_update) + @console_ns.expect(console_ns.models[ApiToolProviderUpdatePayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -342,31 +470,26 @@ class ToolApiProviderUpdateApi(Resource): user_id = user.id - args = parser_api_update.parse_args() + payload = ApiToolProviderUpdatePayload.model_validate(console_ns.payload or {}) return ApiToolManageService.update_api_tool_provider( user_id, tenant_id, - args["provider"], - args["original_provider"], - args["icon"], - args["credentials"], - args["schema_type"], - args["schema"], - args["privacy_policy"], - args["custom_disclaimer"], - args.get("labels", []), + payload.provider, + payload.original_provider, + payload.icon, + payload.credentials, + payload.schema_type, + payload.schema_, + payload.privacy_policy, + payload.custom_disclaimer, + payload.labels or [], ) -parser_api_delete = reqparse.RequestParser().add_argument( - "provider", type=str, required=True, nullable=False, location="json" -) - - @console_ns.route("/workspaces/current/tool-provider/api/delete") class ToolApiProviderDeleteApi(Resource): - @console_ns.expect(parser_api_delete) + @console_ns.expect(console_ns.models[ApiToolProviderDeletePayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -376,21 +499,17 @@ class ToolApiProviderDeleteApi(Resource): user_id = user.id - args = parser_api_delete.parse_args() + payload = ApiToolProviderDeletePayload.model_validate(console_ns.payload or {}) return ApiToolManageService.delete_api_tool_provider( user_id, tenant_id, - args["provider"], + payload.provider, ) -parser_get = reqparse.RequestParser().add_argument("provider", type=str, required=True, nullable=False, location="args") - - @console_ns.route("/workspaces/current/tool-provider/api/get") class ToolApiProviderGetApi(Resource): - @console_ns.expect(parser_get) @setup_required @login_required @account_initialization_required @@ -399,12 +518,13 @@ class ToolApiProviderGetApi(Resource): user_id = user.id - args = parser_get.parse_args() + raw_args = request.args.to_dict() + query = ProviderQuery.model_validate(raw_args) return ApiToolManageService.get_api_tool_provider( user_id, tenant_id, - args["provider"], + query.provider, ) @@ -423,72 +543,43 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource): ) -parser_schema = reqparse.RequestParser().add_argument( - "schema", type=str, required=True, nullable=False, location="json" -) - - @console_ns.route("/workspaces/current/tool-provider/api/schema") class ToolApiProviderSchemaApi(Resource): - @console_ns.expect(parser_schema) + @console_ns.expect(console_ns.models[ApiToolSchemaPayload.__name__]) @setup_required @login_required @account_initialization_required def post(self): - args = parser_schema.parse_args() + payload = ApiToolSchemaPayload.model_validate(console_ns.payload or {}) return ApiToolManageService.parser_api_schema( - schema=args["schema"], + schema=payload.schema_, ) -parser_pre = ( - reqparse.RequestParser() - .add_argument("tool_name", type=str, required=True, nullable=False, location="json") - .add_argument("provider_name", type=str, required=False, nullable=False, location="json") - .add_argument("credentials", type=dict, required=True, nullable=False, location="json") - .add_argument("parameters", type=dict, required=True, nullable=False, location="json") - .add_argument("schema_type", type=str, required=True, nullable=False, location="json") - .add_argument("schema", type=str, required=True, nullable=False, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/api/test/pre") class ToolApiProviderPreviousTestApi(Resource): - @console_ns.expect(parser_pre) + @console_ns.expect(console_ns.models[ApiToolTestPayload.__name__]) @setup_required @login_required @account_initialization_required def post(self): - args = parser_pre.parse_args() + payload = ApiToolTestPayload.model_validate(console_ns.payload or {}) _, current_tenant_id = current_account_with_tenant() return ApiToolManageService.test_api_tool_preview( current_tenant_id, - args["provider_name"] or "", - args["tool_name"], - args["credentials"], - args["parameters"], - args["schema_type"], - args["schema"], + payload.provider_name or "", + payload.tool_name, + payload.credentials, + payload.parameters, + payload.schema_type, + payload.schema_, ) -parser_create = ( - reqparse.RequestParser() - .add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json") - .add_argument("name", type=alphanumeric, required=True, nullable=False, location="json") - .add_argument("label", type=str, required=True, nullable=False, location="json") - .add_argument("description", type=str, required=True, nullable=False, location="json") - .add_argument("icon", type=dict, required=True, nullable=False, location="json") - .add_argument("parameters", type=list[dict], required=True, nullable=False, location="json") - .add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="") - .add_argument("labels", type=list[str], required=False, nullable=True, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/workflow/create") class ToolWorkflowProviderCreateApi(Resource): - @console_ns.expect(parser_create) + @console_ns.expect(console_ns.models[WorkflowToolCreatePayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -498,38 +589,25 @@ class ToolWorkflowProviderCreateApi(Resource): user_id = user.id - args = parser_create.parse_args() + payload = WorkflowToolCreatePayload.model_validate(console_ns.payload or {}) return WorkflowToolManageService.create_workflow_tool( user_id=user_id, tenant_id=tenant_id, - workflow_app_id=args["workflow_app_id"], - name=args["name"], - label=args["label"], - icon=args["icon"], - description=args["description"], - parameters=args["parameters"], - privacy_policy=args["privacy_policy"], - labels=args["labels"], + workflow_app_id=payload.workflow_app_id, + name=payload.name, + label=payload.label, + icon=payload.icon, + description=payload.description, + parameters=payload.parameters, + privacy_policy=payload.privacy_policy or "", + labels=payload.labels or [], ) -parser_workflow_update = ( - reqparse.RequestParser() - .add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json") - .add_argument("name", type=alphanumeric, required=True, nullable=False, location="json") - .add_argument("label", type=str, required=True, nullable=False, location="json") - .add_argument("description", type=str, required=True, nullable=False, location="json") - .add_argument("icon", type=dict, required=True, nullable=False, location="json") - .add_argument("parameters", type=list[dict], required=True, nullable=False, location="json") - .add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="") - .add_argument("labels", type=list[str], required=False, nullable=True, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/workflow/update") class ToolWorkflowProviderUpdateApi(Resource): - @console_ns.expect(parser_workflow_update) + @console_ns.expect(console_ns.models[WorkflowToolUpdatePayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -538,33 +616,25 @@ class ToolWorkflowProviderUpdateApi(Resource): user, tenant_id = current_account_with_tenant() user_id = user.id - args = parser_workflow_update.parse_args() - - if not args["workflow_tool_id"]: - raise ValueError("incorrect workflow_tool_id") + payload = WorkflowToolUpdatePayload.model_validate(console_ns.payload or {}) return WorkflowToolManageService.update_workflow_tool( user_id, tenant_id, - args["workflow_tool_id"], - args["name"], - args["label"], - args["icon"], - args["description"], - args["parameters"], - args["privacy_policy"], - args.get("labels", []), + payload.workflow_tool_id, + payload.name, + payload.label, + payload.icon, + payload.description, + payload.parameters, + payload.privacy_policy or "", + payload.labels or [], ) -parser_workflow_delete = reqparse.RequestParser().add_argument( - "workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json" -) - - @console_ns.route("/workspaces/current/tool-provider/workflow/delete") class ToolWorkflowProviderDeleteApi(Resource): - @console_ns.expect(parser_workflow_delete) + @console_ns.expect(console_ns.models[WorkflowToolDeletePayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -574,25 +644,17 @@ class ToolWorkflowProviderDeleteApi(Resource): user_id = user.id - args = parser_workflow_delete.parse_args() + payload = WorkflowToolDeletePayload.model_validate(console_ns.payload or {}) return WorkflowToolManageService.delete_workflow_tool( user_id, tenant_id, - args["workflow_tool_id"], + payload.workflow_tool_id, ) -parser_wf_get = ( - reqparse.RequestParser() - .add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args") - .add_argument("workflow_app_id", type=uuid_value, required=False, nullable=True, location="args") -) - - @console_ns.route("/workspaces/current/tool-provider/workflow/get") class ToolWorkflowProviderGetApi(Resource): - @console_ns.expect(parser_wf_get) @setup_required @login_required @account_initialization_required @@ -601,19 +663,20 @@ class ToolWorkflowProviderGetApi(Resource): user_id = user.id - args = parser_wf_get.parse_args() + raw_args = request.args.to_dict() + query = WorkflowToolGetQuery.model_validate(raw_args) - if args.get("workflow_tool_id"): + if query.workflow_tool_id: tool = WorkflowToolManageService.get_workflow_tool_by_tool_id( user_id, tenant_id, - args["workflow_tool_id"], + query.workflow_tool_id, ) - elif args.get("workflow_app_id"): + elif query.workflow_app_id: tool = WorkflowToolManageService.get_workflow_tool_by_app_id( user_id, tenant_id, - args["workflow_app_id"], + query.workflow_app_id, ) else: raise ValueError("incorrect workflow_tool_id or workflow_app_id") @@ -621,14 +684,8 @@ class ToolWorkflowProviderGetApi(Resource): return jsonable_encoder(tool) -parser_wf_tools = reqparse.RequestParser().add_argument( - "workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args" -) - - @console_ns.route("/workspaces/current/tool-provider/workflow/tools") class ToolWorkflowProviderListToolApi(Resource): - @console_ns.expect(parser_wf_tools) @setup_required @login_required @account_initialization_required @@ -637,13 +694,14 @@ class ToolWorkflowProviderListToolApi(Resource): user_id = user.id - args = parser_wf_tools.parse_args() + raw_args = request.args.to_dict() + query = WorkflowToolListQuery.model_validate(raw_args) return jsonable_encoder( WorkflowToolManageService.list_single_workflow_tools( user_id, tenant_id, - args["workflow_tool_id"], + query.workflow_tool_id, ) ) @@ -810,49 +868,39 @@ class ToolOAuthCallback(Resource): return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") -parser_default_cred = reqparse.RequestParser().add_argument( - "id", type=str, required=True, nullable=False, location="json" -) - - @console_ns.route("/workspaces/current/tool-provider/builtin//default-credential") class ToolBuiltinProviderSetDefaultApi(Resource): - @console_ns.expect(parser_default_cred) + @console_ns.expect(console_ns.models[BuiltinProviderDefaultCredentialPayload.__name__]) @setup_required @login_required @account_initialization_required def post(self, provider): current_user, current_tenant_id = current_account_with_tenant() - args = parser_default_cred.parse_args() + payload = BuiltinProviderDefaultCredentialPayload.model_validate(console_ns.payload or {}) return BuiltinToolManageService.set_default_provider( - tenant_id=current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"] + tenant_id=current_tenant_id, user_id=current_user.id, provider=provider, id=payload.id ) -parser_custom = ( - reqparse.RequestParser() - .add_argument("client_params", type=dict, required=False, nullable=True, location="json") - .add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/builtin//oauth/custom-client") class ToolOAuthCustomClient(Resource): - @console_ns.expect(parser_custom) + @console_ns.expect(console_ns.models[ToolOAuthCustomClientPayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @account_initialization_required def post(self, provider: str): - args = parser_custom.parse_args() + payload = ToolOAuthCustomClientPayload.model_validate(console_ns.payload or {}) _, tenant_id = current_account_with_tenant() return BuiltinToolManageService.save_custom_oauth_client_params( tenant_id=tenant_id, provider=provider, - client_params=args.get("client_params", {}), - enable_oauth_custom_client=args.get("enable_oauth_custom_client", True), + client_params=payload.client_params or {}, + enable_oauth_custom_client=payload.enable_oauth_custom_client + if payload.enable_oauth_custom_client is not None + else True, ) @setup_required @@ -904,49 +952,19 @@ class ToolBuiltinProviderGetCredentialInfoApi(Resource): ) -parser_mcp = ( - reqparse.RequestParser() - .add_argument("server_url", type=str, required=True, nullable=False, location="json") - .add_argument("name", type=str, required=True, nullable=False, location="json") - .add_argument("icon", type=str, required=True, nullable=False, location="json") - .add_argument("icon_type", type=str, required=True, nullable=False, location="json") - .add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="") - .add_argument("server_identifier", type=str, required=True, nullable=False, location="json") - .add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={}) - .add_argument("headers", type=dict, required=False, nullable=True, location="json", default={}) - .add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={}) -) -parser_mcp_put = ( - reqparse.RequestParser() - .add_argument("server_url", type=str, required=True, nullable=False, location="json") - .add_argument("name", type=str, required=True, nullable=False, location="json") - .add_argument("icon", type=str, required=True, nullable=False, location="json") - .add_argument("icon_type", type=str, required=True, nullable=False, location="json") - .add_argument("icon_background", type=str, required=False, nullable=True, location="json") - .add_argument("provider_id", type=str, required=True, nullable=False, location="json") - .add_argument("server_identifier", type=str, required=True, nullable=False, location="json") - .add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={}) - .add_argument("headers", type=dict, required=False, nullable=True, location="json", default={}) - .add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={}) -) -parser_mcp_delete = reqparse.RequestParser().add_argument( - "provider_id", type=str, required=True, nullable=False, location="json" -) - - @console_ns.route("/workspaces/current/tool-provider/mcp") class ToolProviderMCPApi(Resource): - @console_ns.expect(parser_mcp) + @console_ns.expect(console_ns.models[MCPProviderCreatePayload.__name__]) @setup_required @login_required @account_initialization_required def post(self): - args = parser_mcp.parse_args() + payload = MCPProviderCreatePayload.model_validate(console_ns.payload or {}) user, tenant_id = current_account_with_tenant() # Parse and validate models - configuration = MCPConfiguration.model_validate(args["configuration"]) - authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None + configuration = MCPConfiguration.model_validate(payload.configuration or {}) + authentication = MCPAuthentication.model_validate(payload.authentication) if payload.authentication else None # 1) Create provider in a short transaction (no network I/O inside) with session_factory.create_session() as session, session.begin(): @@ -954,13 +972,13 @@ class ToolProviderMCPApi(Resource): result = service.create_provider( tenant_id=tenant_id, user_id=user.id, - server_url=args["server_url"], - name=args["name"], - icon=args["icon"], - icon_type=args["icon_type"], - icon_background=args["icon_background"], - server_identifier=args["server_identifier"], - headers=args["headers"], + server_url=payload.server_url, + name=payload.name, + icon=payload.icon, + icon_type=payload.icon_type, + icon_background=payload.icon_background, + server_identifier=payload.server_identifier, + headers=payload.headers or {}, configuration=configuration, authentication=authentication, ) @@ -969,8 +987,8 @@ class ToolProviderMCPApi(Resource): # Perform network I/O outside any DB session to avoid holding locks. try: reconnect = MCPToolManageService.reconnect_with_url( - server_url=args["server_url"], - headers=args.get("headers") or {}, + server_url=payload.server_url, + headers=payload.headers or {}, timeout=configuration.timeout, sse_read_timeout=configuration.sse_read_timeout, ) @@ -988,14 +1006,14 @@ class ToolProviderMCPApi(Resource): return jsonable_encoder(result) - @console_ns.expect(parser_mcp_put) + @console_ns.expect(console_ns.models[MCPProviderUpdatePayload.__name__]) @setup_required @login_required @account_initialization_required def put(self): - args = parser_mcp_put.parse_args() - configuration = MCPConfiguration.model_validate(args["configuration"]) - authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None + payload = MCPProviderUpdatePayload.model_validate(console_ns.payload or {}) + configuration = MCPConfiguration.model_validate(payload.configuration or {}) + authentication = MCPAuthentication.model_validate(payload.authentication) if payload.authentication else None _, current_tenant_id = current_account_with_tenant() # Step 1: Get provider data for URL validation (short-lived session, no network I/O) @@ -1003,14 +1021,14 @@ class ToolProviderMCPApi(Resource): with Session(db.engine) as session: service = MCPToolManageService(session=session) validation_data = service.get_provider_for_url_validation( - tenant_id=current_tenant_id, provider_id=args["provider_id"] + tenant_id=current_tenant_id, provider_id=payload.provider_id ) # Step 2: Perform URL validation with network I/O OUTSIDE of any database session # This prevents holding database locks during potentially slow network operations validation_result = MCPToolManageService.validate_server_url_standalone( tenant_id=current_tenant_id, - new_server_url=args["server_url"], + new_server_url=payload.server_url, validation_data=validation_data, ) @@ -1019,14 +1037,14 @@ class ToolProviderMCPApi(Resource): service = MCPToolManageService(session=session) service.update_provider( tenant_id=current_tenant_id, - provider_id=args["provider_id"], - server_url=args["server_url"], - name=args["name"], - icon=args["icon"], - icon_type=args["icon_type"], - icon_background=args["icon_background"], - server_identifier=args["server_identifier"], - headers=args["headers"], + provider_id=payload.provider_id, + server_url=payload.server_url, + name=payload.name, + icon=payload.icon, + icon_type=payload.icon_type, + icon_background=payload.icon_background, + server_identifier=payload.server_identifier, + headers=payload.headers or {}, configuration=configuration, authentication=authentication, validation_result=validation_result, @@ -1034,37 +1052,30 @@ class ToolProviderMCPApi(Resource): return {"result": "success"} - @console_ns.expect(parser_mcp_delete) + @console_ns.expect(console_ns.models[MCPProviderDeletePayload.__name__]) @setup_required @login_required @account_initialization_required def delete(self): - args = parser_mcp_delete.parse_args() + payload = MCPProviderDeletePayload.model_validate(console_ns.payload or {}) _, current_tenant_id = current_account_with_tenant() with Session(db.engine) as session, session.begin(): service = MCPToolManageService(session=session) - service.delete_provider(tenant_id=current_tenant_id, provider_id=args["provider_id"]) + service.delete_provider(tenant_id=current_tenant_id, provider_id=payload.provider_id) return {"result": "success"} -parser_auth = ( - reqparse.RequestParser() - .add_argument("provider_id", type=str, required=True, nullable=False, location="json") - .add_argument("authorization_code", type=str, required=False, nullable=True, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/mcp/auth") class ToolMCPAuthApi(Resource): - @console_ns.expect(parser_auth) + @console_ns.expect(console_ns.models[MCPAuthPayload.__name__]) @setup_required @login_required @account_initialization_required def post(self): - args = parser_auth.parse_args() - provider_id = args["provider_id"] + payload = MCPAuthPayload.model_validate(console_ns.payload or {}) + provider_id = payload.provider_id _, tenant_id = current_account_with_tenant() with Session(db.engine) as session, session.begin(): @@ -1102,7 +1113,7 @@ class ToolMCPAuthApi(Resource): # Pass the extracted OAuth metadata hints to auth() auth_result = auth( provider_entity, - args.get("authorization_code"), + payload.authorization_code, resource_metadata_url=e.resource_metadata_url, scope_hint=e.scope_hint, ) @@ -1167,20 +1178,13 @@ class ToolMCPUpdateApi(Resource): return jsonable_encoder(tools) -parser_cb = ( - reqparse.RequestParser() - .add_argument("code", type=str, required=True, nullable=False, location="args") - .add_argument("state", type=str, required=True, nullable=False, location="args") -) - - @console_ns.route("/mcp/oauth/callback") class ToolMCPCallbackApi(Resource): - @console_ns.expect(parser_cb) def get(self): - args = parser_cb.parse_args() - state_key = args["state"] - authorization_code = args["code"] + raw_args = request.args.to_dict() + query = MCPCallbackQuery.model_validate(raw_args) + state_key = query.state + authorization_code = query.code # Create service instance for handle_callback with Session(db.engine) as session, session.begin(): diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index b3836f3a47..9d8431f066 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -30,6 +30,7 @@ from core.errors.error import ( from core.helper.trace_id_helper import get_external_trace_id from core.model_runtime.errors.invoke import InvokeError from libs import helper +from libs.helper import UUIDStrOrEmpty from models.model import App, AppMode, EndUser from services.app_generate_service import AppGenerateService from services.app_task_service import AppTaskService @@ -52,7 +53,7 @@ class ChatRequestPayload(BaseModel): query: str files: list[dict[str, Any]] | None = None response_mode: Literal["blocking", "streaming"] | None = None - conversation_id: str | None = Field(default=None, description="Conversation UUID") + conversation_id: UUIDStrOrEmpty | None = Field(default=None, description="Conversation UUID") retriever_from: str = Field(default="dev") auto_generate_name: bool = Field(default=True, description="Auto generate conversation name") workflow_id: str | None = Field(default=None, description="Workflow ID for advanced chat") diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 62e8258e25..8e29c9ff0f 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -1,5 +1,4 @@ from typing import Any, Literal -from uuid import UUID from flask import request from flask_restx import Resource @@ -23,12 +22,13 @@ from fields.conversation_variable_fields import ( build_conversation_variable_infinite_scroll_pagination_model, build_conversation_variable_model, ) +from libs.helper import UUIDStrOrEmpty from models.model import App, AppMode, EndUser from services.conversation_service import ConversationService class ConversationListQuery(BaseModel): - last_id: UUID | None = Field(default=None, description="Last conversation ID for pagination") + last_id: UUIDStrOrEmpty | None = Field(default=None, description="Last conversation ID for pagination") limit: int = Field(default=20, ge=1, le=100, description="Number of conversations to return") sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field( default="-updated_at", description="Sort order for conversations" @@ -48,7 +48,7 @@ class ConversationRenamePayload(BaseModel): class ConversationVariablesQuery(BaseModel): - last_id: UUID | None = Field(default=None, description="Last variable ID for pagination") + last_id: UUIDStrOrEmpty | None = Field(default=None, description="Last variable ID for pagination") limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return") variable_name: str | None = Field( default=None, description="Filter variables by name", min_length=1, max_length=255 diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index 8981bbd7d5..2aaf920efb 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -1,6 +1,5 @@ import logging from typing import Literal -from uuid import UUID from flask import request from flask_restx import Resource @@ -15,6 +14,7 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate from core.app.entities.app_invoke_entities import InvokeFrom from fields.conversation_fields import ResultResponse from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem +from libs.helper import UUIDStrOrEmpty from models.model import App, AppMode, EndUser from services.errors.message import ( FirstMessageNotExistsError, @@ -27,8 +27,8 @@ logger = logging.getLogger(__name__) class MessageListQuery(BaseModel): - conversation_id: UUID - first_id: UUID | None = None + conversation_id: UUIDStrOrEmpty + first_id: UUIDStrOrEmpty | None = None limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return") diff --git a/api/controllers/service_api/dataset/hit_testing.py b/api/controllers/service_api/dataset/hit_testing.py index 8dbb690901..97a70f5d0e 100644 --- a/api/controllers/service_api/dataset/hit_testing.py +++ b/api/controllers/service_api/dataset/hit_testing.py @@ -1,7 +1,10 @@ -from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase +from controllers.common.schema import register_schema_model +from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase, HitTestingPayload from controllers.service_api import service_api_ns from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check +register_schema_model(service_api_ns, HitTestingPayload) + @service_api_ns.route("/datasets//hit-testing", "/datasets//retrieve") class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase): @@ -15,6 +18,7 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase): 404: "Dataset not found", } ) + @service_api_ns.expect(service_api_ns.models[HitTestingPayload.__name__]) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id): """Perform hit testing on a dataset. diff --git a/api/core/tools/utils/workflow_configuration_sync.py b/api/core/tools/utils/workflow_configuration_sync.py index 188da0c32d..6d75df3603 100644 --- a/api/core/tools/utils/workflow_configuration_sync.py +++ b/api/core/tools/utils/workflow_configuration_sync.py @@ -7,11 +7,6 @@ from core.workflow.nodes.base.entities import OutputVariableEntity class WorkflowToolConfigurationUtils: - @classmethod - def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]): - for configuration in configurations: - WorkflowToolParameterConfiguration.model_validate(configuration) - @classmethod def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]: """ diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index ab5d5480df..6d84d4e250 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -1,8 +1,6 @@ import json import logging -from collections.abc import Mapping from datetime import datetime -from typing import Any from sqlalchemy import or_, select from sqlalchemy.orm import Session @@ -10,8 +8,8 @@ from sqlalchemy.orm import Session from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.__base.tool_provider import ToolProviderController from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity +from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration from core.tools.tool_label_manager import ToolLabelManager -from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db @@ -38,12 +36,10 @@ class WorkflowToolManageService: label: str, icon: dict, description: str, - parameters: list[Mapping[str, Any]], + parameters: list[WorkflowToolParameterConfiguration], privacy_policy: str = "", labels: list[str] | None = None, ): - WorkflowToolConfigurationUtils.check_parameter_configurations(parameters) - # check if the name is unique existing_workflow_tool_provider = ( db.session.query(WorkflowToolProvider) @@ -75,7 +71,7 @@ class WorkflowToolManageService: label=label, icon=json.dumps(icon), description=description, - parameter_configuration=json.dumps(parameters), + parameter_configuration=json.dumps([p.model_dump() for p in parameters]), privacy_policy=privacy_policy, version=workflow.version, ) @@ -104,7 +100,7 @@ class WorkflowToolManageService: label: str, icon: dict, description: str, - parameters: list[Mapping[str, Any]], + parameters: list[WorkflowToolParameterConfiguration], privacy_policy: str = "", labels: list[str] | None = None, ): @@ -122,8 +118,6 @@ class WorkflowToolManageService: :param labels: labels :return: the updated tool """ - WorkflowToolConfigurationUtils.check_parameter_configurations(parameters) - # check if the name is unique existing_workflow_tool_provider = ( db.session.query(WorkflowToolProvider) @@ -162,7 +156,7 @@ class WorkflowToolManageService: workflow_tool_provider.label = label workflow_tool_provider.icon = json.dumps(icon) workflow_tool_provider.description = description - workflow_tool_provider.parameter_configuration = json.dumps(parameters) + workflow_tool_provider.parameter_configuration = json.dumps([p.model_dump() for p in parameters]) workflow_tool_provider.privacy_policy = privacy_policy workflow_tool_provider.version = workflow.version workflow_tool_provider.updated_at = datetime.now() diff --git a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py index 3d46735a1a..3c0a660e7c 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py @@ -3,7 +3,9 @@ from unittest.mock import patch import pytest from faker import Faker +from pydantic import ValidationError +from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration from models.tools import WorkflowToolProvider from models.workflow import Workflow as WorkflowModel from services.account_service import AccountService, TenantService @@ -130,20 +132,24 @@ class TestWorkflowToolManageService: def _create_test_workflow_tool_parameters(self): """Helper method to create valid workflow tool parameters.""" return [ - { - "name": "input_text", - "description": "Input text for processing", - "form": "form", - "type": "string", - "required": True, - }, - { - "name": "output_format", - "description": "Output format specification", - "form": "form", - "type": "select", - "required": False, - }, + WorkflowToolParameterConfiguration.model_validate( + { + "name": "input_text", + "description": "Input text for processing", + "form": "form", + "type": "string", + "required": True, + } + ), + WorkflowToolParameterConfiguration.model_validate( + { + "name": "output_format", + "description": "Output format specification", + "form": "form", + "type": "select", + "required": False, + } + ), ] def test_create_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies): @@ -208,7 +214,7 @@ class TestWorkflowToolManageService: assert created_tool_provider.label == tool_label assert created_tool_provider.icon == json.dumps(tool_icon) assert created_tool_provider.description == tool_description - assert created_tool_provider.parameter_configuration == json.dumps(tool_parameters) + assert created_tool_provider.parameter_configuration == json.dumps([p.model_dump() for p in tool_parameters]) assert created_tool_provider.privacy_policy == tool_privacy_policy assert created_tool_provider.version == workflow.version assert created_tool_provider.user_id == account.id @@ -353,18 +359,9 @@ class TestWorkflowToolManageService: app, account, workflow = self._create_test_app_and_account( db_session_with_containers, mock_external_service_dependencies ) - - # Setup invalid workflow tool parameters (missing required fields) - invalid_parameters = [ - { - "name": "input_text", - # Missing description and form fields - "type": "string", - "required": True, - } - ] # Attempt to create workflow tool with invalid parameters - with pytest.raises(ValueError) as exc_info: + with pytest.raises(ValidationError) as exc_info: + # Setup invalid workflow tool parameters (missing required fields) WorkflowToolManageService.create_workflow_tool( user_id=account.id, tenant_id=account.current_tenant.id, @@ -373,7 +370,16 @@ class TestWorkflowToolManageService: label=fake.word(), icon={"type": "emoji", "emoji": "🔧"}, description=fake.text(max_nb_chars=200), - parameters=invalid_parameters, + parameters=[ + WorkflowToolParameterConfiguration.model_validate( + { + "name": "input_text", + # Missing description and form fields + "type": "string", + "required": True, + } + ) + ], ) # Verify error message contains validation error @@ -579,11 +585,12 @@ class TestWorkflowToolManageService: # Verify database state was updated db.session.refresh(created_tool) + assert created_tool is not None assert created_tool.name == updated_tool_name assert created_tool.label == updated_tool_label assert created_tool.icon == json.dumps(updated_tool_icon) assert created_tool.description == updated_tool_description - assert created_tool.parameter_configuration == json.dumps(updated_tool_parameters) + assert created_tool.parameter_configuration == json.dumps([p.model_dump() for p in updated_tool_parameters]) assert created_tool.privacy_policy == updated_tool_privacy_policy assert created_tool.version == workflow.version assert created_tool.updated_at is not None @@ -750,13 +757,15 @@ class TestWorkflowToolManageService: # Setup workflow tool parameters with FILE type file_parameters = [ - { - "name": "document", - "description": "Upload a document", - "form": "form", - "type": "file", - "required": False, - } + WorkflowToolParameterConfiguration.model_validate( + { + "name": "document", + "description": "Upload a document", + "form": "form", + "type": "file", + "required": False, + } + ) ] # Execute the method under test @@ -823,13 +832,15 @@ class TestWorkflowToolManageService: # Setup workflow tool parameters with FILES type files_parameters = [ - { - "name": "documents", - "description": "Upload multiple documents", - "form": "form", - "type": "files", - "required": False, - } + WorkflowToolParameterConfiguration.model_validate( + { + "name": "documents", + "description": "Upload multiple documents", + "form": "form", + "type": "files", + "required": False, + } + ) ] # Execute the method under test