From 247069c7e96fa1763fc9dd9da001bc5683c73a64 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Sun, 30 Nov 2025 16:09:42 +0900 Subject: [PATCH] refactor: port reqparse to Pydantic model (#28913) Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- .../console/app/advanced_prompt_template.py | 27 +- api/controllers/console/app/app.py | 378 ++++++++--------- api/controllers/console/app/completion.py | 99 +++-- api/controllers/console/app/conversation.py | 173 ++++---- .../console/app/conversation_variables.py | 30 +- api/controllers/console/app/generator.py | 219 +++++----- api/controllers/console/app/message.py | 156 ++++--- api/controllers/console/app/statistic.py | 91 +++-- api/controllers/console/app/workflow.py | 386 +++++++----------- .../console/app/workflow_app_log.py | 102 +++-- api/controllers/console/app/workflow_run.py | 140 +++---- .../console/app/workflow_statistic.py | 78 ++-- api/controllers/console/workspace/account.py | 76 +--- api/controllers/console/workspace/endpoint.py | 184 ++++----- api/controllers/console/workspace/members.py | 29 +- .../console/workspace/model_providers.py | 46 +-- api/controllers/console/workspace/models.py | 50 +-- api/controllers/console/workspace/plugin.py | 95 +---- .../console/workspace/workspace.py | 23 +- 19 files changed, 1013 insertions(+), 1369 deletions(-) diff --git a/api/controllers/console/app/advanced_prompt_template.py b/api/controllers/console/app/advanced_prompt_template.py index 0ca163d2a5..3bd61feb44 100644 --- a/api/controllers/console/app/advanced_prompt_template.py +++ b/api/controllers/console/app/advanced_prompt_template.py @@ -1,16 +1,23 @@ -from flask_restx import Resource, fields, reqparse +from flask import request +from flask_restx import Resource, fields +from pydantic import BaseModel, Field from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required from libs.login import login_required from services.advanced_prompt_template_service import AdvancedPromptTemplateService -parser = ( - reqparse.RequestParser() - .add_argument("app_mode", type=str, required=True, location="args", help="Application mode") - .add_argument("model_mode", type=str, required=True, location="args", help="Model mode") - .add_argument("has_context", type=str, required=False, default="true", location="args", help="Whether has context") - .add_argument("model_name", type=str, required=True, location="args", help="Model name") + +class AdvancedPromptTemplateQuery(BaseModel): + app_mode: str = Field(..., description="Application mode") + model_mode: str = Field(..., description="Model mode") + has_context: str = Field(default="true", description="Whether has context") + model_name: str = Field(..., description="Model name") + + +console_ns.schema_model( + AdvancedPromptTemplateQuery.__name__, + AdvancedPromptTemplateQuery.model_json_schema(ref_template="#/definitions/{model}"), ) @@ -18,7 +25,7 @@ parser = ( class AdvancedPromptTemplateList(Resource): @console_ns.doc("get_advanced_prompt_templates") @console_ns.doc(description="Get advanced prompt templates based on app mode and model configuration") - @console_ns.expect(parser) + @console_ns.expect(console_ns.models[AdvancedPromptTemplateQuery.__name__]) @console_ns.response( 200, "Prompt templates retrieved successfully", fields.List(fields.Raw(description="Prompt template data")) ) @@ -27,6 +34,6 @@ class AdvancedPromptTemplateList(Resource): @login_required @account_initialization_required def get(self): - args = parser.parse_args() + args = AdvancedPromptTemplateQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore - return AdvancedPromptTemplateService.get_prompt(args) + return AdvancedPromptTemplateService.get_prompt(args.model_dump()) diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index e6687de03e..d6adacd84d 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -1,9 +1,12 @@ import uuid +from typing import Literal -from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse +from flask import request +from flask_restx import Resource, fields, marshal, marshal_with +from pydantic import BaseModel, Field, field_validator from sqlalchemy import select from sqlalchemy.orm import Session -from werkzeug.exceptions import BadRequest, abort +from werkzeug.exceptions import BadRequest from controllers.console import console_ns from controllers.console.app.wraps import get_app_model @@ -36,6 +39,130 @@ from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"] +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class AppListQuery(BaseModel): + page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)") + limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)") + mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"] = Field( + default="all", description="App mode filter" + ) + name: str | None = Field(default=None, description="Filter by app name") + tag_ids: list[str] | None = Field(default=None, description="Comma-separated tag IDs") + is_created_by_me: bool | None = Field(default=None, description="Filter by creator") + + @field_validator("tag_ids", mode="before") + @classmethod + def validate_tag_ids(cls, value: str | list[str] | None) -> list[str] | None: + if not value: + return None + + if isinstance(value, str): + items = [item.strip() for item in value.split(",") if item.strip()] + elif isinstance(value, list): + items = [str(item).strip() for item in value if item and str(item).strip()] + else: + raise TypeError("Unsupported tag_ids type.") + + if not items: + return None + + try: + return [str(uuid.UUID(item)) for item in items] + except ValueError as exc: + raise ValueError("Invalid UUID format in tag_ids.") from exc + + +class CreateAppPayload(BaseModel): + name: str = Field(..., min_length=1, description="App name") + description: str | None = Field(default=None, description="App description (max 400 chars)") + mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode") + icon_type: str | None = Field(default=None, description="Icon type") + icon: str | None = Field(default=None, description="Icon") + icon_background: str | None = Field(default=None, description="Icon background color") + + @field_validator("description") + @classmethod + def validate_description(cls, value: str | None) -> str | None: + if value is None: + return value + return validate_description_length(value) + + +class UpdateAppPayload(BaseModel): + name: str = Field(..., min_length=1, description="App name") + description: str | None = Field(default=None, description="App description (max 400 chars)") + icon_type: str | None = Field(default=None, description="Icon type") + icon: str | None = Field(default=None, description="Icon") + icon_background: str | None = Field(default=None, description="Icon background color") + use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon") + max_active_requests: int | None = Field(default=None, description="Maximum active requests") + + @field_validator("description") + @classmethod + def validate_description(cls, value: str | None) -> str | None: + if value is None: + return value + return validate_description_length(value) + + +class CopyAppPayload(BaseModel): + name: str | None = Field(default=None, description="Name for the copied app") + description: str | None = Field(default=None, description="Description for the copied app") + icon_type: str | None = Field(default=None, description="Icon type") + icon: str | None = Field(default=None, description="Icon") + icon_background: str | None = Field(default=None, description="Icon background color") + + @field_validator("description") + @classmethod + def validate_description(cls, value: str | None) -> str | None: + if value is None: + return value + return validate_description_length(value) + + +class AppExportQuery(BaseModel): + include_secret: bool = Field(default=False, description="Include secrets in export") + workflow_id: str | None = Field(default=None, description="Specific workflow ID to export") + + +class AppNamePayload(BaseModel): + name: str = Field(..., min_length=1, description="Name to check") + + +class AppIconPayload(BaseModel): + icon: str | None = Field(default=None, description="Icon data") + icon_background: str | None = Field(default=None, description="Icon background color") + + +class AppSiteStatusPayload(BaseModel): + enable_site: bool = Field(..., description="Enable or disable site") + + +class AppApiStatusPayload(BaseModel): + enable_api: bool = Field(..., description="Enable or disable API") + + +class AppTracePayload(BaseModel): + enabled: bool = Field(..., description="Enable or disable tracing") + tracing_provider: str = Field(..., description="Tracing provider") + + +def reg(cls: type[BaseModel]): + console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) + + +reg(AppListQuery) +reg(CreateAppPayload) +reg(UpdateAppPayload) +reg(CopyAppPayload) +reg(AppExportQuery) +reg(AppNamePayload) +reg(AppIconPayload) +reg(AppSiteStatusPayload) +reg(AppApiStatusPayload) +reg(AppTracePayload) # Register models for flask_restx to avoid dict type issues in Swagger # Register base models first @@ -147,22 +274,7 @@ app_pagination_model = console_ns.model( class AppListApi(Resource): @console_ns.doc("list_apps") @console_ns.doc(description="Get list of applications with pagination and filtering") - @console_ns.expect( - console_ns.parser() - .add_argument("page", type=int, location="args", help="Page number (1-99999)", default=1) - .add_argument("limit", type=int, location="args", help="Page size (1-100)", default=20) - .add_argument( - "mode", - type=str, - location="args", - choices=["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"], - default="all", - help="App mode filter", - ) - .add_argument("name", type=str, location="args", help="Filter by app name") - .add_argument("tag_ids", type=str, location="args", help="Comma-separated tag IDs") - .add_argument("is_created_by_me", type=bool, location="args", help="Filter by creator") - ) + @console_ns.expect(console_ns.models[AppListQuery.__name__]) @console_ns.response(200, "Success", app_pagination_model) @setup_required @login_required @@ -172,42 +284,12 @@ class AppListApi(Resource): """Get app list""" current_user, current_tenant_id = current_account_with_tenant() - def uuid_list(value): - try: - return [str(uuid.UUID(v)) for v in value.split(",")] - except ValueError: - abort(400, message="Invalid UUID format in tag_ids.") - - parser = ( - reqparse.RequestParser() - .add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args") - .add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") - .add_argument( - "mode", - type=str, - choices=[ - "completion", - "chat", - "advanced-chat", - "workflow", - "agent-chat", - "channel", - "all", - ], - default="all", - location="args", - required=False, - ) - .add_argument("name", type=str, location="args", required=False) - .add_argument("tag_ids", type=uuid_list, location="args", required=False) - .add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False) - ) - - args = parser.parse_args() + args = AppListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args_dict = args.model_dump() # get app list app_service = AppService() - app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args) + app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args_dict) if not app_pagination: return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False} @@ -254,19 +336,7 @@ class AppListApi(Resource): @console_ns.doc("create_app") @console_ns.doc(description="Create a new application") - @console_ns.expect( - console_ns.model( - "CreateAppRequest", - { - "name": fields.String(required=True, description="App name"), - "description": fields.String(description="App description (max 400 chars)"), - "mode": fields.String(required=True, enum=ALLOW_CREATE_APP_MODES, description="App mode"), - "icon_type": fields.String(description="Icon type"), - "icon": fields.String(description="Icon"), - "icon_background": fields.String(description="Icon background color"), - }, - ) - ) + @console_ns.expect(console_ns.models[CreateAppPayload.__name__]) @console_ns.response(201, "App created successfully", app_detail_model) @console_ns.response(403, "Insufficient permissions") @console_ns.response(400, "Invalid request parameters") @@ -279,22 +349,10 @@ class AppListApi(Resource): def post(self): """Create app""" current_user, current_tenant_id = current_account_with_tenant() - parser = ( - reqparse.RequestParser() - .add_argument("name", type=str, required=True, location="json") - .add_argument("description", type=validate_description_length, location="json") - .add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json") - .add_argument("icon_type", type=str, location="json") - .add_argument("icon", type=str, location="json") - .add_argument("icon_background", type=str, location="json") - ) - args = parser.parse_args() - - if "mode" not in args or args["mode"] is None: - raise BadRequest("mode is required") + args = CreateAppPayload.model_validate(console_ns.payload) app_service = AppService() - app = app_service.create_app(current_tenant_id, args, current_user) + app = app_service.create_app(current_tenant_id, args.model_dump(), current_user) return app, 201 @@ -326,20 +384,7 @@ class AppApi(Resource): @console_ns.doc("update_app") @console_ns.doc(description="Update application details") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.expect( - console_ns.model( - "UpdateAppRequest", - { - "name": fields.String(required=True, description="App name"), - "description": fields.String(description="App description (max 400 chars)"), - "icon_type": fields.String(description="Icon type"), - "icon": fields.String(description="Icon"), - "icon_background": fields.String(description="Icon background color"), - "use_icon_as_answer_icon": fields.Boolean(description="Use icon as answer icon"), - "max_active_requests": fields.Integer(description="Maximum active requests"), - }, - ) - ) + @console_ns.expect(console_ns.models[UpdateAppPayload.__name__]) @console_ns.response(200, "App updated successfully", app_detail_with_site_model) @console_ns.response(403, "Insufficient permissions") @console_ns.response(400, "Invalid request parameters") @@ -351,28 +396,18 @@ class AppApi(Resource): @marshal_with(app_detail_with_site_model) def put(self, app_model): """Update app""" - parser = ( - reqparse.RequestParser() - .add_argument("name", type=str, required=True, nullable=False, location="json") - .add_argument("description", type=validate_description_length, location="json") - .add_argument("icon_type", type=str, location="json") - .add_argument("icon", type=str, location="json") - .add_argument("icon_background", type=str, location="json") - .add_argument("use_icon_as_answer_icon", type=bool, location="json") - .add_argument("max_active_requests", type=int, location="json") - ) - args = parser.parse_args() + args = UpdateAppPayload.model_validate(console_ns.payload) app_service = AppService() args_dict: AppService.ArgsDict = { - "name": args["name"], - "description": args.get("description", ""), - "icon_type": args.get("icon_type", ""), - "icon": args.get("icon", ""), - "icon_background": args.get("icon_background", ""), - "use_icon_as_answer_icon": args.get("use_icon_as_answer_icon", False), - "max_active_requests": args.get("max_active_requests", 0), + "name": args.name, + "description": args.description or "", + "icon_type": args.icon_type or "", + "icon": args.icon or "", + "icon_background": args.icon_background or "", + "use_icon_as_answer_icon": args.use_icon_as_answer_icon or False, + "max_active_requests": args.max_active_requests or 0, } app_model = app_service.update_app(app_model, args_dict) @@ -401,18 +436,7 @@ class AppCopyApi(Resource): @console_ns.doc("copy_app") @console_ns.doc(description="Create a copy of an existing application") @console_ns.doc(params={"app_id": "Application ID to copy"}) - @console_ns.expect( - console_ns.model( - "CopyAppRequest", - { - "name": fields.String(description="Name for the copied app"), - "description": fields.String(description="Description for the copied app"), - "icon_type": fields.String(description="Icon type"), - "icon": fields.String(description="Icon"), - "icon_background": fields.String(description="Icon background color"), - }, - ) - ) + @console_ns.expect(console_ns.models[CopyAppPayload.__name__]) @console_ns.response(201, "App copied successfully", app_detail_with_site_model) @console_ns.response(403, "Insufficient permissions") @setup_required @@ -426,15 +450,7 @@ class AppCopyApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor current_user, _ = current_account_with_tenant() - parser = ( - reqparse.RequestParser() - .add_argument("name", type=str, location="json") - .add_argument("description", type=validate_description_length, location="json") - .add_argument("icon_type", type=str, location="json") - .add_argument("icon", type=str, location="json") - .add_argument("icon_background", type=str, location="json") - ) - args = parser.parse_args() + args = CopyAppPayload.model_validate(console_ns.payload or {}) with Session(db.engine) as session: import_service = AppDslService(session) @@ -443,11 +459,11 @@ class AppCopyApi(Resource): account=current_user, import_mode=ImportMode.YAML_CONTENT, yaml_content=yaml_content, - name=args.get("name"), - description=args.get("description"), - icon_type=args.get("icon_type"), - icon=args.get("icon"), - icon_background=args.get("icon_background"), + name=args.name, + description=args.description, + icon_type=args.icon_type, + icon=args.icon, + icon_background=args.icon_background, ) session.commit() @@ -462,11 +478,7 @@ class AppExportApi(Resource): @console_ns.doc("export_app") @console_ns.doc(description="Export application configuration as DSL") @console_ns.doc(params={"app_id": "Application ID to export"}) - @console_ns.expect( - console_ns.parser() - .add_argument("include_secret", type=bool, location="args", default=False, help="Include secrets in export") - .add_argument("workflow_id", type=str, location="args", help="Specific workflow ID to export") - ) + @console_ns.expect(console_ns.models[AppExportQuery.__name__]) @console_ns.response( 200, "App exported successfully", @@ -480,30 +492,23 @@ class AppExportApi(Resource): @edit_permission_required def get(self, app_model): """Export app""" - # Add include_secret params - parser = ( - reqparse.RequestParser() - .add_argument("include_secret", type=inputs.boolean, default=False, location="args") - .add_argument("workflow_id", type=str, location="args") - ) - args = parser.parse_args() + args = AppExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore return { "data": AppDslService.export_dsl( - app_model=app_model, include_secret=args["include_secret"], workflow_id=args.get("workflow_id") + app_model=app_model, + include_secret=args.include_secret, + workflow_id=args.workflow_id, ) } -parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json", help="Name to check") - - @console_ns.route("/apps//name") class AppNameApi(Resource): @console_ns.doc("check_app_name") @console_ns.doc(description="Check if app name is available") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.expect(parser) + @console_ns.expect(console_ns.models[AppNamePayload.__name__]) @console_ns.response(200, "Name availability checked") @setup_required @login_required @@ -512,10 +517,10 @@ class AppNameApi(Resource): @marshal_with(app_detail_model) @edit_permission_required def post(self, app_model): - args = parser.parse_args() + args = AppNamePayload.model_validate(console_ns.payload) app_service = AppService() - app_model = app_service.update_app_name(app_model, args["name"]) + app_model = app_service.update_app_name(app_model, args.name) return app_model @@ -525,16 +530,7 @@ class AppIconApi(Resource): @console_ns.doc("update_app_icon") @console_ns.doc(description="Update application icon") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.expect( - console_ns.model( - "AppIconRequest", - { - "icon": fields.String(required=True, description="Icon data"), - "icon_type": fields.String(description="Icon type"), - "icon_background": fields.String(description="Icon background color"), - }, - ) - ) + @console_ns.expect(console_ns.models[AppIconPayload.__name__]) @console_ns.response(200, "Icon updated successfully") @console_ns.response(403, "Insufficient permissions") @setup_required @@ -544,15 +540,10 @@ class AppIconApi(Resource): @marshal_with(app_detail_model) @edit_permission_required def post(self, app_model): - parser = ( - reqparse.RequestParser() - .add_argument("icon", type=str, location="json") - .add_argument("icon_background", type=str, location="json") - ) - args = parser.parse_args() + args = AppIconPayload.model_validate(console_ns.payload or {}) app_service = AppService() - app_model = app_service.update_app_icon(app_model, args.get("icon") or "", args.get("icon_background") or "") + app_model = app_service.update_app_icon(app_model, args.icon or "", args.icon_background or "") return app_model @@ -562,11 +553,7 @@ class AppSiteStatus(Resource): @console_ns.doc("update_app_site_status") @console_ns.doc(description="Enable or disable app site") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.expect( - console_ns.model( - "AppSiteStatusRequest", {"enable_site": fields.Boolean(required=True, description="Enable or disable site")} - ) - ) + @console_ns.expect(console_ns.models[AppSiteStatusPayload.__name__]) @console_ns.response(200, "Site status updated successfully", app_detail_model) @console_ns.response(403, "Insufficient permissions") @setup_required @@ -576,11 +563,10 @@ class AppSiteStatus(Resource): @marshal_with(app_detail_model) @edit_permission_required def post(self, app_model): - parser = reqparse.RequestParser().add_argument("enable_site", type=bool, required=True, location="json") - args = parser.parse_args() + args = AppSiteStatusPayload.model_validate(console_ns.payload) app_service = AppService() - app_model = app_service.update_app_site_status(app_model, args["enable_site"]) + app_model = app_service.update_app_site_status(app_model, args.enable_site) return app_model @@ -590,11 +576,7 @@ class AppApiStatus(Resource): @console_ns.doc("update_app_api_status") @console_ns.doc(description="Enable or disable app API") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.expect( - console_ns.model( - "AppApiStatusRequest", {"enable_api": fields.Boolean(required=True, description="Enable or disable API")} - ) - ) + @console_ns.expect(console_ns.models[AppApiStatusPayload.__name__]) @console_ns.response(200, "API status updated successfully", app_detail_model) @console_ns.response(403, "Insufficient permissions") @setup_required @@ -604,11 +586,10 @@ class AppApiStatus(Resource): @get_app_model @marshal_with(app_detail_model) def post(self, app_model): - parser = reqparse.RequestParser().add_argument("enable_api", type=bool, required=True, location="json") - args = parser.parse_args() + args = AppApiStatusPayload.model_validate(console_ns.payload) app_service = AppService() - app_model = app_service.update_app_api_status(app_model, args["enable_api"]) + app_model = app_service.update_app_api_status(app_model, args.enable_api) return app_model @@ -631,15 +612,7 @@ class AppTraceApi(Resource): @console_ns.doc("update_app_trace") @console_ns.doc(description="Update app tracing configuration") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.expect( - console_ns.model( - "AppTraceRequest", - { - "enabled": fields.Boolean(required=True, description="Enable or disable tracing"), - "tracing_provider": fields.String(required=True, description="Tracing provider"), - }, - ) - ) + @console_ns.expect(console_ns.models[AppTracePayload.__name__]) @console_ns.response(200, "Trace configuration updated successfully") @console_ns.response(403, "Insufficient permissions") @setup_required @@ -648,17 +621,12 @@ class AppTraceApi(Resource): @edit_permission_required def post(self, app_id): # add app trace - parser = ( - reqparse.RequestParser() - .add_argument("enabled", type=bool, required=True, location="json") - .add_argument("tracing_provider", type=str, required=True, location="json") - ) - args = parser.parse_args() + args = AppTracePayload.model_validate(console_ns.payload) OpsTraceManager.update_app_tracing_config( app_id=app_id, - enabled=args["enabled"], - tracing_provider=args["tracing_provider"], + enabled=args.enabled, + tracing_provider=args.tracing_provider, ) return {"result": "success"} diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index 2f8429f2ff..2922121a54 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -1,7 +1,9 @@ import logging +from typing import Any, Literal from flask import request -from flask_restx import Resource, fields, reqparse +from flask_restx import Resource +from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import InternalServerError, NotFound import services @@ -35,6 +37,41 @@ from services.app_task_service import AppTaskService from services.errors.llm import InvokeRateLimitError logger = logging.getLogger(__name__) +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class BaseMessagePayload(BaseModel): + inputs: dict[str, Any] + model_config_data: dict[str, Any] = Field(..., alias="model_config") + files: list[Any] | None = Field(default=None, description="Uploaded files") + response_mode: Literal["blocking", "streaming"] = Field(default="blocking", description="Response mode") + retriever_from: str = Field(default="dev", description="Retriever source") + + +class CompletionMessagePayload(BaseMessagePayload): + query: str = Field(default="", description="Query text") + + +class ChatMessagePayload(BaseMessagePayload): + query: str = Field(..., description="User query") + conversation_id: str | None = Field(default=None, description="Conversation ID") + parent_message_id: str | None = Field(default=None, description="Parent message ID") + + @field_validator("conversation_id", "parent_message_id") + @classmethod + def validate_uuid(cls, value: str | None) -> str | None: + if value is None: + return value + return uuid_value(value) + + +console_ns.schema_model( + CompletionMessagePayload.__name__, + CompletionMessagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) +console_ns.schema_model( + ChatMessagePayload.__name__, ChatMessagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) +) # define completion message api for user @@ -43,19 +80,7 @@ class CompletionMessageApi(Resource): @console_ns.doc("create_completion_message") @console_ns.doc(description="Generate completion message for debugging") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.expect( - console_ns.model( - "CompletionMessageRequest", - { - "inputs": fields.Raw(required=True, description="Input variables"), - "query": fields.String(description="Query text", default=""), - "files": fields.List(fields.Raw(), description="Uploaded files"), - "model_config": fields.Raw(required=True, description="Model configuration"), - "response_mode": fields.String(enum=["blocking", "streaming"], description="Response mode"), - "retriever_from": fields.String(default="dev", description="Retriever source"), - }, - ) - ) + @console_ns.expect(console_ns.models[CompletionMessagePayload.__name__]) @console_ns.response(200, "Completion generated successfully") @console_ns.response(400, "Invalid request parameters") @console_ns.response(404, "App not found") @@ -64,18 +89,10 @@ class CompletionMessageApi(Resource): @account_initialization_required @get_app_model(mode=AppMode.COMPLETION) def post(self, app_model): - parser = ( - reqparse.RequestParser() - .add_argument("inputs", type=dict, required=True, location="json") - .add_argument("query", type=str, location="json", default="") - .add_argument("files", type=list, required=False, location="json") - .add_argument("model_config", type=dict, required=True, location="json") - .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") - .add_argument("retriever_from", type=str, required=False, default="dev", location="json") - ) - args = parser.parse_args() + args_model = CompletionMessagePayload.model_validate(console_ns.payload) + args = args_model.model_dump(exclude_none=True, by_alias=True) - streaming = args["response_mode"] != "blocking" + streaming = args_model.response_mode != "blocking" args["auto_generate_name"] = False try: @@ -137,21 +154,7 @@ class ChatMessageApi(Resource): @console_ns.doc("create_chat_message") @console_ns.doc(description="Generate chat message for debugging") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.expect( - console_ns.model( - "ChatMessageRequest", - { - "inputs": fields.Raw(required=True, description="Input variables"), - "query": fields.String(required=True, description="User query"), - "files": fields.List(fields.Raw(), description="Uploaded files"), - "model_config": fields.Raw(required=True, description="Model configuration"), - "conversation_id": fields.String(description="Conversation ID"), - "parent_message_id": fields.String(description="Parent message ID"), - "response_mode": fields.String(enum=["blocking", "streaming"], description="Response mode"), - "retriever_from": fields.String(default="dev", description="Retriever source"), - }, - ) - ) + @console_ns.expect(console_ns.models[ChatMessagePayload.__name__]) @console_ns.response(200, "Chat message generated successfully") @console_ns.response(400, "Invalid request parameters") @console_ns.response(404, "App or conversation not found") @@ -161,20 +164,10 @@ class ChatMessageApi(Resource): @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) @edit_permission_required def post(self, app_model): - parser = ( - reqparse.RequestParser() - .add_argument("inputs", type=dict, required=True, location="json") - .add_argument("query", type=str, required=True, location="json") - .add_argument("files", type=list, required=False, location="json") - .add_argument("model_config", type=dict, required=True, location="json") - .add_argument("conversation_id", type=uuid_value, location="json") - .add_argument("parent_message_id", type=uuid_value, required=False, location="json") - .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") - .add_argument("retriever_from", type=str, required=False, default="dev", location="json") - ) - args = parser.parse_args() + args_model = ChatMessagePayload.model_validate(console_ns.payload) + args = args_model.model_dump(exclude_none=True, by_alias=True) - streaming = args["response_mode"] != "blocking" + streaming = args_model.response_mode != "blocking" args["auto_generate_name"] = False external_trace_id = get_external_trace_id(request) diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index 3d92c46756..9dcadc18a4 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -1,7 +1,9 @@ +from typing import Literal + import sqlalchemy as sa -from flask import abort -from flask_restx import Resource, fields, marshal_with, reqparse -from flask_restx.inputs import int_range +from flask import abort, request +from flask_restx import Resource, fields, marshal_with +from pydantic import BaseModel, Field, field_validator from sqlalchemy import func, or_ from sqlalchemy.orm import joinedload from werkzeug.exceptions import NotFound @@ -14,13 +16,54 @@ from extensions.ext_database import db from fields.conversation_fields import MessageTextField from fields.raws import FilesContainedField from libs.datetime_utils import naive_utc_now, parse_time_range -from libs.helper import DatetimeString, TimestampField +from libs.helper import TimestampField from libs.login import current_account_with_tenant, login_required from models import Conversation, EndUser, Message, MessageAnnotation from models.model import AppMode from services.conversation_service import ConversationService from services.errors.conversation import ConversationNotExistsError +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class BaseConversationQuery(BaseModel): + keyword: str | None = Field(default=None, description="Search keyword") + start: str | None = Field(default=None, description="Start date (YYYY-MM-DD HH:MM)") + end: str | None = Field(default=None, description="End date (YYYY-MM-DD HH:MM)") + annotation_status: Literal["annotated", "not_annotated", "all"] = Field( + default="all", description="Annotation status filter" + ) + page: int = Field(default=1, ge=1, le=99999, description="Page number") + limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)") + + @field_validator("start", "end", mode="before") + @classmethod + def blank_to_none(cls, value: str | None) -> str | None: + if value == "": + return None + return value + + +class CompletionConversationQuery(BaseConversationQuery): + pass + + +class ChatConversationQuery(BaseConversationQuery): + message_count_gte: int | None = Field(default=None, ge=1, description="Minimum message count") + sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field( + default="-updated_at", description="Sort field and direction" + ) + + +console_ns.schema_model( + CompletionConversationQuery.__name__, + CompletionConversationQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) +console_ns.schema_model( + ChatConversationQuery.__name__, + ChatConversationQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) + # Register models for flask_restx to avoid dict type issues in Swagger # Register in dependency order: base models first, then dependent models @@ -283,22 +326,7 @@ class CompletionConversationApi(Resource): @console_ns.doc("list_completion_conversations") @console_ns.doc(description="Get completion conversations with pagination and filtering") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.expect( - console_ns.parser() - .add_argument("keyword", type=str, location="args", help="Search keyword") - .add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)") - .add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)") - .add_argument( - "annotation_status", - type=str, - location="args", - choices=["annotated", "not_annotated", "all"], - default="all", - help="Annotation status filter", - ) - .add_argument("page", type=int, location="args", default=1, help="Page number") - .add_argument("limit", type=int, location="args", default=20, help="Page size (1-100)") - ) + @console_ns.expect(console_ns.models[CompletionConversationQuery.__name__]) @console_ns.response(200, "Success", conversation_pagination_model) @console_ns.response(403, "Insufficient permissions") @setup_required @@ -309,32 +337,17 @@ class CompletionConversationApi(Resource): @edit_permission_required def get(self, app_model): current_user, _ = current_account_with_tenant() - parser = ( - reqparse.RequestParser() - .add_argument("keyword", type=str, location="args") - .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") - .add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") - .add_argument( - "annotation_status", - type=str, - choices=["annotated", "not_annotated", "all"], - default="all", - location="args", - ) - .add_argument("page", type=int_range(1, 99999), default=1, location="args") - .add_argument("limit", type=int_range(1, 100), default=20, location="args") - ) - args = parser.parse_args() + args = CompletionConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore query = sa.select(Conversation).where( Conversation.app_id == app_model.id, Conversation.mode == "completion", Conversation.is_deleted.is_(False) ) - if args["keyword"]: + if args.keyword: query = query.join(Message, Message.conversation_id == Conversation.id).where( or_( - Message.query.ilike(f"%{args['keyword']}%"), - Message.answer.ilike(f"%{args['keyword']}%"), + Message.query.ilike(f"%{args.keyword}%"), + Message.answer.ilike(f"%{args.keyword}%"), ) ) @@ -342,7 +355,7 @@ class CompletionConversationApi(Resource): assert account.timezone is not None try: - start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone) + start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone) except ValueError as e: abort(400, description=str(e)) @@ -354,11 +367,11 @@ class CompletionConversationApi(Resource): query = query.where(Conversation.created_at < end_datetime_utc) # FIXME, the type ignore in this file - if args["annotation_status"] == "annotated": + if args.annotation_status == "annotated": query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id ) - elif args["annotation_status"] == "not_annotated": + elif args.annotation_status == "not_annotated": query = ( query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id) .group_by(Conversation.id) @@ -367,7 +380,7 @@ class CompletionConversationApi(Resource): query = query.order_by(Conversation.created_at.desc()) - conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False) + conversations = db.paginate(query, page=args.page, per_page=args.limit, error_out=False) return conversations @@ -419,31 +432,7 @@ class ChatConversationApi(Resource): @console_ns.doc("list_chat_conversations") @console_ns.doc(description="Get chat conversations with pagination, filtering and summary") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.expect( - console_ns.parser() - .add_argument("keyword", type=str, location="args", help="Search keyword") - .add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)") - .add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)") - .add_argument( - "annotation_status", - type=str, - location="args", - choices=["annotated", "not_annotated", "all"], - default="all", - help="Annotation status filter", - ) - .add_argument("message_count_gte", type=int, location="args", help="Minimum message count") - .add_argument("page", type=int, location="args", default=1, help="Page number") - .add_argument("limit", type=int, location="args", default=20, help="Page size (1-100)") - .add_argument( - "sort_by", - type=str, - location="args", - choices=["created_at", "-created_at", "updated_at", "-updated_at"], - default="-updated_at", - help="Sort field and direction", - ) - ) + @console_ns.expect(console_ns.models[ChatConversationQuery.__name__]) @console_ns.response(200, "Success", conversation_with_summary_pagination_model) @console_ns.response(403, "Insufficient permissions") @setup_required @@ -454,31 +443,7 @@ class ChatConversationApi(Resource): @edit_permission_required def get(self, app_model): current_user, _ = current_account_with_tenant() - parser = ( - reqparse.RequestParser() - .add_argument("keyword", type=str, location="args") - .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") - .add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") - .add_argument( - "annotation_status", - type=str, - choices=["annotated", "not_annotated", "all"], - default="all", - location="args", - ) - .add_argument("message_count_gte", type=int_range(1, 99999), required=False, location="args") - .add_argument("page", type=int_range(1, 99999), required=False, default=1, location="args") - .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") - .add_argument( - "sort_by", - type=str, - choices=["created_at", "-created_at", "updated_at", "-updated_at"], - required=False, - default="-updated_at", - location="args", - ) - ) - args = parser.parse_args() + args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore subquery = ( db.session.query( @@ -490,8 +455,8 @@ class ChatConversationApi(Resource): query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False)) - if args["keyword"]: - keyword_filter = f"%{args['keyword']}%" + if args.keyword: + keyword_filter = f"%{args.keyword}%" query = ( query.join( Message, @@ -514,12 +479,12 @@ class ChatConversationApi(Resource): assert account.timezone is not None try: - start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone) + start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone) except ValueError as e: abort(400, description=str(e)) if start_datetime_utc: - match args["sort_by"]: + match args.sort_by: case "updated_at" | "-updated_at": query = query.where(Conversation.updated_at >= start_datetime_utc) case "created_at" | "-created_at" | _: @@ -527,35 +492,35 @@ class ChatConversationApi(Resource): if end_datetime_utc: end_datetime_utc = end_datetime_utc.replace(second=59) - match args["sort_by"]: + match args.sort_by: case "updated_at" | "-updated_at": query = query.where(Conversation.updated_at <= end_datetime_utc) case "created_at" | "-created_at" | _: query = query.where(Conversation.created_at <= end_datetime_utc) - if args["annotation_status"] == "annotated": + if args.annotation_status == "annotated": query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id ) - elif args["annotation_status"] == "not_annotated": + elif args.annotation_status == "not_annotated": query = ( query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id) .group_by(Conversation.id) .having(func.count(MessageAnnotation.id) == 0) ) - if args["message_count_gte"] and args["message_count_gte"] >= 1: + if args.message_count_gte and args.message_count_gte >= 1: query = ( query.options(joinedload(Conversation.messages)) # type: ignore .join(Message, Message.conversation_id == Conversation.id) .group_by(Conversation.id) - .having(func.count(Message.id) >= args["message_count_gte"]) + .having(func.count(Message.id) >= args.message_count_gte) ) if app_model.mode == AppMode.ADVANCED_CHAT: query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER) - match args["sort_by"]: + match args.sort_by: case "created_at": query = query.order_by(Conversation.created_at.asc()) case "-created_at": @@ -567,7 +532,7 @@ class ChatConversationApi(Resource): case _: query = query.order_by(Conversation.created_at.desc()) - conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False) + conversations = db.paginate(query, page=args.page, per_page=args.limit, error_out=False) return conversations diff --git a/api/controllers/console/app/conversation_variables.py b/api/controllers/console/app/conversation_variables.py index c612041fab..368a6112ba 100644 --- a/api/controllers/console/app/conversation_variables.py +++ b/api/controllers/console/app/conversation_variables.py @@ -1,4 +1,6 @@ -from flask_restx import Resource, fields, marshal_with, reqparse +from flask import request +from flask_restx import Resource, fields, marshal_with +from pydantic import BaseModel, Field from sqlalchemy import select from sqlalchemy.orm import Session @@ -14,6 +16,18 @@ from libs.login import login_required from models import ConversationVariable from models.model import AppMode +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class ConversationVariablesQuery(BaseModel): + conversation_id: str = Field(..., description="Conversation ID to filter variables") + + +console_ns.schema_model( + ConversationVariablesQuery.__name__, + ConversationVariablesQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) + # Register models for flask_restx to avoid dict type issues in Swagger # Register base model first conversation_variable_model = console_ns.model("ConversationVariable", conversation_variable_fields) @@ -33,11 +47,7 @@ class ConversationVariablesApi(Resource): @console_ns.doc("get_conversation_variables") @console_ns.doc(description="Get conversation variables for an application") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.expect( - console_ns.parser().add_argument( - "conversation_id", type=str, location="args", help="Conversation ID to filter variables" - ) - ) + @console_ns.expect(console_ns.models[ConversationVariablesQuery.__name__]) @console_ns.response(200, "Conversation variables retrieved successfully", paginated_conversation_variable_model) @setup_required @login_required @@ -45,18 +55,14 @@ class ConversationVariablesApi(Resource): @get_app_model(mode=AppMode.ADVANCED_CHAT) @marshal_with(paginated_conversation_variable_model) def get(self, app_model): - parser = reqparse.RequestParser().add_argument("conversation_id", type=str, location="args") - args = parser.parse_args() + args = ConversationVariablesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore stmt = ( select(ConversationVariable) .where(ConversationVariable.app_id == app_model.id) .order_by(ConversationVariable.created_at) ) - if args["conversation_id"]: - stmt = stmt.where(ConversationVariable.conversation_id == args["conversation_id"]) - else: - raise ValueError("conversation_id is required") + stmt = stmt.where(ConversationVariable.conversation_id == args.conversation_id) # NOTE: This is a temporary solution to avoid performance issues. page = 1 diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index cf8acda018..b4fc44767a 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -1,6 +1,8 @@ from collections.abc import Sequence +from typing import Any -from flask_restx import Resource, fields, reqparse +from flask_restx import Resource +from pydantic import BaseModel, Field from controllers.console import console_ns from controllers.console.app.error import ( @@ -21,21 +23,54 @@ from libs.login import current_account_with_tenant, login_required from models import App from services.workflow_service import WorkflowService +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class RuleGeneratePayload(BaseModel): + instruction: str = Field(..., description="Rule generation instruction") + model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration") + no_variable: bool = Field(default=False, description="Whether to exclude variables") + + +class RuleCodeGeneratePayload(RuleGeneratePayload): + code_language: str = Field(default="javascript", description="Programming language for code generation") + + +class RuleStructuredOutputPayload(BaseModel): + instruction: str = Field(..., description="Structured output generation instruction") + model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration") + + +class InstructionGeneratePayload(BaseModel): + flow_id: str = Field(..., description="Workflow/Flow ID") + node_id: str = Field(default="", description="Node ID for workflow context") + current: str = Field(default="", description="Current instruction text") + language: str = Field(default="javascript", description="Programming language (javascript/python)") + instruction: str = Field(..., description="Instruction for generation") + model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration") + ideal_output: str = Field(default="", description="Expected ideal output") + + +class InstructionTemplatePayload(BaseModel): + type: str = Field(..., description="Instruction template type") + + +def reg(cls: type[BaseModel]): + console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) + + +reg(RuleGeneratePayload) +reg(RuleCodeGeneratePayload) +reg(RuleStructuredOutputPayload) +reg(InstructionGeneratePayload) +reg(InstructionTemplatePayload) + @console_ns.route("/rule-generate") class RuleGenerateApi(Resource): @console_ns.doc("generate_rule_config") @console_ns.doc(description="Generate rule configuration using LLM") - @console_ns.expect( - console_ns.model( - "RuleGenerateRequest", - { - "instruction": fields.String(required=True, description="Rule generation instruction"), - "model_config": fields.Raw(required=True, description="Model configuration"), - "no_variable": fields.Boolean(required=True, default=False, description="Whether to exclude variables"), - }, - ) - ) + @console_ns.expect(console_ns.models[RuleGeneratePayload.__name__]) @console_ns.response(200, "Rule configuration generated successfully") @console_ns.response(400, "Invalid request parameters") @console_ns.response(402, "Provider quota exceeded") @@ -43,21 +78,15 @@ class RuleGenerateApi(Resource): @login_required @account_initialization_required def post(self): - parser = ( - reqparse.RequestParser() - .add_argument("instruction", type=str, required=True, nullable=False, location="json") - .add_argument("model_config", type=dict, required=True, nullable=False, location="json") - .add_argument("no_variable", type=bool, required=True, default=False, location="json") - ) - args = parser.parse_args() + args = RuleGeneratePayload.model_validate(console_ns.payload) _, current_tenant_id = current_account_with_tenant() try: rules = LLMGenerator.generate_rule_config( tenant_id=current_tenant_id, - instruction=args["instruction"], - model_config=args["model_config"], - no_variable=args["no_variable"], + instruction=args.instruction, + model_config=args.model_config_data, + no_variable=args.no_variable, ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -75,19 +104,7 @@ class RuleGenerateApi(Resource): class RuleCodeGenerateApi(Resource): @console_ns.doc("generate_rule_code") @console_ns.doc(description="Generate code rules using LLM") - @console_ns.expect( - console_ns.model( - "RuleCodeGenerateRequest", - { - "instruction": fields.String(required=True, description="Code generation instruction"), - "model_config": fields.Raw(required=True, description="Model configuration"), - "no_variable": fields.Boolean(required=True, default=False, description="Whether to exclude variables"), - "code_language": fields.String( - default="javascript", description="Programming language for code generation" - ), - }, - ) - ) + @console_ns.expect(console_ns.models[RuleCodeGeneratePayload.__name__]) @console_ns.response(200, "Code rules generated successfully") @console_ns.response(400, "Invalid request parameters") @console_ns.response(402, "Provider quota exceeded") @@ -95,22 +112,15 @@ class RuleCodeGenerateApi(Resource): @login_required @account_initialization_required def post(self): - parser = ( - reqparse.RequestParser() - .add_argument("instruction", type=str, required=True, nullable=False, location="json") - .add_argument("model_config", type=dict, required=True, nullable=False, location="json") - .add_argument("no_variable", type=bool, required=True, default=False, location="json") - .add_argument("code_language", type=str, required=False, default="javascript", location="json") - ) - args = parser.parse_args() + args = RuleCodeGeneratePayload.model_validate(console_ns.payload) _, current_tenant_id = current_account_with_tenant() try: code_result = LLMGenerator.generate_code( tenant_id=current_tenant_id, - instruction=args["instruction"], - model_config=args["model_config"], - code_language=args["code_language"], + instruction=args.instruction, + model_config=args.model_config_data, + code_language=args.code_language, ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -128,15 +138,7 @@ class RuleCodeGenerateApi(Resource): class RuleStructuredOutputGenerateApi(Resource): @console_ns.doc("generate_structured_output") @console_ns.doc(description="Generate structured output rules using LLM") - @console_ns.expect( - console_ns.model( - "StructuredOutputGenerateRequest", - { - "instruction": fields.String(required=True, description="Structured output generation instruction"), - "model_config": fields.Raw(required=True, description="Model configuration"), - }, - ) - ) + @console_ns.expect(console_ns.models[RuleStructuredOutputPayload.__name__]) @console_ns.response(200, "Structured output generated successfully") @console_ns.response(400, "Invalid request parameters") @console_ns.response(402, "Provider quota exceeded") @@ -144,19 +146,14 @@ class RuleStructuredOutputGenerateApi(Resource): @login_required @account_initialization_required def post(self): - parser = ( - reqparse.RequestParser() - .add_argument("instruction", type=str, required=True, nullable=False, location="json") - .add_argument("model_config", type=dict, required=True, nullable=False, location="json") - ) - args = parser.parse_args() + args = RuleStructuredOutputPayload.model_validate(console_ns.payload) _, current_tenant_id = current_account_with_tenant() try: structured_output = LLMGenerator.generate_structured_output( tenant_id=current_tenant_id, - instruction=args["instruction"], - model_config=args["model_config"], + instruction=args.instruction, + model_config=args.model_config_data, ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -174,20 +171,7 @@ class RuleStructuredOutputGenerateApi(Resource): class InstructionGenerateApi(Resource): @console_ns.doc("generate_instruction") @console_ns.doc(description="Generate instruction for workflow nodes or general use") - @console_ns.expect( - console_ns.model( - "InstructionGenerateRequest", - { - "flow_id": fields.String(required=True, description="Workflow/Flow ID"), - "node_id": fields.String(description="Node ID for workflow context"), - "current": fields.String(description="Current instruction text"), - "language": fields.String(default="javascript", description="Programming language (javascript/python)"), - "instruction": fields.String(required=True, description="Instruction for generation"), - "model_config": fields.Raw(required=True, description="Model configuration"), - "ideal_output": fields.String(description="Expected ideal output"), - }, - ) - ) + @console_ns.expect(console_ns.models[InstructionGeneratePayload.__name__]) @console_ns.response(200, "Instruction generated successfully") @console_ns.response(400, "Invalid request parameters or flow/workflow not found") @console_ns.response(402, "Provider quota exceeded") @@ -195,79 +179,69 @@ class InstructionGenerateApi(Resource): @login_required @account_initialization_required def post(self): - parser = ( - reqparse.RequestParser() - .add_argument("flow_id", type=str, required=True, default="", location="json") - .add_argument("node_id", type=str, required=False, default="", location="json") - .add_argument("current", type=str, required=False, default="", location="json") - .add_argument("language", type=str, required=False, default="javascript", location="json") - .add_argument("instruction", type=str, required=True, nullable=False, location="json") - .add_argument("model_config", type=dict, required=True, nullable=False, location="json") - .add_argument("ideal_output", type=str, required=False, default="", location="json") - ) - args = parser.parse_args() + args = InstructionGeneratePayload.model_validate(console_ns.payload) _, current_tenant_id = current_account_with_tenant() providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider] code_provider: type[CodeNodeProvider] | None = next( - (p for p in providers if p.is_accept_language(args["language"])), None + (p for p in providers if p.is_accept_language(args.language)), None ) code_template = code_provider.get_default_code() if code_provider else "" try: # Generate from nothing for a workflow node - if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "": - app = db.session.query(App).where(App.id == args["flow_id"]).first() + if (args.current in (code_template, "")) and args.node_id != "": + app = db.session.query(App).where(App.id == args.flow_id).first() if not app: - return {"error": f"app {args['flow_id']} not found"}, 400 + return {"error": f"app {args.flow_id} not found"}, 400 workflow = WorkflowService().get_draft_workflow(app_model=app) if not workflow: - return {"error": f"workflow {args['flow_id']} not found"}, 400 + return {"error": f"workflow {args.flow_id} not found"}, 400 nodes: Sequence = workflow.graph_dict["nodes"] - node = [node for node in nodes if node["id"] == args["node_id"]] + node = [node for node in nodes if node["id"] == args.node_id] if len(node) == 0: - return {"error": f"node {args['node_id']} not found"}, 400 + return {"error": f"node {args.node_id} not found"}, 400 node_type = node[0]["data"]["type"] match node_type: case "llm": return LLMGenerator.generate_rule_config( current_tenant_id, - instruction=args["instruction"], - model_config=args["model_config"], + instruction=args.instruction, + model_config=args.model_config_data, no_variable=True, ) case "agent": return LLMGenerator.generate_rule_config( current_tenant_id, - instruction=args["instruction"], - model_config=args["model_config"], + instruction=args.instruction, + model_config=args.model_config_data, no_variable=True, ) case "code": return LLMGenerator.generate_code( tenant_id=current_tenant_id, - instruction=args["instruction"], - model_config=args["model_config"], - code_language=args["language"], + instruction=args.instruction, + model_config=args.model_config_data, + code_language=args.language, ) case _: return {"error": f"invalid node type: {node_type}"} - if args["node_id"] == "" and args["current"] != "": # For legacy app without a workflow + if args.node_id == "" and args.current != "": # For legacy app without a workflow return LLMGenerator.instruction_modify_legacy( tenant_id=current_tenant_id, - flow_id=args["flow_id"], - current=args["current"], - instruction=args["instruction"], - model_config=args["model_config"], - ideal_output=args["ideal_output"], + flow_id=args.flow_id, + current=args.current, + instruction=args.instruction, + model_config=args.model_config_data, + ideal_output=args.ideal_output, ) - if args["node_id"] != "" and args["current"] != "": # For workflow node + if args.node_id != "" and args.current != "": # For workflow node return LLMGenerator.instruction_modify_workflow( tenant_id=current_tenant_id, - flow_id=args["flow_id"], - node_id=args["node_id"], - current=args["current"], - instruction=args["instruction"], - model_config=args["model_config"], - ideal_output=args["ideal_output"], + flow_id=args.flow_id, + node_id=args.node_id, + current=args.current, + instruction=args.instruction, + model_config=args.model_config_data, + ideal_output=args.ideal_output, workflow_service=WorkflowService(), ) return {"error": "incompatible parameters"}, 400 @@ -285,24 +259,15 @@ class InstructionGenerateApi(Resource): class InstructionGenerationTemplateApi(Resource): @console_ns.doc("get_instruction_template") @console_ns.doc(description="Get instruction generation template") - @console_ns.expect( - console_ns.model( - "InstructionTemplateRequest", - { - "instruction": fields.String(required=True, description="Template instruction"), - "ideal_output": fields.String(description="Expected ideal output"), - }, - ) - ) + @console_ns.expect(console_ns.models[InstructionTemplatePayload.__name__]) @console_ns.response(200, "Template retrieved successfully") @console_ns.response(400, "Invalid request parameters") @setup_required @login_required @account_initialization_required def post(self): - parser = reqparse.RequestParser().add_argument("type", type=str, required=True, default=False, location="json") - args = parser.parse_args() - match args["type"]: + args = InstructionTemplatePayload.model_validate(console_ns.payload) + match args.type: case "prompt": from core.llm_generator.prompts import INSTRUCTION_GENERATE_TEMPLATE_PROMPT @@ -312,4 +277,4 @@ class InstructionGenerationTemplateApi(Resource): return {"data": INSTRUCTION_GENERATE_TEMPLATE_CODE} case _: - raise ValueError(f"Invalid type: {args['type']}") + raise ValueError(f"Invalid type: {args.type}") diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 40e4020267..377297c84c 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -1,7 +1,9 @@ import logging +from typing import Literal -from flask_restx import Resource, fields, marshal_with, reqparse -from flask_restx.inputs import int_range +from flask import request +from flask_restx import Resource, fields, marshal_with +from pydantic import BaseModel, Field, field_validator from sqlalchemy import exists, select from werkzeug.exceptions import InternalServerError, NotFound @@ -33,6 +35,67 @@ from services.errors.message import MessageNotExistsError, SuggestedQuestionsAft from services.message_service import MessageService logger = logging.getLogger(__name__) +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class ChatMessagesQuery(BaseModel): + conversation_id: str = Field(..., description="Conversation ID") + first_id: str | None = Field(default=None, description="First message ID for pagination") + limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return (1-100)") + + @field_validator("first_id", mode="before") + @classmethod + def empty_to_none(cls, value: str | None) -> str | None: + if value == "": + return None + return value + + @field_validator("conversation_id", "first_id") + @classmethod + def validate_uuid(cls, value: str | None) -> str | None: + if value is None: + return value + return uuid_value(value) + + +class MessageFeedbackPayload(BaseModel): + message_id: str = Field(..., description="Message ID") + rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating") + + @field_validator("message_id") + @classmethod + def validate_message_id(cls, value: str) -> str: + return uuid_value(value) + + +class FeedbackExportQuery(BaseModel): + from_source: Literal["user", "admin"] | None = Field(default=None, description="Filter by feedback source") + rating: Literal["like", "dislike"] | None = Field(default=None, description="Filter by rating") + has_comment: bool | None = Field(default=None, description="Only include feedback with comments") + start_date: str | None = Field(default=None, description="Start date (YYYY-MM-DD)") + end_date: str | None = Field(default=None, description="End date (YYYY-MM-DD)") + format: Literal["csv", "json"] = Field(default="csv", description="Export format") + + @field_validator("has_comment", mode="before") + @classmethod + def parse_bool(cls, value: bool | str | None) -> bool | None: + if isinstance(value, bool) or value is None: + return value + lowered = value.lower() + if lowered in {"true", "1", "yes", "on"}: + return True + if lowered in {"false", "0", "no", "off"}: + return False + raise ValueError("has_comment must be a boolean value") + + +def reg(cls: type[BaseModel]): + console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) + + +reg(ChatMessagesQuery) +reg(MessageFeedbackPayload) +reg(FeedbackExportQuery) # Register models for flask_restx to avoid dict type issues in Swagger # Register in dependency order: base models first, then dependent models @@ -157,12 +220,7 @@ class ChatMessageListApi(Resource): @console_ns.doc("list_chat_messages") @console_ns.doc(description="Get chat messages for a conversation with pagination") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.expect( - console_ns.parser() - .add_argument("conversation_id", type=str, required=True, location="args", help="Conversation ID") - .add_argument("first_id", type=str, location="args", help="First message ID for pagination") - .add_argument("limit", type=int, location="args", default=20, help="Number of messages to return (1-100)") - ) + @console_ns.expect(console_ns.models[ChatMessagesQuery.__name__]) @console_ns.response(200, "Success", message_infinite_scroll_pagination_model) @console_ns.response(404, "Conversation not found") @login_required @@ -172,27 +230,21 @@ class ChatMessageListApi(Resource): @marshal_with(message_infinite_scroll_pagination_model) @edit_permission_required def get(self, app_model): - parser = ( - reqparse.RequestParser() - .add_argument("conversation_id", required=True, type=uuid_value, location="args") - .add_argument("first_id", type=uuid_value, location="args") - .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") - ) - args = parser.parse_args() + args = ChatMessagesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore conversation = ( db.session.query(Conversation) - .where(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id) + .where(Conversation.id == args.conversation_id, Conversation.app_id == app_model.id) .first() ) if not conversation: raise NotFound("Conversation Not Exists.") - if args["first_id"]: + if args.first_id: first_message = ( db.session.query(Message) - .where(Message.conversation_id == conversation.id, Message.id == args["first_id"]) + .where(Message.conversation_id == conversation.id, Message.id == args.first_id) .first() ) @@ -207,7 +259,7 @@ class ChatMessageListApi(Resource): Message.id != first_message.id, ) .order_by(Message.created_at.desc()) - .limit(args["limit"]) + .limit(args.limit) .all() ) else: @@ -215,12 +267,12 @@ class ChatMessageListApi(Resource): db.session.query(Message) .where(Message.conversation_id == conversation.id) .order_by(Message.created_at.desc()) - .limit(args["limit"]) + .limit(args.limit) .all() ) # Initialize has_more based on whether we have a full page - if len(history_messages) == args["limit"]: + if len(history_messages) == args.limit: current_page_first_message = history_messages[-1] # Check if there are more messages before the current page has_more = db.session.scalar( @@ -238,7 +290,7 @@ class ChatMessageListApi(Resource): history_messages = list(reversed(history_messages)) - return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more) + return InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more) @console_ns.route("/apps//feedbacks") @@ -246,15 +298,7 @@ class MessageFeedbackApi(Resource): @console_ns.doc("create_message_feedback") @console_ns.doc(description="Create or update message feedback (like/dislike)") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.expect( - console_ns.model( - "MessageFeedbackRequest", - { - "message_id": fields.String(required=True, description="Message ID"), - "rating": fields.String(enum=["like", "dislike"], description="Feedback rating"), - }, - ) - ) + @console_ns.expect(console_ns.models[MessageFeedbackPayload.__name__]) @console_ns.response(200, "Feedback updated successfully") @console_ns.response(404, "Message not found") @console_ns.response(403, "Insufficient permissions") @@ -265,14 +309,9 @@ class MessageFeedbackApi(Resource): def post(self, app_model): current_user, _ = current_account_with_tenant() - parser = ( - reqparse.RequestParser() - .add_argument("message_id", required=True, type=uuid_value, location="json") - .add_argument("rating", type=str, choices=["like", "dislike", None], location="json") - ) - args = parser.parse_args() + args = MessageFeedbackPayload.model_validate(console_ns.payload) - message_id = str(args["message_id"]) + message_id = str(args.message_id) message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first() @@ -281,18 +320,21 @@ class MessageFeedbackApi(Resource): feedback = message.admin_feedback - if not args["rating"] and feedback: + if not args.rating and feedback: db.session.delete(feedback) - elif args["rating"] and feedback: - feedback.rating = args["rating"] - elif not args["rating"] and not feedback: + elif args.rating and feedback: + feedback.rating = args.rating + elif not args.rating and not feedback: raise ValueError("rating cannot be None when feedback not exists") else: + rating_value = args.rating + if rating_value is None: + raise ValueError("rating is required to create feedback") feedback = MessageFeedback( app_id=app_model.id, conversation_id=message.conversation_id, message_id=message.id, - rating=args["rating"], + rating=rating_value, from_source="admin", from_account_id=current_user.id, ) @@ -369,24 +411,12 @@ class MessageSuggestedQuestionApi(Resource): return {"data": questions} -# Shared parser for feedback export (used for both documentation and runtime parsing) -feedback_export_parser = ( - console_ns.parser() - .add_argument("from_source", type=str, choices=["user", "admin"], location="args", help="Filter by feedback source") - .add_argument("rating", type=str, choices=["like", "dislike"], location="args", help="Filter by rating") - .add_argument("has_comment", type=bool, location="args", help="Only include feedback with comments") - .add_argument("start_date", type=str, location="args", help="Start date (YYYY-MM-DD)") - .add_argument("end_date", type=str, location="args", help="End date (YYYY-MM-DD)") - .add_argument("format", type=str, choices=["csv", "json"], default="csv", location="args", help="Export format") -) - - @console_ns.route("/apps//feedbacks/export") class MessageFeedbackExportApi(Resource): @console_ns.doc("export_feedbacks") @console_ns.doc(description="Export user feedback data for Google Sheets") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.expect(feedback_export_parser) + @console_ns.expect(console_ns.models[FeedbackExportQuery.__name__]) @console_ns.response(200, "Feedback data exported successfully") @console_ns.response(400, "Invalid parameters") @console_ns.response(500, "Internal server error") @@ -395,7 +425,7 @@ class MessageFeedbackExportApi(Resource): @login_required @account_initialization_required def get(self, app_model): - args = feedback_export_parser.parse_args() + args = FeedbackExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore # Import the service function from services.feedback_service import FeedbackService @@ -403,12 +433,12 @@ class MessageFeedbackExportApi(Resource): try: export_data = FeedbackService.export_feedbacks( app_id=app_model.id, - from_source=args.get("from_source"), - rating=args.get("rating"), - has_comment=args.get("has_comment"), - start_date=args.get("start_date"), - end_date=args.get("end_date"), - format_type=args.get("format", "csv"), + from_source=args.from_source, + rating=args.rating, + has_comment=args.has_comment, + start_date=args.start_date, + end_date=args.end_date, + format_type=args.format, ) return export_data diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index c8f54c638e..ffa28b1c95 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -1,8 +1,9 @@ from decimal import Decimal import sqlalchemy as sa -from flask import abort, jsonify -from flask_restx import Resource, fields, reqparse +from flask import abort, jsonify, request +from flask_restx import Resource, fields +from pydantic import BaseModel, Field, field_validator from controllers.console import console_ns from controllers.console.app.wraps import get_app_model @@ -10,21 +11,37 @@ from controllers.console.wraps import account_initialization_required, setup_req from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from libs.datetime_utils import parse_time_range -from libs.helper import DatetimeString, convert_datetime_to_date +from libs.helper import convert_datetime_to_date from libs.login import current_account_with_tenant, login_required from models import AppMode +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class StatisticTimeRangeQuery(BaseModel): + start: str | None = Field(default=None, description="Start date (YYYY-MM-DD HH:MM)") + end: str | None = Field(default=None, description="End date (YYYY-MM-DD HH:MM)") + + @field_validator("start", "end", mode="before") + @classmethod + def empty_string_to_none(cls, value: str | None) -> str | None: + if value == "": + return None + return value + + +console_ns.schema_model( + StatisticTimeRangeQuery.__name__, + StatisticTimeRangeQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) + @console_ns.route("/apps//statistics/daily-messages") class DailyMessageStatistic(Resource): @console_ns.doc("get_daily_message_statistics") @console_ns.doc(description="Get daily message statistics for an application") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.expect( - console_ns.parser() - .add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)") - .add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)") - ) + @console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__]) @console_ns.response( 200, "Daily message statistics retrieved successfully", @@ -37,12 +54,7 @@ class DailyMessageStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - parser = ( - reqparse.RequestParser() - .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") - .add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") - ) - args = parser.parse_args() + args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore converted_created_at = convert_datetime_to_date("created_at") sql_query = f"""SELECT @@ -57,7 +69,7 @@ WHERE assert account.timezone is not None try: - start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone) + start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone) except ValueError as e: abort(400, description=str(e)) @@ -81,19 +93,12 @@ WHERE return jsonify({"data": response_data}) -parser = ( - reqparse.RequestParser() - .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args", help="Start date (YYYY-MM-DD HH:MM)") - .add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args", help="End date (YYYY-MM-DD HH:MM)") -) - - @console_ns.route("/apps//statistics/daily-conversations") class DailyConversationStatistic(Resource): @console_ns.doc("get_daily_conversation_statistics") @console_ns.doc(description="Get daily conversation statistics for an application") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.expect(parser) + @console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__]) @console_ns.response( 200, "Daily conversation statistics retrieved successfully", @@ -106,7 +111,7 @@ class DailyConversationStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - args = parser.parse_args() + args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore converted_created_at = convert_datetime_to_date("created_at") sql_query = f"""SELECT @@ -121,7 +126,7 @@ WHERE assert account.timezone is not None try: - start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone) + start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone) except ValueError as e: abort(400, description=str(e)) @@ -149,7 +154,7 @@ class DailyTerminalsStatistic(Resource): @console_ns.doc("get_daily_terminals_statistics") @console_ns.doc(description="Get daily terminal/end-user statistics for an application") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.expect(parser) + @console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__]) @console_ns.response( 200, "Daily terminal statistics retrieved successfully", @@ -162,7 +167,7 @@ class DailyTerminalsStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - args = parser.parse_args() + args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore converted_created_at = convert_datetime_to_date("created_at") sql_query = f"""SELECT @@ -177,7 +182,7 @@ WHERE assert account.timezone is not None try: - start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone) + start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone) except ValueError as e: abort(400, description=str(e)) @@ -206,7 +211,7 @@ class DailyTokenCostStatistic(Resource): @console_ns.doc("get_daily_token_cost_statistics") @console_ns.doc(description="Get daily token cost statistics for an application") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.expect(parser) + @console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__]) @console_ns.response( 200, "Daily token cost statistics retrieved successfully", @@ -219,7 +224,7 @@ class DailyTokenCostStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - args = parser.parse_args() + args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore converted_created_at = convert_datetime_to_date("created_at") sql_query = f"""SELECT @@ -235,7 +240,7 @@ WHERE assert account.timezone is not None try: - start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone) + start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone) except ValueError as e: abort(400, description=str(e)) @@ -266,7 +271,7 @@ class AverageSessionInteractionStatistic(Resource): @console_ns.doc("get_average_session_interaction_statistics") @console_ns.doc(description="Get average session interaction statistics for an application") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.expect(parser) + @console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__]) @console_ns.response( 200, "Average session interaction statistics retrieved successfully", @@ -279,7 +284,7 @@ class AverageSessionInteractionStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - args = parser.parse_args() + args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore converted_created_at = convert_datetime_to_date("c.created_at") sql_query = f"""SELECT @@ -302,7 +307,7 @@ FROM assert account.timezone is not None try: - start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone) + start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone) except ValueError as e: abort(400, description=str(e)) @@ -342,7 +347,7 @@ class UserSatisfactionRateStatistic(Resource): @console_ns.doc("get_user_satisfaction_rate_statistics") @console_ns.doc(description="Get user satisfaction rate statistics for an application") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.expect(parser) + @console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__]) @console_ns.response( 200, "User satisfaction rate statistics retrieved successfully", @@ -355,7 +360,7 @@ class UserSatisfactionRateStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - args = parser.parse_args() + args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore converted_created_at = convert_datetime_to_date("m.created_at") sql_query = f"""SELECT @@ -374,7 +379,7 @@ WHERE assert account.timezone is not None try: - start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone) + start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone) except ValueError as e: abort(400, description=str(e)) @@ -408,7 +413,7 @@ class AverageResponseTimeStatistic(Resource): @console_ns.doc("get_average_response_time_statistics") @console_ns.doc(description="Get average response time statistics for an application") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.expect(parser) + @console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__]) @console_ns.response( 200, "Average response time statistics retrieved successfully", @@ -421,7 +426,7 @@ class AverageResponseTimeStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - args = parser.parse_args() + args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore converted_created_at = convert_datetime_to_date("created_at") sql_query = f"""SELECT @@ -436,7 +441,7 @@ WHERE assert account.timezone is not None try: - start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone) + start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone) except ValueError as e: abort(400, description=str(e)) @@ -465,7 +470,7 @@ class TokensPerSecondStatistic(Resource): @console_ns.doc("get_tokens_per_second_statistics") @console_ns.doc(description="Get tokens per second statistics for an application") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.expect(parser) + @console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__]) @console_ns.response( 200, "Tokens per second statistics retrieved successfully", @@ -477,7 +482,7 @@ class TokensPerSecondStatistic(Resource): @account_initialization_required def get(self, app_model): account, _ = current_account_with_tenant() - args = parser.parse_args() + args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore converted_created_at = convert_datetime_to_date("created_at") sql_query = f"""SELECT @@ -495,7 +500,7 @@ WHERE assert account.timezone is not None try: - start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone) + start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone) except ValueError as e: abort(400, description=str(e)) diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 0082089365..b4f2ef0ba8 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -1,10 +1,11 @@ import json import logging from collections.abc import Sequence -from typing import cast +from typing import Any from flask import abort, request -from flask_restx import Resource, fields, inputs, marshal_with, reqparse +from flask_restx import Resource, fields, marshal_with +from pydantic import BaseModel, Field, field_validator from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, InternalServerError, NotFound @@ -49,6 +50,7 @@ from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseE logger = logging.getLogger(__name__) LISTENING_RETRY_IN = 2000 +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" # Register models for flask_restx to avoid dict type issues in Swagger # Register in dependency order: base models first, then dependent models @@ -107,6 +109,104 @@ if workflow_run_node_execution_model is None: workflow_run_node_execution_model = console_ns.model("WorkflowRunNodeExecution", workflow_run_node_execution_fields) +class SyncDraftWorkflowPayload(BaseModel): + graph: dict[str, Any] + features: dict[str, Any] + hash: str | None = None + environment_variables: list[dict[str, Any]] = Field(default_factory=list) + conversation_variables: list[dict[str, Any]] = Field(default_factory=list) + + +class BaseWorkflowRunPayload(BaseModel): + files: list[dict[str, Any]] | None = None + + +class AdvancedChatWorkflowRunPayload(BaseWorkflowRunPayload): + inputs: dict[str, Any] | None = None + query: str = "" + conversation_id: str | None = None + parent_message_id: str | None = None + + @field_validator("conversation_id", "parent_message_id") + @classmethod + def validate_uuid(cls, value: str | None) -> str | None: + if value is None: + return value + return uuid_value(value) + + +class IterationNodeRunPayload(BaseModel): + inputs: dict[str, Any] | None = None + + +class LoopNodeRunPayload(BaseModel): + inputs: dict[str, Any] | None = None + + +class DraftWorkflowRunPayload(BaseWorkflowRunPayload): + inputs: dict[str, Any] + + +class DraftWorkflowNodeRunPayload(BaseWorkflowRunPayload): + inputs: dict[str, Any] + query: str = "" + + +class PublishWorkflowPayload(BaseModel): + marked_name: str | None = Field(default=None, max_length=20) + marked_comment: str | None = Field(default=None, max_length=100) + + +class DefaultBlockConfigQuery(BaseModel): + q: str | None = None + + +class ConvertToWorkflowPayload(BaseModel): + name: str | None = None + icon_type: str | None = None + icon: str | None = None + icon_background: str | None = None + + +class WorkflowListQuery(BaseModel): + page: int = Field(default=1, ge=1, le=99999) + limit: int = Field(default=10, ge=1, le=100) + user_id: str | None = None + named_only: bool = False + + +class WorkflowUpdatePayload(BaseModel): + marked_name: str | None = Field(default=None, max_length=20) + marked_comment: str | None = Field(default=None, max_length=100) + + +class DraftWorkflowTriggerRunPayload(BaseModel): + node_id: str + + +class DraftWorkflowTriggerRunAllPayload(BaseModel): + node_ids: list[str] + + +def reg(cls: type[BaseModel]): + console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) + + +reg(SyncDraftWorkflowPayload) +reg(AdvancedChatWorkflowRunPayload) +reg(IterationNodeRunPayload) +reg(LoopNodeRunPayload) +reg(DraftWorkflowRunPayload) +reg(DraftWorkflowNodeRunPayload) +reg(PublishWorkflowPayload) +reg(DefaultBlockConfigQuery) +reg(ConvertToWorkflowPayload) +reg(WorkflowListQuery) +reg(WorkflowUpdatePayload) +reg(DraftWorkflowTriggerRunPayload) +reg(DraftWorkflowTriggerRunAllPayload) + + # TODO(QuantumGhost): Refactor existing node run API to handle file parameter parsing # at the controller level rather than in the workflow logic. This would improve separation # of concerns and make the code more maintainable. @@ -158,18 +258,7 @@ class DraftWorkflowApi(Resource): @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @console_ns.doc("sync_draft_workflow") @console_ns.doc(description="Sync draft workflow configuration") - @console_ns.expect( - console_ns.model( - "SyncDraftWorkflowRequest", - { - "graph": fields.Raw(required=True, description="Workflow graph configuration"), - "features": fields.Raw(required=True, description="Workflow features configuration"), - "hash": fields.String(description="Workflow hash for validation"), - "environment_variables": fields.List(fields.Raw, required=True, description="Environment variables"), - "conversation_variables": fields.List(fields.Raw, description="Conversation variables"), - }, - ) - ) + @console_ns.expect(console_ns.models[SyncDraftWorkflowPayload.__name__]) @console_ns.response( 200, "Draft workflow synced successfully", @@ -193,36 +282,23 @@ class DraftWorkflowApi(Resource): content_type = request.headers.get("Content-Type", "") + payload_data: dict[str, Any] | None = None if "application/json" in content_type: - parser = ( - reqparse.RequestParser() - .add_argument("graph", type=dict, required=True, nullable=False, location="json") - .add_argument("features", type=dict, required=True, nullable=False, location="json") - .add_argument("hash", type=str, required=False, location="json") - .add_argument("environment_variables", type=list, required=True, location="json") - .add_argument("conversation_variables", type=list, required=False, location="json") - ) - args = parser.parse_args() + payload_data = request.get_json(silent=True) + if not isinstance(payload_data, dict): + return {"message": "Invalid JSON data"}, 400 elif "text/plain" in content_type: try: - data = json.loads(request.data.decode("utf-8")) - if "graph" not in data or "features" not in data: - raise ValueError("graph or features not found in data") - - if not isinstance(data.get("graph"), dict) or not isinstance(data.get("features"), dict): - raise ValueError("graph or features is not a dict") - - args = { - "graph": data.get("graph"), - "features": data.get("features"), - "hash": data.get("hash"), - "environment_variables": data.get("environment_variables"), - "conversation_variables": data.get("conversation_variables"), - } + payload_data = json.loads(request.data.decode("utf-8")) except json.JSONDecodeError: return {"message": "Invalid JSON data"}, 400 + if not isinstance(payload_data, dict): + return {"message": "Invalid JSON data"}, 400 else: abort(415) + + args_model = SyncDraftWorkflowPayload.model_validate(payload_data) + args = args_model.model_dump() workflow_service = WorkflowService() try: @@ -258,17 +334,7 @@ class AdvancedChatDraftWorkflowRunApi(Resource): @console_ns.doc("run_advanced_chat_draft_workflow") @console_ns.doc(description="Run draft workflow for advanced chat application") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.expect( - console_ns.model( - "AdvancedChatWorkflowRunRequest", - { - "query": fields.String(required=True, description="User query"), - "inputs": fields.Raw(description="Input variables"), - "files": fields.List(fields.Raw, description="File uploads"), - "conversation_id": fields.String(description="Conversation ID"), - }, - ) - ) + @console_ns.expect(console_ns.models[AdvancedChatWorkflowRunPayload.__name__]) @console_ns.response(200, "Workflow run started successfully") @console_ns.response(400, "Invalid request parameters") @console_ns.response(403, "Permission denied") @@ -283,16 +349,8 @@ class AdvancedChatDraftWorkflowRunApi(Resource): """ current_user, _ = current_account_with_tenant() - parser = ( - reqparse.RequestParser() - .add_argument("inputs", type=dict, location="json") - .add_argument("query", type=str, required=True, location="json", default="") - .add_argument("files", type=list, location="json") - .add_argument("conversation_id", type=uuid_value, location="json") - .add_argument("parent_message_id", type=uuid_value, required=False, location="json") - ) - - args = parser.parse_args() + args_model = AdvancedChatWorkflowRunPayload.model_validate(console_ns.payload or {}) + args = args_model.model_dump(exclude_none=True) external_trace_id = get_external_trace_id(request) if external_trace_id: @@ -322,15 +380,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource): @console_ns.doc("run_advanced_chat_draft_iteration_node") @console_ns.doc(description="Run draft workflow iteration node for advanced chat") @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) - @console_ns.expect( - console_ns.model( - "IterationNodeRunRequest", - { - "task_id": fields.String(required=True, description="Task ID"), - "inputs": fields.Raw(description="Input variables"), - }, - ) - ) + @console_ns.expect(console_ns.models[IterationNodeRunPayload.__name__]) @console_ns.response(200, "Iteration node run started successfully") @console_ns.response(403, "Permission denied") @console_ns.response(404, "Node not found") @@ -344,8 +394,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource): Run draft workflow iteration node """ current_user, _ = current_account_with_tenant() - parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json") - args = parser.parse_args() + args = IterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True) try: response = AppGenerateService.generate_single_iteration( @@ -369,15 +418,7 @@ class WorkflowDraftRunIterationNodeApi(Resource): @console_ns.doc("run_workflow_draft_iteration_node") @console_ns.doc(description="Run draft workflow iteration node") @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) - @console_ns.expect( - console_ns.model( - "WorkflowIterationNodeRunRequest", - { - "task_id": fields.String(required=True, description="Task ID"), - "inputs": fields.Raw(description="Input variables"), - }, - ) - ) + @console_ns.expect(console_ns.models[IterationNodeRunPayload.__name__]) @console_ns.response(200, "Workflow iteration node run started successfully") @console_ns.response(403, "Permission denied") @console_ns.response(404, "Node not found") @@ -391,8 +432,7 @@ class WorkflowDraftRunIterationNodeApi(Resource): Run draft workflow iteration node """ current_user, _ = current_account_with_tenant() - parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json") - args = parser.parse_args() + args = IterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True) try: response = AppGenerateService.generate_single_iteration( @@ -416,15 +456,7 @@ class AdvancedChatDraftRunLoopNodeApi(Resource): @console_ns.doc("run_advanced_chat_draft_loop_node") @console_ns.doc(description="Run draft workflow loop node for advanced chat") @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) - @console_ns.expect( - console_ns.model( - "LoopNodeRunRequest", - { - "task_id": fields.String(required=True, description="Task ID"), - "inputs": fields.Raw(description="Input variables"), - }, - ) - ) + @console_ns.expect(console_ns.models[LoopNodeRunPayload.__name__]) @console_ns.response(200, "Loop node run started successfully") @console_ns.response(403, "Permission denied") @console_ns.response(404, "Node not found") @@ -438,8 +470,7 @@ class AdvancedChatDraftRunLoopNodeApi(Resource): Run draft workflow loop node """ current_user, _ = current_account_with_tenant() - parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json") - args = parser.parse_args() + args = LoopNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True) try: response = AppGenerateService.generate_single_loop( @@ -463,15 +494,7 @@ class WorkflowDraftRunLoopNodeApi(Resource): @console_ns.doc("run_workflow_draft_loop_node") @console_ns.doc(description="Run draft workflow loop node") @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) - @console_ns.expect( - console_ns.model( - "WorkflowLoopNodeRunRequest", - { - "task_id": fields.String(required=True, description="Task ID"), - "inputs": fields.Raw(description="Input variables"), - }, - ) - ) + @console_ns.expect(console_ns.models[LoopNodeRunPayload.__name__]) @console_ns.response(200, "Workflow loop node run started successfully") @console_ns.response(403, "Permission denied") @console_ns.response(404, "Node not found") @@ -485,8 +508,7 @@ class WorkflowDraftRunLoopNodeApi(Resource): Run draft workflow loop node """ current_user, _ = current_account_with_tenant() - parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json") - args = parser.parse_args() + args = LoopNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True) try: response = AppGenerateService.generate_single_loop( @@ -510,15 +532,7 @@ class DraftWorkflowRunApi(Resource): @console_ns.doc("run_draft_workflow") @console_ns.doc(description="Run draft workflow") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.expect( - console_ns.model( - "DraftWorkflowRunRequest", - { - "inputs": fields.Raw(required=True, description="Input variables"), - "files": fields.List(fields.Raw, description="File uploads"), - }, - ) - ) + @console_ns.expect(console_ns.models[DraftWorkflowRunPayload.__name__]) @console_ns.response(200, "Draft workflow run started successfully") @console_ns.response(403, "Permission denied") @setup_required @@ -531,12 +545,7 @@ class DraftWorkflowRunApi(Resource): Run draft workflow """ current_user, _ = current_account_with_tenant() - parser = ( - reqparse.RequestParser() - .add_argument("inputs", type=dict, required=True, nullable=False, location="json") - .add_argument("files", type=list, required=False, location="json") - ) - args = parser.parse_args() + args = DraftWorkflowRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True) external_trace_id = get_external_trace_id(request) if external_trace_id: @@ -588,14 +597,7 @@ class DraftWorkflowNodeRunApi(Resource): @console_ns.doc("run_draft_workflow_node") @console_ns.doc(description="Run draft workflow node") @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) - @console_ns.expect( - console_ns.model( - "DraftWorkflowNodeRunRequest", - { - "inputs": fields.Raw(description="Input variables"), - }, - ) - ) + @console_ns.expect(console_ns.models[DraftWorkflowNodeRunPayload.__name__]) @console_ns.response(200, "Node run started successfully", workflow_run_node_execution_model) @console_ns.response(403, "Permission denied") @console_ns.response(404, "Node not found") @@ -610,15 +612,10 @@ class DraftWorkflowNodeRunApi(Resource): Run draft workflow node """ current_user, _ = current_account_with_tenant() - parser = ( - reqparse.RequestParser() - .add_argument("inputs", type=dict, required=True, nullable=False, location="json") - .add_argument("query", type=str, required=False, location="json", default="") - .add_argument("files", type=list, location="json", default=[]) - ) - args = parser.parse_args() + args_model = DraftWorkflowNodeRunPayload.model_validate(console_ns.payload or {}) + args = args_model.model_dump(exclude_none=True) - user_inputs = args.get("inputs") + user_inputs = args_model.inputs if user_inputs is None: raise ValueError("missing inputs") @@ -643,13 +640,6 @@ class DraftWorkflowNodeRunApi(Resource): return workflow_node_execution -parser_publish = ( - reqparse.RequestParser() - .add_argument("marked_name", type=str, required=False, default="", location="json") - .add_argument("marked_comment", type=str, required=False, default="", location="json") -) - - @console_ns.route("/apps//workflows/publish") class PublishedWorkflowApi(Resource): @console_ns.doc("get_published_workflow") @@ -674,7 +664,7 @@ class PublishedWorkflowApi(Resource): # return workflow, if not found, return None return workflow - @console_ns.expect(parser_publish) + @console_ns.expect(console_ns.models[PublishWorkflowPayload.__name__]) @setup_required @login_required @account_initialization_required @@ -686,13 +676,7 @@ class PublishedWorkflowApi(Resource): """ current_user, _ = current_account_with_tenant() - args = parser_publish.parse_args() - - # Validate name and comment length - if args.marked_name and len(args.marked_name) > 20: - raise ValueError("Marked name cannot exceed 20 characters") - if args.marked_comment and len(args.marked_comment) > 100: - raise ValueError("Marked comment cannot exceed 100 characters") + args = PublishWorkflowPayload.model_validate(console_ns.payload or {}) workflow_service = WorkflowService() with Session(db.engine) as session: @@ -741,9 +725,6 @@ class DefaultBlockConfigsApi(Resource): return workflow_service.get_default_block_configs() -parser_block = reqparse.RequestParser().add_argument("q", type=str, location="args") - - @console_ns.route("/apps//workflows/default-workflow-block-configs/") class DefaultBlockConfigApi(Resource): @console_ns.doc("get_default_block_config") @@ -751,7 +732,7 @@ class DefaultBlockConfigApi(Resource): @console_ns.doc(params={"app_id": "Application ID", "block_type": "Block type"}) @console_ns.response(200, "Default block configuration retrieved successfully") @console_ns.response(404, "Block type not found") - @console_ns.expect(parser_block) + @console_ns.expect(console_ns.models[DefaultBlockConfigQuery.__name__]) @setup_required @login_required @account_initialization_required @@ -761,14 +742,12 @@ class DefaultBlockConfigApi(Resource): """ Get default block config """ - args = parser_block.parse_args() - - q = args.get("q") + args = DefaultBlockConfigQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore filters = None - if q: + if args.q: try: - filters = json.loads(args.get("q", "")) + filters = json.loads(args.q) except json.JSONDecodeError: raise ValueError("Invalid filters") @@ -777,18 +756,9 @@ class DefaultBlockConfigApi(Resource): return workflow_service.get_default_block_config(node_type=block_type, filters=filters) -parser_convert = ( - reqparse.RequestParser() - .add_argument("name", type=str, required=False, nullable=True, location="json") - .add_argument("icon_type", type=str, required=False, nullable=True, location="json") - .add_argument("icon", type=str, required=False, nullable=True, location="json") - .add_argument("icon_background", type=str, required=False, nullable=True, location="json") -) - - @console_ns.route("/apps//convert-to-workflow") class ConvertToWorkflowApi(Resource): - @console_ns.expect(parser_convert) + @console_ns.expect(console_ns.models[ConvertToWorkflowPayload.__name__]) @console_ns.doc("convert_to_workflow") @console_ns.doc(description="Convert application to workflow mode") @console_ns.doc(params={"app_id": "Application ID"}) @@ -808,10 +778,8 @@ class ConvertToWorkflowApi(Resource): """ current_user, _ = current_account_with_tenant() - if request.data: - args = parser_convert.parse_args() - else: - args = {} + payload = console_ns.payload or {} + args = ConvertToWorkflowPayload.model_validate(payload).model_dump(exclude_none=True) # convert to workflow mode workflow_service = WorkflowService() @@ -823,18 +791,9 @@ class ConvertToWorkflowApi(Resource): } -parser_workflows = ( - reqparse.RequestParser() - .add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args") - .add_argument("limit", type=inputs.int_range(1, 100), required=False, default=10, location="args") - .add_argument("user_id", type=str, required=False, location="args") - .add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args") -) - - @console_ns.route("/apps//workflows") class PublishedAllWorkflowApi(Resource): - @console_ns.expect(parser_workflows) + @console_ns.expect(console_ns.models[WorkflowListQuery.__name__]) @console_ns.doc("get_all_published_workflows") @console_ns.doc(description="Get all published workflows for an application") @console_ns.doc(params={"app_id": "Application ID"}) @@ -851,16 +810,15 @@ class PublishedAllWorkflowApi(Resource): """ current_user, _ = current_account_with_tenant() - args = parser_workflows.parse_args() - page = args["page"] - limit = args["limit"] - user_id = args.get("user_id") - named_only = args.get("named_only", False) + args = WorkflowListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + page = args.page + limit = args.limit + user_id = args.user_id + named_only = args.named_only if user_id: if user_id != current_user.id: raise Forbidden() - user_id = cast(str, user_id) workflow_service = WorkflowService() with Session(db.engine) as session: @@ -886,15 +844,7 @@ class WorkflowByIdApi(Resource): @console_ns.doc("update_workflow_by_id") @console_ns.doc(description="Update workflow by ID") @console_ns.doc(params={"app_id": "Application ID", "workflow_id": "Workflow ID"}) - @console_ns.expect( - console_ns.model( - "UpdateWorkflowRequest", - { - "environment_variables": fields.List(fields.Raw, description="Environment variables"), - "conversation_variables": fields.List(fields.Raw, description="Conversation variables"), - }, - ) - ) + @console_ns.expect(console_ns.models[WorkflowUpdatePayload.__name__]) @console_ns.response(200, "Workflow updated successfully", workflow_model) @console_ns.response(404, "Workflow not found") @console_ns.response(403, "Permission denied") @@ -909,25 +859,14 @@ class WorkflowByIdApi(Resource): Update workflow attributes """ current_user, _ = current_account_with_tenant() - parser = ( - reqparse.RequestParser() - .add_argument("marked_name", type=str, required=False, location="json") - .add_argument("marked_comment", type=str, required=False, location="json") - ) - args = parser.parse_args() - - # Validate name and comment length - if args.marked_name and len(args.marked_name) > 20: - raise ValueError("Marked name cannot exceed 20 characters") - if args.marked_comment and len(args.marked_comment) > 100: - raise ValueError("Marked comment cannot exceed 100 characters") + args = WorkflowUpdatePayload.model_validate(console_ns.payload or {}) # Prepare update data update_data = {} - if args.get("marked_name") is not None: - update_data["marked_name"] = args["marked_name"] - if args.get("marked_comment") is not None: - update_data["marked_comment"] = args["marked_comment"] + if args.marked_name is not None: + update_data["marked_name"] = args.marked_name + if args.marked_comment is not None: + update_data["marked_comment"] = args.marked_comment if not update_data: return {"message": "No valid fields to update"}, 400 @@ -1040,11 +979,8 @@ class DraftWorkflowTriggerRunApi(Resource): Poll for trigger events and execute full workflow when event arrives """ current_user, _ = current_account_with_tenant() - parser = reqparse.RequestParser().add_argument( - "node_id", type=str, required=True, location="json", nullable=False - ) - args = parser.parse_args() - node_id = args["node_id"] + args = DraftWorkflowTriggerRunPayload.model_validate(console_ns.payload or {}) + node_id = args.node_id workflow_service = WorkflowService() draft_workflow = workflow_service.get_draft_workflow(app_model) if not draft_workflow: @@ -1172,14 +1108,7 @@ class DraftWorkflowTriggerRunAllApi(Resource): @console_ns.doc("draft_workflow_trigger_run_all") @console_ns.doc(description="Full workflow debug when the start node is a trigger") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.expect( - console_ns.model( - "DraftWorkflowTriggerRunAllRequest", - { - "node_ids": fields.List(fields.String, required=True, description="Node IDs"), - }, - ) - ) + @console_ns.expect(console_ns.models[DraftWorkflowTriggerRunAllPayload.__name__]) @console_ns.response(200, "Workflow executed successfully") @console_ns.response(403, "Permission denied") @console_ns.response(500, "Internal server error") @@ -1194,11 +1123,8 @@ class DraftWorkflowTriggerRunAllApi(Resource): """ current_user, _ = current_account_with_tenant() - parser = reqparse.RequestParser().add_argument( - "node_ids", type=list, required=True, location="json", nullable=False - ) - args = parser.parse_args() - node_ids = args["node_ids"] + args = DraftWorkflowTriggerRunAllPayload.model_validate(console_ns.payload or {}) + node_ids = args.node_ids workflow_service = WorkflowService() draft_workflow = workflow_service.get_draft_workflow(app_model) if not draft_workflow: diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py index 677678cb8f..fa67fb8154 100644 --- a/api/controllers/console/app/workflow_app_log.py +++ b/api/controllers/console/app/workflow_app_log.py @@ -1,6 +1,9 @@ +from datetime import datetime + from dateutil.parser import isoparse -from flask_restx import Resource, marshal_with, reqparse -from flask_restx.inputs import int_range +from flask import request +from flask_restx import Resource, marshal_with +from pydantic import BaseModel, Field, field_validator from sqlalchemy.orm import Session from controllers.console import console_ns @@ -14,6 +17,48 @@ from models import App from models.model import AppMode from services.workflow_app_service import WorkflowAppService +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class WorkflowAppLogQuery(BaseModel): + keyword: str | None = Field(default=None, description="Search keyword for filtering logs") + status: WorkflowExecutionStatus | None = Field( + default=None, description="Execution status filter (succeeded, failed, stopped, partial-succeeded)" + ) + created_at__before: datetime | None = Field(default=None, description="Filter logs created before this timestamp") + created_at__after: datetime | None = Field(default=None, description="Filter logs created after this timestamp") + created_by_end_user_session_id: str | None = Field(default=None, description="Filter by end user session ID") + created_by_account: str | None = Field(default=None, description="Filter by account") + detail: bool = Field(default=False, description="Whether to return detailed logs") + page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)") + limit: int = Field(default=20, ge=1, le=100, description="Number of items per page (1-100)") + + @field_validator("created_at__before", "created_at__after", mode="before") + @classmethod + def parse_datetime(cls, value: str | None) -> datetime | None: + if value in (None, ""): + return None + return isoparse(value) # type: ignore + + @field_validator("detail", mode="before") + @classmethod + def parse_bool(cls, value: bool | str | None) -> bool: + if isinstance(value, bool): + return value + if value is None: + return False + lowered = value.lower() + if lowered in {"1", "true", "yes", "on"}: + return True + if lowered in {"0", "false", "no", "off"}: + return False + raise ValueError("Invalid boolean value for detail") + + +console_ns.schema_model( + WorkflowAppLogQuery.__name__, WorkflowAppLogQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) +) + # Register model for flask_restx to avoid dict type issues in Swagger workflow_app_log_pagination_model = build_workflow_app_log_pagination_model(console_ns) @@ -23,19 +68,7 @@ class WorkflowAppLogApi(Resource): @console_ns.doc("get_workflow_app_logs") @console_ns.doc(description="Get workflow application execution logs") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.doc( - params={ - "keyword": "Search keyword for filtering logs", - "status": "Filter by execution status (succeeded, failed, stopped, partial-succeeded)", - "created_at__before": "Filter logs created before this timestamp", - "created_at__after": "Filter logs created after this timestamp", - "created_by_end_user_session_id": "Filter by end user session ID", - "created_by_account": "Filter by account", - "detail": "Whether to return detailed logs", - "page": "Page number (1-99999)", - "limit": "Number of items per page (1-100)", - } - ) + @console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__]) @console_ns.response(200, "Workflow app logs retrieved successfully", workflow_app_log_pagination_model) @setup_required @login_required @@ -46,44 +79,7 @@ class WorkflowAppLogApi(Resource): """ Get workflow app logs """ - parser = ( - reqparse.RequestParser() - .add_argument("keyword", type=str, location="args") - .add_argument( - "status", type=str, choices=["succeeded", "failed", "stopped", "partial-succeeded"], location="args" - ) - .add_argument( - "created_at__before", type=str, location="args", help="Filter logs created before this timestamp" - ) - .add_argument( - "created_at__after", type=str, location="args", help="Filter logs created after this timestamp" - ) - .add_argument( - "created_by_end_user_session_id", - type=str, - location="args", - required=False, - default=None, - ) - .add_argument( - "created_by_account", - type=str, - location="args", - required=False, - default=None, - ) - .add_argument("detail", type=bool, location="args", required=False, default=False) - .add_argument("page", type=int_range(1, 99999), default=1, location="args") - .add_argument("limit", type=int_range(1, 100), default=20, location="args") - ) - args = parser.parse_args() - - args.status = WorkflowExecutionStatus(args.status) if args.status else None - if args.created_at__before: - args.created_at__before = isoparse(args.created_at__before) - - if args.created_at__after: - args.created_at__after = isoparse(args.created_at__after) + args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore # get paginate workflow app logs workflow_app_service = WorkflowAppService() diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index c016104ce0..8f1871f1e9 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -1,7 +1,8 @@ -from typing import cast +from typing import Literal, cast -from flask_restx import Resource, fields, marshal_with, reqparse -from flask_restx.inputs import int_range +from flask import request +from flask_restx import Resource, fields, marshal_with +from pydantic import BaseModel, Field, field_validator from controllers.console import console_ns from controllers.console.app.wraps import get_app_model @@ -92,70 +93,51 @@ workflow_run_node_execution_list_model = console_ns.model( "WorkflowRunNodeExecutionList", workflow_run_node_execution_list_fields_copy ) +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" -def _parse_workflow_run_list_args(): - """ - Parse common arguments for workflow run list endpoints. - Returns: - Parsed arguments containing last_id, limit, status, and triggered_from filters - """ - parser = ( - reqparse.RequestParser() - .add_argument("last_id", type=uuid_value, location="args") - .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") - .add_argument( - "status", - type=str, - choices=WORKFLOW_RUN_STATUS_CHOICES, - location="args", - required=False, - ) - .add_argument( - "triggered_from", - type=str, - choices=["debugging", "app-run"], - location="args", - required=False, - help="Filter by trigger source: debugging or app-run", - ) +class WorkflowRunListQuery(BaseModel): + last_id: str | None = Field(default=None, description="Last run ID for pagination") + limit: int = Field(default=20, ge=1, le=100, description="Number of items per page (1-100)") + status: Literal["running", "succeeded", "failed", "stopped", "partial-succeeded"] | None = Field( + default=None, description="Workflow run status filter" ) - return parser.parse_args() - - -def _parse_workflow_run_count_args(): - """ - Parse common arguments for workflow run count endpoints. - - Returns: - Parsed arguments containing status, time_range, and triggered_from filters - """ - parser = ( - reqparse.RequestParser() - .add_argument( - "status", - type=str, - choices=WORKFLOW_RUN_STATUS_CHOICES, - location="args", - required=False, - ) - .add_argument( - "time_range", - type=time_duration, - location="args", - required=False, - help="Time range filter (e.g., 7d, 4h, 30m, 30s)", - ) - .add_argument( - "triggered_from", - type=str, - choices=["debugging", "app-run"], - location="args", - required=False, - help="Filter by trigger source: debugging or app-run", - ) + triggered_from: Literal["debugging", "app-run"] | None = Field( + default=None, description="Filter by trigger source: debugging or app-run" ) - return parser.parse_args() + + @field_validator("last_id") + @classmethod + def validate_last_id(cls, value: str | None) -> str | None: + if value is None: + return value + return uuid_value(value) + + +class WorkflowRunCountQuery(BaseModel): + status: Literal["running", "succeeded", "failed", "stopped", "partial-succeeded"] | None = Field( + default=None, description="Workflow run status filter" + ) + time_range: str | None = Field(default=None, description="Time range filter (e.g., 7d, 4h, 30m, 30s)") + triggered_from: Literal["debugging", "app-run"] | None = Field( + default=None, description="Filter by trigger source: debugging or app-run" + ) + + @field_validator("time_range") + @classmethod + def validate_time_range(cls, value: str | None) -> str | None: + if value is None: + return value + return time_duration(value) + + +console_ns.schema_model( + WorkflowRunListQuery.__name__, WorkflowRunListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) +) +console_ns.schema_model( + WorkflowRunCountQuery.__name__, + WorkflowRunCountQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) @console_ns.route("/apps//advanced-chat/workflow-runs") @@ -170,6 +152,7 @@ class AdvancedChatAppWorkflowRunListApi(Resource): @console_ns.doc( params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"} ) + @console_ns.expect(console_ns.models[WorkflowRunListQuery.__name__]) @console_ns.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_model) @setup_required @login_required @@ -180,12 +163,13 @@ class AdvancedChatAppWorkflowRunListApi(Resource): """ Get advanced chat app workflow run list """ - args = _parse_workflow_run_list_args() + args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = args_model.model_dump(exclude_none=True) # Default to DEBUGGING if not specified triggered_from = ( - WorkflowRunTriggeredFrom(args.get("triggered_from")) - if args.get("triggered_from") + WorkflowRunTriggeredFrom(args_model.triggered_from) + if args_model.triggered_from else WorkflowRunTriggeredFrom.DEBUGGING ) @@ -217,6 +201,7 @@ class AdvancedChatAppWorkflowRunCountApi(Resource): params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"} ) @console_ns.response(200, "Workflow runs count retrieved successfully", workflow_run_count_model) + @console_ns.expect(console_ns.models[WorkflowRunCountQuery.__name__]) @setup_required @login_required @account_initialization_required @@ -226,12 +211,13 @@ class AdvancedChatAppWorkflowRunCountApi(Resource): """ Get advanced chat workflow runs count statistics """ - args = _parse_workflow_run_count_args() + args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = args_model.model_dump(exclude_none=True) # Default to DEBUGGING if not specified triggered_from = ( - WorkflowRunTriggeredFrom(args.get("triggered_from")) - if args.get("triggered_from") + WorkflowRunTriggeredFrom(args_model.triggered_from) + if args_model.triggered_from else WorkflowRunTriggeredFrom.DEBUGGING ) @@ -259,6 +245,7 @@ class WorkflowRunListApi(Resource): params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"} ) @console_ns.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_model) + @console_ns.expect(console_ns.models[WorkflowRunListQuery.__name__]) @setup_required @login_required @account_initialization_required @@ -268,12 +255,13 @@ class WorkflowRunListApi(Resource): """ Get workflow run list """ - args = _parse_workflow_run_list_args() + args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = args_model.model_dump(exclude_none=True) # Default to DEBUGGING for workflow if not specified (backward compatibility) triggered_from = ( - WorkflowRunTriggeredFrom(args.get("triggered_from")) - if args.get("triggered_from") + WorkflowRunTriggeredFrom(args_model.triggered_from) + if args_model.triggered_from else WorkflowRunTriggeredFrom.DEBUGGING ) @@ -305,6 +293,7 @@ class WorkflowRunCountApi(Resource): params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"} ) @console_ns.response(200, "Workflow runs count retrieved successfully", workflow_run_count_model) + @console_ns.expect(console_ns.models[WorkflowRunCountQuery.__name__]) @setup_required @login_required @account_initialization_required @@ -314,12 +303,13 @@ class WorkflowRunCountApi(Resource): """ Get workflow runs count statistics """ - args = _parse_workflow_run_count_args() + args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = args_model.model_dump(exclude_none=True) # Default to DEBUGGING for workflow if not specified (backward compatibility) triggered_from = ( - WorkflowRunTriggeredFrom(args.get("triggered_from")) - if args.get("triggered_from") + WorkflowRunTriggeredFrom(args_model.triggered_from) + if args_model.triggered_from else WorkflowRunTriggeredFrom.DEBUGGING ) diff --git a/api/controllers/console/app/workflow_statistic.py b/api/controllers/console/app/workflow_statistic.py index 4a873e5ec1..e48cf42762 100644 --- a/api/controllers/console/app/workflow_statistic.py +++ b/api/controllers/console/app/workflow_statistic.py @@ -1,5 +1,6 @@ -from flask import abort, jsonify -from flask_restx import Resource, reqparse +from flask import abort, jsonify, request +from flask_restx import Resource +from pydantic import BaseModel, Field, field_validator from sqlalchemy.orm import sessionmaker from controllers.console import console_ns @@ -7,12 +8,31 @@ from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from extensions.ext_database import db from libs.datetime_utils import parse_time_range -from libs.helper import DatetimeString from libs.login import current_account_with_tenant, login_required from models.enums import WorkflowRunTriggeredFrom from models.model import AppMode from repositories.factory import DifyAPIRepositoryFactory +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class WorkflowStatisticQuery(BaseModel): + start: str | None = Field(default=None, description="Start date and time (YYYY-MM-DD HH:MM)") + end: str | None = Field(default=None, description="End date and time (YYYY-MM-DD HH:MM)") + + @field_validator("start", "end", mode="before") + @classmethod + def blank_to_none(cls, value: str | None) -> str | None: + if value == "": + return None + return value + + +console_ns.schema_model( + WorkflowStatisticQuery.__name__, + WorkflowStatisticQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) + @console_ns.route("/apps//workflow/statistics/daily-conversations") class WorkflowDailyRunsStatistic(Resource): @@ -24,9 +44,7 @@ class WorkflowDailyRunsStatistic(Resource): @console_ns.doc("get_workflow_daily_runs_statistic") @console_ns.doc(description="Get workflow daily runs statistics") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.doc( - params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"} - ) + @console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__]) @console_ns.response(200, "Daily runs statistics retrieved successfully") @get_app_model @setup_required @@ -35,17 +53,12 @@ class WorkflowDailyRunsStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - parser = ( - reqparse.RequestParser() - .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") - .add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") - ) - args = parser.parse_args() + args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore assert account.timezone is not None try: - start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone) + start_date, end_date = parse_time_range(args.start, args.end, account.timezone) except ValueError as e: abort(400, description=str(e)) @@ -71,9 +84,7 @@ class WorkflowDailyTerminalsStatistic(Resource): @console_ns.doc("get_workflow_daily_terminals_statistic") @console_ns.doc(description="Get workflow daily terminals statistics") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.doc( - params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"} - ) + @console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__]) @console_ns.response(200, "Daily terminals statistics retrieved successfully") @get_app_model @setup_required @@ -82,17 +93,12 @@ class WorkflowDailyTerminalsStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - parser = ( - reqparse.RequestParser() - .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") - .add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") - ) - args = parser.parse_args() + args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore assert account.timezone is not None try: - start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone) + start_date, end_date = parse_time_range(args.start, args.end, account.timezone) except ValueError as e: abort(400, description=str(e)) @@ -118,9 +124,7 @@ class WorkflowDailyTokenCostStatistic(Resource): @console_ns.doc("get_workflow_daily_token_cost_statistic") @console_ns.doc(description="Get workflow daily token cost statistics") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.doc( - params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"} - ) + @console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__]) @console_ns.response(200, "Daily token cost statistics retrieved successfully") @get_app_model @setup_required @@ -129,17 +133,12 @@ class WorkflowDailyTokenCostStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - parser = ( - reqparse.RequestParser() - .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") - .add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") - ) - args = parser.parse_args() + args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore assert account.timezone is not None try: - start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone) + start_date, end_date = parse_time_range(args.start, args.end, account.timezone) except ValueError as e: abort(400, description=str(e)) @@ -165,9 +164,7 @@ class WorkflowAverageAppInteractionStatistic(Resource): @console_ns.doc("get_workflow_average_app_interaction_statistic") @console_ns.doc(description="Get workflow average app interaction statistics") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.doc( - params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"} - ) + @console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__]) @console_ns.response(200, "Average app interaction statistics retrieved successfully") @setup_required @login_required @@ -176,17 +173,12 @@ class WorkflowAverageAppInteractionStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - parser = ( - reqparse.RequestParser() - .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") - .add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") - ) - args = parser.parse_args() + args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore assert account.timezone is not None try: - start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone) + start_date, end_date = parse_time_range(args.start, args.end, account.timezone) except ValueError as e: abort(400, description=str(e)) diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index b4d1b42657..6334314988 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -174,63 +174,25 @@ class CheckEmailUniquePayload(BaseModel): return email(value) -console_ns.schema_model( - AccountInitPayload.__name__, AccountInitPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) -console_ns.schema_model( - AccountNamePayload.__name__, AccountNamePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) -console_ns.schema_model( - AccountAvatarPayload.__name__, AccountAvatarPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) -console_ns.schema_model( - AccountInterfaceLanguagePayload.__name__, - AccountInterfaceLanguagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) -console_ns.schema_model( - AccountInterfaceThemePayload.__name__, - AccountInterfaceThemePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) -console_ns.schema_model( - AccountTimezonePayload.__name__, - AccountTimezonePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) -console_ns.schema_model( - AccountPasswordPayload.__name__, - AccountPasswordPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) -console_ns.schema_model( - AccountDeletePayload.__name__, - AccountDeletePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) -console_ns.schema_model( - AccountDeletionFeedbackPayload.__name__, - AccountDeletionFeedbackPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) -console_ns.schema_model( - EducationActivatePayload.__name__, - EducationActivatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) -console_ns.schema_model( - EducationAutocompleteQuery.__name__, - EducationAutocompleteQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) -console_ns.schema_model( - ChangeEmailSendPayload.__name__, - ChangeEmailSendPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) -console_ns.schema_model( - ChangeEmailValidityPayload.__name__, - ChangeEmailValidityPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) -console_ns.schema_model( - ChangeEmailResetPayload.__name__, - ChangeEmailResetPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) -console_ns.schema_model( - CheckEmailUniquePayload.__name__, - CheckEmailUniquePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) +def reg(cls: type[BaseModel]): + console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) + + +reg(AccountInitPayload) +reg(AccountNamePayload) +reg(AccountAvatarPayload) +reg(AccountInterfaceLanguagePayload) +reg(AccountInterfaceThemePayload) +reg(AccountTimezonePayload) +reg(AccountPasswordPayload) +reg(AccountDeletePayload) +reg(AccountDeletionFeedbackPayload) +reg(EducationActivatePayload) +reg(EducationAutocompleteQuery) +reg(ChangeEmailSendPayload) +reg(ChangeEmailValidityPayload) +reg(ChangeEmailResetPayload) +reg(CheckEmailUniquePayload) @console_ns.route("/account/init") diff --git a/api/controllers/console/workspace/endpoint.py b/api/controllers/console/workspace/endpoint.py index 7216b5e0e7..bfd9fc6c29 100644 --- a/api/controllers/console/workspace/endpoint.py +++ b/api/controllers/console/workspace/endpoint.py @@ -1,4 +1,8 @@ -from flask_restx import Resource, fields, reqparse +from typing import Any + +from flask import request +from flask_restx import Resource, fields +from pydantic import BaseModel, Field from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required @@ -7,21 +11,49 @@ from core.plugin.impl.exc import PluginPermissionDeniedError from libs.login import current_account_with_tenant, login_required from services.plugin.endpoint_service import EndpointService +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class EndpointCreatePayload(BaseModel): + plugin_unique_identifier: str + settings: dict[str, Any] + name: str = Field(min_length=1) + + +class EndpointIdPayload(BaseModel): + endpoint_id: str + + +class EndpointUpdatePayload(EndpointIdPayload): + settings: dict[str, Any] + name: str = Field(min_length=1) + + +class EndpointListQuery(BaseModel): + page: int = Field(ge=1) + page_size: int = Field(gt=0) + + +class EndpointListForPluginQuery(EndpointListQuery): + plugin_id: str + + +def reg(cls: type[BaseModel]): + console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) + + +reg(EndpointCreatePayload) +reg(EndpointIdPayload) +reg(EndpointUpdatePayload) +reg(EndpointListQuery) +reg(EndpointListForPluginQuery) + @console_ns.route("/workspaces/current/endpoints/create") class EndpointCreateApi(Resource): @console_ns.doc("create_endpoint") @console_ns.doc(description="Create a new plugin endpoint") - @console_ns.expect( - console_ns.model( - "EndpointCreateRequest", - { - "plugin_unique_identifier": fields.String(required=True, description="Plugin unique identifier"), - "settings": fields.Raw(required=True, description="Endpoint settings"), - "name": fields.String(required=True, description="Endpoint name"), - }, - ) - ) + @console_ns.expect(console_ns.models[EndpointCreatePayload.__name__]) @console_ns.response( 200, "Endpoint created successfully", @@ -35,26 +67,16 @@ class EndpointCreateApi(Resource): def post(self): user, tenant_id = current_account_with_tenant() - parser = ( - reqparse.RequestParser() - .add_argument("plugin_unique_identifier", type=str, required=True) - .add_argument("settings", type=dict, required=True) - .add_argument("name", type=str, required=True) - ) - args = parser.parse_args() - - plugin_unique_identifier = args["plugin_unique_identifier"] - settings = args["settings"] - name = args["name"] + args = EndpointCreatePayload.model_validate(console_ns.payload) try: return { "success": EndpointService.create_endpoint( tenant_id=tenant_id, user_id=user.id, - plugin_unique_identifier=plugin_unique_identifier, - name=name, - settings=settings, + plugin_unique_identifier=args.plugin_unique_identifier, + name=args.name, + settings=args.settings, ) } except PluginPermissionDeniedError as e: @@ -65,11 +87,7 @@ class EndpointCreateApi(Resource): class EndpointListApi(Resource): @console_ns.doc("list_endpoints") @console_ns.doc(description="List plugin endpoints with pagination") - @console_ns.expect( - console_ns.parser() - .add_argument("page", type=int, required=True, location="args", help="Page number") - .add_argument("page_size", type=int, required=True, location="args", help="Page size") - ) + @console_ns.expect(console_ns.models[EndpointListQuery.__name__]) @console_ns.response( 200, "Success", @@ -83,15 +101,10 @@ class EndpointListApi(Resource): def get(self): user, tenant_id = current_account_with_tenant() - parser = ( - reqparse.RequestParser() - .add_argument("page", type=int, required=True, location="args") - .add_argument("page_size", type=int, required=True, location="args") - ) - args = parser.parse_args() + args = EndpointListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore - page = args["page"] - page_size = args["page_size"] + page = args.page + page_size = args.page_size return jsonable_encoder( { @@ -109,12 +122,7 @@ class EndpointListApi(Resource): class EndpointListForSinglePluginApi(Resource): @console_ns.doc("list_plugin_endpoints") @console_ns.doc(description="List endpoints for a specific plugin") - @console_ns.expect( - console_ns.parser() - .add_argument("page", type=int, required=True, location="args", help="Page number") - .add_argument("page_size", type=int, required=True, location="args", help="Page size") - .add_argument("plugin_id", type=str, required=True, location="args", help="Plugin ID") - ) + @console_ns.expect(console_ns.models[EndpointListForPluginQuery.__name__]) @console_ns.response( 200, "Success", @@ -128,17 +136,11 @@ class EndpointListForSinglePluginApi(Resource): def get(self): user, tenant_id = current_account_with_tenant() - parser = ( - reqparse.RequestParser() - .add_argument("page", type=int, required=True, location="args") - .add_argument("page_size", type=int, required=True, location="args") - .add_argument("plugin_id", type=str, required=True, location="args") - ) - args = parser.parse_args() + args = EndpointListForPluginQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore - page = args["page"] - page_size = args["page_size"] - plugin_id = args["plugin_id"] + page = args.page + page_size = args.page_size + plugin_id = args.plugin_id return jsonable_encoder( { @@ -157,11 +159,7 @@ class EndpointListForSinglePluginApi(Resource): class EndpointDeleteApi(Resource): @console_ns.doc("delete_endpoint") @console_ns.doc(description="Delete a plugin endpoint") - @console_ns.expect( - console_ns.model( - "EndpointDeleteRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")} - ) - ) + @console_ns.expect(console_ns.models[EndpointIdPayload.__name__]) @console_ns.response( 200, "Endpoint deleted successfully", @@ -175,13 +173,12 @@ class EndpointDeleteApi(Resource): def post(self): user, tenant_id = current_account_with_tenant() - parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True) - args = parser.parse_args() - - endpoint_id = args["endpoint_id"] + args = EndpointIdPayload.model_validate(console_ns.payload) return { - "success": EndpointService.delete_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id) + "success": EndpointService.delete_endpoint( + tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id + ) } @@ -189,16 +186,7 @@ class EndpointDeleteApi(Resource): class EndpointUpdateApi(Resource): @console_ns.doc("update_endpoint") @console_ns.doc(description="Update a plugin endpoint") - @console_ns.expect( - console_ns.model( - "EndpointUpdateRequest", - { - "endpoint_id": fields.String(required=True, description="Endpoint ID"), - "settings": fields.Raw(required=True, description="Updated settings"), - "name": fields.String(required=True, description="Updated name"), - }, - ) - ) + @console_ns.expect(console_ns.models[EndpointUpdatePayload.__name__]) @console_ns.response( 200, "Endpoint updated successfully", @@ -212,25 +200,15 @@ class EndpointUpdateApi(Resource): def post(self): user, tenant_id = current_account_with_tenant() - parser = ( - reqparse.RequestParser() - .add_argument("endpoint_id", type=str, required=True) - .add_argument("settings", type=dict, required=True) - .add_argument("name", type=str, required=True) - ) - args = parser.parse_args() - - endpoint_id = args["endpoint_id"] - settings = args["settings"] - name = args["name"] + args = EndpointUpdatePayload.model_validate(console_ns.payload) return { "success": EndpointService.update_endpoint( tenant_id=tenant_id, user_id=user.id, - endpoint_id=endpoint_id, - name=name, - settings=settings, + endpoint_id=args.endpoint_id, + name=args.name, + settings=args.settings, ) } @@ -239,11 +217,7 @@ class EndpointUpdateApi(Resource): class EndpointEnableApi(Resource): @console_ns.doc("enable_endpoint") @console_ns.doc(description="Enable a plugin endpoint") - @console_ns.expect( - console_ns.model( - "EndpointEnableRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")} - ) - ) + @console_ns.expect(console_ns.models[EndpointIdPayload.__name__]) @console_ns.response( 200, "Endpoint enabled successfully", @@ -257,13 +231,12 @@ class EndpointEnableApi(Resource): def post(self): user, tenant_id = current_account_with_tenant() - parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True) - args = parser.parse_args() - - endpoint_id = args["endpoint_id"] + args = EndpointIdPayload.model_validate(console_ns.payload) return { - "success": EndpointService.enable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id) + "success": EndpointService.enable_endpoint( + tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id + ) } @@ -271,11 +244,7 @@ class EndpointEnableApi(Resource): class EndpointDisableApi(Resource): @console_ns.doc("disable_endpoint") @console_ns.doc(description="Disable a plugin endpoint") - @console_ns.expect( - console_ns.model( - "EndpointDisableRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")} - ) - ) + @console_ns.expect(console_ns.models[EndpointIdPayload.__name__]) @console_ns.response( 200, "Endpoint disabled successfully", @@ -289,11 +258,10 @@ class EndpointDisableApi(Resource): def post(self): user, tenant_id = current_account_with_tenant() - parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True) - args = parser.parse_args() - - endpoint_id = args["endpoint_id"] + args = EndpointIdPayload.model_validate(console_ns.payload) return { - "success": EndpointService.disable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id) + "success": EndpointService.disable_endpoint( + tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id + ) } diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index f72d247398..0142e14fb0 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -58,26 +58,15 @@ class OwnerTransferPayload(BaseModel): token: str -console_ns.schema_model( - MemberInvitePayload.__name__, - MemberInvitePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) -console_ns.schema_model( - MemberRoleUpdatePayload.__name__, - MemberRoleUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) -console_ns.schema_model( - OwnerTransferEmailPayload.__name__, - OwnerTransferEmailPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) -console_ns.schema_model( - OwnerTransferCheckPayload.__name__, - OwnerTransferCheckPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) -console_ns.schema_model( - OwnerTransferPayload.__name__, - OwnerTransferPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) +def reg(cls: type[BaseModel]): + console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) + + +reg(MemberInvitePayload) +reg(MemberRoleUpdatePayload) +reg(OwnerTransferEmailPayload) +reg(OwnerTransferCheckPayload) +reg(OwnerTransferPayload) @console_ns.route("/workspaces/current/members") diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index d40748d5e3..7bada2fa12 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -75,44 +75,18 @@ class ParserPreferredProviderType(BaseModel): preferred_provider_type: Literal["system", "custom"] -console_ns.schema_model( - ParserModelList.__name__, ParserModelList.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) +def reg(cls: type[BaseModel]): + console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) -console_ns.schema_model( - ParserCredentialId.__name__, - ParserCredentialId.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) -console_ns.schema_model( - ParserCredentialCreate.__name__, - ParserCredentialCreate.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) - -console_ns.schema_model( - ParserCredentialUpdate.__name__, - ParserCredentialUpdate.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) - -console_ns.schema_model( - ParserCredentialDelete.__name__, - ParserCredentialDelete.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) - -console_ns.schema_model( - ParserCredentialSwitch.__name__, - ParserCredentialSwitch.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) - -console_ns.schema_model( - ParserCredentialValidate.__name__, - ParserCredentialValidate.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) - -console_ns.schema_model( - ParserPreferredProviderType.__name__, - ParserPreferredProviderType.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) +reg(ParserModelList) +reg(ParserCredentialId) +reg(ParserCredentialCreate) +reg(ParserCredentialUpdate) +reg(ParserCredentialDelete) +reg(ParserCredentialSwitch) +reg(ParserCredentialValidate) +reg(ParserPreferredProviderType) @console_ns.route("/workspaces/current/model-providers") diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index c820a8d1f2..246a869291 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -32,25 +32,11 @@ class ParserPostDefault(BaseModel): model_settings: list[Inner] -console_ns.schema_model( - ParserGetDefault.__name__, ParserGetDefault.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) - -console_ns.schema_model( - ParserPostDefault.__name__, ParserPostDefault.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) - - class ParserDeleteModels(BaseModel): model: str model_type: ModelType -console_ns.schema_model( - ParserDeleteModels.__name__, ParserDeleteModels.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) - - class LoadBalancingPayload(BaseModel): configs: list[dict[str, Any]] | None = None enabled: bool | None = None @@ -119,33 +105,19 @@ class ParserParameter(BaseModel): model: str -console_ns.schema_model( - ParserPostModels.__name__, ParserPostModels.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) +def reg(cls: type[BaseModel]): + console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) -console_ns.schema_model( - ParserGetCredentials.__name__, - ParserGetCredentials.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) -console_ns.schema_model( - ParserCreateCredential.__name__, - ParserCreateCredential.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) - -console_ns.schema_model( - ParserUpdateCredential.__name__, - ParserUpdateCredential.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) - -console_ns.schema_model( - ParserDeleteCredential.__name__, - ParserDeleteCredential.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) - -console_ns.schema_model( - ParserParameter.__name__, ParserParameter.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) +reg(ParserGetDefault) +reg(ParserPostDefault) +reg(ParserDeleteModels) +reg(ParserPostModels) +reg(ParserGetCredentials) +reg(ParserCreateCredential) +reg(ParserUpdateCredential) +reg(ParserDeleteCredential) +reg(ParserParameter) @console_ns.route("/workspaces/current/default-model") diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py index 7e08ea55f9..c5624e0fc2 100644 --- a/api/controllers/console/workspace/plugin.py +++ b/api/controllers/console/workspace/plugin.py @@ -22,6 +22,10 @@ from services.plugin.plugin_service import PluginService DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" +def reg(cls: type[BaseModel]): + console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) + + @console_ns.route("/workspaces/current/plugin/debugging-key") class PluginDebuggingKeyApi(Resource): @setup_required @@ -46,9 +50,7 @@ class ParserList(BaseModel): page_size: int = Field(default=256) -console_ns.schema_model( - ParserList.__name__, ParserList.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) +reg(ParserList) @console_ns.route("/workspaces/current/plugin/list") @@ -72,11 +74,6 @@ class ParserLatest(BaseModel): plugin_ids: list[str] -console_ns.schema_model( - ParserLatest.__name__, ParserLatest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) - - class ParserIcon(BaseModel): tenant_id: str filename: str @@ -173,72 +170,22 @@ class ParserReadme(BaseModel): language: str = Field(default="en-US") -console_ns.schema_model( - ParserIcon.__name__, ParserIcon.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) - -console_ns.schema_model( - ParserAsset.__name__, ParserAsset.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) - -console_ns.schema_model( - ParserGithubUpload.__name__, ParserGithubUpload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) - -console_ns.schema_model( - ParserPluginIdentifiers.__name__, - ParserPluginIdentifiers.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) - -console_ns.schema_model( - ParserGithubInstall.__name__, ParserGithubInstall.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) - -console_ns.schema_model( - ParserPluginIdentifierQuery.__name__, - ParserPluginIdentifierQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) - -console_ns.schema_model( - ParserTasks.__name__, ParserTasks.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) - -console_ns.schema_model( - ParserMarketplaceUpgrade.__name__, - ParserMarketplaceUpgrade.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) - -console_ns.schema_model( - ParserGithubUpgrade.__name__, ParserGithubUpgrade.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) - -console_ns.schema_model( - ParserUninstall.__name__, ParserUninstall.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) - -console_ns.schema_model( - ParserPermissionChange.__name__, - ParserPermissionChange.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) - -console_ns.schema_model( - ParserDynamicOptions.__name__, - ParserDynamicOptions.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) - -console_ns.schema_model( - ParserPreferencesChange.__name__, - ParserPreferencesChange.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) - -console_ns.schema_model( - ParserExcludePlugin.__name__, - ParserExcludePlugin.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) - -console_ns.schema_model( - ParserReadme.__name__, ParserReadme.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) +reg(ParserLatest) +reg(ParserIcon) +reg(ParserAsset) +reg(ParserGithubUpload) +reg(ParserPluginIdentifiers) +reg(ParserGithubInstall) +reg(ParserPluginIdentifierQuery) +reg(ParserTasks) +reg(ParserMarketplaceUpgrade) +reg(ParserGithubUpgrade) +reg(ParserUninstall) +reg(ParserPermissionChange) +reg(ParserDynamicOptions) +reg(ParserPreferencesChange) +reg(ParserExcludePlugin) +reg(ParserReadme) @console_ns.route("/workspaces/current/plugin/list/latest-versions") diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 9b76cb7a9c..909a5ce201 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -54,25 +54,14 @@ class WorkspaceInfoPayload(BaseModel): name: str -console_ns.schema_model( - WorkspaceListQuery.__name__, WorkspaceListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) +def reg(cls: type[BaseModel]): + console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) -console_ns.schema_model( - SwitchWorkspacePayload.__name__, - SwitchWorkspacePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) - -console_ns.schema_model( - WorkspaceCustomConfigPayload.__name__, - WorkspaceCustomConfigPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) - -console_ns.schema_model( - WorkspaceInfoPayload.__name__, - WorkspaceInfoPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) +reg(WorkspaceListQuery) +reg(SwitchWorkspacePayload) +reg(WorkspaceCustomConfigPayload) +reg(WorkspaceInfoPayload) provider_fields = { "provider_name": fields.String,