diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index 3c95779475..3f23fd55b4 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -2,7 +2,8 @@ import logging from typing import Any from flask import request -from flask_restx import Resource, inputs, marshal_with, reqparse +from flask_restx import Resource, marshal_with +from pydantic import BaseModel from sqlalchemy import and_, select from werkzeug.exceptions import BadRequest, Forbidden, NotFound @@ -18,6 +19,15 @@ from services.account_service import TenantService from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService + +class InstalledAppCreatePayload(BaseModel): + app_id: str + + +class InstalledAppUpdatePayload(BaseModel): + is_pinned: bool | None = None + + logger = logging.getLogger(__name__) @@ -105,16 +115,15 @@ class InstalledAppsListApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("apps") def post(self): - parser = reqparse.RequestParser().add_argument("app_id", type=str, required=True, help="Invalid app_id") - args = parser.parse_args() + payload = InstalledAppCreatePayload.model_validate(console_ns.payload or {}) - recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == args["app_id"]).first() + recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == payload.app_id).first() if recommended_app is None: raise NotFound("App not found") _, current_tenant_id = current_account_with_tenant() - app = db.session.query(App).where(App.id == args["app_id"]).first() + app = db.session.query(App).where(App.id == payload.app_id).first() if app is None: raise NotFound("App not found") @@ -124,7 +133,7 @@ class InstalledAppsListApi(Resource): installed_app = ( db.session.query(InstalledApp) - .where(and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id)) + .where(and_(InstalledApp.app_id == payload.app_id, InstalledApp.tenant_id == current_tenant_id)) .first() ) @@ -133,7 +142,7 @@ class InstalledAppsListApi(Resource): recommended_app.install_count += 1 new_installed_app = InstalledApp( - app_id=args["app_id"], + app_id=payload.app_id, tenant_id=current_tenant_id, app_owner_tenant_id=app.tenant_id, is_pinned=False, @@ -163,12 +172,11 @@ class InstalledAppApi(InstalledAppResource): return {"result": "success", "message": "App uninstalled successfully"}, 204 def patch(self, installed_app): - parser = reqparse.RequestParser().add_argument("is_pinned", type=inputs.boolean) - args = parser.parse_args() + payload = InstalledAppUpdatePayload.model_validate(console_ns.payload or {}) commit_args = False - if "is_pinned" in args: - installed_app.is_pinned = args["is_pinned"] + if payload.is_pinned is not None: + installed_app.is_pinned = payload.is_pinned commit_args = True if commit_args: diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index 08f29b4655..c2cb4578ea 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -1,7 +1,10 @@ -from flask_restx import Resource, fields, marshal_with, reqparse +from flask import request +from flask_restx import Resource, fields, marshal_with +from pydantic import BaseModel, Field from constants import HIDDEN_VALUE from controllers.console import console_ns +from controllers.common.schema import register_schema_models from controllers.console.wraps import account_initialization_required, setup_required from fields.api_based_extension_fields import api_based_extension_fields from libs.login import current_account_with_tenant, login_required @@ -9,6 +12,28 @@ from models.api_based_extension import APIBasedExtension from services.api_based_extension_service import APIBasedExtensionService from services.code_based_extension_service import CodeBasedExtensionService + +class CodeBasedExtensionQuery(BaseModel): + module: str + + +class APIBasedExtensionPayload(BaseModel): + name: str = Field(description="Extension name") + api_endpoint: str = Field(description="API endpoint URL") + api_key: str = Field(description="API key for authentication") + + +class CreateAPIBasedExtensionPayload(APIBasedExtensionPayload): + pass + + +class UpdateAPIBasedExtensionPayload(APIBasedExtensionPayload): + pass + + +register_schema_models(console_ns, CreateAPIBasedExtensionPayload, UpdateAPIBasedExtensionPayload) + + api_based_extension_model = console_ns.model("ApiBasedExtensionModel", api_based_extension_fields) api_based_extension_list_model = fields.List(fields.Nested(api_based_extension_model)) @@ -18,11 +43,7 @@ api_based_extension_list_model = fields.List(fields.Nested(api_based_extension_m class CodeBasedExtensionAPI(Resource): @console_ns.doc("get_code_based_extension") @console_ns.doc(description="Get code-based extension data by module name") - @console_ns.expect( - console_ns.parser().add_argument( - "module", type=str, required=True, location="args", help="Extension module name" - ) - ) + @console_ns.doc(params={"module": "Extension module name"}) @console_ns.response( 200, "Success", @@ -35,10 +56,9 @@ class CodeBasedExtensionAPI(Resource): @login_required @account_initialization_required def get(self): - parser = reqparse.RequestParser().add_argument("module", type=str, required=True, location="args") - args = parser.parse_args() + query = CodeBasedExtensionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore - return {"module": args["module"], "data": CodeBasedExtensionService.get_code_based_extension(args["module"])} + return {"module": query.module, "data": CodeBasedExtensionService.get_code_based_extension(query.module)} @console_ns.route("/api-based-extension") @@ -56,30 +76,21 @@ class APIBasedExtensionAPI(Resource): @console_ns.doc("create_api_based_extension") @console_ns.doc(description="Create a new API-based extension") - @console_ns.expect( - console_ns.model( - "CreateAPIBasedExtensionRequest", - { - "name": fields.String(required=True, description="Extension name"), - "api_endpoint": fields.String(required=True, description="API endpoint URL"), - "api_key": fields.String(required=True, description="API key for authentication"), - }, - ) - ) + @console_ns.expect(console_ns.models[CreateAPIBasedExtensionPayload.__name__]) @console_ns.response(201, "Extension created successfully", api_based_extension_model) @setup_required @login_required @account_initialization_required @marshal_with(api_based_extension_model) def post(self): - args = console_ns.payload + payload = CreateAPIBasedExtensionPayload.model_validate(console_ns.payload or {}) _, current_tenant_id = current_account_with_tenant() extension_data = APIBasedExtension( tenant_id=current_tenant_id, - name=args["name"], - api_endpoint=args["api_endpoint"], - api_key=args["api_key"], + name=payload.name, + api_endpoint=payload.api_endpoint, + api_key=payload.api_key, ) return APIBasedExtensionService.save(extension_data) @@ -104,16 +115,7 @@ class APIBasedExtensionDetailAPI(Resource): @console_ns.doc("update_api_based_extension") @console_ns.doc(description="Update API-based extension") @console_ns.doc(params={"id": "Extension ID"}) - @console_ns.expect( - console_ns.model( - "UpdateAPIBasedExtensionRequest", - { - "name": fields.String(required=True, description="Extension name"), - "api_endpoint": fields.String(required=True, description="API endpoint URL"), - "api_key": fields.String(required=True, description="API key for authentication"), - }, - ) - ) + @console_ns.expect(console_ns.models[UpdateAPIBasedExtensionPayload.__name__]) @console_ns.response(200, "Extension updated successfully", api_based_extension_model) @setup_required @login_required @@ -125,13 +127,13 @@ class APIBasedExtensionDetailAPI(Resource): extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id) - args = console_ns.payload + payload = UpdateAPIBasedExtensionPayload.model_validate(console_ns.payload or {}) - extension_data_from_db.name = args["name"] - extension_data_from_db.api_endpoint = args["api_endpoint"] + extension_data_from_db.name = payload.name + extension_data_from_db.api_endpoint = payload.api_endpoint - if args["api_key"] != HIDDEN_VALUE: - extension_data_from_db.api_key = args["api_key"] + if payload.api_key != HIDDEN_VALUE: + extension_data_from_db.api_key = payload.api_key return APIBasedExtensionService.save(extension_data_from_db) diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index 17cfc3ff4b..c545348cde 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -1,31 +1,39 @@ from flask import request -from flask_restx import Resource, marshal_with, reqparse +from flask_restx import Resource, marshal_with +from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from fields.tag_fields import dataset_tag_fields from libs.login import current_account_with_tenant, login_required -from models.model import Tag from services.tag_service import TagService - -def _validate_name(name): - if not name or len(name) < 1 or len(name) > 50: - raise ValueError("Name must be between 1 to 50 characters.") - return name +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" -parser_tags = ( - reqparse.RequestParser() - .add_argument( - "name", - nullable=False, - required=True, - help="Name must be between 1 to 50 characters.", - type=_validate_name, - ) - .add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.") +class TagBasePayload(BaseModel): + name: str = Field(description="Tag name", min_length=1, max_length=50) + type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type") + + +class TagBindingPayload(BaseModel): + tag_ids: list[str] + target_id: str + type: Literal["knowledge", "app"] | None = None + + +class TagBindingRemovePayload(BaseModel): + tag_id: str + target_id: str + type: Literal["knowledge", "app"] | None = None + + +register_schema_models( + console_ns, + TagBasePayload, + TagBindingPayload, + TagBindingRemovePayload, ) @@ -43,7 +51,7 @@ class TagListApi(Resource): return tags, 200 - @console_ns.expect(parser_tags) + @console_ns.expect(console_ns.models[TagBasePayload.__name__]) @setup_required @login_required @account_initialization_required @@ -53,22 +61,17 @@ class TagListApi(Resource): if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() - args = parser_tags.parse_args() - tag = TagService.save_tags(args) + payload = TagBasePayload.model_validate(console_ns.payload or {}) + tag = TagService.save_tags(payload.model_dump()) response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} return response, 200 -parser_tag_id = reqparse.RequestParser().add_argument( - "name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name -) - - @console_ns.route("/tags/") class TagUpdateDeleteApi(Resource): - @console_ns.expect(parser_tag_id) + @console_ns.expect(console_ns.models[TagBasePayload.__name__]) @setup_required @login_required @account_initialization_required @@ -79,8 +82,8 @@ class TagUpdateDeleteApi(Resource): if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() - args = parser_tag_id.parse_args() - tag = TagService.update_tags(args, tag_id) + payload = TagBasePayload.model_validate(console_ns.payload or {}) + tag = TagService.update_tags(payload.model_dump(), tag_id) binding_count = TagService.get_tag_binding_count(tag_id) @@ -100,17 +103,9 @@ class TagUpdateDeleteApi(Resource): return 204 -parser_create = ( - reqparse.RequestParser() - .add_argument("tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required.") - .add_argument("target_id", type=str, nullable=False, required=True, location="json", help="Target ID is required.") - .add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.") -) - - @console_ns.route("/tag-bindings/create") class TagBindingCreateApi(Resource): - @console_ns.expect(parser_create) + @console_ns.expect(console_ns.models[TagBindingPayload.__name__]) @setup_required @login_required @account_initialization_required @@ -120,23 +115,15 @@ class TagBindingCreateApi(Resource): if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() - args = parser_create.parse_args() - TagService.save_tag_binding(args) + payload = TagBindingPayload.model_validate(console_ns.payload or {}) + TagService.save_tag_binding(payload.model_dump()) return {"result": "success"}, 200 -parser_remove = ( - reqparse.RequestParser() - .add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.") - .add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.") - .add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.") -) - - @console_ns.route("/tag-bindings/remove") class TagBindingDeleteApi(Resource): - @console_ns.expect(parser_remove) + @console_ns.expect(console_ns.models[TagBindingRemovePayload.__name__]) @setup_required @login_required @account_initialization_required @@ -146,7 +133,7 @@ class TagBindingDeleteApi(Resource): if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() - args = parser_remove.parse_args() - TagService.delete_tag_binding(args) + payload = TagBindingRemovePayload.model_validate(console_ns.payload or {}) + TagService.delete_tag_binding(payload.model_dump()) return {"result": "success"}, 200 diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py index 9bf393ea2e..588bff494b 100644 --- a/api/controllers/console/workspace/load_balancing_config.py +++ b/api/controllers/console/workspace/load_balancing_config.py @@ -1,4 +1,5 @@ -from flask_restx import Resource, reqparse +from flask_restx import Resource +from pydantic import BaseModel from werkzeug.exceptions import Forbidden from controllers.console import console_ns @@ -9,6 +10,20 @@ from libs.login import current_account_with_tenant, login_required from models import TenantAccountRole from services.model_load_balancing_service import ModelLoadBalancingService +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class LoadBalancingCredentialPayload(BaseModel): + model: str + model_type: ModelType + credentials: dict + + +console_ns.schema_model( + LoadBalancingCredentialPayload.__name__, + LoadBalancingCredentialPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) + @console_ns.route( "/workspaces/current/model-providers//models/load-balancing-configs/credentials-validate" @@ -24,20 +39,7 @@ class LoadBalancingCredentialsValidateApi(Resource): tenant_id = current_tenant_id - parser = ( - reqparse.RequestParser() - .add_argument("model", type=str, required=True, nullable=False, location="json") - .add_argument( - "model_type", - type=str, - required=True, - nullable=False, - choices=[mt.value for mt in ModelType], - location="json", - ) - .add_argument("credentials", type=dict, required=True, nullable=False, location="json") - ) - args = parser.parse_args() + payload = LoadBalancingCredentialPayload.model_validate(console_ns.payload or {}) # validate model load balancing credentials model_load_balancing_service = ModelLoadBalancingService() @@ -49,9 +51,9 @@ class LoadBalancingCredentialsValidateApi(Resource): model_load_balancing_service.validate_load_balancing_credentials( tenant_id=tenant_id, provider=provider, - model=args["model"], - model_type=args["model_type"], - credentials=args["credentials"], + model=payload.model, + model_type=payload.model_type, + credentials=payload.credentials, ) except CredentialsValidateFailedError as ex: result = False @@ -79,20 +81,7 @@ class LoadBalancingConfigCredentialsValidateApi(Resource): tenant_id = current_tenant_id - parser = ( - reqparse.RequestParser() - .add_argument("model", type=str, required=True, nullable=False, location="json") - .add_argument( - "model_type", - type=str, - required=True, - nullable=False, - choices=[mt.value for mt in ModelType], - location="json", - ) - .add_argument("credentials", type=dict, required=True, nullable=False, location="json") - ) - args = parser.parse_args() + payload = LoadBalancingCredentialPayload.model_validate(console_ns.payload or {}) # validate model load balancing config credentials model_load_balancing_service = ModelLoadBalancingService() @@ -104,9 +93,9 @@ class LoadBalancingConfigCredentialsValidateApi(Resource): model_load_balancing_service.validate_load_balancing_credentials( tenant_id=tenant_id, provider=provider, - model=args["model"], - model_type=args["model_type"], - credentials=args["credentials"], + model=payload.model, + model_type=payload.model_type, + credentials=payload.credentials, config_id=config_id, ) except CredentialsValidateFailedError as ex: diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 2c54aa5a20..d9c1220b78 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -1,15 +1,15 @@ import io +from typing import Any, Literal from urllib.parse import urlparse from flask import make_response, redirect, request, send_file -from flask_restx import ( - Resource, - reqparse, -) +from flask_restx import Resource +from pydantic import BaseModel, Field, field_validator, model_validator from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden from configs import dify_config +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import ( account_initialization_required, @@ -25,7 +25,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.oauth import OAuthHandler from extensions.ext_database import db -from libs.helper import StrLen, alphanumeric, uuid_value +from libs.helper import alphanumeric, uuid_value from libs.login import current_account_with_tenant, login_required from models.provider_ids import ToolProviderID @@ -48,24 +48,216 @@ def is_valid_url(url: str) -> bool: parsed = urlparse(url) return all([parsed.scheme, parsed.netloc]) and parsed.scheme in ["http", "https"] except (ValueError, TypeError): - # ValueError: Invalid URL format - # TypeError: url is not a string return False -parser_tool = reqparse.RequestParser().add_argument( - "type", - type=str, - choices=["builtin", "model", "api", "workflow", "mcp"], - required=False, - nullable=True, - location="args", +class ToolProviderListQuery(BaseModel): + type: Literal["builtin", "model", "api", "workflow", "mcp"] | None = None + + +class BuiltinToolCredentialDeletePayload(BaseModel): + credential_id: str + + +class BuiltinToolAddPayload(BaseModel): + credentials: dict[str, Any] + name: str | None = Field(default=None, max_length=30) + type: CredentialType + + +class BuiltinToolUpdatePayload(BaseModel): + credential_id: str + credentials: dict[str, Any] | None = None + name: str | None = Field(default=None, max_length=30) + + +class ApiToolProviderBasePayload(BaseModel): + credentials: dict[str, Any] + schema_type: str + schema: str + provider: str + icon: dict[str, Any] + privacy_policy: str | None = None + labels: list[str] | None = None + custom_disclaimer: str | None = None + + +class ApiToolProviderAddPayload(ApiToolProviderBasePayload): + pass + + +class ApiToolProviderUpdatePayload(ApiToolProviderBasePayload): + original_provider: str + + +class UrlQuery(BaseModel): + url: str + + @field_validator("url") + @classmethod + def validate_url(cls, value: str) -> str: + if not is_valid_url(value): + raise ValueError("Invalid URL") + return value + + +class ProviderQuery(BaseModel): + provider: str + + +class ApiToolProviderDeletePayload(BaseModel): + provider: str + + +class ApiToolSchemaPayload(BaseModel): + schema: str + + +class ApiToolTestPayload(BaseModel): + tool_name: str + provider_name: str | None = None + credentials: dict[str, Any] + parameters: dict[str, Any] + schema_type: str + schema: str + + +class WorkflowToolBasePayload(BaseModel): + name: str + label: str + description: str + icon: dict[str, Any] + parameters: list[dict[str, Any]] + privacy_policy: str | None = "" + labels: list[str] | None = None + + @field_validator("name") + @classmethod + def validate_name(cls, value: str) -> str: + return alphanumeric(value) + + +class WorkflowToolCreatePayload(WorkflowToolBasePayload): + workflow_app_id: str + + @field_validator("workflow_app_id") + @classmethod + def validate_workflow_app_id(cls, value: str) -> str: + return uuid_value(value) + + +class WorkflowToolUpdatePayload(WorkflowToolBasePayload): + workflow_tool_id: str + + @field_validator("workflow_tool_id") + @classmethod + def validate_workflow_tool_id(cls, value: str) -> str: + return uuid_value(value) + + +class WorkflowToolDeletePayload(BaseModel): + workflow_tool_id: str + + @field_validator("workflow_tool_id") + @classmethod + def validate_workflow_tool_id(cls, value: str) -> str: + return uuid_value(value) + + +class WorkflowToolGetQuery(BaseModel): + workflow_tool_id: str | None = None + workflow_app_id: str | None = None + + @field_validator("workflow_tool_id", "workflow_app_id") + @classmethod + def validate_ids(cls, value: str | None) -> str | None: + if value is None: + return value + return uuid_value(value) + + @model_validator(mode="after") + def ensure_one(self) -> "WorkflowToolGetQuery": + if not self.workflow_tool_id and not self.workflow_app_id: + raise ValueError("workflow_tool_id or workflow_app_id is required") + return self + + +class WorkflowToolListQuery(BaseModel): + workflow_tool_id: str + + @field_validator("workflow_tool_id") + @classmethod + def validate_workflow_tool_id(cls, value: str) -> str: + return uuid_value(value) + + +class BuiltinProviderDefaultCredentialPayload(BaseModel): + id: str + + +class ToolOAuthCustomClientPayload(BaseModel): + client_params: dict[str, Any] | None = None + enable_oauth_custom_client: bool | None = True + + +class MCPProviderBasePayload(BaseModel): + server_url: str + name: str + icon: str + icon_type: str + icon_background: str | None = "" + server_identifier: str + configuration: dict[str, Any] | None = Field(default_factory=dict) + headers: dict[str, Any] | None = Field(default_factory=dict) + authentication: dict[str, Any] | None = Field(default_factory=dict) + + +class MCPProviderCreatePayload(MCPProviderBasePayload): + pass + + +class MCPProviderUpdatePayload(MCPProviderBasePayload): + provider_id: str + + +class MCPProviderDeletePayload(BaseModel): + provider_id: str + + +class MCPAuthPayload(BaseModel): + provider_id: str + authorization_code: str | None = None + + +class MCPCallbackQuery(BaseModel): + code: str + state: str + + +register_schema_models( + console_ns, + BuiltinToolCredentialDeletePayload, + BuiltinToolAddPayload, + BuiltinToolUpdatePayload, + ApiToolProviderAddPayload, + ApiToolProviderUpdatePayload, + ApiToolProviderDeletePayload, + ApiToolSchemaPayload, + ApiToolTestPayload, + WorkflowToolCreatePayload, + WorkflowToolUpdatePayload, + WorkflowToolDeletePayload, + BuiltinProviderDefaultCredentialPayload, + ToolOAuthCustomClientPayload, + MCPProviderCreatePayload, + MCPProviderUpdatePayload, + MCPProviderDeletePayload, + MCPAuthPayload, ) @console_ns.route("/workspaces/current/tool-providers") class ToolProviderListApi(Resource): - @console_ns.expect(parser_tool) @setup_required @login_required @account_initialization_required @@ -74,9 +266,10 @@ class ToolProviderListApi(Resource): user_id = user.id - args = parser_tool.parse_args() + raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] + query = ToolProviderListQuery.model_validate(raw_args) - return ToolCommonService.list_tool_providers(user_id, tenant_id, args.get("type", None)) + return ToolCommonService.list_tool_providers(user_id, tenant_id, query.type) @console_ns.route("/workspaces/current/tool-provider/builtin//tools") @@ -106,14 +299,9 @@ class ToolBuiltinProviderInfoApi(Resource): return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider)) -parser_delete = reqparse.RequestParser().add_argument( - "credential_id", type=str, required=True, nullable=False, location="json" -) - - @console_ns.route("/workspaces/current/tool-provider/builtin//delete") class ToolBuiltinProviderDeleteApi(Resource): - @console_ns.expect(parser_delete) + @console_ns.expect(console_ns.models[BuiltinToolCredentialDeletePayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -121,26 +309,18 @@ class ToolBuiltinProviderDeleteApi(Resource): def post(self, provider): _, tenant_id = current_account_with_tenant() - args = parser_delete.parse_args() + payload = BuiltinToolCredentialDeletePayload.model_validate(console_ns.payload or {}) return BuiltinToolManageService.delete_builtin_tool_provider( tenant_id, provider, - args["credential_id"], + payload.credential_id, ) -parser_add = ( - reqparse.RequestParser() - .add_argument("credentials", type=dict, required=True, nullable=False, location="json") - .add_argument("name", type=StrLen(30), required=False, nullable=False, location="json") - .add_argument("type", type=str, required=True, nullable=False, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/builtin//add") class ToolBuiltinProviderAddApi(Resource): - @console_ns.expect(parser_add) + @console_ns.expect(console_ns.models[BuiltinToolAddPayload.__name__]) @setup_required @login_required @account_initialization_required @@ -149,32 +329,21 @@ class ToolBuiltinProviderAddApi(Resource): user_id = user.id - args = parser_add.parse_args() - - if args["type"] not in CredentialType.values(): - raise ValueError(f"Invalid credential type: {args['type']}") + payload = BuiltinToolAddPayload.model_validate(console_ns.payload or {}) return BuiltinToolManageService.add_builtin_tool_provider( user_id=user_id, tenant_id=tenant_id, provider=provider, - credentials=args["credentials"], - name=args["name"], - api_type=CredentialType.of(args["type"]), + credentials=payload.credentials, + name=payload.name, + api_type=CredentialType.of(payload.type), ) -parser_update = ( - reqparse.RequestParser() - .add_argument("credential_id", type=str, required=True, nullable=False, location="json") - .add_argument("credentials", type=dict, required=False, nullable=True, location="json") - .add_argument("name", type=StrLen(30), required=False, nullable=True, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/builtin//update") class ToolBuiltinProviderUpdateApi(Resource): - @console_ns.expect(parser_update) + @console_ns.expect(console_ns.models[BuiltinToolUpdatePayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -183,15 +352,15 @@ class ToolBuiltinProviderUpdateApi(Resource): user, tenant_id = current_account_with_tenant() user_id = user.id - args = parser_update.parse_args() + payload = BuiltinToolUpdatePayload.model_validate(console_ns.payload or {}) result = BuiltinToolManageService.update_builtin_tool_provider( user_id=user_id, tenant_id=tenant_id, provider=provider, - credential_id=args["credential_id"], - credentials=args.get("credentials", None), - name=args.get("name", ""), + credential_id=payload.credential_id, + credentials=payload.credentials, + name=payload.name or "", ) return result @@ -221,22 +390,9 @@ class ToolBuiltinProviderIconApi(Resource): return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age) -parser_api_add = ( - reqparse.RequestParser() - .add_argument("credentials", type=dict, required=True, nullable=False, location="json") - .add_argument("schema_type", type=str, required=True, nullable=False, location="json") - .add_argument("schema", type=str, required=True, nullable=False, location="json") - .add_argument("provider", type=str, required=True, nullable=False, location="json") - .add_argument("icon", type=dict, required=True, nullable=False, location="json") - .add_argument("privacy_policy", type=str, required=False, nullable=True, location="json") - .add_argument("labels", type=list[str], required=False, nullable=True, location="json", default=[]) - .add_argument("custom_disclaimer", type=str, required=False, nullable=True, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/api/add") class ToolApiProviderAddApi(Resource): - @console_ns.expect(parser_api_add) + @console_ns.expect(console_ns.models[ApiToolProviderAddPayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -246,28 +402,24 @@ class ToolApiProviderAddApi(Resource): user_id = user.id - args = parser_api_add.parse_args() + payload = ApiToolProviderAddPayload.model_validate(console_ns.payload or {}) return ApiToolManageService.create_api_tool_provider( user_id, tenant_id, - args["provider"], - args["icon"], - args["credentials"], - args["schema_type"], - args["schema"], - args.get("privacy_policy", ""), - args.get("custom_disclaimer", ""), - args.get("labels", []), + payload.provider, + payload.icon, + payload.credentials, + payload.schema_type, + payload.schema, + payload.privacy_policy or "", + payload.custom_disclaimer or "", + payload.labels or [], ) -parser_remote = reqparse.RequestParser().add_argument("url", type=str, required=True, nullable=False, location="args") - - @console_ns.route("/workspaces/current/tool-provider/api/remote") class ToolApiProviderGetRemoteSchemaApi(Resource): - @console_ns.expect(parser_remote) @setup_required @login_required @account_initialization_required @@ -276,23 +428,18 @@ class ToolApiProviderGetRemoteSchemaApi(Resource): user_id = user.id - args = parser_remote.parse_args() + raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] + query = UrlQuery.model_validate(raw_args) return ApiToolManageService.get_api_tool_provider_remote_schema( user_id, tenant_id, - args["url"], + query.url, ) -parser_tools = reqparse.RequestParser().add_argument( - "provider", type=str, required=True, nullable=False, location="args" -) - - @console_ns.route("/workspaces/current/tool-provider/api/tools") class ToolApiProviderListToolsApi(Resource): - @console_ns.expect(parser_tools) @setup_required @login_required @account_initialization_required @@ -301,34 +448,21 @@ class ToolApiProviderListToolsApi(Resource): user_id = user.id - args = parser_tools.parse_args() + raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] + query = ProviderQuery.model_validate(raw_args) return jsonable_encoder( ApiToolManageService.list_api_tool_provider_tools( user_id, tenant_id, - args["provider"], + query.provider, ) ) -parser_api_update = ( - reqparse.RequestParser() - .add_argument("credentials", type=dict, required=True, nullable=False, location="json") - .add_argument("schema_type", type=str, required=True, nullable=False, location="json") - .add_argument("schema", type=str, required=True, nullable=False, location="json") - .add_argument("provider", type=str, required=True, nullable=False, location="json") - .add_argument("original_provider", type=str, required=True, nullable=False, location="json") - .add_argument("icon", type=dict, required=True, nullable=False, location="json") - .add_argument("privacy_policy", type=str, required=True, nullable=True, location="json") - .add_argument("labels", type=list[str], required=False, nullable=True, location="json") - .add_argument("custom_disclaimer", type=str, required=True, nullable=True, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/api/update") class ToolApiProviderUpdateApi(Resource): - @console_ns.expect(parser_api_update) + @console_ns.expect(console_ns.models[ApiToolProviderUpdatePayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -338,31 +472,26 @@ class ToolApiProviderUpdateApi(Resource): user_id = user.id - args = parser_api_update.parse_args() + payload = ApiToolProviderUpdatePayload.model_validate(console_ns.payload or {}) return ApiToolManageService.update_api_tool_provider( user_id, tenant_id, - args["provider"], - args["original_provider"], - args["icon"], - args["credentials"], - args["schema_type"], - args["schema"], - args["privacy_policy"], - args["custom_disclaimer"], - args.get("labels", []), + payload.provider, + payload.original_provider, + payload.icon, + payload.credentials, + payload.schema_type, + payload.schema, + payload.privacy_policy, + payload.custom_disclaimer, + payload.labels or [], ) -parser_api_delete = reqparse.RequestParser().add_argument( - "provider", type=str, required=True, nullable=False, location="json" -) - - @console_ns.route("/workspaces/current/tool-provider/api/delete") class ToolApiProviderDeleteApi(Resource): - @console_ns.expect(parser_api_delete) + @console_ns.expect(console_ns.models[ApiToolProviderDeletePayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -372,21 +501,17 @@ class ToolApiProviderDeleteApi(Resource): user_id = user.id - args = parser_api_delete.parse_args() + payload = ApiToolProviderDeletePayload.model_validate(console_ns.payload or {}) return ApiToolManageService.delete_api_tool_provider( user_id, tenant_id, - args["provider"], + payload.provider, ) -parser_get = reqparse.RequestParser().add_argument("provider", type=str, required=True, nullable=False, location="args") - - @console_ns.route("/workspaces/current/tool-provider/api/get") class ToolApiProviderGetApi(Resource): - @console_ns.expect(parser_get) @setup_required @login_required @account_initialization_required @@ -395,12 +520,13 @@ class ToolApiProviderGetApi(Resource): user_id = user.id - args = parser_get.parse_args() + raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] + query = ProviderQuery.model_validate(raw_args) return ApiToolManageService.get_api_tool_provider( user_id, tenant_id, - args["provider"], + query.provider, ) @@ -419,72 +545,43 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource): ) -parser_schema = reqparse.RequestParser().add_argument( - "schema", type=str, required=True, nullable=False, location="json" -) - - @console_ns.route("/workspaces/current/tool-provider/api/schema") class ToolApiProviderSchemaApi(Resource): - @console_ns.expect(parser_schema) + @console_ns.expect(console_ns.models[ApiToolSchemaPayload.__name__]) @setup_required @login_required @account_initialization_required def post(self): - args = parser_schema.parse_args() + payload = ApiToolSchemaPayload.model_validate(console_ns.payload or {}) return ApiToolManageService.parser_api_schema( - schema=args["schema"], + schema=payload.schema, ) -parser_pre = ( - reqparse.RequestParser() - .add_argument("tool_name", type=str, required=True, nullable=False, location="json") - .add_argument("provider_name", type=str, required=False, nullable=False, location="json") - .add_argument("credentials", type=dict, required=True, nullable=False, location="json") - .add_argument("parameters", type=dict, required=True, nullable=False, location="json") - .add_argument("schema_type", type=str, required=True, nullable=False, location="json") - .add_argument("schema", type=str, required=True, nullable=False, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/api/test/pre") class ToolApiProviderPreviousTestApi(Resource): - @console_ns.expect(parser_pre) + @console_ns.expect(console_ns.models[ApiToolTestPayload.__name__]) @setup_required @login_required @account_initialization_required def post(self): - args = parser_pre.parse_args() + payload = ApiToolTestPayload.model_validate(console_ns.payload or {}) _, current_tenant_id = current_account_with_tenant() return ApiToolManageService.test_api_tool_preview( current_tenant_id, - args["provider_name"] or "", - args["tool_name"], - args["credentials"], - args["parameters"], - args["schema_type"], - args["schema"], + payload.provider_name or "", + payload.tool_name, + payload.credentials, + payload.parameters, + payload.schema_type, + payload.schema, ) -parser_create = ( - reqparse.RequestParser() - .add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json") - .add_argument("name", type=alphanumeric, required=True, nullable=False, location="json") - .add_argument("label", type=str, required=True, nullable=False, location="json") - .add_argument("description", type=str, required=True, nullable=False, location="json") - .add_argument("icon", type=dict, required=True, nullable=False, location="json") - .add_argument("parameters", type=list[dict], required=True, nullable=False, location="json") - .add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="") - .add_argument("labels", type=list[str], required=False, nullable=True, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/workflow/create") class ToolWorkflowProviderCreateApi(Resource): - @console_ns.expect(parser_create) + @console_ns.expect(console_ns.models[WorkflowToolCreatePayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -494,38 +591,25 @@ class ToolWorkflowProviderCreateApi(Resource): user_id = user.id - args = parser_create.parse_args() + payload = WorkflowToolCreatePayload.model_validate(console_ns.payload or {}) return WorkflowToolManageService.create_workflow_tool( user_id=user_id, tenant_id=tenant_id, - workflow_app_id=args["workflow_app_id"], - name=args["name"], - label=args["label"], - icon=args["icon"], - description=args["description"], - parameters=args["parameters"], - privacy_policy=args["privacy_policy"], - labels=args["labels"], + workflow_app_id=payload.workflow_app_id, + name=payload.name, + label=payload.label, + icon=payload.icon, + description=payload.description, + parameters=payload.parameters, + privacy_policy=payload.privacy_policy or "", + labels=payload.labels, ) -parser_workflow_update = ( - reqparse.RequestParser() - .add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json") - .add_argument("name", type=alphanumeric, required=True, nullable=False, location="json") - .add_argument("label", type=str, required=True, nullable=False, location="json") - .add_argument("description", type=str, required=True, nullable=False, location="json") - .add_argument("icon", type=dict, required=True, nullable=False, location="json") - .add_argument("parameters", type=list[dict], required=True, nullable=False, location="json") - .add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="") - .add_argument("labels", type=list[str], required=False, nullable=True, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/workflow/update") class ToolWorkflowProviderUpdateApi(Resource): - @console_ns.expect(parser_workflow_update) + @console_ns.expect(console_ns.models[WorkflowToolUpdatePayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -534,33 +618,25 @@ class ToolWorkflowProviderUpdateApi(Resource): user, tenant_id = current_account_with_tenant() user_id = user.id - args = parser_workflow_update.parse_args() - - if not args["workflow_tool_id"]: - raise ValueError("incorrect workflow_tool_id") + payload = WorkflowToolUpdatePayload.model_validate(console_ns.payload or {}) return WorkflowToolManageService.update_workflow_tool( user_id, tenant_id, - args["workflow_tool_id"], - args["name"], - args["label"], - args["icon"], - args["description"], - args["parameters"], - args["privacy_policy"], - args.get("labels", []), + payload.workflow_tool_id, + payload.name, + payload.label, + payload.icon, + payload.description, + payload.parameters, + payload.privacy_policy or "", + payload.labels or [], ) -parser_workflow_delete = reqparse.RequestParser().add_argument( - "workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json" -) - - @console_ns.route("/workspaces/current/tool-provider/workflow/delete") class ToolWorkflowProviderDeleteApi(Resource): - @console_ns.expect(parser_workflow_delete) + @console_ns.expect(console_ns.models[WorkflowToolDeletePayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -570,25 +646,17 @@ class ToolWorkflowProviderDeleteApi(Resource): user_id = user.id - args = parser_workflow_delete.parse_args() + payload = WorkflowToolDeletePayload.model_validate(console_ns.payload or {}) return WorkflowToolManageService.delete_workflow_tool( user_id, tenant_id, - args["workflow_tool_id"], + payload.workflow_tool_id, ) -parser_wf_get = ( - reqparse.RequestParser() - .add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args") - .add_argument("workflow_app_id", type=uuid_value, required=False, nullable=True, location="args") -) - - @console_ns.route("/workspaces/current/tool-provider/workflow/get") class ToolWorkflowProviderGetApi(Resource): - @console_ns.expect(parser_wf_get) @setup_required @login_required @account_initialization_required @@ -597,19 +665,20 @@ class ToolWorkflowProviderGetApi(Resource): user_id = user.id - args = parser_wf_get.parse_args() + raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] + query = WorkflowToolGetQuery.model_validate(raw_args) - if args.get("workflow_tool_id"): + if query.workflow_tool_id: tool = WorkflowToolManageService.get_workflow_tool_by_tool_id( user_id, tenant_id, - args["workflow_tool_id"], + query.workflow_tool_id, ) - elif args.get("workflow_app_id"): + elif query.workflow_app_id: tool = WorkflowToolManageService.get_workflow_tool_by_app_id( user_id, tenant_id, - args["workflow_app_id"], + query.workflow_app_id, ) else: raise ValueError("incorrect workflow_tool_id or workflow_app_id") @@ -617,14 +686,8 @@ class ToolWorkflowProviderGetApi(Resource): return jsonable_encoder(tool) -parser_wf_tools = reqparse.RequestParser().add_argument( - "workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args" -) - - @console_ns.route("/workspaces/current/tool-provider/workflow/tools") class ToolWorkflowProviderListToolApi(Resource): - @console_ns.expect(parser_wf_tools) @setup_required @login_required @account_initialization_required @@ -633,13 +696,14 @@ class ToolWorkflowProviderListToolApi(Resource): user_id = user.id - args = parser_wf_tools.parse_args() + raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] + query = WorkflowToolListQuery.model_validate(raw_args) return jsonable_encoder( WorkflowToolManageService.list_single_workflow_tools( user_id, tenant_id, - args["workflow_tool_id"], + query.workflow_tool_id, ) ) @@ -806,49 +870,39 @@ class ToolOAuthCallback(Resource): return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") -parser_default_cred = reqparse.RequestParser().add_argument( - "id", type=str, required=True, nullable=False, location="json" -) - - @console_ns.route("/workspaces/current/tool-provider/builtin//default-credential") class ToolBuiltinProviderSetDefaultApi(Resource): - @console_ns.expect(parser_default_cred) + @console_ns.expect(console_ns.models[BuiltinProviderDefaultCredentialPayload.__name__]) @setup_required @login_required @account_initialization_required def post(self, provider): current_user, current_tenant_id = current_account_with_tenant() - args = parser_default_cred.parse_args() + payload = BuiltinProviderDefaultCredentialPayload.model_validate(console_ns.payload or {}) return BuiltinToolManageService.set_default_provider( - tenant_id=current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"] + tenant_id=current_tenant_id, user_id=current_user.id, provider=provider, id=payload.id ) -parser_custom = ( - reqparse.RequestParser() - .add_argument("client_params", type=dict, required=False, nullable=True, location="json") - .add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/builtin//oauth/custom-client") class ToolOAuthCustomClient(Resource): - @console_ns.expect(parser_custom) + @console_ns.expect(console_ns.models[ToolOAuthCustomClientPayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @account_initialization_required def post(self, provider: str): - args = parser_custom.parse_args() + payload = ToolOAuthCustomClientPayload.model_validate(console_ns.payload or {}) _, tenant_id = current_account_with_tenant() return BuiltinToolManageService.save_custom_oauth_client_params( tenant_id=tenant_id, provider=provider, - client_params=args.get("client_params", {}), - enable_oauth_custom_client=args.get("enable_oauth_custom_client", True), + client_params=payload.client_params or {}, + enable_oauth_custom_client=payload.enable_oauth_custom_client + if payload.enable_oauth_custom_client is not None + else True, ) @setup_required @@ -900,49 +954,19 @@ class ToolBuiltinProviderGetCredentialInfoApi(Resource): ) -parser_mcp = ( - reqparse.RequestParser() - .add_argument("server_url", type=str, required=True, nullable=False, location="json") - .add_argument("name", type=str, required=True, nullable=False, location="json") - .add_argument("icon", type=str, required=True, nullable=False, location="json") - .add_argument("icon_type", type=str, required=True, nullable=False, location="json") - .add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="") - .add_argument("server_identifier", type=str, required=True, nullable=False, location="json") - .add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={}) - .add_argument("headers", type=dict, required=False, nullable=True, location="json", default={}) - .add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={}) -) -parser_mcp_put = ( - reqparse.RequestParser() - .add_argument("server_url", type=str, required=True, nullable=False, location="json") - .add_argument("name", type=str, required=True, nullable=False, location="json") - .add_argument("icon", type=str, required=True, nullable=False, location="json") - .add_argument("icon_type", type=str, required=True, nullable=False, location="json") - .add_argument("icon_background", type=str, required=False, nullable=True, location="json") - .add_argument("provider_id", type=str, required=True, nullable=False, location="json") - .add_argument("server_identifier", type=str, required=True, nullable=False, location="json") - .add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={}) - .add_argument("headers", type=dict, required=False, nullable=True, location="json", default={}) - .add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={}) -) -parser_mcp_delete = reqparse.RequestParser().add_argument( - "provider_id", type=str, required=True, nullable=False, location="json" -) - - @console_ns.route("/workspaces/current/tool-provider/mcp") class ToolProviderMCPApi(Resource): - @console_ns.expect(parser_mcp) + @console_ns.expect(console_ns.models[MCPProviderCreatePayload.__name__]) @setup_required @login_required @account_initialization_required def post(self): - args = parser_mcp.parse_args() + payload = MCPProviderCreatePayload.model_validate(console_ns.payload or {}) user, tenant_id = current_account_with_tenant() # Parse and validate models - configuration = MCPConfiguration.model_validate(args["configuration"]) - authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None + configuration = MCPConfiguration.model_validate(payload.configuration or {}) + authentication = MCPAuthentication.model_validate(payload.authentication) if payload.authentication else None # Create provider with Session(db.engine) as session, session.begin(): @@ -950,26 +974,26 @@ class ToolProviderMCPApi(Resource): result = service.create_provider( tenant_id=tenant_id, user_id=user.id, - server_url=args["server_url"], - name=args["name"], - icon=args["icon"], - icon_type=args["icon_type"], - icon_background=args["icon_background"], - server_identifier=args["server_identifier"], - headers=args["headers"], + server_url=payload.server_url, + name=payload.name, + icon=payload.icon, + icon_type=payload.icon_type, + icon_background=payload.icon_background or "", + server_identifier=payload.server_identifier, + headers=payload.headers or {}, configuration=configuration, authentication=authentication, ) return jsonable_encoder(result) - @console_ns.expect(parser_mcp_put) + @console_ns.expect(console_ns.models[MCPProviderUpdatePayload.__name__]) @setup_required @login_required @account_initialization_required def put(self): - args = parser_mcp_put.parse_args() - configuration = MCPConfiguration.model_validate(args["configuration"]) - authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None + payload = MCPProviderUpdatePayload.model_validate(console_ns.payload or {}) + configuration = MCPConfiguration.model_validate(payload.configuration or {}) + authentication = MCPAuthentication.model_validate(payload.authentication) if payload.authentication else None _, current_tenant_id = current_account_with_tenant() # Step 1: Validate server URL change if needed (includes URL format validation and network operation) @@ -977,7 +1001,9 @@ class ToolProviderMCPApi(Resource): with Session(db.engine) as session: service = MCPToolManageService(session=session) validation_result = service.validate_server_url_change( - tenant_id=current_tenant_id, provider_id=args["provider_id"], new_server_url=args["server_url"] + tenant_id=current_tenant_id, + provider_id=payload.provider_id, + new_server_url=payload.server_url, ) # No need to check for errors here, exceptions will be raised directly @@ -987,50 +1013,43 @@ class ToolProviderMCPApi(Resource): service = MCPToolManageService(session=session) service.update_provider( tenant_id=current_tenant_id, - provider_id=args["provider_id"], - server_url=args["server_url"], - name=args["name"], - icon=args["icon"], - icon_type=args["icon_type"], - icon_background=args["icon_background"], - server_identifier=args["server_identifier"], - headers=args["headers"], + provider_id=payload.provider_id, + server_url=payload.server_url, + name=payload.name, + icon=payload.icon, + icon_type=payload.icon_type, + icon_background=payload.icon_background, + server_identifier=payload.server_identifier, + headers=payload.headers or {}, configuration=configuration, authentication=authentication, validation_result=validation_result, ) return {"result": "success"} - @console_ns.expect(parser_mcp_delete) + @console_ns.expect(console_ns.models[MCPProviderDeletePayload.__name__]) @setup_required @login_required @account_initialization_required def delete(self): - args = parser_mcp_delete.parse_args() + payload = MCPProviderDeletePayload.model_validate(console_ns.payload or {}) _, current_tenant_id = current_account_with_tenant() with Session(db.engine) as session, session.begin(): service = MCPToolManageService(session=session) - service.delete_provider(tenant_id=current_tenant_id, provider_id=args["provider_id"]) + service.delete_provider(tenant_id=current_tenant_id, provider_id=payload.provider_id) return {"result": "success"} -parser_auth = ( - reqparse.RequestParser() - .add_argument("provider_id", type=str, required=True, nullable=False, location="json") - .add_argument("authorization_code", type=str, required=False, nullable=True, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/mcp/auth") class ToolMCPAuthApi(Resource): - @console_ns.expect(parser_auth) + @console_ns.expect(console_ns.models[MCPAuthPayload.__name__]) @setup_required @login_required @account_initialization_required def post(self): - args = parser_auth.parse_args() - provider_id = args["provider_id"] + payload = MCPAuthPayload.model_validate(console_ns.payload or {}) + provider_id = payload.provider_id _, tenant_id = current_account_with_tenant() with Session(db.engine) as session, session.begin(): @@ -1068,7 +1087,7 @@ class ToolMCPAuthApi(Resource): # Pass the extracted OAuth metadata hints to auth() auth_result = auth( provider_entity, - args.get("authorization_code"), + payload.authorization_code, resource_metadata_url=e.resource_metadata_url, scope_hint=e.scope_hint, ) @@ -1133,20 +1152,13 @@ class ToolMCPUpdateApi(Resource): return jsonable_encoder(tools) -parser_cb = ( - reqparse.RequestParser() - .add_argument("code", type=str, required=True, nullable=False, location="args") - .add_argument("state", type=str, required=True, nullable=False, location="args") -) - - @console_ns.route("/mcp/oauth/callback") class ToolMCPCallbackApi(Resource): - @console_ns.expect(parser_cb) def get(self): - args = parser_cb.parse_args() - state_key = args["state"] - authorization_code = args["code"] + raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] + query = MCPCallbackQuery.model_validate(raw_args) + state_key = query.state + authorization_code = query.code # Create service instance for handle_callback with Session(db.engine) as session, session.begin(): diff --git a/api/controllers/console/workspace/trigger_providers.py b/api/controllers/console/workspace/trigger_providers.py index 268473d6d1..30fb954755 100644 --- a/api/controllers/console/workspace/trigger_providers.py +++ b/api/controllers/console/workspace/trigger_providers.py @@ -1,11 +1,14 @@ import logging +from typing import Any from flask import make_response, redirect, request -from flask_restx import Resource, reqparse +from flask_restx import Resource +from pydantic import BaseModel from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, Forbidden from configs import dify_config +from controllers.common.schema import register_schema_models from controllers.web.error import NotFoundError from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin_daemon import CredentialType @@ -32,6 +35,35 @@ from ..wraps import ( logger = logging.getLogger(__name__) +class TriggerSubscriptionBuilderCreatePayload(BaseModel): + credential_type: str | None = None + + +class TriggerSubscriptionBuilderVerifyPayload(BaseModel): + credentials: dict[str, Any] | None = None + + +class TriggerSubscriptionBuilderUpdatePayload(BaseModel): + name: str | None = None + parameters: dict[str, Any] | None = None + properties: dict[str, Any] | None = None + credentials: dict[str, Any] | None = None + + +class TriggerOAuthClientPayload(BaseModel): + client_params: dict[str, Any] | None = None + enabled: bool | None = None + + +register_schema_models( + console_ns, + TriggerSubscriptionBuilderCreatePayload, + TriggerSubscriptionBuilderVerifyPayload, + TriggerSubscriptionBuilderUpdatePayload, + TriggerOAuthClientPayload, +) + + @console_ns.route("/workspaces/current/trigger-provider//icon") class TriggerProviderIconApi(Resource): @setup_required @@ -97,16 +129,11 @@ class TriggerSubscriptionListApi(Resource): raise -parser = reqparse.RequestParser().add_argument( - "credential_type", type=str, required=False, nullable=True, location="json" -) - - @console_ns.route( "/workspaces/current/trigger-provider//subscriptions/builder/create", ) class TriggerSubscriptionBuilderCreateApi(Resource): - @console_ns.expect(parser) + @console_ns.expect(console_ns.models[TriggerSubscriptionBuilderCreatePayload.__name__]) @setup_required @login_required @edit_permission_required @@ -116,10 +143,10 @@ class TriggerSubscriptionBuilderCreateApi(Resource): user = current_user assert user.current_tenant_id is not None - args = parser.parse_args() + payload = TriggerSubscriptionBuilderCreatePayload.model_validate(console_ns.payload or {}) try: - credential_type = CredentialType.of(args.get("credential_type") or CredentialType.UNAUTHORIZED.value) + credential_type = CredentialType.of(payload.credential_type or CredentialType.UNAUTHORIZED.value) subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder( tenant_id=user.current_tenant_id, user_id=user.id, @@ -147,18 +174,11 @@ class TriggerSubscriptionBuilderGetApi(Resource): ) -parser_api = ( - reqparse.RequestParser() - # The credentials of the subscription builder - .add_argument("credentials", type=dict, required=False, nullable=True, location="json") -) - - @console_ns.route( "/workspaces/current/trigger-provider//subscriptions/builder/verify/", ) class TriggerSubscriptionBuilderVerifyApi(Resource): - @console_ns.expect(parser_api) + @console_ns.expect(console_ns.models[TriggerSubscriptionBuilderVerifyPayload.__name__]) @setup_required @login_required @edit_permission_required @@ -168,7 +188,7 @@ class TriggerSubscriptionBuilderVerifyApi(Resource): user = current_user assert user.current_tenant_id is not None - args = parser_api.parse_args() + payload = TriggerSubscriptionBuilderVerifyPayload.model_validate(console_ns.payload or {}) try: # Use atomic update_and_verify to prevent race conditions @@ -178,7 +198,7 @@ class TriggerSubscriptionBuilderVerifyApi(Resource): provider_id=TriggerProviderID(provider), subscription_builder_id=subscription_builder_id, subscription_builder_updater=SubscriptionBuilderUpdater( - credentials=args.get("credentials", None), + credentials=payload.credentials, ), ) except Exception as e: @@ -186,24 +206,11 @@ class TriggerSubscriptionBuilderVerifyApi(Resource): raise ValueError(str(e)) from e -parser_update_api = ( - reqparse.RequestParser() - # The name of the subscription builder - .add_argument("name", type=str, required=False, nullable=True, location="json") - # The parameters of the subscription builder - .add_argument("parameters", type=dict, required=False, nullable=True, location="json") - # The properties of the subscription builder - .add_argument("properties", type=dict, required=False, nullable=True, location="json") - # The credentials of the subscription builder - .add_argument("credentials", type=dict, required=False, nullable=True, location="json") -) - - @console_ns.route( "/workspaces/current/trigger-provider//subscriptions/builder/update/", ) class TriggerSubscriptionBuilderUpdateApi(Resource): - @console_ns.expect(parser_update_api) + @console_ns.expect(console_ns.models[TriggerSubscriptionBuilderUpdatePayload.__name__]) @setup_required @login_required @edit_permission_required @@ -214,7 +221,7 @@ class TriggerSubscriptionBuilderUpdateApi(Resource): assert isinstance(user, Account) assert user.current_tenant_id is not None - args = parser_update_api.parse_args() + payload = TriggerSubscriptionBuilderUpdatePayload.model_validate(console_ns.payload or {}) try: return jsonable_encoder( TriggerSubscriptionBuilderService.update_trigger_subscription_builder( @@ -222,10 +229,10 @@ class TriggerSubscriptionBuilderUpdateApi(Resource): provider_id=TriggerProviderID(provider), subscription_builder_id=subscription_builder_id, subscription_builder_updater=SubscriptionBuilderUpdater( - name=args.get("name", None), - parameters=args.get("parameters", None), - properties=args.get("properties", None), - credentials=args.get("credentials", None), + name=payload.name, + parameters=payload.parameters, + properties=payload.properties, + credentials=payload.credentials, ), ) ) @@ -260,7 +267,7 @@ class TriggerSubscriptionBuilderLogsApi(Resource): "/workspaces/current/trigger-provider//subscriptions/builder/build/", ) class TriggerSubscriptionBuilderBuildApi(Resource): - @console_ns.expect(parser_update_api) + @console_ns.expect(console_ns.models[TriggerSubscriptionBuilderUpdatePayload.__name__]) @setup_required @login_required @edit_permission_required @@ -269,7 +276,7 @@ class TriggerSubscriptionBuilderBuildApi(Resource): """Build a subscription instance for a trigger provider""" user = current_user assert user.current_tenant_id is not None - args = parser_update_api.parse_args() + payload = TriggerSubscriptionBuilderUpdatePayload.model_validate(console_ns.payload or {}) try: # Use atomic update_and_build to prevent race conditions TriggerSubscriptionBuilderService.update_and_build_builder( @@ -278,9 +285,9 @@ class TriggerSubscriptionBuilderBuildApi(Resource): provider_id=TriggerProviderID(provider), subscription_builder_id=subscription_builder_id, subscription_builder_updater=SubscriptionBuilderUpdater( - name=args.get("name", None), - parameters=args.get("parameters", None), - properties=args.get("properties", None), + name=payload.name, + parameters=payload.parameters, + properties=payload.properties, ), ) return 200 @@ -474,13 +481,6 @@ class TriggerOAuthCallbackApi(Resource): return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") -parser_oauth_client = ( - reqparse.RequestParser() - .add_argument("client_params", type=dict, required=False, nullable=True, location="json") - .add_argument("enabled", type=bool, required=False, nullable=True, location="json") -) - - @console_ns.route("/workspaces/current/trigger-provider//oauth/client") class TriggerOAuthClientManageApi(Resource): @setup_required @@ -528,7 +528,7 @@ class TriggerOAuthClientManageApi(Resource): logger.exception("Error getting OAuth client", exc_info=e) raise - @console_ns.expect(parser_oauth_client) + @console_ns.expect(console_ns.models[TriggerOAuthClientPayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -538,15 +538,15 @@ class TriggerOAuthClientManageApi(Resource): user = current_user assert user.current_tenant_id is not None - args = parser_oauth_client.parse_args() + payload = TriggerOAuthClientPayload.model_validate(console_ns.payload or {}) try: provider_id = TriggerProviderID(provider) return TriggerProviderService.save_custom_oauth_client_params( tenant_id=user.current_tenant_id, provider_id=provider_id, - client_params=args.get("client_params"), - enabled=args.get("enabled"), + client_params=payload.client_params, + enabled=payload.enabled, ) except ValueError as e: diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index 60193f5f15..9721003eab 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -1,7 +1,8 @@ import logging from flask import request -from flask_restx import Resource, marshal_with, reqparse +from flask_restx import Resource, marshal_with +from pydantic import BaseModel, ConfigDict, Field from werkzeug.exceptions import Unauthorized from constants import HEADER_NAME_APP_CODE @@ -21,6 +22,13 @@ from services.webapp_auth_service import WebAppAuthService logger = logging.getLogger(__name__) +class AppAccessModeQuery(BaseModel): + model_config = ConfigDict(populate_by_name=True) + + app_id: str | None = Field(default=None, alias="appId", description="Application ID") + app_code: str | None = Field(default=None, alias="appCode", description="Application code") + + @web_ns.route("/parameters") class AppParameterApi(WebApiResource): """Resource for app variables.""" @@ -82,12 +90,7 @@ class AppMeta(WebApiResource): class AppAccessMode(Resource): @web_ns.doc("Get App Access Mode") @web_ns.doc(description="Retrieve the access mode for a web application (public or restricted).") - @web_ns.doc( - params={ - "appId": {"description": "Application ID", "type": "string", "required": False}, - "appCode": {"description": "Application code", "type": "string", "required": False}, - } - ) + @web_ns.expect(web_ns.models[AppAccessModeQuery.__name__]) @web_ns.doc( responses={ 200: "Success", @@ -96,21 +99,16 @@ class AppAccessMode(Resource): } ) def get(self): - parser = ( - reqparse.RequestParser() - .add_argument("appId", type=str, required=False, location="args") - .add_argument("appCode", type=str, required=False, location="args") - ) - args = parser.parse_args() + raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] + args = AppAccessModeQuery.model_validate(raw_args) features = FeatureService.get_system_features() if not features.webapp_auth.enabled: return {"accessMode": "public"} - app_id = args.get("appId") - if args.get("appCode"): - app_code = args["appCode"] - app_id = AppService.get_app_id_by_code(app_code) + app_id = args.app_id + if args.app_code: + app_id = AppService.get_app_id_by_code(args.app_code) if not app_id: raise ValueError("appId or appCode must be provided") diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index b9fef48c4d..30bd07847e 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -1,7 +1,8 @@ import logging from flask import request -from flask_restx import fields, marshal_with, reqparse +from flask_restx import fields, marshal_with +from pydantic import BaseModel, field_validator from werkzeug.exceptions import InternalServerError import services @@ -17,9 +18,11 @@ from controllers.web.error import ( ProviderQuotaExceededError, UnsupportedAudioTypeError, ) +from ..common import register_schema_models from controllers.web.wraps import WebApiResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError +from libs.helper import uuid_value from models.model import App from services.audio_service import AudioService from services.errors.audio import ( @@ -29,6 +32,22 @@ from services.errors.audio import ( UnsupportedAudioTypeServiceError, ) + +class TextToAudioPayload(BaseModel): + message_id: str | None = None + voice: str | None = None + text: str | None = None + streaming: bool | None = None + + @field_validator("message_id") + @classmethod + def validate_message_id(cls, value: str | None) -> str | None: + if value is None: + return value + return uuid_value(value) + +register_schema_models(web_ns, TextToAudioPayload) + logger = logging.getLogger(__name__) @@ -88,6 +107,7 @@ class AudioApi(WebApiResource): @web_ns.route("/text-to-audio") class TextApi(WebApiResource): + @web_ns.expect(web_ns.models[TextToAudioPayload.__name__]) @web_ns.doc("Text to Audio") @web_ns.doc(description="Convert text to audio using text-to-speech service.") @web_ns.doc( @@ -102,18 +122,11 @@ class TextApi(WebApiResource): def post(self, app_model: App, end_user): """Convert text to audio""" try: - parser = ( - reqparse.RequestParser() - .add_argument("message_id", type=str, required=False, location="json") - .add_argument("voice", type=str, location="json") - .add_argument("text", type=str, location="json") - .add_argument("streaming", type=bool, location="json") - ) - args = parser.parse_args() + payload = TextToAudioPayload.model_validate(web_ns.payload or {}) - message_id = args.get("message_id", None) - text = args.get("text", None) - voice = args.get("voice", None) + message_id = payload.message_id + text = payload.text + voice = payload.voice response = AudioService.transcript_tts( app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id ) diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index e8a4698375..830d2a16b5 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -1,6 +1,7 @@ import logging +from typing import Any, Literal -from flask_restx import reqparse +from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import InternalServerError, NotFound import services @@ -34,25 +35,41 @@ from services.errors.llm import InvokeRateLimitError logger = logging.getLogger(__name__) +class CompletionMessagePayload(BaseModel): + inputs: dict[str, Any] = Field(description="Input variables for the completion") + query: str = Field(description="Query text for completion") + files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed") + response_mode: Literal["blocking", "streaming"] | None = Field( + default=None, description="Response mode: blocking or streaming" + ) + retriever_from: str | None = Field(default="web_app", description="Source of retriever") + + +class ChatMessagePayload(BaseModel): + inputs: dict[str, Any] = Field(description="Input variables for the chat") + query: str = Field(description="User query/message") + files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed") + response_mode: Literal["blocking", "streaming"] | None = Field( + default=None, description="Response mode: blocking or streaming" + ) + conversation_id: str | None = Field(default=None, description="Conversation ID") + parent_message_id: str | None = Field(default=None, description="Parent message ID") + retriever_from: str | None = Field(default="web_app", description="Source of retriever") + + @field_validator("conversation_id", "parent_message_id") + @classmethod + def validate_uuid(cls, value: str | None) -> str | None: + if value is None: + return value + return uuid_value(value) + + # define completion api for user @web_ns.route("/completion-messages") class CompletionApi(WebApiResource): @web_ns.doc("Create Completion Message") @web_ns.doc(description="Create a completion message for text generation applications.") - @web_ns.doc( - params={ - "inputs": {"description": "Input variables for the completion", "type": "object", "required": True}, - "query": {"description": "Query text for completion", "type": "string", "required": False}, - "files": {"description": "Files to be processed", "type": "array", "required": False}, - "response_mode": { - "description": "Response mode: blocking or streaming", - "type": "string", - "enum": ["blocking", "streaming"], - "required": False, - }, - "retriever_from": {"description": "Source of retriever", "type": "string", "required": False}, - } - ) + @web_ns.expect(web_ns.models[CompletionMessagePayload.__name__]) @web_ns.doc( responses={ 200: "Success", @@ -67,18 +84,10 @@ class CompletionApi(WebApiResource): if app_model.mode != AppMode.COMPLETION: raise NotCompletionAppError() - parser = ( - reqparse.RequestParser() - .add_argument("inputs", type=dict, required=True, location="json") - .add_argument("query", type=str, location="json", default="") - .add_argument("files", type=list, required=False, location="json") - .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") - .add_argument("retriever_from", type=str, required=False, default="web_app", location="json") - ) + payload = CompletionMessagePayload.model_validate(web_ns.payload or {}) + args = payload.model_dump(exclude_none=True) - args = parser.parse_args() - - streaming = args["response_mode"] == "streaming" + streaming = payload.response_mode == "streaming" args["auto_generate_name"] = False try: @@ -142,22 +151,7 @@ class CompletionStopApi(WebApiResource): class ChatApi(WebApiResource): @web_ns.doc("Create Chat Message") @web_ns.doc(description="Create a chat message for conversational applications.") - @web_ns.doc( - params={ - "inputs": {"description": "Input variables for the chat", "type": "object", "required": True}, - "query": {"description": "User query/message", "type": "string", "required": True}, - "files": {"description": "Files to be processed", "type": "array", "required": False}, - "response_mode": { - "description": "Response mode: blocking or streaming", - "type": "string", - "enum": ["blocking", "streaming"], - "required": False, - }, - "conversation_id": {"description": "Conversation UUID", "type": "string", "required": False}, - "parent_message_id": {"description": "Parent message UUID", "type": "string", "required": False}, - "retriever_from": {"description": "Source of retriever", "type": "string", "required": False}, - } - ) + @web_ns.expect(web_ns.models[ChatMessagePayload.__name__]) @web_ns.doc( responses={ 200: "Success", @@ -173,20 +167,10 @@ class ChatApi(WebApiResource): if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - parser = ( - reqparse.RequestParser() - .add_argument("inputs", type=dict, required=True, location="json") - .add_argument("query", type=str, required=True, location="json") - .add_argument("files", type=list, required=False, location="json") - .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") - .add_argument("conversation_id", type=uuid_value, location="json") - .add_argument("parent_message_id", type=uuid_value, required=False, location="json") - .add_argument("retriever_from", type=str, required=False, default="web_app", location="json") - ) + payload = ChatMessagePayload.model_validate(web_ns.payload or {}) + args = payload.model_dump(exclude_none=True) - args = parser.parse_args() - - streaming = args["response_mode"] == "streaming" + streaming = payload.response_mode == "streaming" args["auto_generate_name"] = False try: diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index 86e19423e5..3281e0b001 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -1,5 +1,8 @@ -from flask_restx import fields, marshal_with, reqparse -from flask_restx.inputs import int_range +from typing import Literal + +from flask import request +from flask_restx import fields, marshal_with +from pydantic import BaseModel, Field, field_validator from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound @@ -16,6 +19,25 @@ from services.errors.conversation import ConversationNotExistsError, LastConvers from services.web_conversation_service import WebConversationService +class ConversationListQuery(BaseModel): + last_id: str | None = None + limit: int = Field(default=20, ge=1, le=100) + pinned: bool | None = None + sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = "-updated_at" + + @field_validator("last_id") + @classmethod + def validate_last_id(cls, value: str | None) -> str | None: + if value is None: + return value + return uuid_value(value) + + +class ConversationRenamePayload(BaseModel): + name: str | None = None + auto_generate: bool = False + + @web_ns.route("/conversations") class ConversationListApi(WebApiResource): @web_ns.doc("Get Conversation List") @@ -60,25 +82,8 @@ class ConversationListApi(WebApiResource): if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - parser = ( - reqparse.RequestParser() - .add_argument("last_id", type=uuid_value, location="args") - .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") - .add_argument("pinned", type=str, choices=["true", "false", None], location="args") - .add_argument( - "sort_by", - type=str, - choices=["created_at", "-created_at", "updated_at", "-updated_at"], - required=False, - default="-updated_at", - location="args", - ) - ) - args = parser.parse_args() - - pinned = None - if "pinned" in args and args["pinned"] is not None: - pinned = args["pinned"] == "true" + raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] + query = ConversationListQuery.model_validate(raw_args) try: with Session(db.engine) as session: @@ -86,11 +91,11 @@ class ConversationListApi(WebApiResource): session=session, app_model=app_model, user=end_user, - last_id=args["last_id"], - limit=args["limit"], + last_id=query.last_id, + limit=query.limit, invoke_from=InvokeFrom.WEB_APP, - pinned=pinned, - sort_by=args["sort_by"], + pinned=query.pinned, + sort_by=query.sort_by, ) except LastConversationNotExistsError: raise NotFound("Last Conversation Not Exists.") @@ -163,15 +168,10 @@ class ConversationRenameApi(WebApiResource): conversation_id = str(c_id) - parser = ( - reqparse.RequestParser() - .add_argument("name", type=str, required=False, location="json") - .add_argument("auto_generate", type=bool, required=False, default=False, location="json") - ) - args = parser.parse_args() + payload = ConversationRenamePayload.model_validate(web_ns.payload or {}) try: - return ConversationService.rename(app_model, conversation_id, end_user, args["name"], args["auto_generate"]) + return ConversationService.rename(app_model, conversation_id, end_user, payload.name, payload.auto_generate) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") diff --git a/api/controllers/web/forgot_password.py b/api/controllers/web/forgot_password.py index b9e391e049..6f22171c67 100644 --- a/api/controllers/web/forgot_password.py +++ b/api/controllers/web/forgot_password.py @@ -2,7 +2,8 @@ import base64 import secrets from flask import request -from flask_restx import Resource, reqparse +from flask_restx import Resource +from pydantic import BaseModel, Field, field_validator from sqlalchemy import select from sqlalchemy.orm import Session @@ -18,12 +19,34 @@ from controllers.console.error import EmailSendIpLimitError from controllers.console.wraps import email_password_login_enabled, only_edition_enterprise, setup_required from controllers.web import web_ns from extensions.ext_database import db -from libs.helper import email, extract_remote_ip +from libs.helper import EmailStr, extract_remote_ip from libs.password import hash_password, valid_password from models import Account from services.account_service import AccountService +class ForgotPasswordSendPayload(BaseModel): + email: EmailStr + language: str | None = None + + +class ForgotPasswordCheckPayload(BaseModel): + email: EmailStr + code: str + token: str = Field(min_length=1) + + +class ForgotPasswordResetPayload(BaseModel): + token: str = Field(min_length=1) + new_password: str + password_confirm: str + + @field_validator("new_password", "password_confirm") + @classmethod + def validate_password(cls, value: str) -> str: + return valid_password(value) + + @web_ns.route("/forgot-password") class ForgotPasswordSendEmailApi(Resource): @only_edition_enterprise @@ -40,29 +63,24 @@ class ForgotPasswordSendEmailApi(Resource): } ) def post(self): - parser = ( - reqparse.RequestParser() - .add_argument("email", type=email, required=True, location="json") - .add_argument("language", type=str, required=False, location="json") - ) - args = parser.parse_args() + payload = ForgotPasswordSendPayload.model_validate(web_ns.payload or {}) ip_address = extract_remote_ip(request) if AccountService.is_email_send_ip_limit(ip_address): raise EmailSendIpLimitError() - if args["language"] is not None and args["language"] == "zh-Hans": + if payload.language == "zh-Hans": language = "zh-Hans" else: language = "en-US" with Session(db.engine) as session: - account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none() + account = session.execute(select(Account).filter_by(email=payload.email)).scalar_one_or_none() token = None if account is None: raise AuthenticationFailedError() else: - token = AccountService.send_reset_password_email(account=account, email=args["email"], language=language) + token = AccountService.send_reset_password_email(account=account, email=payload.email, language=language) return {"result": "success", "data": token} @@ -78,40 +96,34 @@ class ForgotPasswordCheckApi(Resource): responses={200: "Token is valid", 400: "Bad request - invalid token format", 401: "Invalid or expired token"} ) def post(self): - parser = ( - reqparse.RequestParser() - .add_argument("email", type=str, required=True, location="json") - .add_argument("code", type=str, required=True, location="json") - .add_argument("token", type=str, required=True, nullable=False, location="json") - ) - args = parser.parse_args() + payload = ForgotPasswordCheckPayload.model_validate(web_ns.payload or {}) - user_email = args["email"] + user_email = payload.email - is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args["email"]) + is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(payload.email) if is_forgot_password_error_rate_limit: raise EmailPasswordResetLimitError() - token_data = AccountService.get_reset_password_data(args["token"]) + token_data = AccountService.get_reset_password_data(payload.token) if token_data is None: raise InvalidTokenError() if user_email != token_data.get("email"): raise InvalidEmailError() - if args["code"] != token_data.get("code"): - AccountService.add_forgot_password_error_rate_limit(args["email"]) + if payload.code != token_data.get("code"): + AccountService.add_forgot_password_error_rate_limit(payload.email) raise EmailCodeError() # Verified, revoke the first token - AccountService.revoke_reset_password_token(args["token"]) + AccountService.revoke_reset_password_token(payload.token) # Refresh token data by generating a new token _, new_token = AccountService.generate_reset_password_token( - user_email, code=args["code"], additional_data={"phase": "reset"} + user_email, code=payload.code, additional_data={"phase": "reset"} ) - AccountService.reset_forgot_password_error_rate_limit(args["email"]) + AccountService.reset_forgot_password_error_rate_limit(payload.email) return {"is_valid": True, "email": token_data.get("email"), "token": new_token} @@ -131,20 +143,14 @@ class ForgotPasswordResetApi(Resource): } ) def post(self): - parser = ( - reqparse.RequestParser() - .add_argument("token", type=str, required=True, nullable=False, location="json") - .add_argument("new_password", type=valid_password, required=True, nullable=False, location="json") - .add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json") - ) - args = parser.parse_args() + payload = ForgotPasswordResetPayload.model_validate(web_ns.payload or {}) # Validate passwords match - if args["new_password"] != args["password_confirm"]: + if payload.new_password != payload.password_confirm: raise PasswordMismatchError() # Validate token and get reset data - reset_data = AccountService.get_reset_password_data(args["token"]) + reset_data = AccountService.get_reset_password_data(payload.token) if not reset_data: raise InvalidTokenError() # Must use token in reset phase @@ -152,11 +158,11 @@ class ForgotPasswordResetApi(Resource): raise InvalidTokenError() # Revoke token to prevent reuse - AccountService.revoke_reset_password_token(args["token"]) + AccountService.revoke_reset_password_token(payload.token) # Generate secure salt and hash password salt = secrets.token_bytes(16) - password_hashed = hash_password(args["new_password"], salt) + password_hashed = hash_password(payload.new_password, salt) email = reset_data.get("email", "") diff --git a/api/controllers/web/login.py b/api/controllers/web/login.py index 538d0c44be..861ca27fc9 100644 --- a/api/controllers/web/login.py +++ b/api/controllers/web/login.py @@ -1,6 +1,7 @@ from flask import make_response, request -from flask_restx import Resource, reqparse +from flask_restx import Resource from jwt import InvalidTokenError +from pydantic import BaseModel, Field, field_validator import services from configs import dify_config @@ -13,7 +14,7 @@ from controllers.console.error import AccountBannedError from controllers.console.wraps import only_edition_enterprise, setup_required from controllers.web import web_ns from controllers.web.wraps import decode_jwt_token -from libs.helper import email +from libs.helper import EmailStr from libs.passport import PassportService from libs.password import valid_password from libs.token import ( @@ -25,6 +26,27 @@ from services.app_service import AppService from services.webapp_auth_service import WebAppAuthService +class LoginPayload(BaseModel): + email: EmailStr + password: str + + @field_validator("password") + @classmethod + def validate_password(cls, value: str) -> str: + return valid_password(value) + + +class EmailCodeLoginSendPayload(BaseModel): + email: EmailStr + language: str | None = None + + +class EmailCodeLoginVerifyPayload(BaseModel): + email: EmailStr + code: str + token: str = Field(min_length=1) + + @web_ns.route("/login") class LoginApi(Resource): """Resource for web app email/password login.""" @@ -44,15 +66,10 @@ class LoginApi(Resource): ) def post(self): """Authenticate user and login.""" - parser = ( - reqparse.RequestParser() - .add_argument("email", type=email, required=True, location="json") - .add_argument("password", type=valid_password, required=True, location="json") - ) - args = parser.parse_args() + payload = LoginPayload.model_validate(web_ns.payload or {}) try: - account = WebAppAuthService.authenticate(args["email"], args["password"]) + account = WebAppAuthService.authenticate(payload.email, payload.password) except services.errors.account.AccountLoginError: raise AccountBannedError() except services.errors.account.AccountPasswordError: @@ -139,6 +156,7 @@ class EmailCodeLoginSendEmailApi(Resource): @only_edition_enterprise @web_ns.doc("send_email_code_login") @web_ns.doc(description="Send email verification code for login") + @web_ns.expect(web_ns.models[EmailCodeLoginSendPayload.__name__]) @web_ns.doc( responses={ 200: "Email code sent successfully", @@ -147,19 +165,14 @@ class EmailCodeLoginSendEmailApi(Resource): } ) def post(self): - parser = ( - reqparse.RequestParser() - .add_argument("email", type=email, required=True, location="json") - .add_argument("language", type=str, required=False, location="json") - ) - args = parser.parse_args() + payload = EmailCodeLoginSendPayload.model_validate(web_ns.payload or {}) - if args["language"] is not None and args["language"] == "zh-Hans": + if payload.language == "zh-Hans": language = "zh-Hans" else: language = "en-US" - account = WebAppAuthService.get_user_through_email(args["email"]) + account = WebAppAuthService.get_user_through_email(payload.email) if account is None: raise AuthenticationFailedError() else: @@ -173,6 +186,7 @@ class EmailCodeLoginApi(Resource): @only_edition_enterprise @web_ns.doc("verify_email_code_login") @web_ns.doc(description="Verify email code and complete login") + @web_ns.expect(web_ns.models[EmailCodeLoginVerifyPayload.__name__]) @web_ns.doc( responses={ 200: "Email code verified and login successful", @@ -182,33 +196,27 @@ class EmailCodeLoginApi(Resource): } ) def post(self): - parser = ( - reqparse.RequestParser() - .add_argument("email", type=str, required=True, location="json") - .add_argument("code", type=str, required=True, location="json") - .add_argument("token", type=str, required=True, location="json") - ) - args = parser.parse_args() + payload = EmailCodeLoginVerifyPayload.model_validate(web_ns.payload or {}) - user_email = args["email"] + user_email = payload.email - token_data = WebAppAuthService.get_email_code_login_data(args["token"]) + token_data = WebAppAuthService.get_email_code_login_data(payload.token) if token_data is None: raise InvalidTokenError() - if token_data["email"] != args["email"]: + if token_data["email"] != payload.email: raise InvalidEmailError() - if token_data["code"] != args["code"]: + if token_data["code"] != payload.code: raise EmailCodeError() - WebAppAuthService.revoke_email_code_login_token(args["token"]) + WebAppAuthService.revoke_email_code_login_token(payload.token) account = WebAppAuthService.get_user_through_email(user_email) if not account: raise AuthenticationFailedError() token = WebAppAuthService.login(account=account) - AccountService.reset_login_error_rate_limit(args["email"]) + AccountService.reset_login_error_rate_limit(payload.email) response = make_response({"result": "success", "data": {"access_token": token}}) # set_access_token_to_cookie(request, response, token, samesite="None", httponly=False) return response diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 9f9aa4838c..fdd9352972 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -1,7 +1,9 @@ import logging +from typing import Literal -from flask_restx import fields, marshal_with, reqparse -from flask_restx.inputs import int_range +from flask import request +from flask_restx import fields, marshal_with +from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import InternalServerError, NotFound from controllers.web import web_ns @@ -38,6 +40,30 @@ from services.message_service import MessageService logger = logging.getLogger(__name__) +class MessageListQuery(BaseModel): + conversation_id: str = Field(description="Conversation UUID") + first_id: str | None = Field(default=None, description="First message ID for pagination") + limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return (1-100)") + + @field_validator("conversation_id", "first_id") + @classmethod + def validate_uuid(cls, value: str | None) -> str | None: + if value is None: + return value + return uuid_value(value) + + +class MessageFeedbackPayload(BaseModel): + rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating") + content: str | None = Field(default=None, description="Feedback content") + + +class MessageMoreLikeThisQuery(BaseModel): + response_mode: Literal["blocking", "streaming"] = Field( + description="Response mode", + ) + + @web_ns.route("/messages") class MessageListApi(WebApiResource): message_fields = { @@ -65,18 +91,7 @@ class MessageListApi(WebApiResource): @web_ns.doc("Get Message List") @web_ns.doc(description="Retrieve paginated list of messages from a conversation in a chat application.") - @web_ns.doc( - params={ - "conversation_id": {"description": "Conversation UUID", "type": "string", "required": True}, - "first_id": {"description": "First message ID for pagination", "type": "string", "required": False}, - "limit": { - "description": "Number of messages to return (1-100)", - "type": "integer", - "required": False, - "default": 20, - }, - } - ) + @web_ns.expect(web_ns.models[MessageListQuery.__name__]) @web_ns.doc( responses={ 200: "Success", @@ -93,17 +108,12 @@ class MessageListApi(WebApiResource): if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - parser = ( - reqparse.RequestParser() - .add_argument("conversation_id", required=True, type=uuid_value, location="args") - .add_argument("first_id", type=uuid_value, location="args") - .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") - ) - args = parser.parse_args() + raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] + query = MessageListQuery.model_validate(raw_args) try: return MessageService.pagination_by_first_id( - app_model, end_user, args["conversation_id"], args["first_id"], args["limit"] + app_model, end_user, query.conversation_id, query.first_id, query.limit ) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @@ -120,17 +130,7 @@ class MessageFeedbackApi(WebApiResource): @web_ns.doc("Create Message Feedback") @web_ns.doc(description="Submit feedback (like/dislike) for a specific message.") @web_ns.doc(params={"message_id": {"description": "Message UUID", "type": "string", "required": True}}) - @web_ns.doc( - params={ - "rating": { - "description": "Feedback rating", - "type": "string", - "enum": ["like", "dislike"], - "required": False, - }, - "content": {"description": "Feedback content/comment", "type": "string", "required": False}, - } - ) + @web_ns.expect(web_ns.models[MessageFeedbackPayload.__name__]) @web_ns.doc( responses={ 200: "Feedback submitted successfully", @@ -145,20 +145,15 @@ class MessageFeedbackApi(WebApiResource): def post(self, app_model, end_user, message_id): message_id = str(message_id) - parser = ( - reqparse.RequestParser() - .add_argument("rating", type=str, choices=["like", "dislike", None], location="json") - .add_argument("content", type=str, location="json", default=None) - ) - args = parser.parse_args() + payload = MessageFeedbackPayload.model_validate(web_ns.payload or {}) try: MessageService.create_feedback( app_model=app_model, message_id=message_id, user=end_user, - rating=args.get("rating"), - content=args.get("content"), + rating=payload.rating, + content=payload.content, ) except MessageNotExistsError: raise NotFound("Message Not Exists.") @@ -170,17 +165,7 @@ class MessageFeedbackApi(WebApiResource): class MessageMoreLikeThisApi(WebApiResource): @web_ns.doc("Generate More Like This") @web_ns.doc(description="Generate a new completion similar to an existing message (completion apps only).") - @web_ns.doc( - params={ - "message_id": {"description": "Message UUID", "type": "string", "required": True}, - "response_mode": { - "description": "Response mode", - "type": "string", - "enum": ["blocking", "streaming"], - "required": True, - }, - } - ) + @web_ns.expect(web_ns.models[MessageMoreLikeThisQuery.__name__]) @web_ns.doc( responses={ 200: "Success", @@ -197,12 +182,10 @@ class MessageMoreLikeThisApi(WebApiResource): message_id = str(message_id) - parser = reqparse.RequestParser().add_argument( - "response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args" - ) - args = parser.parse_args() + raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] + query = MessageMoreLikeThisQuery.model_validate(raw_args) - streaming = args["response_mode"] == "streaming" + streaming = query.response_mode == "streaming" try: response = AppGenerateService.generate_more_like_this( diff --git a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py index dac4b3da94..fed8df40b0 100644 --- a/api/controllers/web/remote_files.py +++ b/api/controllers/web/remote_files.py @@ -1,7 +1,8 @@ import urllib.parse import httpx -from flask_restx import marshal_with, reqparse +from flask_restx import marshal_with +from pydantic import BaseModel, Field, HttpUrl import services from controllers.common import helpers @@ -19,6 +20,10 @@ from fields.file_fields import build_file_with_signed_url_model, build_remote_fi from services.file_service import FileService +class RemoteFileUploadPayload(BaseModel): + url: HttpUrl = Field(description="Remote file URL") + + @web_ns.route("/remote-files/") class RemoteFileInfoApi(WebApiResource): @web_ns.doc("get_remote_file_info") @@ -97,10 +102,8 @@ class RemoteFileUploadApi(WebApiResource): FileTooLargeError: File exceeds size limit UnsupportedFileTypeError: File type not supported """ - parser = reqparse.RequestParser().add_argument("url", type=str, required=True, help="URL is required") - args = parser.parse_args() - - url = args["url"] + payload = RemoteFileUploadPayload.model_validate(web_ns.payload or {}) + url = str(payload.url) try: resp = ssrf_proxy.head(url=url) diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py index 865f3610a7..f407507883 100644 --- a/api/controllers/web/saved_message.py +++ b/api/controllers/web/saved_message.py @@ -1,5 +1,6 @@ -from flask_restx import fields, marshal_with, reqparse -from flask_restx.inputs import int_range +from flask import request +from flask_restx import fields, marshal_with +from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import NotFound from controllers.web import web_ns @@ -23,6 +24,27 @@ message_fields = { } +class SavedMessageListQuery(BaseModel): + last_id: str | None = None + limit: int = Field(default=20, ge=1, le=100) + + @field_validator("last_id") + @classmethod + def validate_last_id(cls, value: str | None) -> str | None: + if value is None: + return value + return uuid_value(value) + + +class SavedMessageCreatePayload(BaseModel): + message_id: str + + @field_validator("message_id") + @classmethod + def validate_message_id(cls, value: str) -> str: + return uuid_value(value) + + @web_ns.route("/saved-messages") class SavedMessageListApi(WebApiResource): saved_message_infinite_scroll_pagination_fields = { @@ -63,14 +85,10 @@ class SavedMessageListApi(WebApiResource): if app_model.mode != "completion": raise NotCompletionAppError() - parser = ( - reqparse.RequestParser() - .add_argument("last_id", type=uuid_value, location="args") - .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") - ) - args = parser.parse_args() + raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] + query = SavedMessageListQuery.model_validate(raw_args) - return SavedMessageService.pagination_by_last_id(app_model, end_user, args["last_id"], args["limit"]) + return SavedMessageService.pagination_by_last_id(app_model, end_user, query.last_id, query.limit) @web_ns.doc("Save Message") @web_ns.doc(description="Save a specific message for later reference.") @@ -94,11 +112,10 @@ class SavedMessageListApi(WebApiResource): if app_model.mode != "completion": raise NotCompletionAppError() - parser = reqparse.RequestParser().add_argument("message_id", type=uuid_value, required=True, location="json") - args = parser.parse_args() + payload = SavedMessageCreatePayload.model_validate(web_ns.payload or {}) try: - SavedMessageService.save(app_model, end_user, args["message_id"]) + SavedMessageService.save(app_model, end_user, payload.message_id) except MessageNotExistsError: raise NotFound("Message Not Exists.") diff --git a/api/controllers/web/workflow.py b/api/controllers/web/workflow.py index 3cbb07a296..136f93c574 100644 --- a/api/controllers/web/workflow.py +++ b/api/controllers/web/workflow.py @@ -1,6 +1,7 @@ import logging +from typing import Any -from flask_restx import reqparse +from pydantic import BaseModel from werkzeug.exceptions import InternalServerError from controllers.web import web_ns @@ -27,6 +28,12 @@ from models.model import App, AppMode, EndUser from services.app_generate_service import AppGenerateService from services.errors.llm import InvokeRateLimitError + +class WorkflowRunPayload(BaseModel): + inputs: dict[str, Any] + files: list[dict[str, Any]] | None = None + + logger = logging.getLogger(__name__) @@ -58,12 +65,8 @@ class WorkflowRunApi(WebApiResource): if app_mode != AppMode.WORKFLOW: raise NotWorkflowAppError() - parser = ( - reqparse.RequestParser() - .add_argument("inputs", type=dict, required=True, nullable=False, location="json") - .add_argument("files", type=list, required=False, location="json") - ) - args = parser.parse_args() + payload = WorkflowRunPayload.model_validate(web_ns.payload or {}) + args = payload.model_dump(exclude_none=True) try: response = AppGenerateService.generate(