From 3225ca8337d39164981655718e3e85f016588137 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Mon, 8 Dec 2025 16:32:16 +0900 Subject: [PATCH 01/28] reqparse is gone --- .../console/explore/installed_app.py | 30 +- api/controllers/console/extension.py | 78 +- api/controllers/console/tag/tags.py | 87 +-- .../workspace/load_balancing_config.py | 59 +- .../console/workspace/tool_providers.py | 678 +++++++++--------- .../console/workspace/trigger_providers.py | 108 +-- api/controllers/web/app.py | 32 +- api/controllers/web/audio.py | 37 +- api/controllers/web/completion.py | 94 +-- api/controllers/web/conversation.py | 64 +- api/controllers/web/forgot_password.py | 80 ++- api/controllers/web/login.py | 68 +- api/controllers/web/message.py | 97 ++- api/controllers/web/remote_files.py | 13 +- api/controllers/web/saved_message.py | 41 +- api/controllers/web/workflow.py | 17 +- 16 files changed, 798 insertions(+), 785 deletions(-) diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index 3c95779475..3f23fd55b4 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -2,7 +2,8 @@ import logging from typing import Any from flask import request -from flask_restx import Resource, inputs, marshal_with, reqparse +from flask_restx import Resource, marshal_with +from pydantic import BaseModel from sqlalchemy import and_, select from werkzeug.exceptions import BadRequest, Forbidden, NotFound @@ -18,6 +19,15 @@ from services.account_service import TenantService from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService + +class InstalledAppCreatePayload(BaseModel): + app_id: str + + +class InstalledAppUpdatePayload(BaseModel): + is_pinned: bool | None = None + + logger = logging.getLogger(__name__) @@ -105,16 +115,15 @@ class InstalledAppsListApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("apps") def post(self): - parser = reqparse.RequestParser().add_argument("app_id", type=str, required=True, help="Invalid app_id") - args = parser.parse_args() + payload = InstalledAppCreatePayload.model_validate(console_ns.payload or {}) - recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == args["app_id"]).first() + recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == payload.app_id).first() if recommended_app is None: raise NotFound("App not found") _, current_tenant_id = current_account_with_tenant() - app = db.session.query(App).where(App.id == args["app_id"]).first() + app = db.session.query(App).where(App.id == payload.app_id).first() if app is None: raise NotFound("App not found") @@ -124,7 +133,7 @@ class InstalledAppsListApi(Resource): installed_app = ( db.session.query(InstalledApp) - .where(and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id)) + .where(and_(InstalledApp.app_id == payload.app_id, InstalledApp.tenant_id == current_tenant_id)) .first() ) @@ -133,7 +142,7 @@ class InstalledAppsListApi(Resource): recommended_app.install_count += 1 new_installed_app = InstalledApp( - app_id=args["app_id"], + app_id=payload.app_id, tenant_id=current_tenant_id, app_owner_tenant_id=app.tenant_id, is_pinned=False, @@ -163,12 +172,11 @@ class InstalledAppApi(InstalledAppResource): return {"result": "success", "message": "App uninstalled successfully"}, 204 def patch(self, installed_app): - parser = reqparse.RequestParser().add_argument("is_pinned", type=inputs.boolean) - args = parser.parse_args() + payload = InstalledAppUpdatePayload.model_validate(console_ns.payload or {}) commit_args = False - if "is_pinned" in args: - installed_app.is_pinned = args["is_pinned"] + if payload.is_pinned is not None: + installed_app.is_pinned = payload.is_pinned commit_args = True if commit_args: diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index 08f29b4655..c2cb4578ea 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -1,7 +1,10 @@ -from flask_restx import Resource, fields, marshal_with, reqparse +from flask import request +from flask_restx import Resource, fields, marshal_with +from pydantic import BaseModel, Field from constants import HIDDEN_VALUE from controllers.console import console_ns +from controllers.common.schema import register_schema_models from controllers.console.wraps import account_initialization_required, setup_required from fields.api_based_extension_fields import api_based_extension_fields from libs.login import current_account_with_tenant, login_required @@ -9,6 +12,28 @@ from models.api_based_extension import APIBasedExtension from services.api_based_extension_service import APIBasedExtensionService from services.code_based_extension_service import CodeBasedExtensionService + +class CodeBasedExtensionQuery(BaseModel): + module: str + + +class APIBasedExtensionPayload(BaseModel): + name: str = Field(description="Extension name") + api_endpoint: str = Field(description="API endpoint URL") + api_key: str = Field(description="API key for authentication") + + +class CreateAPIBasedExtensionPayload(APIBasedExtensionPayload): + pass + + +class UpdateAPIBasedExtensionPayload(APIBasedExtensionPayload): + pass + + +register_schema_models(console_ns, CreateAPIBasedExtensionPayload, UpdateAPIBasedExtensionPayload) + + api_based_extension_model = console_ns.model("ApiBasedExtensionModel", api_based_extension_fields) api_based_extension_list_model = fields.List(fields.Nested(api_based_extension_model)) @@ -18,11 +43,7 @@ api_based_extension_list_model = fields.List(fields.Nested(api_based_extension_m class CodeBasedExtensionAPI(Resource): @console_ns.doc("get_code_based_extension") @console_ns.doc(description="Get code-based extension data by module name") - @console_ns.expect( - console_ns.parser().add_argument( - "module", type=str, required=True, location="args", help="Extension module name" - ) - ) + @console_ns.doc(params={"module": "Extension module name"}) @console_ns.response( 200, "Success", @@ -35,10 +56,9 @@ class CodeBasedExtensionAPI(Resource): @login_required @account_initialization_required def get(self): - parser = reqparse.RequestParser().add_argument("module", type=str, required=True, location="args") - args = parser.parse_args() + query = CodeBasedExtensionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore - return {"module": args["module"], "data": CodeBasedExtensionService.get_code_based_extension(args["module"])} + return {"module": query.module, "data": CodeBasedExtensionService.get_code_based_extension(query.module)} @console_ns.route("/api-based-extension") @@ -56,30 +76,21 @@ class APIBasedExtensionAPI(Resource): @console_ns.doc("create_api_based_extension") @console_ns.doc(description="Create a new API-based extension") - @console_ns.expect( - console_ns.model( - "CreateAPIBasedExtensionRequest", - { - "name": fields.String(required=True, description="Extension name"), - "api_endpoint": fields.String(required=True, description="API endpoint URL"), - "api_key": fields.String(required=True, description="API key for authentication"), - }, - ) - ) + @console_ns.expect(console_ns.models[CreateAPIBasedExtensionPayload.__name__]) @console_ns.response(201, "Extension created successfully", api_based_extension_model) @setup_required @login_required @account_initialization_required @marshal_with(api_based_extension_model) def post(self): - args = console_ns.payload + payload = CreateAPIBasedExtensionPayload.model_validate(console_ns.payload or {}) _, current_tenant_id = current_account_with_tenant() extension_data = APIBasedExtension( tenant_id=current_tenant_id, - name=args["name"], - api_endpoint=args["api_endpoint"], - api_key=args["api_key"], + name=payload.name, + api_endpoint=payload.api_endpoint, + api_key=payload.api_key, ) return APIBasedExtensionService.save(extension_data) @@ -104,16 +115,7 @@ class APIBasedExtensionDetailAPI(Resource): @console_ns.doc("update_api_based_extension") @console_ns.doc(description="Update API-based extension") @console_ns.doc(params={"id": "Extension ID"}) - @console_ns.expect( - console_ns.model( - "UpdateAPIBasedExtensionRequest", - { - "name": fields.String(required=True, description="Extension name"), - "api_endpoint": fields.String(required=True, description="API endpoint URL"), - "api_key": fields.String(required=True, description="API key for authentication"), - }, - ) - ) + @console_ns.expect(console_ns.models[UpdateAPIBasedExtensionPayload.__name__]) @console_ns.response(200, "Extension updated successfully", api_based_extension_model) @setup_required @login_required @@ -125,13 +127,13 @@ class APIBasedExtensionDetailAPI(Resource): extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id) - args = console_ns.payload + payload = UpdateAPIBasedExtensionPayload.model_validate(console_ns.payload or {}) - extension_data_from_db.name = args["name"] - extension_data_from_db.api_endpoint = args["api_endpoint"] + extension_data_from_db.name = payload.name + extension_data_from_db.api_endpoint = payload.api_endpoint - if args["api_key"] != HIDDEN_VALUE: - extension_data_from_db.api_key = args["api_key"] + if payload.api_key != HIDDEN_VALUE: + extension_data_from_db.api_key = payload.api_key return APIBasedExtensionService.save(extension_data_from_db) diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index 17cfc3ff4b..c545348cde 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -1,31 +1,39 @@ from flask import request -from flask_restx import Resource, marshal_with, reqparse +from flask_restx import Resource, marshal_with +from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from fields.tag_fields import dataset_tag_fields from libs.login import current_account_with_tenant, login_required -from models.model import Tag from services.tag_service import TagService - -def _validate_name(name): - if not name or len(name) < 1 or len(name) > 50: - raise ValueError("Name must be between 1 to 50 characters.") - return name +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" -parser_tags = ( - reqparse.RequestParser() - .add_argument( - "name", - nullable=False, - required=True, - help="Name must be between 1 to 50 characters.", - type=_validate_name, - ) - .add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.") +class TagBasePayload(BaseModel): + name: str = Field(description="Tag name", min_length=1, max_length=50) + type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type") + + +class TagBindingPayload(BaseModel): + tag_ids: list[str] + target_id: str + type: Literal["knowledge", "app"] | None = None + + +class TagBindingRemovePayload(BaseModel): + tag_id: str + target_id: str + type: Literal["knowledge", "app"] | None = None + + +register_schema_models( + console_ns, + TagBasePayload, + TagBindingPayload, + TagBindingRemovePayload, ) @@ -43,7 +51,7 @@ class TagListApi(Resource): return tags, 200 - @console_ns.expect(parser_tags) + @console_ns.expect(console_ns.models[TagBasePayload.__name__]) @setup_required @login_required @account_initialization_required @@ -53,22 +61,17 @@ class TagListApi(Resource): if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() - args = parser_tags.parse_args() - tag = TagService.save_tags(args) + payload = TagBasePayload.model_validate(console_ns.payload or {}) + tag = TagService.save_tags(payload.model_dump()) response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} return response, 200 -parser_tag_id = reqparse.RequestParser().add_argument( - "name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name -) - - @console_ns.route("/tags/") class TagUpdateDeleteApi(Resource): - @console_ns.expect(parser_tag_id) + @console_ns.expect(console_ns.models[TagBasePayload.__name__]) @setup_required @login_required @account_initialization_required @@ -79,8 +82,8 @@ class TagUpdateDeleteApi(Resource): if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() - args = parser_tag_id.parse_args() - tag = TagService.update_tags(args, tag_id) + payload = TagBasePayload.model_validate(console_ns.payload or {}) + tag = TagService.update_tags(payload.model_dump(), tag_id) binding_count = TagService.get_tag_binding_count(tag_id) @@ -100,17 +103,9 @@ class TagUpdateDeleteApi(Resource): return 204 -parser_create = ( - reqparse.RequestParser() - .add_argument("tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required.") - .add_argument("target_id", type=str, nullable=False, required=True, location="json", help="Target ID is required.") - .add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.") -) - - @console_ns.route("/tag-bindings/create") class TagBindingCreateApi(Resource): - @console_ns.expect(parser_create) + @console_ns.expect(console_ns.models[TagBindingPayload.__name__]) @setup_required @login_required @account_initialization_required @@ -120,23 +115,15 @@ class TagBindingCreateApi(Resource): if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() - args = parser_create.parse_args() - TagService.save_tag_binding(args) + payload = TagBindingPayload.model_validate(console_ns.payload or {}) + TagService.save_tag_binding(payload.model_dump()) return {"result": "success"}, 200 -parser_remove = ( - reqparse.RequestParser() - .add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.") - .add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.") - .add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.") -) - - @console_ns.route("/tag-bindings/remove") class TagBindingDeleteApi(Resource): - @console_ns.expect(parser_remove) + @console_ns.expect(console_ns.models[TagBindingRemovePayload.__name__]) @setup_required @login_required @account_initialization_required @@ -146,7 +133,7 @@ class TagBindingDeleteApi(Resource): if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() - args = parser_remove.parse_args() - TagService.delete_tag_binding(args) + payload = TagBindingRemovePayload.model_validate(console_ns.payload or {}) + TagService.delete_tag_binding(payload.model_dump()) return {"result": "success"}, 200 diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py index 9bf393ea2e..588bff494b 100644 --- a/api/controllers/console/workspace/load_balancing_config.py +++ b/api/controllers/console/workspace/load_balancing_config.py @@ -1,4 +1,5 @@ -from flask_restx import Resource, reqparse +from flask_restx import Resource +from pydantic import BaseModel from werkzeug.exceptions import Forbidden from controllers.console import console_ns @@ -9,6 +10,20 @@ from libs.login import current_account_with_tenant, login_required from models import TenantAccountRole from services.model_load_balancing_service import ModelLoadBalancingService +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class LoadBalancingCredentialPayload(BaseModel): + model: str + model_type: ModelType + credentials: dict + + +console_ns.schema_model( + LoadBalancingCredentialPayload.__name__, + LoadBalancingCredentialPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) + @console_ns.route( "/workspaces/current/model-providers//models/load-balancing-configs/credentials-validate" @@ -24,20 +39,7 @@ class LoadBalancingCredentialsValidateApi(Resource): tenant_id = current_tenant_id - parser = ( - reqparse.RequestParser() - .add_argument("model", type=str, required=True, nullable=False, location="json") - .add_argument( - "model_type", - type=str, - required=True, - nullable=False, - choices=[mt.value for mt in ModelType], - location="json", - ) - .add_argument("credentials", type=dict, required=True, nullable=False, location="json") - ) - args = parser.parse_args() + payload = LoadBalancingCredentialPayload.model_validate(console_ns.payload or {}) # validate model load balancing credentials model_load_balancing_service = ModelLoadBalancingService() @@ -49,9 +51,9 @@ class LoadBalancingCredentialsValidateApi(Resource): model_load_balancing_service.validate_load_balancing_credentials( tenant_id=tenant_id, provider=provider, - model=args["model"], - model_type=args["model_type"], - credentials=args["credentials"], + model=payload.model, + model_type=payload.model_type, + credentials=payload.credentials, ) except CredentialsValidateFailedError as ex: result = False @@ -79,20 +81,7 @@ class LoadBalancingConfigCredentialsValidateApi(Resource): tenant_id = current_tenant_id - parser = ( - reqparse.RequestParser() - .add_argument("model", type=str, required=True, nullable=False, location="json") - .add_argument( - "model_type", - type=str, - required=True, - nullable=False, - choices=[mt.value for mt in ModelType], - location="json", - ) - .add_argument("credentials", type=dict, required=True, nullable=False, location="json") - ) - args = parser.parse_args() + payload = LoadBalancingCredentialPayload.model_validate(console_ns.payload or {}) # validate model load balancing config credentials model_load_balancing_service = ModelLoadBalancingService() @@ -104,9 +93,9 @@ class LoadBalancingConfigCredentialsValidateApi(Resource): model_load_balancing_service.validate_load_balancing_credentials( tenant_id=tenant_id, provider=provider, - model=args["model"], - model_type=args["model_type"], - credentials=args["credentials"], + model=payload.model, + model_type=payload.model_type, + credentials=payload.credentials, config_id=config_id, ) except CredentialsValidateFailedError as ex: diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 2c54aa5a20..d9c1220b78 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -1,15 +1,15 @@ import io +from typing import Any, Literal from urllib.parse import urlparse from flask import make_response, redirect, request, send_file -from flask_restx import ( - Resource, - reqparse, -) +from flask_restx import Resource +from pydantic import BaseModel, Field, field_validator, model_validator from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden from configs import dify_config +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import ( account_initialization_required, @@ -25,7 +25,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.oauth import OAuthHandler from extensions.ext_database import db -from libs.helper import StrLen, alphanumeric, uuid_value +from libs.helper import alphanumeric, uuid_value from libs.login import current_account_with_tenant, login_required from models.provider_ids import ToolProviderID @@ -48,24 +48,216 @@ def is_valid_url(url: str) -> bool: parsed = urlparse(url) return all([parsed.scheme, parsed.netloc]) and parsed.scheme in ["http", "https"] except (ValueError, TypeError): - # ValueError: Invalid URL format - # TypeError: url is not a string return False -parser_tool = reqparse.RequestParser().add_argument( - "type", - type=str, - choices=["builtin", "model", "api", "workflow", "mcp"], - required=False, - nullable=True, - location="args", +class ToolProviderListQuery(BaseModel): + type: Literal["builtin", "model", "api", "workflow", "mcp"] | None = None + + +class BuiltinToolCredentialDeletePayload(BaseModel): + credential_id: str + + +class BuiltinToolAddPayload(BaseModel): + credentials: dict[str, Any] + name: str | None = Field(default=None, max_length=30) + type: CredentialType + + +class BuiltinToolUpdatePayload(BaseModel): + credential_id: str + credentials: dict[str, Any] | None = None + name: str | None = Field(default=None, max_length=30) + + +class ApiToolProviderBasePayload(BaseModel): + credentials: dict[str, Any] + schema_type: str + schema: str + provider: str + icon: dict[str, Any] + privacy_policy: str | None = None + labels: list[str] | None = None + custom_disclaimer: str | None = None + + +class ApiToolProviderAddPayload(ApiToolProviderBasePayload): + pass + + +class ApiToolProviderUpdatePayload(ApiToolProviderBasePayload): + original_provider: str + + +class UrlQuery(BaseModel): + url: str + + @field_validator("url") + @classmethod + def validate_url(cls, value: str) -> str: + if not is_valid_url(value): + raise ValueError("Invalid URL") + return value + + +class ProviderQuery(BaseModel): + provider: str + + +class ApiToolProviderDeletePayload(BaseModel): + provider: str + + +class ApiToolSchemaPayload(BaseModel): + schema: str + + +class ApiToolTestPayload(BaseModel): + tool_name: str + provider_name: str | None = None + credentials: dict[str, Any] + parameters: dict[str, Any] + schema_type: str + schema: str + + +class WorkflowToolBasePayload(BaseModel): + name: str + label: str + description: str + icon: dict[str, Any] + parameters: list[dict[str, Any]] + privacy_policy: str | None = "" + labels: list[str] | None = None + + @field_validator("name") + @classmethod + def validate_name(cls, value: str) -> str: + return alphanumeric(value) + + +class WorkflowToolCreatePayload(WorkflowToolBasePayload): + workflow_app_id: str + + @field_validator("workflow_app_id") + @classmethod + def validate_workflow_app_id(cls, value: str) -> str: + return uuid_value(value) + + +class WorkflowToolUpdatePayload(WorkflowToolBasePayload): + workflow_tool_id: str + + @field_validator("workflow_tool_id") + @classmethod + def validate_workflow_tool_id(cls, value: str) -> str: + return uuid_value(value) + + +class WorkflowToolDeletePayload(BaseModel): + workflow_tool_id: str + + @field_validator("workflow_tool_id") + @classmethod + def validate_workflow_tool_id(cls, value: str) -> str: + return uuid_value(value) + + +class WorkflowToolGetQuery(BaseModel): + workflow_tool_id: str | None = None + workflow_app_id: str | None = None + + @field_validator("workflow_tool_id", "workflow_app_id") + @classmethod + def validate_ids(cls, value: str | None) -> str | None: + if value is None: + return value + return uuid_value(value) + + @model_validator(mode="after") + def ensure_one(self) -> "WorkflowToolGetQuery": + if not self.workflow_tool_id and not self.workflow_app_id: + raise ValueError("workflow_tool_id or workflow_app_id is required") + return self + + +class WorkflowToolListQuery(BaseModel): + workflow_tool_id: str + + @field_validator("workflow_tool_id") + @classmethod + def validate_workflow_tool_id(cls, value: str) -> str: + return uuid_value(value) + + +class BuiltinProviderDefaultCredentialPayload(BaseModel): + id: str + + +class ToolOAuthCustomClientPayload(BaseModel): + client_params: dict[str, Any] | None = None + enable_oauth_custom_client: bool | None = True + + +class MCPProviderBasePayload(BaseModel): + server_url: str + name: str + icon: str + icon_type: str + icon_background: str | None = "" + server_identifier: str + configuration: dict[str, Any] | None = Field(default_factory=dict) + headers: dict[str, Any] | None = Field(default_factory=dict) + authentication: dict[str, Any] | None = Field(default_factory=dict) + + +class MCPProviderCreatePayload(MCPProviderBasePayload): + pass + + +class MCPProviderUpdatePayload(MCPProviderBasePayload): + provider_id: str + + +class MCPProviderDeletePayload(BaseModel): + provider_id: str + + +class MCPAuthPayload(BaseModel): + provider_id: str + authorization_code: str | None = None + + +class MCPCallbackQuery(BaseModel): + code: str + state: str + + +register_schema_models( + console_ns, + BuiltinToolCredentialDeletePayload, + BuiltinToolAddPayload, + BuiltinToolUpdatePayload, + ApiToolProviderAddPayload, + ApiToolProviderUpdatePayload, + ApiToolProviderDeletePayload, + ApiToolSchemaPayload, + ApiToolTestPayload, + WorkflowToolCreatePayload, + WorkflowToolUpdatePayload, + WorkflowToolDeletePayload, + BuiltinProviderDefaultCredentialPayload, + ToolOAuthCustomClientPayload, + MCPProviderCreatePayload, + MCPProviderUpdatePayload, + MCPProviderDeletePayload, + MCPAuthPayload, ) @console_ns.route("/workspaces/current/tool-providers") class ToolProviderListApi(Resource): - @console_ns.expect(parser_tool) @setup_required @login_required @account_initialization_required @@ -74,9 +266,10 @@ class ToolProviderListApi(Resource): user_id = user.id - args = parser_tool.parse_args() + raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] + query = ToolProviderListQuery.model_validate(raw_args) - return ToolCommonService.list_tool_providers(user_id, tenant_id, args.get("type", None)) + return ToolCommonService.list_tool_providers(user_id, tenant_id, query.type) @console_ns.route("/workspaces/current/tool-provider/builtin//tools") @@ -106,14 +299,9 @@ class ToolBuiltinProviderInfoApi(Resource): return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider)) -parser_delete = reqparse.RequestParser().add_argument( - "credential_id", type=str, required=True, nullable=False, location="json" -) - - @console_ns.route("/workspaces/current/tool-provider/builtin//delete") class ToolBuiltinProviderDeleteApi(Resource): - @console_ns.expect(parser_delete) + @console_ns.expect(console_ns.models[BuiltinToolCredentialDeletePayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -121,26 +309,18 @@ class ToolBuiltinProviderDeleteApi(Resource): def post(self, provider): _, tenant_id = current_account_with_tenant() - args = parser_delete.parse_args() + payload = BuiltinToolCredentialDeletePayload.model_validate(console_ns.payload or {}) return BuiltinToolManageService.delete_builtin_tool_provider( tenant_id, provider, - args["credential_id"], + payload.credential_id, ) -parser_add = ( - reqparse.RequestParser() - .add_argument("credentials", type=dict, required=True, nullable=False, location="json") - .add_argument("name", type=StrLen(30), required=False, nullable=False, location="json") - .add_argument("type", type=str, required=True, nullable=False, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/builtin//add") class ToolBuiltinProviderAddApi(Resource): - @console_ns.expect(parser_add) + @console_ns.expect(console_ns.models[BuiltinToolAddPayload.__name__]) @setup_required @login_required @account_initialization_required @@ -149,32 +329,21 @@ class ToolBuiltinProviderAddApi(Resource): user_id = user.id - args = parser_add.parse_args() - - if args["type"] not in CredentialType.values(): - raise ValueError(f"Invalid credential type: {args['type']}") + payload = BuiltinToolAddPayload.model_validate(console_ns.payload or {}) return BuiltinToolManageService.add_builtin_tool_provider( user_id=user_id, tenant_id=tenant_id, provider=provider, - credentials=args["credentials"], - name=args["name"], - api_type=CredentialType.of(args["type"]), + credentials=payload.credentials, + name=payload.name, + api_type=CredentialType.of(payload.type), ) -parser_update = ( - reqparse.RequestParser() - .add_argument("credential_id", type=str, required=True, nullable=False, location="json") - .add_argument("credentials", type=dict, required=False, nullable=True, location="json") - .add_argument("name", type=StrLen(30), required=False, nullable=True, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/builtin//update") class ToolBuiltinProviderUpdateApi(Resource): - @console_ns.expect(parser_update) + @console_ns.expect(console_ns.models[BuiltinToolUpdatePayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -183,15 +352,15 @@ class ToolBuiltinProviderUpdateApi(Resource): user, tenant_id = current_account_with_tenant() user_id = user.id - args = parser_update.parse_args() + payload = BuiltinToolUpdatePayload.model_validate(console_ns.payload or {}) result = BuiltinToolManageService.update_builtin_tool_provider( user_id=user_id, tenant_id=tenant_id, provider=provider, - credential_id=args["credential_id"], - credentials=args.get("credentials", None), - name=args.get("name", ""), + credential_id=payload.credential_id, + credentials=payload.credentials, + name=payload.name or "", ) return result @@ -221,22 +390,9 @@ class ToolBuiltinProviderIconApi(Resource): return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age) -parser_api_add = ( - reqparse.RequestParser() - .add_argument("credentials", type=dict, required=True, nullable=False, location="json") - .add_argument("schema_type", type=str, required=True, nullable=False, location="json") - .add_argument("schema", type=str, required=True, nullable=False, location="json") - .add_argument("provider", type=str, required=True, nullable=False, location="json") - .add_argument("icon", type=dict, required=True, nullable=False, location="json") - .add_argument("privacy_policy", type=str, required=False, nullable=True, location="json") - .add_argument("labels", type=list[str], required=False, nullable=True, location="json", default=[]) - .add_argument("custom_disclaimer", type=str, required=False, nullable=True, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/api/add") class ToolApiProviderAddApi(Resource): - @console_ns.expect(parser_api_add) + @console_ns.expect(console_ns.models[ApiToolProviderAddPayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -246,28 +402,24 @@ class ToolApiProviderAddApi(Resource): user_id = user.id - args = parser_api_add.parse_args() + payload = ApiToolProviderAddPayload.model_validate(console_ns.payload or {}) return ApiToolManageService.create_api_tool_provider( user_id, tenant_id, - args["provider"], - args["icon"], - args["credentials"], - args["schema_type"], - args["schema"], - args.get("privacy_policy", ""), - args.get("custom_disclaimer", ""), - args.get("labels", []), + payload.provider, + payload.icon, + payload.credentials, + payload.schema_type, + payload.schema, + payload.privacy_policy or "", + payload.custom_disclaimer or "", + payload.labels or [], ) -parser_remote = reqparse.RequestParser().add_argument("url", type=str, required=True, nullable=False, location="args") - - @console_ns.route("/workspaces/current/tool-provider/api/remote") class ToolApiProviderGetRemoteSchemaApi(Resource): - @console_ns.expect(parser_remote) @setup_required @login_required @account_initialization_required @@ -276,23 +428,18 @@ class ToolApiProviderGetRemoteSchemaApi(Resource): user_id = user.id - args = parser_remote.parse_args() + raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] + query = UrlQuery.model_validate(raw_args) return ApiToolManageService.get_api_tool_provider_remote_schema( user_id, tenant_id, - args["url"], + query.url, ) -parser_tools = reqparse.RequestParser().add_argument( - "provider", type=str, required=True, nullable=False, location="args" -) - - @console_ns.route("/workspaces/current/tool-provider/api/tools") class ToolApiProviderListToolsApi(Resource): - @console_ns.expect(parser_tools) @setup_required @login_required @account_initialization_required @@ -301,34 +448,21 @@ class ToolApiProviderListToolsApi(Resource): user_id = user.id - args = parser_tools.parse_args() + raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] + query = ProviderQuery.model_validate(raw_args) return jsonable_encoder( ApiToolManageService.list_api_tool_provider_tools( user_id, tenant_id, - args["provider"], + query.provider, ) ) -parser_api_update = ( - reqparse.RequestParser() - .add_argument("credentials", type=dict, required=True, nullable=False, location="json") - .add_argument("schema_type", type=str, required=True, nullable=False, location="json") - .add_argument("schema", type=str, required=True, nullable=False, location="json") - .add_argument("provider", type=str, required=True, nullable=False, location="json") - .add_argument("original_provider", type=str, required=True, nullable=False, location="json") - .add_argument("icon", type=dict, required=True, nullable=False, location="json") - .add_argument("privacy_policy", type=str, required=True, nullable=True, location="json") - .add_argument("labels", type=list[str], required=False, nullable=True, location="json") - .add_argument("custom_disclaimer", type=str, required=True, nullable=True, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/api/update") class ToolApiProviderUpdateApi(Resource): - @console_ns.expect(parser_api_update) + @console_ns.expect(console_ns.models[ApiToolProviderUpdatePayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -338,31 +472,26 @@ class ToolApiProviderUpdateApi(Resource): user_id = user.id - args = parser_api_update.parse_args() + payload = ApiToolProviderUpdatePayload.model_validate(console_ns.payload or {}) return ApiToolManageService.update_api_tool_provider( user_id, tenant_id, - args["provider"], - args["original_provider"], - args["icon"], - args["credentials"], - args["schema_type"], - args["schema"], - args["privacy_policy"], - args["custom_disclaimer"], - args.get("labels", []), + payload.provider, + payload.original_provider, + payload.icon, + payload.credentials, + payload.schema_type, + payload.schema, + payload.privacy_policy, + payload.custom_disclaimer, + payload.labels or [], ) -parser_api_delete = reqparse.RequestParser().add_argument( - "provider", type=str, required=True, nullable=False, location="json" -) - - @console_ns.route("/workspaces/current/tool-provider/api/delete") class ToolApiProviderDeleteApi(Resource): - @console_ns.expect(parser_api_delete) + @console_ns.expect(console_ns.models[ApiToolProviderDeletePayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -372,21 +501,17 @@ class ToolApiProviderDeleteApi(Resource): user_id = user.id - args = parser_api_delete.parse_args() + payload = ApiToolProviderDeletePayload.model_validate(console_ns.payload or {}) return ApiToolManageService.delete_api_tool_provider( user_id, tenant_id, - args["provider"], + payload.provider, ) -parser_get = reqparse.RequestParser().add_argument("provider", type=str, required=True, nullable=False, location="args") - - @console_ns.route("/workspaces/current/tool-provider/api/get") class ToolApiProviderGetApi(Resource): - @console_ns.expect(parser_get) @setup_required @login_required @account_initialization_required @@ -395,12 +520,13 @@ class ToolApiProviderGetApi(Resource): user_id = user.id - args = parser_get.parse_args() + raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] + query = ProviderQuery.model_validate(raw_args) return ApiToolManageService.get_api_tool_provider( user_id, tenant_id, - args["provider"], + query.provider, ) @@ -419,72 +545,43 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource): ) -parser_schema = reqparse.RequestParser().add_argument( - "schema", type=str, required=True, nullable=False, location="json" -) - - @console_ns.route("/workspaces/current/tool-provider/api/schema") class ToolApiProviderSchemaApi(Resource): - @console_ns.expect(parser_schema) + @console_ns.expect(console_ns.models[ApiToolSchemaPayload.__name__]) @setup_required @login_required @account_initialization_required def post(self): - args = parser_schema.parse_args() + payload = ApiToolSchemaPayload.model_validate(console_ns.payload or {}) return ApiToolManageService.parser_api_schema( - schema=args["schema"], + schema=payload.schema, ) -parser_pre = ( - reqparse.RequestParser() - .add_argument("tool_name", type=str, required=True, nullable=False, location="json") - .add_argument("provider_name", type=str, required=False, nullable=False, location="json") - .add_argument("credentials", type=dict, required=True, nullable=False, location="json") - .add_argument("parameters", type=dict, required=True, nullable=False, location="json") - .add_argument("schema_type", type=str, required=True, nullable=False, location="json") - .add_argument("schema", type=str, required=True, nullable=False, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/api/test/pre") class ToolApiProviderPreviousTestApi(Resource): - @console_ns.expect(parser_pre) + @console_ns.expect(console_ns.models[ApiToolTestPayload.__name__]) @setup_required @login_required @account_initialization_required def post(self): - args = parser_pre.parse_args() + payload = ApiToolTestPayload.model_validate(console_ns.payload or {}) _, current_tenant_id = current_account_with_tenant() return ApiToolManageService.test_api_tool_preview( current_tenant_id, - args["provider_name"] or "", - args["tool_name"], - args["credentials"], - args["parameters"], - args["schema_type"], - args["schema"], + payload.provider_name or "", + payload.tool_name, + payload.credentials, + payload.parameters, + payload.schema_type, + payload.schema, ) -parser_create = ( - reqparse.RequestParser() - .add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json") - .add_argument("name", type=alphanumeric, required=True, nullable=False, location="json") - .add_argument("label", type=str, required=True, nullable=False, location="json") - .add_argument("description", type=str, required=True, nullable=False, location="json") - .add_argument("icon", type=dict, required=True, nullable=False, location="json") - .add_argument("parameters", type=list[dict], required=True, nullable=False, location="json") - .add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="") - .add_argument("labels", type=list[str], required=False, nullable=True, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/workflow/create") class ToolWorkflowProviderCreateApi(Resource): - @console_ns.expect(parser_create) + @console_ns.expect(console_ns.models[WorkflowToolCreatePayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -494,38 +591,25 @@ class ToolWorkflowProviderCreateApi(Resource): user_id = user.id - args = parser_create.parse_args() + payload = WorkflowToolCreatePayload.model_validate(console_ns.payload or {}) return WorkflowToolManageService.create_workflow_tool( user_id=user_id, tenant_id=tenant_id, - workflow_app_id=args["workflow_app_id"], - name=args["name"], - label=args["label"], - icon=args["icon"], - description=args["description"], - parameters=args["parameters"], - privacy_policy=args["privacy_policy"], - labels=args["labels"], + workflow_app_id=payload.workflow_app_id, + name=payload.name, + label=payload.label, + icon=payload.icon, + description=payload.description, + parameters=payload.parameters, + privacy_policy=payload.privacy_policy or "", + labels=payload.labels, ) -parser_workflow_update = ( - reqparse.RequestParser() - .add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json") - .add_argument("name", type=alphanumeric, required=True, nullable=False, location="json") - .add_argument("label", type=str, required=True, nullable=False, location="json") - .add_argument("description", type=str, required=True, nullable=False, location="json") - .add_argument("icon", type=dict, required=True, nullable=False, location="json") - .add_argument("parameters", type=list[dict], required=True, nullable=False, location="json") - .add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="") - .add_argument("labels", type=list[str], required=False, nullable=True, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/workflow/update") class ToolWorkflowProviderUpdateApi(Resource): - @console_ns.expect(parser_workflow_update) + @console_ns.expect(console_ns.models[WorkflowToolUpdatePayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -534,33 +618,25 @@ class ToolWorkflowProviderUpdateApi(Resource): user, tenant_id = current_account_with_tenant() user_id = user.id - args = parser_workflow_update.parse_args() - - if not args["workflow_tool_id"]: - raise ValueError("incorrect workflow_tool_id") + payload = WorkflowToolUpdatePayload.model_validate(console_ns.payload or {}) return WorkflowToolManageService.update_workflow_tool( user_id, tenant_id, - args["workflow_tool_id"], - args["name"], - args["label"], - args["icon"], - args["description"], - args["parameters"], - args["privacy_policy"], - args.get("labels", []), + payload.workflow_tool_id, + payload.name, + payload.label, + payload.icon, + payload.description, + payload.parameters, + payload.privacy_policy or "", + payload.labels or [], ) -parser_workflow_delete = reqparse.RequestParser().add_argument( - "workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json" -) - - @console_ns.route("/workspaces/current/tool-provider/workflow/delete") class ToolWorkflowProviderDeleteApi(Resource): - @console_ns.expect(parser_workflow_delete) + @console_ns.expect(console_ns.models[WorkflowToolDeletePayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -570,25 +646,17 @@ class ToolWorkflowProviderDeleteApi(Resource): user_id = user.id - args = parser_workflow_delete.parse_args() + payload = WorkflowToolDeletePayload.model_validate(console_ns.payload or {}) return WorkflowToolManageService.delete_workflow_tool( user_id, tenant_id, - args["workflow_tool_id"], + payload.workflow_tool_id, ) -parser_wf_get = ( - reqparse.RequestParser() - .add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args") - .add_argument("workflow_app_id", type=uuid_value, required=False, nullable=True, location="args") -) - - @console_ns.route("/workspaces/current/tool-provider/workflow/get") class ToolWorkflowProviderGetApi(Resource): - @console_ns.expect(parser_wf_get) @setup_required @login_required @account_initialization_required @@ -597,19 +665,20 @@ class ToolWorkflowProviderGetApi(Resource): user_id = user.id - args = parser_wf_get.parse_args() + raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] + query = WorkflowToolGetQuery.model_validate(raw_args) - if args.get("workflow_tool_id"): + if query.workflow_tool_id: tool = WorkflowToolManageService.get_workflow_tool_by_tool_id( user_id, tenant_id, - args["workflow_tool_id"], + query.workflow_tool_id, ) - elif args.get("workflow_app_id"): + elif query.workflow_app_id: tool = WorkflowToolManageService.get_workflow_tool_by_app_id( user_id, tenant_id, - args["workflow_app_id"], + query.workflow_app_id, ) else: raise ValueError("incorrect workflow_tool_id or workflow_app_id") @@ -617,14 +686,8 @@ class ToolWorkflowProviderGetApi(Resource): return jsonable_encoder(tool) -parser_wf_tools = reqparse.RequestParser().add_argument( - "workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args" -) - - @console_ns.route("/workspaces/current/tool-provider/workflow/tools") class ToolWorkflowProviderListToolApi(Resource): - @console_ns.expect(parser_wf_tools) @setup_required @login_required @account_initialization_required @@ -633,13 +696,14 @@ class ToolWorkflowProviderListToolApi(Resource): user_id = user.id - args = parser_wf_tools.parse_args() + raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] + query = WorkflowToolListQuery.model_validate(raw_args) return jsonable_encoder( WorkflowToolManageService.list_single_workflow_tools( user_id, tenant_id, - args["workflow_tool_id"], + query.workflow_tool_id, ) ) @@ -806,49 +870,39 @@ class ToolOAuthCallback(Resource): return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") -parser_default_cred = reqparse.RequestParser().add_argument( - "id", type=str, required=True, nullable=False, location="json" -) - - @console_ns.route("/workspaces/current/tool-provider/builtin//default-credential") class ToolBuiltinProviderSetDefaultApi(Resource): - @console_ns.expect(parser_default_cred) + @console_ns.expect(console_ns.models[BuiltinProviderDefaultCredentialPayload.__name__]) @setup_required @login_required @account_initialization_required def post(self, provider): current_user, current_tenant_id = current_account_with_tenant() - args = parser_default_cred.parse_args() + payload = BuiltinProviderDefaultCredentialPayload.model_validate(console_ns.payload or {}) return BuiltinToolManageService.set_default_provider( - tenant_id=current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"] + tenant_id=current_tenant_id, user_id=current_user.id, provider=provider, id=payload.id ) -parser_custom = ( - reqparse.RequestParser() - .add_argument("client_params", type=dict, required=False, nullable=True, location="json") - .add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/builtin//oauth/custom-client") class ToolOAuthCustomClient(Resource): - @console_ns.expect(parser_custom) + @console_ns.expect(console_ns.models[ToolOAuthCustomClientPayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @account_initialization_required def post(self, provider: str): - args = parser_custom.parse_args() + payload = ToolOAuthCustomClientPayload.model_validate(console_ns.payload or {}) _, tenant_id = current_account_with_tenant() return BuiltinToolManageService.save_custom_oauth_client_params( tenant_id=tenant_id, provider=provider, - client_params=args.get("client_params", {}), - enable_oauth_custom_client=args.get("enable_oauth_custom_client", True), + client_params=payload.client_params or {}, + enable_oauth_custom_client=payload.enable_oauth_custom_client + if payload.enable_oauth_custom_client is not None + else True, ) @setup_required @@ -900,49 +954,19 @@ class ToolBuiltinProviderGetCredentialInfoApi(Resource): ) -parser_mcp = ( - reqparse.RequestParser() - .add_argument("server_url", type=str, required=True, nullable=False, location="json") - .add_argument("name", type=str, required=True, nullable=False, location="json") - .add_argument("icon", type=str, required=True, nullable=False, location="json") - .add_argument("icon_type", type=str, required=True, nullable=False, location="json") - .add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="") - .add_argument("server_identifier", type=str, required=True, nullable=False, location="json") - .add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={}) - .add_argument("headers", type=dict, required=False, nullable=True, location="json", default={}) - .add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={}) -) -parser_mcp_put = ( - reqparse.RequestParser() - .add_argument("server_url", type=str, required=True, nullable=False, location="json") - .add_argument("name", type=str, required=True, nullable=False, location="json") - .add_argument("icon", type=str, required=True, nullable=False, location="json") - .add_argument("icon_type", type=str, required=True, nullable=False, location="json") - .add_argument("icon_background", type=str, required=False, nullable=True, location="json") - .add_argument("provider_id", type=str, required=True, nullable=False, location="json") - .add_argument("server_identifier", type=str, required=True, nullable=False, location="json") - .add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={}) - .add_argument("headers", type=dict, required=False, nullable=True, location="json", default={}) - .add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={}) -) -parser_mcp_delete = reqparse.RequestParser().add_argument( - "provider_id", type=str, required=True, nullable=False, location="json" -) - - @console_ns.route("/workspaces/current/tool-provider/mcp") class ToolProviderMCPApi(Resource): - @console_ns.expect(parser_mcp) + @console_ns.expect(console_ns.models[MCPProviderCreatePayload.__name__]) @setup_required @login_required @account_initialization_required def post(self): - args = parser_mcp.parse_args() + payload = MCPProviderCreatePayload.model_validate(console_ns.payload or {}) user, tenant_id = current_account_with_tenant() # Parse and validate models - configuration = MCPConfiguration.model_validate(args["configuration"]) - authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None + configuration = MCPConfiguration.model_validate(payload.configuration or {}) + authentication = MCPAuthentication.model_validate(payload.authentication) if payload.authentication else None # Create provider with Session(db.engine) as session, session.begin(): @@ -950,26 +974,26 @@ class ToolProviderMCPApi(Resource): result = service.create_provider( tenant_id=tenant_id, user_id=user.id, - server_url=args["server_url"], - name=args["name"], - icon=args["icon"], - icon_type=args["icon_type"], - icon_background=args["icon_background"], - server_identifier=args["server_identifier"], - headers=args["headers"], + server_url=payload.server_url, + name=payload.name, + icon=payload.icon, + icon_type=payload.icon_type, + icon_background=payload.icon_background or "", + server_identifier=payload.server_identifier, + headers=payload.headers or {}, configuration=configuration, authentication=authentication, ) return jsonable_encoder(result) - @console_ns.expect(parser_mcp_put) + @console_ns.expect(console_ns.models[MCPProviderUpdatePayload.__name__]) @setup_required @login_required @account_initialization_required def put(self): - args = parser_mcp_put.parse_args() - configuration = MCPConfiguration.model_validate(args["configuration"]) - authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None + payload = MCPProviderUpdatePayload.model_validate(console_ns.payload or {}) + configuration = MCPConfiguration.model_validate(payload.configuration or {}) + authentication = MCPAuthentication.model_validate(payload.authentication) if payload.authentication else None _, current_tenant_id = current_account_with_tenant() # Step 1: Validate server URL change if needed (includes URL format validation and network operation) @@ -977,7 +1001,9 @@ class ToolProviderMCPApi(Resource): with Session(db.engine) as session: service = MCPToolManageService(session=session) validation_result = service.validate_server_url_change( - tenant_id=current_tenant_id, provider_id=args["provider_id"], new_server_url=args["server_url"] + tenant_id=current_tenant_id, + provider_id=payload.provider_id, + new_server_url=payload.server_url, ) # No need to check for errors here, exceptions will be raised directly @@ -987,50 +1013,43 @@ class ToolProviderMCPApi(Resource): service = MCPToolManageService(session=session) service.update_provider( tenant_id=current_tenant_id, - provider_id=args["provider_id"], - server_url=args["server_url"], - name=args["name"], - icon=args["icon"], - icon_type=args["icon_type"], - icon_background=args["icon_background"], - server_identifier=args["server_identifier"], - headers=args["headers"], + provider_id=payload.provider_id, + server_url=payload.server_url, + name=payload.name, + icon=payload.icon, + icon_type=payload.icon_type, + icon_background=payload.icon_background, + server_identifier=payload.server_identifier, + headers=payload.headers or {}, configuration=configuration, authentication=authentication, validation_result=validation_result, ) return {"result": "success"} - @console_ns.expect(parser_mcp_delete) + @console_ns.expect(console_ns.models[MCPProviderDeletePayload.__name__]) @setup_required @login_required @account_initialization_required def delete(self): - args = parser_mcp_delete.parse_args() + payload = MCPProviderDeletePayload.model_validate(console_ns.payload or {}) _, current_tenant_id = current_account_with_tenant() with Session(db.engine) as session, session.begin(): service = MCPToolManageService(session=session) - service.delete_provider(tenant_id=current_tenant_id, provider_id=args["provider_id"]) + service.delete_provider(tenant_id=current_tenant_id, provider_id=payload.provider_id) return {"result": "success"} -parser_auth = ( - reqparse.RequestParser() - .add_argument("provider_id", type=str, required=True, nullable=False, location="json") - .add_argument("authorization_code", type=str, required=False, nullable=True, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/mcp/auth") class ToolMCPAuthApi(Resource): - @console_ns.expect(parser_auth) + @console_ns.expect(console_ns.models[MCPAuthPayload.__name__]) @setup_required @login_required @account_initialization_required def post(self): - args = parser_auth.parse_args() - provider_id = args["provider_id"] + payload = MCPAuthPayload.model_validate(console_ns.payload or {}) + provider_id = payload.provider_id _, tenant_id = current_account_with_tenant() with Session(db.engine) as session, session.begin(): @@ -1068,7 +1087,7 @@ class ToolMCPAuthApi(Resource): # Pass the extracted OAuth metadata hints to auth() auth_result = auth( provider_entity, - args.get("authorization_code"), + payload.authorization_code, resource_metadata_url=e.resource_metadata_url, scope_hint=e.scope_hint, ) @@ -1133,20 +1152,13 @@ class ToolMCPUpdateApi(Resource): return jsonable_encoder(tools) -parser_cb = ( - reqparse.RequestParser() - .add_argument("code", type=str, required=True, nullable=False, location="args") - .add_argument("state", type=str, required=True, nullable=False, location="args") -) - - @console_ns.route("/mcp/oauth/callback") class ToolMCPCallbackApi(Resource): - @console_ns.expect(parser_cb) def get(self): - args = parser_cb.parse_args() - state_key = args["state"] - authorization_code = args["code"] + raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] + query = MCPCallbackQuery.model_validate(raw_args) + state_key = query.state + authorization_code = query.code # Create service instance for handle_callback with Session(db.engine) as session, session.begin(): diff --git a/api/controllers/console/workspace/trigger_providers.py b/api/controllers/console/workspace/trigger_providers.py index 268473d6d1..30fb954755 100644 --- a/api/controllers/console/workspace/trigger_providers.py +++ b/api/controllers/console/workspace/trigger_providers.py @@ -1,11 +1,14 @@ import logging +from typing import Any from flask import make_response, redirect, request -from flask_restx import Resource, reqparse +from flask_restx import Resource +from pydantic import BaseModel from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, Forbidden from configs import dify_config +from controllers.common.schema import register_schema_models from controllers.web.error import NotFoundError from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin_daemon import CredentialType @@ -32,6 +35,35 @@ from ..wraps import ( logger = logging.getLogger(__name__) +class TriggerSubscriptionBuilderCreatePayload(BaseModel): + credential_type: str | None = None + + +class TriggerSubscriptionBuilderVerifyPayload(BaseModel): + credentials: dict[str, Any] | None = None + + +class TriggerSubscriptionBuilderUpdatePayload(BaseModel): + name: str | None = None + parameters: dict[str, Any] | None = None + properties: dict[str, Any] | None = None + credentials: dict[str, Any] | None = None + + +class TriggerOAuthClientPayload(BaseModel): + client_params: dict[str, Any] | None = None + enabled: bool | None = None + + +register_schema_models( + console_ns, + TriggerSubscriptionBuilderCreatePayload, + TriggerSubscriptionBuilderVerifyPayload, + TriggerSubscriptionBuilderUpdatePayload, + TriggerOAuthClientPayload, +) + + @console_ns.route("/workspaces/current/trigger-provider//icon") class TriggerProviderIconApi(Resource): @setup_required @@ -97,16 +129,11 @@ class TriggerSubscriptionListApi(Resource): raise -parser = reqparse.RequestParser().add_argument( - "credential_type", type=str, required=False, nullable=True, location="json" -) - - @console_ns.route( "/workspaces/current/trigger-provider//subscriptions/builder/create", ) class TriggerSubscriptionBuilderCreateApi(Resource): - @console_ns.expect(parser) + @console_ns.expect(console_ns.models[TriggerSubscriptionBuilderCreatePayload.__name__]) @setup_required @login_required @edit_permission_required @@ -116,10 +143,10 @@ class TriggerSubscriptionBuilderCreateApi(Resource): user = current_user assert user.current_tenant_id is not None - args = parser.parse_args() + payload = TriggerSubscriptionBuilderCreatePayload.model_validate(console_ns.payload or {}) try: - credential_type = CredentialType.of(args.get("credential_type") or CredentialType.UNAUTHORIZED.value) + credential_type = CredentialType.of(payload.credential_type or CredentialType.UNAUTHORIZED.value) subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder( tenant_id=user.current_tenant_id, user_id=user.id, @@ -147,18 +174,11 @@ class TriggerSubscriptionBuilderGetApi(Resource): ) -parser_api = ( - reqparse.RequestParser() - # The credentials of the subscription builder - .add_argument("credentials", type=dict, required=False, nullable=True, location="json") -) - - @console_ns.route( "/workspaces/current/trigger-provider//subscriptions/builder/verify/", ) class TriggerSubscriptionBuilderVerifyApi(Resource): - @console_ns.expect(parser_api) + @console_ns.expect(console_ns.models[TriggerSubscriptionBuilderVerifyPayload.__name__]) @setup_required @login_required @edit_permission_required @@ -168,7 +188,7 @@ class TriggerSubscriptionBuilderVerifyApi(Resource): user = current_user assert user.current_tenant_id is not None - args = parser_api.parse_args() + payload = TriggerSubscriptionBuilderVerifyPayload.model_validate(console_ns.payload or {}) try: # Use atomic update_and_verify to prevent race conditions @@ -178,7 +198,7 @@ class TriggerSubscriptionBuilderVerifyApi(Resource): provider_id=TriggerProviderID(provider), subscription_builder_id=subscription_builder_id, subscription_builder_updater=SubscriptionBuilderUpdater( - credentials=args.get("credentials", None), + credentials=payload.credentials, ), ) except Exception as e: @@ -186,24 +206,11 @@ class TriggerSubscriptionBuilderVerifyApi(Resource): raise ValueError(str(e)) from e -parser_update_api = ( - reqparse.RequestParser() - # The name of the subscription builder - .add_argument("name", type=str, required=False, nullable=True, location="json") - # The parameters of the subscription builder - .add_argument("parameters", type=dict, required=False, nullable=True, location="json") - # The properties of the subscription builder - .add_argument("properties", type=dict, required=False, nullable=True, location="json") - # The credentials of the subscription builder - .add_argument("credentials", type=dict, required=False, nullable=True, location="json") -) - - @console_ns.route( "/workspaces/current/trigger-provider//subscriptions/builder/update/", ) class TriggerSubscriptionBuilderUpdateApi(Resource): - @console_ns.expect(parser_update_api) + @console_ns.expect(console_ns.models[TriggerSubscriptionBuilderUpdatePayload.__name__]) @setup_required @login_required @edit_permission_required @@ -214,7 +221,7 @@ class TriggerSubscriptionBuilderUpdateApi(Resource): assert isinstance(user, Account) assert user.current_tenant_id is not None - args = parser_update_api.parse_args() + payload = TriggerSubscriptionBuilderUpdatePayload.model_validate(console_ns.payload or {}) try: return jsonable_encoder( TriggerSubscriptionBuilderService.update_trigger_subscription_builder( @@ -222,10 +229,10 @@ class TriggerSubscriptionBuilderUpdateApi(Resource): provider_id=TriggerProviderID(provider), subscription_builder_id=subscription_builder_id, subscription_builder_updater=SubscriptionBuilderUpdater( - name=args.get("name", None), - parameters=args.get("parameters", None), - properties=args.get("properties", None), - credentials=args.get("credentials", None), + name=payload.name, + parameters=payload.parameters, + properties=payload.properties, + credentials=payload.credentials, ), ) ) @@ -260,7 +267,7 @@ class TriggerSubscriptionBuilderLogsApi(Resource): "/workspaces/current/trigger-provider//subscriptions/builder/build/", ) class TriggerSubscriptionBuilderBuildApi(Resource): - @console_ns.expect(parser_update_api) + @console_ns.expect(console_ns.models[TriggerSubscriptionBuilderUpdatePayload.__name__]) @setup_required @login_required @edit_permission_required @@ -269,7 +276,7 @@ class TriggerSubscriptionBuilderBuildApi(Resource): """Build a subscription instance for a trigger provider""" user = current_user assert user.current_tenant_id is not None - args = parser_update_api.parse_args() + payload = TriggerSubscriptionBuilderUpdatePayload.model_validate(console_ns.payload or {}) try: # Use atomic update_and_build to prevent race conditions TriggerSubscriptionBuilderService.update_and_build_builder( @@ -278,9 +285,9 @@ class TriggerSubscriptionBuilderBuildApi(Resource): provider_id=TriggerProviderID(provider), subscription_builder_id=subscription_builder_id, subscription_builder_updater=SubscriptionBuilderUpdater( - name=args.get("name", None), - parameters=args.get("parameters", None), - properties=args.get("properties", None), + name=payload.name, + parameters=payload.parameters, + properties=payload.properties, ), ) return 200 @@ -474,13 +481,6 @@ class TriggerOAuthCallbackApi(Resource): return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") -parser_oauth_client = ( - reqparse.RequestParser() - .add_argument("client_params", type=dict, required=False, nullable=True, location="json") - .add_argument("enabled", type=bool, required=False, nullable=True, location="json") -) - - @console_ns.route("/workspaces/current/trigger-provider//oauth/client") class TriggerOAuthClientManageApi(Resource): @setup_required @@ -528,7 +528,7 @@ class TriggerOAuthClientManageApi(Resource): logger.exception("Error getting OAuth client", exc_info=e) raise - @console_ns.expect(parser_oauth_client) + @console_ns.expect(console_ns.models[TriggerOAuthClientPayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -538,15 +538,15 @@ class TriggerOAuthClientManageApi(Resource): user = current_user assert user.current_tenant_id is not None - args = parser_oauth_client.parse_args() + payload = TriggerOAuthClientPayload.model_validate(console_ns.payload or {}) try: provider_id = TriggerProviderID(provider) return TriggerProviderService.save_custom_oauth_client_params( tenant_id=user.current_tenant_id, provider_id=provider_id, - client_params=args.get("client_params"), - enabled=args.get("enabled"), + client_params=payload.client_params, + enabled=payload.enabled, ) except ValueError as e: diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index 60193f5f15..9721003eab 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -1,7 +1,8 @@ import logging from flask import request -from flask_restx import Resource, marshal_with, reqparse +from flask_restx import Resource, marshal_with +from pydantic import BaseModel, ConfigDict, Field from werkzeug.exceptions import Unauthorized from constants import HEADER_NAME_APP_CODE @@ -21,6 +22,13 @@ from services.webapp_auth_service import WebAppAuthService logger = logging.getLogger(__name__) +class AppAccessModeQuery(BaseModel): + model_config = ConfigDict(populate_by_name=True) + + app_id: str | None = Field(default=None, alias="appId", description="Application ID") + app_code: str | None = Field(default=None, alias="appCode", description="Application code") + + @web_ns.route("/parameters") class AppParameterApi(WebApiResource): """Resource for app variables.""" @@ -82,12 +90,7 @@ class AppMeta(WebApiResource): class AppAccessMode(Resource): @web_ns.doc("Get App Access Mode") @web_ns.doc(description="Retrieve the access mode for a web application (public or restricted).") - @web_ns.doc( - params={ - "appId": {"description": "Application ID", "type": "string", "required": False}, - "appCode": {"description": "Application code", "type": "string", "required": False}, - } - ) + @web_ns.expect(web_ns.models[AppAccessModeQuery.__name__]) @web_ns.doc( responses={ 200: "Success", @@ -96,21 +99,16 @@ class AppAccessMode(Resource): } ) def get(self): - parser = ( - reqparse.RequestParser() - .add_argument("appId", type=str, required=False, location="args") - .add_argument("appCode", type=str, required=False, location="args") - ) - args = parser.parse_args() + raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] + args = AppAccessModeQuery.model_validate(raw_args) features = FeatureService.get_system_features() if not features.webapp_auth.enabled: return {"accessMode": "public"} - app_id = args.get("appId") - if args.get("appCode"): - app_code = args["appCode"] - app_id = AppService.get_app_id_by_code(app_code) + app_id = args.app_id + if args.app_code: + app_id = AppService.get_app_id_by_code(args.app_code) if not app_id: raise ValueError("appId or appCode must be provided") diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index b9fef48c4d..30bd07847e 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -1,7 +1,8 @@ import logging from flask import request -from flask_restx import fields, marshal_with, reqparse +from flask_restx import fields, marshal_with +from pydantic import BaseModel, field_validator from werkzeug.exceptions import InternalServerError import services @@ -17,9 +18,11 @@ from controllers.web.error import ( ProviderQuotaExceededError, UnsupportedAudioTypeError, ) +from ..common import register_schema_models from controllers.web.wraps import WebApiResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError +from libs.helper import uuid_value from models.model import App from services.audio_service import AudioService from services.errors.audio import ( @@ -29,6 +32,22 @@ from services.errors.audio import ( UnsupportedAudioTypeServiceError, ) + +class TextToAudioPayload(BaseModel): + message_id: str | None = None + voice: str | None = None + text: str | None = None + streaming: bool | None = None + + @field_validator("message_id") + @classmethod + def validate_message_id(cls, value: str | None) -> str | None: + if value is None: + return value + return uuid_value(value) + +register_schema_models(web_ns, TextToAudioPayload) + logger = logging.getLogger(__name__) @@ -88,6 +107,7 @@ class AudioApi(WebApiResource): @web_ns.route("/text-to-audio") class TextApi(WebApiResource): + @web_ns.expect(web_ns.models[TextToAudioPayload.__name__]) @web_ns.doc("Text to Audio") @web_ns.doc(description="Convert text to audio using text-to-speech service.") @web_ns.doc( @@ -102,18 +122,11 @@ class TextApi(WebApiResource): def post(self, app_model: App, end_user): """Convert text to audio""" try: - parser = ( - reqparse.RequestParser() - .add_argument("message_id", type=str, required=False, location="json") - .add_argument("voice", type=str, location="json") - .add_argument("text", type=str, location="json") - .add_argument("streaming", type=bool, location="json") - ) - args = parser.parse_args() + payload = TextToAudioPayload.model_validate(web_ns.payload or {}) - message_id = args.get("message_id", None) - text = args.get("text", None) - voice = args.get("voice", None) + message_id = payload.message_id + text = payload.text + voice = payload.voice response = AudioService.transcript_tts( app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id ) diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index e8a4698375..830d2a16b5 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -1,6 +1,7 @@ import logging +from typing import Any, Literal -from flask_restx import reqparse +from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import InternalServerError, NotFound import services @@ -34,25 +35,41 @@ from services.errors.llm import InvokeRateLimitError logger = logging.getLogger(__name__) +class CompletionMessagePayload(BaseModel): + inputs: dict[str, Any] = Field(description="Input variables for the completion") + query: str = Field(description="Query text for completion") + files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed") + response_mode: Literal["blocking", "streaming"] | None = Field( + default=None, description="Response mode: blocking or streaming" + ) + retriever_from: str | None = Field(default="web_app", description="Source of retriever") + + +class ChatMessagePayload(BaseModel): + inputs: dict[str, Any] = Field(description="Input variables for the chat") + query: str = Field(description="User query/message") + files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed") + response_mode: Literal["blocking", "streaming"] | None = Field( + default=None, description="Response mode: blocking or streaming" + ) + conversation_id: str | None = Field(default=None, description="Conversation ID") + parent_message_id: str | None = Field(default=None, description="Parent message ID") + retriever_from: str | None = Field(default="web_app", description="Source of retriever") + + @field_validator("conversation_id", "parent_message_id") + @classmethod + def validate_uuid(cls, value: str | None) -> str | None: + if value is None: + return value + return uuid_value(value) + + # define completion api for user @web_ns.route("/completion-messages") class CompletionApi(WebApiResource): @web_ns.doc("Create Completion Message") @web_ns.doc(description="Create a completion message for text generation applications.") - @web_ns.doc( - params={ - "inputs": {"description": "Input variables for the completion", "type": "object", "required": True}, - "query": {"description": "Query text for completion", "type": "string", "required": False}, - "files": {"description": "Files to be processed", "type": "array", "required": False}, - "response_mode": { - "description": "Response mode: blocking or streaming", - "type": "string", - "enum": ["blocking", "streaming"], - "required": False, - }, - "retriever_from": {"description": "Source of retriever", "type": "string", "required": False}, - } - ) + @web_ns.expect(web_ns.models[CompletionMessagePayload.__name__]) @web_ns.doc( responses={ 200: "Success", @@ -67,18 +84,10 @@ class CompletionApi(WebApiResource): if app_model.mode != AppMode.COMPLETION: raise NotCompletionAppError() - parser = ( - reqparse.RequestParser() - .add_argument("inputs", type=dict, required=True, location="json") - .add_argument("query", type=str, location="json", default="") - .add_argument("files", type=list, required=False, location="json") - .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") - .add_argument("retriever_from", type=str, required=False, default="web_app", location="json") - ) + payload = CompletionMessagePayload.model_validate(web_ns.payload or {}) + args = payload.model_dump(exclude_none=True) - args = parser.parse_args() - - streaming = args["response_mode"] == "streaming" + streaming = payload.response_mode == "streaming" args["auto_generate_name"] = False try: @@ -142,22 +151,7 @@ class CompletionStopApi(WebApiResource): class ChatApi(WebApiResource): @web_ns.doc("Create Chat Message") @web_ns.doc(description="Create a chat message for conversational applications.") - @web_ns.doc( - params={ - "inputs": {"description": "Input variables for the chat", "type": "object", "required": True}, - "query": {"description": "User query/message", "type": "string", "required": True}, - "files": {"description": "Files to be processed", "type": "array", "required": False}, - "response_mode": { - "description": "Response mode: blocking or streaming", - "type": "string", - "enum": ["blocking", "streaming"], - "required": False, - }, - "conversation_id": {"description": "Conversation UUID", "type": "string", "required": False}, - "parent_message_id": {"description": "Parent message UUID", "type": "string", "required": False}, - "retriever_from": {"description": "Source of retriever", "type": "string", "required": False}, - } - ) + @web_ns.expect(web_ns.models[ChatMessagePayload.__name__]) @web_ns.doc( responses={ 200: "Success", @@ -173,20 +167,10 @@ class ChatApi(WebApiResource): if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - parser = ( - reqparse.RequestParser() - .add_argument("inputs", type=dict, required=True, location="json") - .add_argument("query", type=str, required=True, location="json") - .add_argument("files", type=list, required=False, location="json") - .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") - .add_argument("conversation_id", type=uuid_value, location="json") - .add_argument("parent_message_id", type=uuid_value, required=False, location="json") - .add_argument("retriever_from", type=str, required=False, default="web_app", location="json") - ) + payload = ChatMessagePayload.model_validate(web_ns.payload or {}) + args = payload.model_dump(exclude_none=True) - args = parser.parse_args() - - streaming = args["response_mode"] == "streaming" + streaming = payload.response_mode == "streaming" args["auto_generate_name"] = False try: diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index 86e19423e5..3281e0b001 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -1,5 +1,8 @@ -from flask_restx import fields, marshal_with, reqparse -from flask_restx.inputs import int_range +from typing import Literal + +from flask import request +from flask_restx import fields, marshal_with +from pydantic import BaseModel, Field, field_validator from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound @@ -16,6 +19,25 @@ from services.errors.conversation import ConversationNotExistsError, LastConvers from services.web_conversation_service import WebConversationService +class ConversationListQuery(BaseModel): + last_id: str | None = None + limit: int = Field(default=20, ge=1, le=100) + pinned: bool | None = None + sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = "-updated_at" + + @field_validator("last_id") + @classmethod + def validate_last_id(cls, value: str | None) -> str | None: + if value is None: + return value + return uuid_value(value) + + +class ConversationRenamePayload(BaseModel): + name: str | None = None + auto_generate: bool = False + + @web_ns.route("/conversations") class ConversationListApi(WebApiResource): @web_ns.doc("Get Conversation List") @@ -60,25 +82,8 @@ class ConversationListApi(WebApiResource): if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - parser = ( - reqparse.RequestParser() - .add_argument("last_id", type=uuid_value, location="args") - .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") - .add_argument("pinned", type=str, choices=["true", "false", None], location="args") - .add_argument( - "sort_by", - type=str, - choices=["created_at", "-created_at", "updated_at", "-updated_at"], - required=False, - default="-updated_at", - location="args", - ) - ) - args = parser.parse_args() - - pinned = None - if "pinned" in args and args["pinned"] is not None: - pinned = args["pinned"] == "true" + raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] + query = ConversationListQuery.model_validate(raw_args) try: with Session(db.engine) as session: @@ -86,11 +91,11 @@ class ConversationListApi(WebApiResource): session=session, app_model=app_model, user=end_user, - last_id=args["last_id"], - limit=args["limit"], + last_id=query.last_id, + limit=query.limit, invoke_from=InvokeFrom.WEB_APP, - pinned=pinned, - sort_by=args["sort_by"], + pinned=query.pinned, + sort_by=query.sort_by, ) except LastConversationNotExistsError: raise NotFound("Last Conversation Not Exists.") @@ -163,15 +168,10 @@ class ConversationRenameApi(WebApiResource): conversation_id = str(c_id) - parser = ( - reqparse.RequestParser() - .add_argument("name", type=str, required=False, location="json") - .add_argument("auto_generate", type=bool, required=False, default=False, location="json") - ) - args = parser.parse_args() + payload = ConversationRenamePayload.model_validate(web_ns.payload or {}) try: - return ConversationService.rename(app_model, conversation_id, end_user, args["name"], args["auto_generate"]) + return ConversationService.rename(app_model, conversation_id, end_user, payload.name, payload.auto_generate) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") diff --git a/api/controllers/web/forgot_password.py b/api/controllers/web/forgot_password.py index b9e391e049..6f22171c67 100644 --- a/api/controllers/web/forgot_password.py +++ b/api/controllers/web/forgot_password.py @@ -2,7 +2,8 @@ import base64 import secrets from flask import request -from flask_restx import Resource, reqparse +from flask_restx import Resource +from pydantic import BaseModel, Field, field_validator from sqlalchemy import select from sqlalchemy.orm import Session @@ -18,12 +19,34 @@ from controllers.console.error import EmailSendIpLimitError from controllers.console.wraps import email_password_login_enabled, only_edition_enterprise, setup_required from controllers.web import web_ns from extensions.ext_database import db -from libs.helper import email, extract_remote_ip +from libs.helper import EmailStr, extract_remote_ip from libs.password import hash_password, valid_password from models import Account from services.account_service import AccountService +class ForgotPasswordSendPayload(BaseModel): + email: EmailStr + language: str | None = None + + +class ForgotPasswordCheckPayload(BaseModel): + email: EmailStr + code: str + token: str = Field(min_length=1) + + +class ForgotPasswordResetPayload(BaseModel): + token: str = Field(min_length=1) + new_password: str + password_confirm: str + + @field_validator("new_password", "password_confirm") + @classmethod + def validate_password(cls, value: str) -> str: + return valid_password(value) + + @web_ns.route("/forgot-password") class ForgotPasswordSendEmailApi(Resource): @only_edition_enterprise @@ -40,29 +63,24 @@ class ForgotPasswordSendEmailApi(Resource): } ) def post(self): - parser = ( - reqparse.RequestParser() - .add_argument("email", type=email, required=True, location="json") - .add_argument("language", type=str, required=False, location="json") - ) - args = parser.parse_args() + payload = ForgotPasswordSendPayload.model_validate(web_ns.payload or {}) ip_address = extract_remote_ip(request) if AccountService.is_email_send_ip_limit(ip_address): raise EmailSendIpLimitError() - if args["language"] is not None and args["language"] == "zh-Hans": + if payload.language == "zh-Hans": language = "zh-Hans" else: language = "en-US" with Session(db.engine) as session: - account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none() + account = session.execute(select(Account).filter_by(email=payload.email)).scalar_one_or_none() token = None if account is None: raise AuthenticationFailedError() else: - token = AccountService.send_reset_password_email(account=account, email=args["email"], language=language) + token = AccountService.send_reset_password_email(account=account, email=payload.email, language=language) return {"result": "success", "data": token} @@ -78,40 +96,34 @@ class ForgotPasswordCheckApi(Resource): responses={200: "Token is valid", 400: "Bad request - invalid token format", 401: "Invalid or expired token"} ) def post(self): - parser = ( - reqparse.RequestParser() - .add_argument("email", type=str, required=True, location="json") - .add_argument("code", type=str, required=True, location="json") - .add_argument("token", type=str, required=True, nullable=False, location="json") - ) - args = parser.parse_args() + payload = ForgotPasswordCheckPayload.model_validate(web_ns.payload or {}) - user_email = args["email"] + user_email = payload.email - is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args["email"]) + is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(payload.email) if is_forgot_password_error_rate_limit: raise EmailPasswordResetLimitError() - token_data = AccountService.get_reset_password_data(args["token"]) + token_data = AccountService.get_reset_password_data(payload.token) if token_data is None: raise InvalidTokenError() if user_email != token_data.get("email"): raise InvalidEmailError() - if args["code"] != token_data.get("code"): - AccountService.add_forgot_password_error_rate_limit(args["email"]) + if payload.code != token_data.get("code"): + AccountService.add_forgot_password_error_rate_limit(payload.email) raise EmailCodeError() # Verified, revoke the first token - AccountService.revoke_reset_password_token(args["token"]) + AccountService.revoke_reset_password_token(payload.token) # Refresh token data by generating a new token _, new_token = AccountService.generate_reset_password_token( - user_email, code=args["code"], additional_data={"phase": "reset"} + user_email, code=payload.code, additional_data={"phase": "reset"} ) - AccountService.reset_forgot_password_error_rate_limit(args["email"]) + AccountService.reset_forgot_password_error_rate_limit(payload.email) return {"is_valid": True, "email": token_data.get("email"), "token": new_token} @@ -131,20 +143,14 @@ class ForgotPasswordResetApi(Resource): } ) def post(self): - parser = ( - reqparse.RequestParser() - .add_argument("token", type=str, required=True, nullable=False, location="json") - .add_argument("new_password", type=valid_password, required=True, nullable=False, location="json") - .add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json") - ) - args = parser.parse_args() + payload = ForgotPasswordResetPayload.model_validate(web_ns.payload or {}) # Validate passwords match - if args["new_password"] != args["password_confirm"]: + if payload.new_password != payload.password_confirm: raise PasswordMismatchError() # Validate token and get reset data - reset_data = AccountService.get_reset_password_data(args["token"]) + reset_data = AccountService.get_reset_password_data(payload.token) if not reset_data: raise InvalidTokenError() # Must use token in reset phase @@ -152,11 +158,11 @@ class ForgotPasswordResetApi(Resource): raise InvalidTokenError() # Revoke token to prevent reuse - AccountService.revoke_reset_password_token(args["token"]) + AccountService.revoke_reset_password_token(payload.token) # Generate secure salt and hash password salt = secrets.token_bytes(16) - password_hashed = hash_password(args["new_password"], salt) + password_hashed = hash_password(payload.new_password, salt) email = reset_data.get("email", "") diff --git a/api/controllers/web/login.py b/api/controllers/web/login.py index 538d0c44be..861ca27fc9 100644 --- a/api/controllers/web/login.py +++ b/api/controllers/web/login.py @@ -1,6 +1,7 @@ from flask import make_response, request -from flask_restx import Resource, reqparse +from flask_restx import Resource from jwt import InvalidTokenError +from pydantic import BaseModel, Field, field_validator import services from configs import dify_config @@ -13,7 +14,7 @@ from controllers.console.error import AccountBannedError from controllers.console.wraps import only_edition_enterprise, setup_required from controllers.web import web_ns from controllers.web.wraps import decode_jwt_token -from libs.helper import email +from libs.helper import EmailStr from libs.passport import PassportService from libs.password import valid_password from libs.token import ( @@ -25,6 +26,27 @@ from services.app_service import AppService from services.webapp_auth_service import WebAppAuthService +class LoginPayload(BaseModel): + email: EmailStr + password: str + + @field_validator("password") + @classmethod + def validate_password(cls, value: str) -> str: + return valid_password(value) + + +class EmailCodeLoginSendPayload(BaseModel): + email: EmailStr + language: str | None = None + + +class EmailCodeLoginVerifyPayload(BaseModel): + email: EmailStr + code: str + token: str = Field(min_length=1) + + @web_ns.route("/login") class LoginApi(Resource): """Resource for web app email/password login.""" @@ -44,15 +66,10 @@ class LoginApi(Resource): ) def post(self): """Authenticate user and login.""" - parser = ( - reqparse.RequestParser() - .add_argument("email", type=email, required=True, location="json") - .add_argument("password", type=valid_password, required=True, location="json") - ) - args = parser.parse_args() + payload = LoginPayload.model_validate(web_ns.payload or {}) try: - account = WebAppAuthService.authenticate(args["email"], args["password"]) + account = WebAppAuthService.authenticate(payload.email, payload.password) except services.errors.account.AccountLoginError: raise AccountBannedError() except services.errors.account.AccountPasswordError: @@ -139,6 +156,7 @@ class EmailCodeLoginSendEmailApi(Resource): @only_edition_enterprise @web_ns.doc("send_email_code_login") @web_ns.doc(description="Send email verification code for login") + @web_ns.expect(web_ns.models[EmailCodeLoginSendPayload.__name__]) @web_ns.doc( responses={ 200: "Email code sent successfully", @@ -147,19 +165,14 @@ class EmailCodeLoginSendEmailApi(Resource): } ) def post(self): - parser = ( - reqparse.RequestParser() - .add_argument("email", type=email, required=True, location="json") - .add_argument("language", type=str, required=False, location="json") - ) - args = parser.parse_args() + payload = EmailCodeLoginSendPayload.model_validate(web_ns.payload or {}) - if args["language"] is not None and args["language"] == "zh-Hans": + if payload.language == "zh-Hans": language = "zh-Hans" else: language = "en-US" - account = WebAppAuthService.get_user_through_email(args["email"]) + account = WebAppAuthService.get_user_through_email(payload.email) if account is None: raise AuthenticationFailedError() else: @@ -173,6 +186,7 @@ class EmailCodeLoginApi(Resource): @only_edition_enterprise @web_ns.doc("verify_email_code_login") @web_ns.doc(description="Verify email code and complete login") + @web_ns.expect(web_ns.models[EmailCodeLoginVerifyPayload.__name__]) @web_ns.doc( responses={ 200: "Email code verified and login successful", @@ -182,33 +196,27 @@ class EmailCodeLoginApi(Resource): } ) def post(self): - parser = ( - reqparse.RequestParser() - .add_argument("email", type=str, required=True, location="json") - .add_argument("code", type=str, required=True, location="json") - .add_argument("token", type=str, required=True, location="json") - ) - args = parser.parse_args() + payload = EmailCodeLoginVerifyPayload.model_validate(web_ns.payload or {}) - user_email = args["email"] + user_email = payload.email - token_data = WebAppAuthService.get_email_code_login_data(args["token"]) + token_data = WebAppAuthService.get_email_code_login_data(payload.token) if token_data is None: raise InvalidTokenError() - if token_data["email"] != args["email"]: + if token_data["email"] != payload.email: raise InvalidEmailError() - if token_data["code"] != args["code"]: + if token_data["code"] != payload.code: raise EmailCodeError() - WebAppAuthService.revoke_email_code_login_token(args["token"]) + WebAppAuthService.revoke_email_code_login_token(payload.token) account = WebAppAuthService.get_user_through_email(user_email) if not account: raise AuthenticationFailedError() token = WebAppAuthService.login(account=account) - AccountService.reset_login_error_rate_limit(args["email"]) + AccountService.reset_login_error_rate_limit(payload.email) response = make_response({"result": "success", "data": {"access_token": token}}) # set_access_token_to_cookie(request, response, token, samesite="None", httponly=False) return response diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 9f9aa4838c..fdd9352972 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -1,7 +1,9 @@ import logging +from typing import Literal -from flask_restx import fields, marshal_with, reqparse -from flask_restx.inputs import int_range +from flask import request +from flask_restx import fields, marshal_with +from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import InternalServerError, NotFound from controllers.web import web_ns @@ -38,6 +40,30 @@ from services.message_service import MessageService logger = logging.getLogger(__name__) +class MessageListQuery(BaseModel): + conversation_id: str = Field(description="Conversation UUID") + first_id: str | None = Field(default=None, description="First message ID for pagination") + limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return (1-100)") + + @field_validator("conversation_id", "first_id") + @classmethod + def validate_uuid(cls, value: str | None) -> str | None: + if value is None: + return value + return uuid_value(value) + + +class MessageFeedbackPayload(BaseModel): + rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating") + content: str | None = Field(default=None, description="Feedback content") + + +class MessageMoreLikeThisQuery(BaseModel): + response_mode: Literal["blocking", "streaming"] = Field( + description="Response mode", + ) + + @web_ns.route("/messages") class MessageListApi(WebApiResource): message_fields = { @@ -65,18 +91,7 @@ class MessageListApi(WebApiResource): @web_ns.doc("Get Message List") @web_ns.doc(description="Retrieve paginated list of messages from a conversation in a chat application.") - @web_ns.doc( - params={ - "conversation_id": {"description": "Conversation UUID", "type": "string", "required": True}, - "first_id": {"description": "First message ID for pagination", "type": "string", "required": False}, - "limit": { - "description": "Number of messages to return (1-100)", - "type": "integer", - "required": False, - "default": 20, - }, - } - ) + @web_ns.expect(web_ns.models[MessageListQuery.__name__]) @web_ns.doc( responses={ 200: "Success", @@ -93,17 +108,12 @@ class MessageListApi(WebApiResource): if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - parser = ( - reqparse.RequestParser() - .add_argument("conversation_id", required=True, type=uuid_value, location="args") - .add_argument("first_id", type=uuid_value, location="args") - .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") - ) - args = parser.parse_args() + raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] + query = MessageListQuery.model_validate(raw_args) try: return MessageService.pagination_by_first_id( - app_model, end_user, args["conversation_id"], args["first_id"], args["limit"] + app_model, end_user, query.conversation_id, query.first_id, query.limit ) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @@ -120,17 +130,7 @@ class MessageFeedbackApi(WebApiResource): @web_ns.doc("Create Message Feedback") @web_ns.doc(description="Submit feedback (like/dislike) for a specific message.") @web_ns.doc(params={"message_id": {"description": "Message UUID", "type": "string", "required": True}}) - @web_ns.doc( - params={ - "rating": { - "description": "Feedback rating", - "type": "string", - "enum": ["like", "dislike"], - "required": False, - }, - "content": {"description": "Feedback content/comment", "type": "string", "required": False}, - } - ) + @web_ns.expect(web_ns.models[MessageFeedbackPayload.__name__]) @web_ns.doc( responses={ 200: "Feedback submitted successfully", @@ -145,20 +145,15 @@ class MessageFeedbackApi(WebApiResource): def post(self, app_model, end_user, message_id): message_id = str(message_id) - parser = ( - reqparse.RequestParser() - .add_argument("rating", type=str, choices=["like", "dislike", None], location="json") - .add_argument("content", type=str, location="json", default=None) - ) - args = parser.parse_args() + payload = MessageFeedbackPayload.model_validate(web_ns.payload or {}) try: MessageService.create_feedback( app_model=app_model, message_id=message_id, user=end_user, - rating=args.get("rating"), - content=args.get("content"), + rating=payload.rating, + content=payload.content, ) except MessageNotExistsError: raise NotFound("Message Not Exists.") @@ -170,17 +165,7 @@ class MessageFeedbackApi(WebApiResource): class MessageMoreLikeThisApi(WebApiResource): @web_ns.doc("Generate More Like This") @web_ns.doc(description="Generate a new completion similar to an existing message (completion apps only).") - @web_ns.doc( - params={ - "message_id": {"description": "Message UUID", "type": "string", "required": True}, - "response_mode": { - "description": "Response mode", - "type": "string", - "enum": ["blocking", "streaming"], - "required": True, - }, - } - ) + @web_ns.expect(web_ns.models[MessageMoreLikeThisQuery.__name__]) @web_ns.doc( responses={ 200: "Success", @@ -197,12 +182,10 @@ class MessageMoreLikeThisApi(WebApiResource): message_id = str(message_id) - parser = reqparse.RequestParser().add_argument( - "response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args" - ) - args = parser.parse_args() + raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] + query = MessageMoreLikeThisQuery.model_validate(raw_args) - streaming = args["response_mode"] == "streaming" + streaming = query.response_mode == "streaming" try: response = AppGenerateService.generate_more_like_this( diff --git a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py index dac4b3da94..fed8df40b0 100644 --- a/api/controllers/web/remote_files.py +++ b/api/controllers/web/remote_files.py @@ -1,7 +1,8 @@ import urllib.parse import httpx -from flask_restx import marshal_with, reqparse +from flask_restx import marshal_with +from pydantic import BaseModel, Field, HttpUrl import services from controllers.common import helpers @@ -19,6 +20,10 @@ from fields.file_fields import build_file_with_signed_url_model, build_remote_fi from services.file_service import FileService +class RemoteFileUploadPayload(BaseModel): + url: HttpUrl = Field(description="Remote file URL") + + @web_ns.route("/remote-files/") class RemoteFileInfoApi(WebApiResource): @web_ns.doc("get_remote_file_info") @@ -97,10 +102,8 @@ class RemoteFileUploadApi(WebApiResource): FileTooLargeError: File exceeds size limit UnsupportedFileTypeError: File type not supported """ - parser = reqparse.RequestParser().add_argument("url", type=str, required=True, help="URL is required") - args = parser.parse_args() - - url = args["url"] + payload = RemoteFileUploadPayload.model_validate(web_ns.payload or {}) + url = str(payload.url) try: resp = ssrf_proxy.head(url=url) diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py index 865f3610a7..f407507883 100644 --- a/api/controllers/web/saved_message.py +++ b/api/controllers/web/saved_message.py @@ -1,5 +1,6 @@ -from flask_restx import fields, marshal_with, reqparse -from flask_restx.inputs import int_range +from flask import request +from flask_restx import fields, marshal_with +from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import NotFound from controllers.web import web_ns @@ -23,6 +24,27 @@ message_fields = { } +class SavedMessageListQuery(BaseModel): + last_id: str | None = None + limit: int = Field(default=20, ge=1, le=100) + + @field_validator("last_id") + @classmethod + def validate_last_id(cls, value: str | None) -> str | None: + if value is None: + return value + return uuid_value(value) + + +class SavedMessageCreatePayload(BaseModel): + message_id: str + + @field_validator("message_id") + @classmethod + def validate_message_id(cls, value: str) -> str: + return uuid_value(value) + + @web_ns.route("/saved-messages") class SavedMessageListApi(WebApiResource): saved_message_infinite_scroll_pagination_fields = { @@ -63,14 +85,10 @@ class SavedMessageListApi(WebApiResource): if app_model.mode != "completion": raise NotCompletionAppError() - parser = ( - reqparse.RequestParser() - .add_argument("last_id", type=uuid_value, location="args") - .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") - ) - args = parser.parse_args() + raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] + query = SavedMessageListQuery.model_validate(raw_args) - return SavedMessageService.pagination_by_last_id(app_model, end_user, args["last_id"], args["limit"]) + return SavedMessageService.pagination_by_last_id(app_model, end_user, query.last_id, query.limit) @web_ns.doc("Save Message") @web_ns.doc(description="Save a specific message for later reference.") @@ -94,11 +112,10 @@ class SavedMessageListApi(WebApiResource): if app_model.mode != "completion": raise NotCompletionAppError() - parser = reqparse.RequestParser().add_argument("message_id", type=uuid_value, required=True, location="json") - args = parser.parse_args() + payload = SavedMessageCreatePayload.model_validate(web_ns.payload or {}) try: - SavedMessageService.save(app_model, end_user, args["message_id"]) + SavedMessageService.save(app_model, end_user, payload.message_id) except MessageNotExistsError: raise NotFound("Message Not Exists.") diff --git a/api/controllers/web/workflow.py b/api/controllers/web/workflow.py index 3cbb07a296..136f93c574 100644 --- a/api/controllers/web/workflow.py +++ b/api/controllers/web/workflow.py @@ -1,6 +1,7 @@ import logging +from typing import Any -from flask_restx import reqparse +from pydantic import BaseModel from werkzeug.exceptions import InternalServerError from controllers.web import web_ns @@ -27,6 +28,12 @@ from models.model import App, AppMode, EndUser from services.app_generate_service import AppGenerateService from services.errors.llm import InvokeRateLimitError + +class WorkflowRunPayload(BaseModel): + inputs: dict[str, Any] + files: list[dict[str, Any]] | None = None + + logger = logging.getLogger(__name__) @@ -58,12 +65,8 @@ class WorkflowRunApi(WebApiResource): if app_mode != AppMode.WORKFLOW: raise NotWorkflowAppError() - parser = ( - reqparse.RequestParser() - .add_argument("inputs", type=dict, required=True, nullable=False, location="json") - .add_argument("files", type=list, required=False, location="json") - ) - args = parser.parse_args() + payload = WorkflowRunPayload.model_validate(web_ns.payload or {}) + args = payload.model_dump(exclude_none=True) try: response = AppGenerateService.generate( From 995bfe16fc386e0744a734ea77728d7fa1a97ebc Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Mon, 8 Dec 2025 20:48:57 +0900 Subject: [PATCH 02/28] fix fix type fix --- api/controllers/console/extension.py | 7 ++-- api/controllers/console/tag/tags.py | 3 ++ .../console/workspace/tool_providers.py | 37 ++++++++++--------- api/controllers/web/app.py | 13 +++++-- api/controllers/web/audio.py | 6 ++- api/controllers/web/conversation.py | 4 +- api/controllers/web/message.py | 4 +- api/controllers/web/saved_message.py | 2 +- .../utils/workflow_configuration_sync.py | 5 --- .../tools/api_tools_manage_service.py | 4 +- .../tools/workflow_tools_manage_service.py | 16 +++----- 11 files changed, 51 insertions(+), 50 deletions(-) diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index c2cb4578ea..e333f5b129 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -3,15 +3,16 @@ from flask_restx import Resource, fields, marshal_with from pydantic import BaseModel, Field from constants import HIDDEN_VALUE -from controllers.console import console_ns -from controllers.common.schema import register_schema_models -from controllers.console.wraps import account_initialization_required, setup_required from fields.api_based_extension_fields import api_based_extension_fields from libs.login import current_account_with_tenant, login_required from models.api_based_extension import APIBasedExtension from services.api_based_extension_service import APIBasedExtensionService from services.code_based_extension_service import CodeBasedExtensionService +from ..common.schema import register_schema_models +from . import console_ns +from .wraps import account_initialization_required, setup_required + class CodeBasedExtensionQuery(BaseModel): module: str diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index c545348cde..108fa88d05 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -1,8 +1,11 @@ +from typing import Literal + from flask import request from flask_restx import Resource, marshal_with from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from fields.tag_fields import dataset_tag_fields diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index d9c1220b78..8fa52608de 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -24,6 +24,7 @@ from core.mcp.mcp_client import MCPClient from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.oauth import OAuthHandler +from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration from extensions.ext_database import db from libs.helper import alphanumeric, uuid_value from libs.login import current_account_with_tenant, login_required @@ -74,12 +75,12 @@ class BuiltinToolUpdatePayload(BaseModel): class ApiToolProviderBasePayload(BaseModel): credentials: dict[str, Any] schema_type: str - schema: str + schema_: str = Field(alias="schema") provider: str icon: dict[str, Any] privacy_policy: str | None = None labels: list[str] | None = None - custom_disclaimer: str | None = None + custom_disclaimer: str = "" class ApiToolProviderAddPayload(ApiToolProviderBasePayload): @@ -110,7 +111,7 @@ class ApiToolProviderDeletePayload(BaseModel): class ApiToolSchemaPayload(BaseModel): - schema: str + schema_: str = Field(alias="schema") class ApiToolTestPayload(BaseModel): @@ -119,7 +120,7 @@ class ApiToolTestPayload(BaseModel): credentials: dict[str, Any] parameters: dict[str, Any] schema_type: str - schema: str + schema_: str = Field(alias="schema") class WorkflowToolBasePayload(BaseModel): @@ -127,7 +128,7 @@ class WorkflowToolBasePayload(BaseModel): label: str description: str icon: dict[str, Any] - parameters: list[dict[str, Any]] + parameters: list[WorkflowToolParameterConfiguration] privacy_policy: str | None = "" labels: list[str] | None = None @@ -205,7 +206,7 @@ class MCPProviderBasePayload(BaseModel): name: str icon: str icon_type: str - icon_background: str | None = "" + icon_background: str = "" server_identifier: str configuration: dict[str, Any] | None = Field(default_factory=dict) headers: dict[str, Any] | None = Field(default_factory=dict) @@ -266,10 +267,10 @@ class ToolProviderListApi(Resource): user_id = user.id - raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] + raw_args = request.args.to_dict() query = ToolProviderListQuery.model_validate(raw_args) - return ToolCommonService.list_tool_providers(user_id, tenant_id, query.type) + return ToolCommonService.list_tool_providers(user_id, tenant_id, query.type) # type: ignore @console_ns.route("/workspaces/current/tool-provider/builtin//tools") @@ -411,7 +412,7 @@ class ToolApiProviderAddApi(Resource): payload.icon, payload.credentials, payload.schema_type, - payload.schema, + payload.schema_, payload.privacy_policy or "", payload.custom_disclaimer or "", payload.labels or [], @@ -428,7 +429,7 @@ class ToolApiProviderGetRemoteSchemaApi(Resource): user_id = user.id - raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] + raw_args = request.args.to_dict() query = UrlQuery.model_validate(raw_args) return ApiToolManageService.get_api_tool_provider_remote_schema( @@ -448,7 +449,7 @@ class ToolApiProviderListToolsApi(Resource): user_id = user.id - raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] + raw_args = request.args.to_dict() query = ProviderQuery.model_validate(raw_args) return jsonable_encoder( @@ -482,7 +483,7 @@ class ToolApiProviderUpdateApi(Resource): payload.icon, payload.credentials, payload.schema_type, - payload.schema, + payload.schema_, payload.privacy_policy, payload.custom_disclaimer, payload.labels or [], @@ -520,7 +521,7 @@ class ToolApiProviderGetApi(Resource): user_id = user.id - raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] + raw_args = request.args.to_dict() query = ProviderQuery.model_validate(raw_args) return ApiToolManageService.get_api_tool_provider( @@ -555,7 +556,7 @@ class ToolApiProviderSchemaApi(Resource): payload = ApiToolSchemaPayload.model_validate(console_ns.payload or {}) return ApiToolManageService.parser_api_schema( - schema=payload.schema, + schema=payload.schema_, ) @@ -575,7 +576,7 @@ class ToolApiProviderPreviousTestApi(Resource): payload.credentials, payload.parameters, payload.schema_type, - payload.schema, + payload.schema_, ) @@ -665,7 +666,7 @@ class ToolWorkflowProviderGetApi(Resource): user_id = user.id - raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] + raw_args = request.args.to_dict() query = WorkflowToolGetQuery.model_validate(raw_args) if query.workflow_tool_id: @@ -696,7 +697,7 @@ class ToolWorkflowProviderListToolApi(Resource): user_id = user.id - raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] + raw_args = request.args.to_dict() query = WorkflowToolListQuery.model_validate(raw_args) return jsonable_encoder( @@ -1155,7 +1156,7 @@ class ToolMCPUpdateApi(Resource): @console_ns.route("/mcp/oauth/callback") class ToolMCPCallbackApi(Resource): def get(self): - raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] + raw_args = request.args.to_dict() query = MCPCallbackQuery.model_validate(raw_args) state_key = query.state authorization_code = query.code diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index 9721003eab..b8eefcb1e0 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -7,9 +7,7 @@ from werkzeug.exceptions import Unauthorized from constants import HEADER_NAME_APP_CODE from controllers.common import fields -from controllers.web import web_ns -from controllers.web.error import AppUnavailableError -from controllers.web.wraps import WebApiResource +from controllers.common.schema import register_schema_models from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict from libs.passport import PassportService from libs.token import extract_webapp_passport @@ -19,6 +17,10 @@ from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService from services.webapp_auth_service import WebAppAuthService +from . import web_ns +from .error import AppUnavailableError +from .wraps import WebApiResource + logger = logging.getLogger(__name__) @@ -29,6 +31,9 @@ class AppAccessModeQuery(BaseModel): app_code: str | None = Field(default=None, alias="appCode", description="Application code") +register_schema_models(web_ns, AppAccessModeQuery) + + @web_ns.route("/parameters") class AppParameterApi(WebApiResource): """Resource for app variables.""" @@ -99,7 +104,7 @@ class AppAccessMode(Resource): } ) def get(self): - raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] + raw_args = request.args.to_dict() args = AppAccessModeQuery.model_validate(raw_args) features = FeatureService.get_system_features() diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index 30bd07847e..af31fbc334 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -18,8 +18,6 @@ from controllers.web.error import ( ProviderQuotaExceededError, UnsupportedAudioTypeError, ) -from ..common import register_schema_models -from controllers.web.wraps import WebApiResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from libs.helper import uuid_value @@ -32,6 +30,9 @@ from services.errors.audio import ( UnsupportedAudioTypeServiceError, ) +from ..common.schema import register_schema_models +from ..web.wraps import WebApiResource + class TextToAudioPayload(BaseModel): message_id: str | None = None @@ -46,6 +47,7 @@ class TextToAudioPayload(BaseModel): return value return uuid_value(value) + register_schema_models(web_ns, TextToAudioPayload) logger = logging.getLogger(__name__) diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index 3281e0b001..21f48ebb36 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -82,7 +82,7 @@ class ConversationListApi(WebApiResource): if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] + raw_args = request.args.to_dict() query = ConversationListQuery.model_validate(raw_args) try: @@ -171,7 +171,7 @@ class ConversationRenameApi(WebApiResource): payload = ConversationRenamePayload.model_validate(web_ns.payload or {}) try: - return ConversationService.rename(app_model, conversation_id, end_user, payload.name, payload.auto_generate) + return ConversationService.rename(app_model, conversation_id, end_user, payload.name, payload.auto_generate) # type: ignore except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index fdd9352972..1a7327ef26 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -108,7 +108,7 @@ class MessageListApi(WebApiResource): if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] + raw_args = request.args.to_dict() query = MessageListQuery.model_validate(raw_args) try: @@ -182,7 +182,7 @@ class MessageMoreLikeThisApi(WebApiResource): message_id = str(message_id) - raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] + raw_args = request.args.to_dict() query = MessageMoreLikeThisQuery.model_validate(raw_args) streaming = query.response_mode == "streaming" diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py index f407507883..4d3b190f7c 100644 --- a/api/controllers/web/saved_message.py +++ b/api/controllers/web/saved_message.py @@ -85,7 +85,7 @@ class SavedMessageListApi(WebApiResource): if app_model.mode != "completion": raise NotCompletionAppError() - raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] + raw_args = request.args.to_dict() query = SavedMessageListQuery.model_validate(raw_args) return SavedMessageService.pagination_by_last_id(app_model, end_user, query.last_id, query.limit) diff --git a/api/core/tools/utils/workflow_configuration_sync.py b/api/core/tools/utils/workflow_configuration_sync.py index 188da0c32d..6d75df3603 100644 --- a/api/core/tools/utils/workflow_configuration_sync.py +++ b/api/core/tools/utils/workflow_configuration_sync.py @@ -7,11 +7,6 @@ from core.workflow.nodes.base.entities import OutputVariableEntity class WorkflowToolConfigurationUtils: - @classmethod - def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]): - for configuration in configurations: - WorkflowToolParameterConfiguration.model_validate(configuration) - @classmethod def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]: """ diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index b3b6e36346..82570afca1 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -247,14 +247,14 @@ class ApiToolManageService: credentials: dict, schema_type: str, schema: str, - privacy_policy: str, + privacy_policy: str | None, custom_disclaimer: str, labels: list[str], ): """ update api tool provider """ - if schema_type not in [member.value for member in ApiProviderSchemaType]: + if schema_type not in list(ApiProviderSchemaType): raise ValueError(f"invalid schema type {schema}") provider_name = provider_name.strip() diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index fe77ff2dc5..4f2ac94e91 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -1,8 +1,6 @@ import json import logging -from collections.abc import Mapping from datetime import datetime -from typing import Any from sqlalchemy import or_, select from sqlalchemy.orm import Session @@ -11,8 +9,8 @@ from core.helper.tool_provider_cache import ToolProviderListCache from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.__base.tool_provider import ToolProviderController from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity +from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration from core.tools.tool_label_manager import ToolLabelManager -from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db @@ -39,12 +37,10 @@ class WorkflowToolManageService: label: str, icon: dict, description: str, - parameters: list[Mapping[str, Any]], + parameters: list[WorkflowToolParameterConfiguration], privacy_policy: str = "", labels: list[str] | None = None, ): - WorkflowToolConfigurationUtils.check_parameter_configurations(parameters) - # check if the name is unique existing_workflow_tool_provider = ( db.session.query(WorkflowToolProvider) @@ -77,7 +73,7 @@ class WorkflowToolManageService: label=label, icon=json.dumps(icon), description=description, - parameter_configuration=json.dumps(parameters), + parameter_configuration=json.dumps([p.model_dump() for p in parameters]), privacy_policy=privacy_policy, version=workflow.version, ) @@ -108,7 +104,7 @@ class WorkflowToolManageService: label: str, icon: dict, description: str, - parameters: list[Mapping[str, Any]], + parameters: list[WorkflowToolParameterConfiguration], privacy_policy: str = "", labels: list[str] | None = None, ): @@ -126,8 +122,6 @@ class WorkflowToolManageService: :param labels: labels :return: the updated tool """ - WorkflowToolConfigurationUtils.check_parameter_configurations(parameters) - # check if the name is unique existing_workflow_tool_provider = ( db.session.query(WorkflowToolProvider) @@ -166,7 +160,7 @@ class WorkflowToolManageService: workflow_tool_provider.label = label workflow_tool_provider.icon = json.dumps(icon) workflow_tool_provider.description = description - workflow_tool_provider.parameter_configuration = json.dumps(parameters) + workflow_tool_provider.parameter_configuration = json.dumps(obj=[p.model_dump() for p in parameters]) workflow_tool_provider.privacy_policy = privacy_policy workflow_tool_provider.version = workflow.version workflow_tool_provider.updated_at = datetime.now() From c00d835d7f7fd6a03ac62b9e1da2440e1deae995 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Mon, 8 Dec 2025 21:27:37 +0900 Subject: [PATCH 03/28] reg name --- api/controllers/console/explore/completion.py | 8 ++++---- api/controllers/web/completion.py | 4 ++++ 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index 5901eca915..a6e5b2822a 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -40,7 +40,7 @@ from .. import console_ns logger = logging.getLogger(__name__) -class CompletionMessagePayload(BaseModel): +class CompletionMessageExplorePayload(BaseModel): inputs: dict[str, Any] query: str = "" files: list[dict[str, Any]] | None = None @@ -71,7 +71,7 @@ class ChatMessagePayload(BaseModel): raise ValueError("must be a valid UUID") from exc -register_schema_models(console_ns, CompletionMessagePayload, ChatMessagePayload) +register_schema_models(console_ns, CompletionMessageExplorePayload, ChatMessagePayload) # define completion api for user @@ -80,13 +80,13 @@ register_schema_models(console_ns, CompletionMessagePayload, ChatMessagePayload) endpoint="installed_app_completion", ) class CompletionApi(InstalledAppResource): - @console_ns.expect(console_ns.models[CompletionMessagePayload.__name__]) + @console_ns.expect(console_ns.models[CompletionMessageExplorePayload.__name__]) def post(self, installed_app): app_model = installed_app.app if app_model.mode != AppMode.COMPLETION: raise NotCompletionAppError() - payload = CompletionMessagePayload.model_validate(console_ns.payload or {}) + payload = CompletionMessageExplorePayload.model_validate(console_ns.payload or {}) args = payload.model_dump(exclude_none=True) streaming = payload.response_mode == "streaming" diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index 830d2a16b5..005f109557 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -5,6 +5,7 @@ from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import InternalServerError, NotFound import services +from controllers.common.schema import register_schema_models from controllers.web import web_ns from controllers.web.error import ( AppUnavailableError, @@ -64,6 +65,9 @@ class ChatMessagePayload(BaseModel): return uuid_value(value) +register_schema_models(web_ns, CompletionMessagePayload, ChatMessagePayload) + + # define completion api for user @web_ns.route("/completion-messages") class CompletionApi(WebApiResource): From 60921d2587bd0db2e299149fd4e04e559471c31c Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Mon, 8 Dec 2025 21:43:05 +0900 Subject: [PATCH 04/28] based on gemini's advice --- .../console/workspace/tool_providers.py | 2 +- api/controllers/web/app.py | 7 +++- api/controllers/web/conversation.py | 8 ++++- api/controllers/web/forgot_password.py | 4 +++ api/controllers/web/login.py | 4 +++ api/controllers/web/message.py | 33 +++++++++++++++++-- api/controllers/web/remote_files.py | 9 +++-- api/controllers/web/saved_message.py | 4 +++ api/controllers/web/workflow.py | 3 ++ .../tools/api_tools_manage_service.py | 2 +- 10 files changed, 68 insertions(+), 8 deletions(-) diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 8fa52608de..c462467c15 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -128,7 +128,7 @@ class WorkflowToolBasePayload(BaseModel): label: str description: str icon: dict[str, Any] - parameters: list[WorkflowToolParameterConfiguration] + parameters: list[WorkflowToolParameterConfiguration] = Field(default_factory=list) privacy_policy: str | None = "" labels: list[str] | None = None diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index b8eefcb1e0..db3b93a4dc 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -95,7 +95,12 @@ class AppMeta(WebApiResource): class AppAccessMode(Resource): @web_ns.doc("Get App Access Mode") @web_ns.doc(description="Retrieve the access mode for a web application (public or restricted).") - @web_ns.expect(web_ns.models[AppAccessModeQuery.__name__]) + @web_ns.doc( + params={ + "appId": {"description": "Application ID", "type": "string", "required": False}, + "appCode": {"description": "Application code", "type": "string", "required": False}, + } + ) @web_ns.doc( responses={ 200: "Success", diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index 21f48ebb36..bf197a6601 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -6,6 +6,7 @@ from pydantic import BaseModel, Field, field_validator from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound +from controllers.common.schema import register_schema_models from controllers.web import web_ns from controllers.web.error import NotChatAppError from controllers.web.wraps import WebApiResource @@ -38,6 +39,9 @@ class ConversationRenamePayload(BaseModel): auto_generate: bool = False +register_schema_models(web_ns, ConversationListQuery, ConversationRenamePayload) + + @web_ns.route("/conversations") class ConversationListApi(WebApiResource): @web_ns.doc("Get Conversation List") @@ -171,7 +175,9 @@ class ConversationRenameApi(WebApiResource): payload = ConversationRenamePayload.model_validate(web_ns.payload or {}) try: - return ConversationService.rename(app_model, conversation_id, end_user, payload.name, payload.auto_generate) # type: ignore + return ConversationService.rename( + app_model, conversation_id, end_user, payload.name or "", payload.auto_generate + ) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") diff --git a/api/controllers/web/forgot_password.py b/api/controllers/web/forgot_password.py index 6f22171c67..2b16513dd9 100644 --- a/api/controllers/web/forgot_password.py +++ b/api/controllers/web/forgot_password.py @@ -7,6 +7,7 @@ from pydantic import BaseModel, Field, field_validator from sqlalchemy import select from sqlalchemy.orm import Session +from controllers.common.schema import register_schema_models from controllers.console.auth.error import ( AuthenticationFailedError, EmailCodeError, @@ -47,6 +48,9 @@ class ForgotPasswordResetPayload(BaseModel): return valid_password(value) +register_schema_models(web_ns, ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload) + + @web_ns.route("/forgot-password") class ForgotPasswordSendEmailApi(Resource): @only_edition_enterprise diff --git a/api/controllers/web/login.py b/api/controllers/web/login.py index 861ca27fc9..9c0d1973a9 100644 --- a/api/controllers/web/login.py +++ b/api/controllers/web/login.py @@ -5,6 +5,7 @@ from pydantic import BaseModel, Field, field_validator import services from configs import dify_config +from controllers.common.schema import register_schema_models from controllers.console.auth.error import ( AuthenticationFailedError, EmailCodeError, @@ -47,6 +48,9 @@ class EmailCodeLoginVerifyPayload(BaseModel): token: str = Field(min_length=1) +register_schema_models(web_ns, LoginPayload, EmailCodeLoginSendPayload, EmailCodeLoginVerifyPayload) + + @web_ns.route("/login") class LoginApi(Resource): """Resource for web app email/password login.""" diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 1a7327ef26..5c7ea9e69a 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -6,6 +6,7 @@ from flask_restx import fields, marshal_with from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import InternalServerError, NotFound +from controllers.common.schema import register_schema_models from controllers.web import web_ns from controllers.web.error import ( AppMoreLikeThisDisabledError, @@ -64,6 +65,9 @@ class MessageMoreLikeThisQuery(BaseModel): ) +register_schema_models(web_ns, MessageListQuery, MessageFeedbackPayload, MessageMoreLikeThisQuery) + + @web_ns.route("/messages") class MessageListApi(WebApiResource): message_fields = { @@ -91,7 +95,22 @@ class MessageListApi(WebApiResource): @web_ns.doc("Get Message List") @web_ns.doc(description="Retrieve paginated list of messages from a conversation in a chat application.") - @web_ns.expect(web_ns.models[MessageListQuery.__name__]) + @web_ns.doc( + params={ + "conversation_id": {"description": "Conversation UUID", "type": "string", "required": True}, + "first_id": { + "description": "First message ID for pagination", + "type": "string", + "required": False, + }, + "limit": { + "description": "Number of messages to return (1-100)", + "type": "integer", + "required": False, + "default": 20, + }, + } + ) @web_ns.doc( responses={ 200: "Success", @@ -130,7 +149,17 @@ class MessageFeedbackApi(WebApiResource): @web_ns.doc("Create Message Feedback") @web_ns.doc(description="Submit feedback (like/dislike) for a specific message.") @web_ns.doc(params={"message_id": {"description": "Message UUID", "type": "string", "required": True}}) - @web_ns.expect(web_ns.models[MessageFeedbackPayload.__name__]) + @web_ns.doc( + params={ + "rating": { + "description": "Feedback rating", + "type": "string", + "enum": ["like", "dislike"], + "required": False, + }, + "content": {"description": "Feedback content", "type": "string", "required": False}, + } + ) @web_ns.doc( responses={ 200: "Feedback submitted successfully", diff --git a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py index fed8df40b0..c1f976829f 100644 --- a/api/controllers/web/remote_files.py +++ b/api/controllers/web/remote_files.py @@ -11,19 +11,24 @@ from controllers.common.errors import ( RemoteFileUploadError, UnsupportedFileTypeError, ) -from controllers.web import web_ns -from controllers.web.wraps import WebApiResource from core.file import helpers as file_helpers from core.helper import ssrf_proxy from extensions.ext_database import db from fields.file_fields import build_file_with_signed_url_model, build_remote_file_info_model from services.file_service import FileService +from ..common.schema import register_schema_models +from . import web_ns +from .wraps import WebApiResource + class RemoteFileUploadPayload(BaseModel): url: HttpUrl = Field(description="Remote file URL") +register_schema_models(web_ns, RemoteFileUploadPayload) + + @web_ns.route("/remote-files/") class RemoteFileInfoApi(WebApiResource): @web_ns.doc("get_remote_file_info") diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py index 4d3b190f7c..f3fe7a0aca 100644 --- a/api/controllers/web/saved_message.py +++ b/api/controllers/web/saved_message.py @@ -3,6 +3,7 @@ from flask_restx import fields, marshal_with from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import NotFound +from controllers.common.schema import register_schema_models from controllers.web import web_ns from controllers.web.error import NotCompletionAppError from controllers.web.wraps import WebApiResource @@ -45,6 +46,9 @@ class SavedMessageCreatePayload(BaseModel): return uuid_value(value) +register_schema_models(web_ns, SavedMessageListQuery, SavedMessageCreatePayload) + + @web_ns.route("/saved-messages") class SavedMessageListApi(WebApiResource): saved_message_infinite_scroll_pagination_fields = { diff --git a/api/controllers/web/workflow.py b/api/controllers/web/workflow.py index 136f93c574..a62c27a033 100644 --- a/api/controllers/web/workflow.py +++ b/api/controllers/web/workflow.py @@ -4,6 +4,7 @@ from typing import Any from pydantic import BaseModel from werkzeug.exceptions import InternalServerError +from controllers.common.schema import register_schema_models from controllers.web import web_ns from controllers.web.error import ( CompletionRequestError, @@ -36,6 +37,8 @@ class WorkflowRunPayload(BaseModel): logger = logging.getLogger(__name__) +register_schema_models(web_ns, WorkflowRunPayload) + @web_ns.route("/workflows/run") class WorkflowRunApi(WebApiResource): diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index 82570afca1..73d9095580 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -254,7 +254,7 @@ class ApiToolManageService: """ update api tool provider """ - if schema_type not in list(ApiProviderSchemaType): + if schema_type not in ApiProviderSchemaType: raise ValueError(f"invalid schema type {schema}") provider_name = provider_name.strip() From d985e448854fc0420cd446033c020c0c9d9922b6 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Mon, 8 Dec 2025 23:31:16 +0900 Subject: [PATCH 05/28] Update api/controllers/console/tag/tags.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- api/controllers/console/tag/tags.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index 108fa88d05..472232899a 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -12,9 +12,6 @@ from fields.tag_fields import dataset_tag_fields from libs.login import current_account_with_tenant, login_required from services.tag_service import TagService -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - - class TagBasePayload(BaseModel): name: str = Field(description="Tag name", min_length=1, max_length=50) type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type") From d118b78d307d3772c1384b0b4739f8e415d77c60 Mon Sep 17 00:00:00 2001 From: "autofix-ci[bot]" <114827586+autofix-ci[bot]@users.noreply.github.com> Date: Mon, 8 Dec 2025 14:33:13 +0000 Subject: [PATCH 06/28] [autofix.ci] apply automated fixes --- api/controllers/console/tag/tags.py | 1 + 1 file changed, 1 insertion(+) diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index 472232899a..22af108a19 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -12,6 +12,7 @@ from fields.tag_fields import dataset_tag_fields from libs.login import current_account_with_tenant, login_required from services.tag_service import TagService + class TagBasePayload(BaseModel): name: str = Field(description="Tag name", min_length=1, max_length=50) type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type") From 8fa5e711294be057603b10956ba8129a6aa9b637 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Mon, 8 Dec 2025 23:37:25 +0900 Subject: [PATCH 07/28] Update api/services/tools/workflow_tools_manage_service.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- api/services/tools/workflow_tools_manage_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index 4f2ac94e91..aa6ace635b 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -160,7 +160,7 @@ class WorkflowToolManageService: workflow_tool_provider.label = label workflow_tool_provider.icon = json.dumps(icon) workflow_tool_provider.description = description - workflow_tool_provider.parameter_configuration = json.dumps(obj=[p.model_dump() for p in parameters]) + workflow_tool_provider.parameter_configuration = json.dumps([p.model_dump() for p in parameters]) workflow_tool_provider.privacy_policy = privacy_policy workflow_tool_provider.version = workflow.version workflow_tool_provider.updated_at = datetime.now() From 163f56686f624ae5026294869293650607668884 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Mon, 8 Dec 2025 23:38:37 +0900 Subject: [PATCH 08/28] Update api/controllers/web/completion.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- api/controllers/web/completion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index 005f109557..e0e969b011 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -38,7 +38,7 @@ logger = logging.getLogger(__name__) class CompletionMessagePayload(BaseModel): inputs: dict[str, Any] = Field(description="Input variables for the completion") - query: str = Field(description="Query text for completion") + query: str = Field(default="", description="Query text for completion") files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed") response_mode: Literal["blocking", "streaming"] | None = Field( default=None, description="Response mode: blocking or streaming" From ed4bb86e26515f6492693ade031e658ad6a1cc9a Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Mon, 8 Dec 2025 23:39:46 +0900 Subject: [PATCH 09/28] rm | None --- api/controllers/web/completion.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index e0e969b011..a97d745471 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -43,7 +43,7 @@ class CompletionMessagePayload(BaseModel): response_mode: Literal["blocking", "streaming"] | None = Field( default=None, description="Response mode: blocking or streaming" ) - retriever_from: str | None = Field(default="web_app", description="Source of retriever") + retriever_from: str = Field(default="web_app", description="Source of retriever") class ChatMessagePayload(BaseModel): @@ -55,7 +55,7 @@ class ChatMessagePayload(BaseModel): ) conversation_id: str | None = Field(default=None, description="Conversation ID") parent_message_id: str | None = Field(default=None, description="Parent message ID") - retriever_from: str | None = Field(default="web_app", description="Source of retriever") + retriever_from: str = Field(default="web_app", description="Source of retriever") @field_validator("conversation_id", "parent_message_id") @classmethod From 69ebd37e2a111640c3c2283f3707b78bf168c88e Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Mon, 8 Dec 2025 23:44:04 +0900 Subject: [PATCH 10/28] fix test [autofix.ci] apply automated fixes fix test fix declare ? --- .../test_workflow_tools_manage_service.py | 63 ++++++++++--------- 1 file changed, 35 insertions(+), 28 deletions(-) diff --git a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py index 71cedd26c4..8ddacfdc1c 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py @@ -3,7 +3,9 @@ from unittest.mock import patch import pytest from faker import Faker +from pydantic import ValidationError +from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration from models.tools import WorkflowToolProvider from models.workflow import Workflow as WorkflowModel from services.account_service import AccountService, TenantService @@ -130,20 +132,24 @@ class TestWorkflowToolManageService: def _create_test_workflow_tool_parameters(self): """Helper method to create valid workflow tool parameters.""" return [ - { - "name": "input_text", - "description": "Input text for processing", - "form": "form", - "type": "string", - "required": True, - }, - { - "name": "output_format", - "description": "Output format specification", - "form": "form", - "type": "select", - "required": False, - }, + WorkflowToolParameterConfiguration.model_validate( + { + "name": "input_text", + "description": "Input text for processing", + "form": "form", + "type": "string", + "required": True, + } + ), + WorkflowToolParameterConfiguration.model_validate( + { + "name": "output_format", + "description": "Output format specification", + "form": "form", + "type": "select", + "required": False, + } + ), ] def test_create_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies): @@ -208,7 +214,7 @@ class TestWorkflowToolManageService: assert created_tool_provider.label == tool_label assert created_tool_provider.icon == json.dumps(tool_icon) assert created_tool_provider.description == tool_description - assert created_tool_provider.parameter_configuration == json.dumps(tool_parameters) + assert created_tool_provider.parameter_configuration == json.dumps([p.model_dump() for p in tool_parameters]) assert created_tool_provider.privacy_policy == tool_privacy_policy assert created_tool_provider.version == workflow.version assert created_tool_provider.user_id == account.id @@ -353,18 +359,9 @@ class TestWorkflowToolManageService: app, account, workflow = self._create_test_app_and_account( db_session_with_containers, mock_external_service_dependencies ) - - # Setup invalid workflow tool parameters (missing required fields) - invalid_parameters = [ - { - "name": "input_text", - # Missing description and form fields - "type": "string", - "required": True, - } - ] # Attempt to create workflow tool with invalid parameters - with pytest.raises(ValueError) as exc_info: + with pytest.raises(ValidationError) as exc_info: + # Setup invalid workflow tool parameters (missing required fields) WorkflowToolManageService.create_workflow_tool( user_id=account.id, tenant_id=account.current_tenant.id, @@ -373,7 +370,16 @@ class TestWorkflowToolManageService: label=fake.word(), icon={"type": "emoji", "emoji": "🔧"}, description=fake.text(max_nb_chars=200), - parameters=invalid_parameters, + parameters=[ + WorkflowToolParameterConfiguration.model_validate( + { + "name": "input_text", + # Missing description and form fields + "type": "string", + "required": True, + } + ) + ], ) # Verify error message contains validation error @@ -579,11 +585,12 @@ class TestWorkflowToolManageService: # Verify database state was updated db.session.refresh(created_tool) + assert created_tool is not None assert created_tool.name == updated_tool_name assert created_tool.label == updated_tool_label assert created_tool.icon == json.dumps(updated_tool_icon) assert created_tool.description == updated_tool_description - assert created_tool.parameter_configuration == json.dumps(updated_tool_parameters) + assert created_tool.parameter_configuration == json.dumps([p.model_dump() for p in updated_tool_parameters]) assert created_tool.privacy_policy == updated_tool_privacy_policy assert created_tool.version == workflow.version assert created_tool.updated_at is not None From 05db04317a5bc85396ef7008ea07cdf5b16bca6f Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Tue, 9 Dec 2025 18:27:09 +0900 Subject: [PATCH 11/28] add import linter --- api/.ruff.toml | 10 ++++++++++ .../console/datasets/hit_testing_base.py | 15 +++++---------- .../service_api/dataset/hit_testing.py | 8 ++++++-- 3 files changed, 21 insertions(+), 12 deletions(-) diff --git a/api/.ruff.toml b/api/.ruff.toml index 7206f7fa0f..81ded773dc 100644 --- a/api/.ruff.toml +++ b/api/.ruff.toml @@ -49,6 +49,7 @@ select = [ "S301", # suspicious-pickle-usage, disallow use of `pickle` and its wrappers. "S302", # suspicious-marshal-usage, disallow use of `marshal` module "S311", # suspicious-non-cryptographic-random-usage, + "TID", # flake8-tidy-imports ] @@ -84,6 +85,7 @@ ignore = [ "SIM113", # enumerate-for-loop "SIM117", # multiple-with-statements "SIM210", # if-expr-with-true-false + "TID252", # allow relative imports from parent modules ] [lint.per-file-ignores] @@ -112,3 +114,11 @@ allowed-unused-imports = [ "tests.integration_tests", "tests.unit_tests", ] + +[lint.flake8-tidy-imports] + +[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse"] +msg = "Use Pydantic payload/query models instead of reqparse." + +[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse.RequestParser"] +msg = "Use Pydantic payload/query models instead of reqparse." diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py index db7c50f422..43d7375472 100644 --- a/api/controllers/console/datasets/hit_testing_base.py +++ b/api/controllers/console/datasets/hit_testing_base.py @@ -1,7 +1,7 @@ import logging from typing import Any -from flask_restx import marshal, reqparse +from flask_restx import marshal from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden, InternalServerError, NotFound @@ -56,15 +56,10 @@ class DatasetsHitTestingBase: HitTestingService.hit_testing_args_check(args) @staticmethod - def parse_args(): - parser = ( - reqparse.RequestParser() - .add_argument("query", type=str, required=False, location="json") - .add_argument("attachment_ids", type=list, required=False, location="json") - .add_argument("retrieval_model", type=dict, required=False, location="json") - .add_argument("external_retrieval_model", type=dict, required=False, location="json") - ) - return parser.parse_args() + def parse_args(payload: dict[str, Any] | None = None) -> dict[str, Any]: + """Validate and return hit-testing arguments from an incoming payload.""" + hit_testing_payload = HitTestingPayload.model_validate(payload or {}) + return hit_testing_payload.model_dump(exclude_none=True) @staticmethod def perform_hit_testing(dataset, args): diff --git a/api/controllers/service_api/dataset/hit_testing.py b/api/controllers/service_api/dataset/hit_testing.py index d81287d56f..97a70f5d0e 100644 --- a/api/controllers/service_api/dataset/hit_testing.py +++ b/api/controllers/service_api/dataset/hit_testing.py @@ -1,7 +1,10 @@ -from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase +from controllers.common.schema import register_schema_model +from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase, HitTestingPayload from controllers.service_api import service_api_ns from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check +register_schema_model(service_api_ns, HitTestingPayload) + @service_api_ns.route("/datasets//hit-testing", "/datasets//retrieve") class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase): @@ -15,6 +18,7 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase): 404: "Dataset not found", } ) + @service_api_ns.expect(service_api_ns.models[HitTestingPayload.__name__]) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id): """Perform hit testing on a dataset. @@ -24,7 +28,7 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase): dataset_id_str = str(dataset_id) dataset = self.get_and_validate_dataset(dataset_id_str) - args = self.parse_args() + args = self.parse_args(service_api_ns.payload) self.hit_testing_args_check(args) return self.perform_hit_testing(dataset, args) From 766443178c850e0951b5660f38e71d8d58156062 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Wed, 10 Dec 2025 22:18:09 +0900 Subject: [PATCH 12/28] use type --- api/controllers/console/billing/billing.py | 21 ++++----------------- 1 file changed, 4 insertions(+), 17 deletions(-) diff --git a/api/controllers/console/billing/billing.py b/api/controllers/console/billing/billing.py index 7f907dc420..ac039f9c5d 100644 --- a/api/controllers/console/billing/billing.py +++ b/api/controllers/console/billing/billing.py @@ -1,8 +1,9 @@ import base64 +from typing import Literal from flask import request from flask_restx import Resource, fields -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field from werkzeug.exceptions import BadRequest from controllers.console import console_ns @@ -15,22 +16,8 @@ DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" class SubscriptionQuery(BaseModel): - plan: str = Field(..., description="Subscription plan") - interval: str = Field(..., description="Billing interval") - - @field_validator("plan") - @classmethod - def validate_plan(cls, value: str) -> str: - if value not in [CloudPlan.PROFESSIONAL, CloudPlan.TEAM]: - raise ValueError("Invalid plan") - return value - - @field_validator("interval") - @classmethod - def validate_interval(cls, value: str) -> str: - if value not in {"month", "year"}: - raise ValueError("Invalid interval") - return value + plan: Literal[CloudPlan.PROFESSIONAL, CloudPlan.TEAM] = Field(..., description="Subscription plan") + interval: Literal["month", "year"] = Field(..., description="Billing interval") class PartnerTenantsPayload(BaseModel): From 80f63c728f38b1d958fa157b028bfaa06dcf4c37 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Thu, 11 Dec 2025 19:11:58 +0900 Subject: [PATCH 13/28] use custom UUIDStrOrEmpty --- .../rag_pipeline/rag_pipeline_workflow.py | 5 ++--- .../console/explore/conversation.py | 4 ++-- api/controllers/console/explore/message.py | 5 +++-- .../console/explore/saved_message.py | 6 +++--- api/controllers/service_api/app/completion.py | 8 ++++---- .../service_api/app/conversation.py | 6 +++--- api/controllers/service_api/app/message.py | 7 +++---- api/controllers/web/saved_message.py | 19 +++---------------- api/libs/helper.py | 3 +++ 9 files changed, 26 insertions(+), 37 deletions(-) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index 46d67f0581..22b5b64009 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -1,7 +1,6 @@ import json import logging from typing import Any, Literal, cast -from uuid import UUID from flask import abort, request from flask_restx import Resource, marshal_with, reqparse # type: ignore @@ -38,7 +37,7 @@ from fields.workflow_run_fields import ( workflow_run_pagination_fields, ) from libs import helper -from libs.helper import TimestampField +from libs.helper import TimestampField, UUIDStrOrEmpty from libs.login import current_account_with_tenant, current_user, login_required from models import Account from models.dataset import Pipeline @@ -110,7 +109,7 @@ class NodeIdQuery(BaseModel): class WorkflowRunQuery(BaseModel): - last_id: UUID | None = None + last_id: UUIDStrOrEmpty | None = None limit: int = Field(default=20, ge=1, le=100) diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index 92da591ab4..51995b8b8a 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -1,5 +1,4 @@ from typing import Any -from uuid import UUID from flask import request from flask_restx import marshal_with @@ -13,6 +12,7 @@ from controllers.console.explore.wraps import InstalledAppResource from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields +from libs.helper import UUIDStrOrEmpty from libs.login import current_user from models import Account from models.model import AppMode @@ -24,7 +24,7 @@ from .. import console_ns class ConversationListQuery(BaseModel): - last_id: UUID | None = None + last_id: UUIDStrOrEmpty | None = None limit: int = Field(default=20, ge=1, le=100) pinned: bool | None = None diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 229b7c8865..7976298106 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -26,6 +26,7 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni from core.model_runtime.errors.invoke import InvokeError from fields.message_fields import message_infinite_scroll_pagination_fields from libs import helper +from libs.helper import UUIDStrOrEmpty from libs.login import current_account_with_tenant from models.model import AppMode from services.app_generate_service import AppGenerateService @@ -44,8 +45,8 @@ logger = logging.getLogger(__name__) class MessageListQuery(BaseModel): - conversation_id: UUID - first_id: UUID | None = None + conversation_id: UUIDStrOrEmpty + first_id: UUIDStrOrEmpty | None = None limit: int = Field(default=20, ge=1, le=100) diff --git a/api/controllers/console/explore/saved_message.py b/api/controllers/console/explore/saved_message.py index 6a9e274a0e..768e0d23f8 100644 --- a/api/controllers/console/explore/saved_message.py +++ b/api/controllers/console/explore/saved_message.py @@ -10,19 +10,19 @@ from controllers.console import console_ns from controllers.console.explore.error import NotCompletionAppError from controllers.console.explore.wraps import InstalledAppResource from fields.conversation_fields import message_file_fields -from libs.helper import TimestampField +from libs.helper import TimestampField, UUIDStrOrEmpty from libs.login import current_account_with_tenant from services.errors.message import MessageNotExistsError from services.saved_message_service import SavedMessageService class SavedMessageListQuery(BaseModel): - last_id: UUID | None = None + last_id: UUIDStrOrEmpty | None = None limit: int = Field(default=20, ge=1, le=100) class SavedMessageCreatePayload(BaseModel): - message_id: UUID + message_id: UUIDStrOrEmpty register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload) diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index b3836f3a47..56f63d8a78 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -1,10 +1,9 @@ import logging from typing import Any, Literal -from uuid import UUID from flask import request from flask_restx import Resource -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field from werkzeug.exceptions import BadRequest, InternalServerError, NotFound import services @@ -30,6 +29,7 @@ from core.errors.error import ( from core.helper.trace_id_helper import get_external_trace_id from core.model_runtime.errors.invoke import InvokeError from libs import helper +from libs.helper import UUIDStrOrEmpty from models.model import App, AppMode, EndUser from services.app_generate_service import AppGenerateService from services.app_task_service import AppTaskService @@ -52,11 +52,11 @@ class ChatRequestPayload(BaseModel): query: str files: list[dict[str, Any]] | None = None response_mode: Literal["blocking", "streaming"] | None = None - conversation_id: str | None = Field(default=None, description="Conversation UUID") + conversation_id: UUIDStrOrEmpty | None = Field(default=None, description="Conversation UUID") retriever_from: str = Field(default="dev") auto_generate_name: bool = Field(default=True, description="Auto generate conversation name") workflow_id: str | None = Field(default=None, description="Workflow ID for advanced chat") - + @field_validator("conversation_id", mode="before") @classmethod def normalize_conversation_id(cls, value: str | UUID | None) -> str | None: diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index be6d837032..eb15f34b0d 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -1,5 +1,4 @@ from typing import Any, Literal -from uuid import UUID from flask import request from flask_restx import Resource @@ -24,12 +23,13 @@ from fields.conversation_variable_fields import ( build_conversation_variable_infinite_scroll_pagination_model, build_conversation_variable_model, ) +from libs.helper import UUIDStrOrEmpty from models.model import App, AppMode, EndUser from services.conversation_service import ConversationService class ConversationListQuery(BaseModel): - last_id: UUID | None = Field(default=None, description="Last conversation ID for pagination") + last_id: UUIDStrOrEmpty | None = Field(default=None, description="Last conversation ID for pagination") limit: int = Field(default=20, ge=1, le=100, description="Number of conversations to return") sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field( default="-updated_at", description="Sort order for conversations" @@ -49,7 +49,7 @@ class ConversationRenamePayload(BaseModel): class ConversationVariablesQuery(BaseModel): - last_id: UUID | None = Field(default=None, description="Last variable ID for pagination") + last_id: UUIDStrOrEmpty | None = Field(default=None, description="Last variable ID for pagination") limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return") diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index d342f4e661..1f96e35bfd 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -1,7 +1,6 @@ import json import logging from typing import Literal -from uuid import UUID from flask import request from flask_restx import Namespace, Resource, fields @@ -17,7 +16,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom from fields.conversation_fields import build_message_file_model from fields.message_fields import build_agent_thought_model, build_feedback_model from fields.raws import FilesContainedField -from libs.helper import TimestampField +from libs.helper import TimestampField, UUIDStrOrEmpty from models.model import App, AppMode, EndUser from services.errors.message import ( FirstMessageNotExistsError, @@ -30,8 +29,8 @@ logger = logging.getLogger(__name__) class MessageListQuery(BaseModel): - conversation_id: UUID - first_id: UUID | None = None + conversation_id: UUIDStrOrEmpty + first_id: UUIDStrOrEmpty | None = None limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return") diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py index f3fe7a0aca..ce5d2858ce 100644 --- a/api/controllers/web/saved_message.py +++ b/api/controllers/web/saved_message.py @@ -1,3 +1,4 @@ +from libs.helper import UUIDStrOrEmpty from flask import request from flask_restx import fields, marshal_with from pydantic import BaseModel, Field, field_validator @@ -26,25 +27,11 @@ message_fields = { class SavedMessageListQuery(BaseModel): - last_id: str | None = None + last_id: UUIDStrOrEmpty | None = None limit: int = Field(default=20, ge=1, le=100) - @field_validator("last_id") - @classmethod - def validate_last_id(cls, value: str | None) -> str | None: - if value is None: - return value - return uuid_value(value) - - class SavedMessageCreatePayload(BaseModel): - message_id: str - - @field_validator("message_id") - @classmethod - def validate_message_id(cls, value: str) -> str: - return uuid_value(value) - + message_id: UUIDStrOrEmpty register_schema_models(web_ns, SavedMessageListQuery, SavedMessageCreatePayload) diff --git a/api/libs/helper.py b/api/libs/helper.py index 4a7afe0bda..ba02fe1f5d 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -119,6 +119,9 @@ def uuid_value(value: Any) -> str: raise ValueError(error) +UUIDStrOrEmpty = Annotated[str, AfterValidator(uuid_value)] + + def alphanumeric(value: str): # check if the value is alphanumeric and underlined if re.match(r"^[a-zA-Z0-9_]+$", value): From e634d2504916e6dfbe621ec28191846a1a3cb4a3 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Thu, 11 Dec 2025 19:12:13 +0900 Subject: [PATCH 14/28] fmt --- api/controllers/console/explore/message.py | 1 - api/controllers/console/explore/saved_message.py | 2 -- api/controllers/web/saved_message.py | 7 ++++--- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 7976298106..d596d60b36 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -1,6 +1,5 @@ import logging from typing import Literal -from uuid import UUID from flask import request from flask_restx import marshal_with diff --git a/api/controllers/console/explore/saved_message.py b/api/controllers/console/explore/saved_message.py index 768e0d23f8..bc7b8e7651 100644 --- a/api/controllers/console/explore/saved_message.py +++ b/api/controllers/console/explore/saved_message.py @@ -1,5 +1,3 @@ -from uuid import UUID - from flask import request from flask_restx import fields, marshal_with from pydantic import BaseModel, Field diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py index ce5d2858ce..dc3a9c348c 100644 --- a/api/controllers/web/saved_message.py +++ b/api/controllers/web/saved_message.py @@ -1,7 +1,6 @@ -from libs.helper import UUIDStrOrEmpty from flask import request from flask_restx import fields, marshal_with -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field from werkzeug.exceptions import NotFound from controllers.common.schema import register_schema_models @@ -9,7 +8,7 @@ from controllers.web import web_ns from controllers.web.error import NotCompletionAppError from controllers.web.wraps import WebApiResource from fields.conversation_fields import message_file_fields -from libs.helper import TimestampField, uuid_value +from libs.helper import TimestampField, UUIDStrOrEmpty from services.errors.message import MessageNotExistsError from services.saved_message_service import SavedMessageService @@ -30,9 +29,11 @@ class SavedMessageListQuery(BaseModel): last_id: UUIDStrOrEmpty | None = None limit: int = Field(default=20, ge=1, le=100) + class SavedMessageCreatePayload(BaseModel): message_id: UUIDStrOrEmpty + register_schema_models(web_ns, SavedMessageListQuery, SavedMessageCreatePayload) From 413623598bf65e168e72e437675380c1d6ca4bd2 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Fri, 12 Dec 2025 02:50:12 +0900 Subject: [PATCH 15/28] fix --- api/libs/helper.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/api/libs/helper.py b/api/libs/helper.py index ba02fe1f5d..888a99fb5c 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -11,6 +11,7 @@ from collections.abc import Generator, Mapping from datetime import datetime from hashlib import sha256 from typing import TYPE_CHECKING, Annotated, Any, Optional, Union, cast +from uuid import UUID from zoneinfo import available_timezones from flask import Response, stream_with_context @@ -119,7 +120,17 @@ def uuid_value(value: Any) -> str: raise ValueError(error) -UUIDStrOrEmpty = Annotated[str, AfterValidator(uuid_value)] +def normalize_uuid(value: str | UUID | None) -> str | None: + if not value: + return None + + try: + return uuid_value(value) + except ValueError as exc: + raise ValueError("must be a valid UUID") from exc + + +UUIDStrOrEmpty = Annotated[str, AfterValidator(normalize_uuid)] def alphanumeric(value: str): From d74ac9f09a9cdae618c392aec0efafec34658fd2 Mon Sep 17 00:00:00 2001 From: "autofix-ci[bot]" <114827586+autofix-ci[bot]@users.noreply.github.com> Date: Tue, 16 Dec 2025 07:20:48 +0000 Subject: [PATCH 16/28] [autofix.ci] apply automated fixes --- api/controllers/service_api/app/completion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index 56f63d8a78..dd061b23b0 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -56,7 +56,7 @@ class ChatRequestPayload(BaseModel): retriever_from: str = Field(default="dev") auto_generate_name: bool = Field(default=True, description="Auto generate conversation name") workflow_id: str | None = Field(default=None, description="Workflow ID for advanced chat") - + @field_validator("conversation_id", mode="before") @classmethod def normalize_conversation_id(cls, value: str | UUID | None) -> str | None: From 1d14692dd396acac6b4c8f3545f0c4a706603832 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Tue, 16 Dec 2025 16:28:38 +0900 Subject: [PATCH 17/28] missing import --- api/controllers/service_api/app/completion.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index dd061b23b0..9d8431f066 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -1,9 +1,10 @@ import logging from typing import Any, Literal +from uuid import UUID from flask import request from flask_restx import Resource -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import BadRequest, InternalServerError, NotFound import services From 64a20aea655a23ac9031659823f4665b8c735f30 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Wed, 17 Dec 2025 21:06:52 +0900 Subject: [PATCH 18/28] fix --- .../datasets/rag_pipeline/rag_pipeline_workflow.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index 22b5b64009..df2cc3f892 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -3,7 +3,7 @@ import logging from typing import Any, Literal, cast from flask import abort, request -from flask_restx import Resource, marshal_with, reqparse # type: ignore +from flask_restx import Resource, marshal_with # type: ignore from pydantic import BaseModel, Field from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, InternalServerError, NotFound @@ -120,6 +120,10 @@ class DatasourceVariablesPayload(BaseModel): start_node_title: str +class RagPipelineRecommendedPluginQuery(BaseModel): + type: str = "all" + + register_schema_models( console_ns, DraftWorkflowSyncPayload, @@ -134,6 +138,7 @@ register_schema_models( NodeIdQuery, WorkflowRunQuery, DatasourceVariablesPayload, + RagPipelineRecommendedPluginQuery, ) @@ -974,11 +979,8 @@ class RagPipelineRecommendedPluginApi(Resource): @login_required @account_initialization_required def get(self): - parser = reqparse.RequestParser() - parser.add_argument("type", type=str, location="args", required=False, default="all") - args = parser.parse_args() - type = args["type"] + query = RagPipelineRecommendedPluginQuery.model_validate(request.args.to_dict()) rag_pipeline_service = RagPipelineService() - recommended_plugins = rag_pipeline_service.get_recommended_plugins(type) + recommended_plugins = rag_pipeline_service.get_recommended_plugins(query.type) return recommended_plugins From cdd60dc1b83790e281654a217eb289b2d38dc087 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Wed, 17 Dec 2025 21:58:55 +0900 Subject: [PATCH 19/28] use type annotation --- api/controllers/console/workspace/tool_providers.py | 4 ++-- api/services/tools/api_tools_manage_service.py | 3 --- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index c462467c15..3f78d792e6 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -24,7 +24,7 @@ from core.mcp.mcp_client import MCPClient from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.oauth import OAuthHandler -from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration +from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration, ApiProviderSchemaType from extensions.ext_database import db from libs.helper import alphanumeric, uuid_value from libs.login import current_account_with_tenant, login_required @@ -74,7 +74,7 @@ class BuiltinToolUpdatePayload(BaseModel): class ApiToolProviderBasePayload(BaseModel): credentials: dict[str, Any] - schema_type: str + schema_type: ApiProviderSchemaType schema_: str = Field(alias="schema") provider: str icon: dict[str, Any] diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index 73d9095580..9c9ccd4f4e 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -113,9 +113,6 @@ class ApiToolManageService: """ create api tool provider """ - if schema_type not in [member.value for member in ApiProviderSchemaType]: - raise ValueError(f"invalid schema type {schema}") - provider_name = provider_name.strip() # check if the provider exists From af8a137a2abf388359b9fdd641a87b3b19dfbbec Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Wed, 17 Dec 2025 22:00:29 +0900 Subject: [PATCH 20/28] nit --- api/services/tools/api_tools_manage_service.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index 9c9ccd4f4e..1a8092157a 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -251,9 +251,6 @@ class ApiToolManageService: """ update api tool provider """ - if schema_type not in ApiProviderSchemaType: - raise ValueError(f"invalid schema type {schema}") - provider_name = provider_name.strip() # check if the provider exists From e373dc487eccc70e92104e87371dad13491fd02e Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Wed, 17 Dec 2025 22:02:08 +0900 Subject: [PATCH 21/28] nit --- api/controllers/web/conversation.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index bf197a6601..541975aade 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -2,7 +2,7 @@ from typing import Literal from flask import request from flask_restx import fields, marshal_with -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field, field_validator, model_validator from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound @@ -38,6 +38,13 @@ class ConversationRenamePayload(BaseModel): name: str | None = None auto_generate: bool = False + @model_validator(mode="after") + def validate_name_requirement(self): + if not self.auto_generate: + if self.name is None or not self.name.strip(): + raise ValueError("name is required when auto_generate is false") + return self + register_schema_models(web_ns, ConversationListQuery, ConversationRenamePayload) From 0e2eab75020744d80547837ccd77a365bc211b2f Mon Sep 17 00:00:00 2001 From: "autofix-ci[bot]" <114827586+autofix-ci[bot]@users.noreply.github.com> Date: Wed, 17 Dec 2025 13:04:08 +0000 Subject: [PATCH 22/28] [autofix.ci] apply automated fixes --- api/controllers/console/workspace/tool_providers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 3f78d792e6..cb02277c4e 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -24,7 +24,7 @@ from core.mcp.mcp_client import MCPClient from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.oauth import OAuthHandler -from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration, ApiProviderSchemaType +from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration from extensions.ext_database import db from libs.helper import alphanumeric, uuid_value from libs.login import current_account_with_tenant, login_required From 275dff2db32f6955a9b52fc03525404e2dbad8f2 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Wed, 17 Dec 2025 22:06:12 +0900 Subject: [PATCH 23/28] Update api/libs/helper.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- api/libs/helper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/libs/helper.py b/api/libs/helper.py index 888a99fb5c..a1d431654e 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -130,7 +130,7 @@ def normalize_uuid(value: str | UUID | None) -> str | None: raise ValueError("must be a valid UUID") from exc -UUIDStrOrEmpty = Annotated[str, AfterValidator(normalize_uuid)] +UUIDStrOrEmpty = Annotated[str | None, AfterValidator(normalize_uuid)] def alphanumeric(value: str): From 723cda45b118793bde37635a79a7d3f708dfca27 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Wed, 17 Dec 2025 22:13:20 +0900 Subject: [PATCH 24/28] fix --- api/libs/helper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/libs/helper.py b/api/libs/helper.py index a1d431654e..888a99fb5c 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -130,7 +130,7 @@ def normalize_uuid(value: str | UUID | None) -> str | None: raise ValueError("must be a valid UUID") from exc -UUIDStrOrEmpty = Annotated[str | None, AfterValidator(normalize_uuid)] +UUIDStrOrEmpty = Annotated[str, AfterValidator(normalize_uuid)] def alphanumeric(value: str): From 16a06923cc22a949572c656290808cd28a666e80 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Thu, 18 Dec 2025 14:58:02 +0900 Subject: [PATCH 25/28] fix test --- api/core/tools/utils/parser.py | 2 +- .../tools/api_tools_manage_service.py | 6 +++-- .../tools/test_api_tools_manage_service.py | 25 ++++++------------- 3 files changed, 13 insertions(+), 20 deletions(-) diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index 3486182192..584975de05 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -378,7 +378,7 @@ class ApiBasedToolSchemaParser: @staticmethod def auto_parse_to_tool_bundle( content: str, extra_info: dict | None = None, warning: dict | None = None - ) -> tuple[list[ApiToolBundle], str]: + ) -> tuple[list[ApiToolBundle], ApiProviderSchemaType]: """ auto parse to tool bundle diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index 1a8092157a..ce97a1a9d0 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -86,7 +86,9 @@ class ApiToolManageService: raise ValueError(f"invalid schema: {str(e)}") @staticmethod - def convert_schema_to_tool_bundles(schema: str, extra_info: dict | None = None) -> tuple[list[ApiToolBundle], str]: + def convert_schema_to_tool_bundles( + schema: str, extra_info: dict | None = None + ) -> tuple[list[ApiToolBundle], ApiProviderSchemaType]: """ convert schema to tool bundles @@ -104,7 +106,7 @@ class ApiToolManageService: provider_name: str, icon: dict, credentials: dict, - schema_type: str, + schema_type: ApiProviderSchemaType, schema: str, privacy_policy: str, custom_disclaimer: str, diff --git a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py index 0871467a05..26f62d62f0 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py @@ -2,7 +2,9 @@ from unittest.mock import patch import pytest from faker import Faker +from pydantic import TypeAdapter, ValidationError +from core.tools.entities.tool_entities import ApiProviderSchemaType from models import Account, Tenant from models.tools import ApiToolProvider from services.tools.api_tools_manage_service import ApiToolManageService @@ -298,7 +300,7 @@ class TestApiToolManageService: provider_name = fake.company() icon = {"type": "emoji", "value": "🔧"} credentials = {"auth_type": "none", "api_key_header": "X-API-Key", "api_key_value": ""} - schema_type = "openapi" + schema_type = ApiProviderSchemaType.OPENAPI schema = self._create_test_openapi_schema() privacy_policy = "https://example.com/privacy" custom_disclaimer = "Custom disclaimer text" @@ -364,7 +366,7 @@ class TestApiToolManageService: provider_name = fake.company() icon = {"type": "emoji", "value": "🔧"} credentials = {"auth_type": "none"} - schema_type = "openapi" + schema_type = ApiProviderSchemaType.OPENAPI schema = self._create_test_openapi_schema() privacy_policy = "https://example.com/privacy" custom_disclaimer = "Custom disclaimer text" @@ -428,19 +430,8 @@ class TestApiToolManageService: labels = ["test"] # Act & Assert: Try to create provider with invalid schema type - with pytest.raises(ValueError) as exc_info: - ApiToolManageService.create_api_tool_provider( - user_id=account.id, - tenant_id=tenant.id, - provider_name=provider_name, - icon=icon, - credentials=credentials, - schema_type=schema_type, - schema=schema, - privacy_policy=privacy_policy, - custom_disclaimer=custom_disclaimer, - labels=labels, - ) + with pytest.raises(ValidationError) as exc_info: + TypeAdapter(ApiProviderSchemaType).validate_python(schema_type) assert "invalid schema type" in str(exc_info.value) @@ -464,7 +455,7 @@ class TestApiToolManageService: provider_name = fake.company() icon = {"type": "emoji", "value": "🔧"} credentials = {} # Missing auth_type - schema_type = "openapi" + schema_type = ApiProviderSchemaType.OPENAPI schema = self._create_test_openapi_schema() privacy_policy = "https://example.com/privacy" custom_disclaimer = "Custom disclaimer text" @@ -507,7 +498,7 @@ class TestApiToolManageService: provider_name = fake.company() icon = {"type": "emoji", "value": "🔑"} credentials = {"auth_type": "api_key", "api_key_header": "X-API-Key", "api_key_value": fake.uuid4()} - schema_type = "openapi" + schema_type = ApiProviderSchemaType.OPENAPI schema = self._create_test_openapi_schema() privacy_policy = "https://example.com/privacy" custom_disclaimer = "Custom disclaimer text" From 06599f906190d482f699629ad0531bddfa051406 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Thu, 18 Dec 2025 15:03:42 +0900 Subject: [PATCH 26/28] try --- api/controllers/console/workspace/tool_providers.py | 2 +- api/services/tools/api_tools_manage_service.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index cb02277c4e..48eb866def 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -119,7 +119,7 @@ class ApiToolTestPayload(BaseModel): provider_name: str | None = None credentials: dict[str, Any] parameters: dict[str, Any] - schema_type: str + schema_type: ApiProviderSchemaType schema_: str = Field(alias="schema") diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index ce97a1a9d0..00b64d1790 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -244,7 +244,7 @@ class ApiToolManageService: original_provider: str, icon: dict, credentials: dict, - schema_type: str, + _schema_type: ApiProviderSchemaType, schema: str, privacy_policy: str | None, custom_disclaimer: str, @@ -362,7 +362,7 @@ class ApiToolManageService: tool_name: str, credentials: dict, parameters: dict, - schema_type: str, + schema_type: ApiProviderSchemaType, schema: str, ): """ From a8c344e978edfd55c0eb0b202ce5b7197879ed19 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Thu, 18 Dec 2025 15:08:54 +0900 Subject: [PATCH 27/28] fix not used bug --- api/services/tools/api_tools_manage_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index 00b64d1790..d3ddea4e3b 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -277,7 +277,7 @@ class ApiToolManageService: provider.icon = json.dumps(icon) provider.schema = schema provider.description = extra_info.get("description", "") - provider.schema_type_str = ApiProviderSchemaType.OPENAPI + provider.schema_type_str = schema_type provider.tools_str = json.dumps(jsonable_encoder(tool_bundles)) provider.privacy_policy = privacy_policy provider.custom_disclaimer = custom_disclaimer From 0f1df4f3393fd0ffafba08c52bee0b8eae2ae0c8 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Thu, 18 Dec 2025 15:21:28 +0900 Subject: [PATCH 28/28] assert correct msg --- .../services/tools/test_api_tools_manage_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py index 26f62d62f0..2ff71ea6ea 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py @@ -433,7 +433,7 @@ class TestApiToolManageService: with pytest.raises(ValidationError) as exc_info: TypeAdapter(ApiProviderSchemaType).validate_python(schema_type) - assert "invalid schema type" in str(exc_info.value) + assert "validation error" in str(exc_info.value) def test_create_api_tool_provider_missing_auth_type( self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies