refactor: port reqparse to Pydantic model (#28913)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Asuka Minato 2025-11-30 16:09:42 +09:00 committed by GitHub
parent bb096f4ae3
commit 247069c7e9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 1013 additions and 1369 deletions

View File

@ -1,16 +1,23 @@
from flask_restx import Resource, fields, reqparse
from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import login_required
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
parser = (
reqparse.RequestParser()
.add_argument("app_mode", type=str, required=True, location="args", help="Application mode")
.add_argument("model_mode", type=str, required=True, location="args", help="Model mode")
.add_argument("has_context", type=str, required=False, default="true", location="args", help="Whether has context")
.add_argument("model_name", type=str, required=True, location="args", help="Model name")
class AdvancedPromptTemplateQuery(BaseModel):
app_mode: str = Field(..., description="Application mode")
model_mode: str = Field(..., description="Model mode")
has_context: str = Field(default="true", description="Whether has context")
model_name: str = Field(..., description="Model name")
console_ns.schema_model(
AdvancedPromptTemplateQuery.__name__,
AdvancedPromptTemplateQuery.model_json_schema(ref_template="#/definitions/{model}"),
)
@ -18,7 +25,7 @@ parser = (
class AdvancedPromptTemplateList(Resource):
@console_ns.doc("get_advanced_prompt_templates")
@console_ns.doc(description="Get advanced prompt templates based on app mode and model configuration")
@console_ns.expect(parser)
@console_ns.expect(console_ns.models[AdvancedPromptTemplateQuery.__name__])
@console_ns.response(
200, "Prompt templates retrieved successfully", fields.List(fields.Raw(description="Prompt template data"))
)
@ -27,6 +34,6 @@ class AdvancedPromptTemplateList(Resource):
@login_required
@account_initialization_required
def get(self):
args = parser.parse_args()
args = AdvancedPromptTemplateQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
return AdvancedPromptTemplateService.get_prompt(args)
return AdvancedPromptTemplateService.get_prompt(args.model_dump())

View File

@ -1,9 +1,12 @@
import uuid
from typing import Literal
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
from flask import request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, abort
from werkzeug.exceptions import BadRequest
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
@ -36,6 +39,130 @@ from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class AppListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")
mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"] = Field(
default="all", description="App mode filter"
)
name: str | None = Field(default=None, description="Filter by app name")
tag_ids: list[str] | None = Field(default=None, description="Comma-separated tag IDs")
is_created_by_me: bool | None = Field(default=None, description="Filter by creator")
@field_validator("tag_ids", mode="before")
@classmethod
def validate_tag_ids(cls, value: str | list[str] | None) -> list[str] | None:
if not value:
return None
if isinstance(value, str):
items = [item.strip() for item in value.split(",") if item.strip()]
elif isinstance(value, list):
items = [str(item).strip() for item in value if item and str(item).strip()]
else:
raise TypeError("Unsupported tag_ids type.")
if not items:
return None
try:
return [str(uuid.UUID(item)) for item in items]
except ValueError as exc:
raise ValueError("Invalid UUID format in tag_ids.") from exc
class CreateAppPayload(BaseModel):
name: str = Field(..., min_length=1, description="App name")
description: str | None = Field(default=None, description="App description (max 400 chars)")
mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode")
icon_type: str | None = Field(default=None, description="Icon type")
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
@field_validator("description")
@classmethod
def validate_description(cls, value: str | None) -> str | None:
if value is None:
return value
return validate_description_length(value)
class UpdateAppPayload(BaseModel):
name: str = Field(..., min_length=1, description="App name")
description: str | None = Field(default=None, description="App description (max 400 chars)")
icon_type: str | None = Field(default=None, description="Icon type")
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon")
max_active_requests: int | None = Field(default=None, description="Maximum active requests")
@field_validator("description")
@classmethod
def validate_description(cls, value: str | None) -> str | None:
if value is None:
return value
return validate_description_length(value)
class CopyAppPayload(BaseModel):
name: str | None = Field(default=None, description="Name for the copied app")
description: str | None = Field(default=None, description="Description for the copied app")
icon_type: str | None = Field(default=None, description="Icon type")
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
@field_validator("description")
@classmethod
def validate_description(cls, value: str | None) -> str | None:
if value is None:
return value
return validate_description_length(value)
class AppExportQuery(BaseModel):
include_secret: bool = Field(default=False, description="Include secrets in export")
workflow_id: str | None = Field(default=None, description="Specific workflow ID to export")
class AppNamePayload(BaseModel):
name: str = Field(..., min_length=1, description="Name to check")
class AppIconPayload(BaseModel):
icon: str | None = Field(default=None, description="Icon data")
icon_background: str | None = Field(default=None, description="Icon background color")
class AppSiteStatusPayload(BaseModel):
enable_site: bool = Field(..., description="Enable or disable site")
class AppApiStatusPayload(BaseModel):
enable_api: bool = Field(..., description="Enable or disable API")
class AppTracePayload(BaseModel):
enabled: bool = Field(..., description="Enable or disable tracing")
tracing_provider: str = Field(..., description="Tracing provider")
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(AppListQuery)
reg(CreateAppPayload)
reg(UpdateAppPayload)
reg(CopyAppPayload)
reg(AppExportQuery)
reg(AppNamePayload)
reg(AppIconPayload)
reg(AppSiteStatusPayload)
reg(AppApiStatusPayload)
reg(AppTracePayload)
# Register models for flask_restx to avoid dict type issues in Swagger
# Register base models first
@ -147,22 +274,7 @@ app_pagination_model = console_ns.model(
class AppListApi(Resource):
@console_ns.doc("list_apps")
@console_ns.doc(description="Get list of applications with pagination and filtering")
@console_ns.expect(
console_ns.parser()
.add_argument("page", type=int, location="args", help="Page number (1-99999)", default=1)
.add_argument("limit", type=int, location="args", help="Page size (1-100)", default=20)
.add_argument(
"mode",
type=str,
location="args",
choices=["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"],
default="all",
help="App mode filter",
)
.add_argument("name", type=str, location="args", help="Filter by app name")
.add_argument("tag_ids", type=str, location="args", help="Comma-separated tag IDs")
.add_argument("is_created_by_me", type=bool, location="args", help="Filter by creator")
)
@console_ns.expect(console_ns.models[AppListQuery.__name__])
@console_ns.response(200, "Success", app_pagination_model)
@setup_required
@login_required
@ -172,42 +284,12 @@ class AppListApi(Resource):
"""Get app list"""
current_user, current_tenant_id = current_account_with_tenant()
def uuid_list(value):
try:
return [str(uuid.UUID(v)) for v in value.split(",")]
except ValueError:
abort(400, message="Invalid UUID format in tag_ids.")
parser = (
reqparse.RequestParser()
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
.add_argument(
"mode",
type=str,
choices=[
"completion",
"chat",
"advanced-chat",
"workflow",
"agent-chat",
"channel",
"all",
],
default="all",
location="args",
required=False,
)
.add_argument("name", type=str, location="args", required=False)
.add_argument("tag_ids", type=uuid_list, location="args", required=False)
.add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False)
)
args = parser.parse_args()
args = AppListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args_dict = args.model_dump()
# get app list
app_service = AppService()
app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args)
app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args_dict)
if not app_pagination:
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
@ -254,19 +336,7 @@ class AppListApi(Resource):
@console_ns.doc("create_app")
@console_ns.doc(description="Create a new application")
@console_ns.expect(
console_ns.model(
"CreateAppRequest",
{
"name": fields.String(required=True, description="App name"),
"description": fields.String(description="App description (max 400 chars)"),
"mode": fields.String(required=True, enum=ALLOW_CREATE_APP_MODES, description="App mode"),
"icon_type": fields.String(description="Icon type"),
"icon": fields.String(description="Icon"),
"icon_background": fields.String(description="Icon background color"),
},
)
)
@console_ns.expect(console_ns.models[CreateAppPayload.__name__])
@console_ns.response(201, "App created successfully", app_detail_model)
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(400, "Invalid request parameters")
@ -279,22 +349,10 @@ class AppListApi(Resource):
def post(self):
"""Create app"""
current_user, current_tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("name", type=str, required=True, location="json")
.add_argument("description", type=validate_description_length, location="json")
.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json")
.add_argument("icon_type", type=str, location="json")
.add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
)
args = parser.parse_args()
if "mode" not in args or args["mode"] is None:
raise BadRequest("mode is required")
args = CreateAppPayload.model_validate(console_ns.payload)
app_service = AppService()
app = app_service.create_app(current_tenant_id, args, current_user)
app = app_service.create_app(current_tenant_id, args.model_dump(), current_user)
return app, 201
@ -326,20 +384,7 @@ class AppApi(Resource):
@console_ns.doc("update_app")
@console_ns.doc(description="Update application details")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(
console_ns.model(
"UpdateAppRequest",
{
"name": fields.String(required=True, description="App name"),
"description": fields.String(description="App description (max 400 chars)"),
"icon_type": fields.String(description="Icon type"),
"icon": fields.String(description="Icon"),
"icon_background": fields.String(description="Icon background color"),
"use_icon_as_answer_icon": fields.Boolean(description="Use icon as answer icon"),
"max_active_requests": fields.Integer(description="Maximum active requests"),
},
)
)
@console_ns.expect(console_ns.models[UpdateAppPayload.__name__])
@console_ns.response(200, "App updated successfully", app_detail_with_site_model)
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(400, "Invalid request parameters")
@ -351,28 +396,18 @@ class AppApi(Resource):
@marshal_with(app_detail_with_site_model)
def put(self, app_model):
"""Update app"""
parser = (
reqparse.RequestParser()
.add_argument("name", type=str, required=True, nullable=False, location="json")
.add_argument("description", type=validate_description_length, location="json")
.add_argument("icon_type", type=str, location="json")
.add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
.add_argument("use_icon_as_answer_icon", type=bool, location="json")
.add_argument("max_active_requests", type=int, location="json")
)
args = parser.parse_args()
args = UpdateAppPayload.model_validate(console_ns.payload)
app_service = AppService()
args_dict: AppService.ArgsDict = {
"name": args["name"],
"description": args.get("description", ""),
"icon_type": args.get("icon_type", ""),
"icon": args.get("icon", ""),
"icon_background": args.get("icon_background", ""),
"use_icon_as_answer_icon": args.get("use_icon_as_answer_icon", False),
"max_active_requests": args.get("max_active_requests", 0),
"name": args.name,
"description": args.description or "",
"icon_type": args.icon_type or "",
"icon": args.icon or "",
"icon_background": args.icon_background or "",
"use_icon_as_answer_icon": args.use_icon_as_answer_icon or False,
"max_active_requests": args.max_active_requests or 0,
}
app_model = app_service.update_app(app_model, args_dict)
@ -401,18 +436,7 @@ class AppCopyApi(Resource):
@console_ns.doc("copy_app")
@console_ns.doc(description="Create a copy of an existing application")
@console_ns.doc(params={"app_id": "Application ID to copy"})
@console_ns.expect(
console_ns.model(
"CopyAppRequest",
{
"name": fields.String(description="Name for the copied app"),
"description": fields.String(description="Description for the copied app"),
"icon_type": fields.String(description="Icon type"),
"icon": fields.String(description="Icon"),
"icon_background": fields.String(description="Icon background color"),
},
)
)
@console_ns.expect(console_ns.models[CopyAppPayload.__name__])
@console_ns.response(201, "App copied successfully", app_detail_with_site_model)
@console_ns.response(403, "Insufficient permissions")
@setup_required
@ -426,15 +450,7 @@ class AppCopyApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("name", type=str, location="json")
.add_argument("description", type=validate_description_length, location="json")
.add_argument("icon_type", type=str, location="json")
.add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
)
args = parser.parse_args()
args = CopyAppPayload.model_validate(console_ns.payload or {})
with Session(db.engine) as session:
import_service = AppDslService(session)
@ -443,11 +459,11 @@ class AppCopyApi(Resource):
account=current_user,
import_mode=ImportMode.YAML_CONTENT,
yaml_content=yaml_content,
name=args.get("name"),
description=args.get("description"),
icon_type=args.get("icon_type"),
icon=args.get("icon"),
icon_background=args.get("icon_background"),
name=args.name,
description=args.description,
icon_type=args.icon_type,
icon=args.icon,
icon_background=args.icon_background,
)
session.commit()
@ -462,11 +478,7 @@ class AppExportApi(Resource):
@console_ns.doc("export_app")
@console_ns.doc(description="Export application configuration as DSL")
@console_ns.doc(params={"app_id": "Application ID to export"})
@console_ns.expect(
console_ns.parser()
.add_argument("include_secret", type=bool, location="args", default=False, help="Include secrets in export")
.add_argument("workflow_id", type=str, location="args", help="Specific workflow ID to export")
)
@console_ns.expect(console_ns.models[AppExportQuery.__name__])
@console_ns.response(
200,
"App exported successfully",
@ -480,30 +492,23 @@ class AppExportApi(Resource):
@edit_permission_required
def get(self, app_model):
"""Export app"""
# Add include_secret params
parser = (
reqparse.RequestParser()
.add_argument("include_secret", type=inputs.boolean, default=False, location="args")
.add_argument("workflow_id", type=str, location="args")
)
args = parser.parse_args()
args = AppExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
return {
"data": AppDslService.export_dsl(
app_model=app_model, include_secret=args["include_secret"], workflow_id=args.get("workflow_id")
app_model=app_model,
include_secret=args.include_secret,
workflow_id=args.workflow_id,
)
}
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json", help="Name to check")
@console_ns.route("/apps/<uuid:app_id>/name")
class AppNameApi(Resource):
@console_ns.doc("check_app_name")
@console_ns.doc(description="Check if app name is available")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(parser)
@console_ns.expect(console_ns.models[AppNamePayload.__name__])
@console_ns.response(200, "Name availability checked")
@setup_required
@login_required
@ -512,10 +517,10 @@ class AppNameApi(Resource):
@marshal_with(app_detail_model)
@edit_permission_required
def post(self, app_model):
args = parser.parse_args()
args = AppNamePayload.model_validate(console_ns.payload)
app_service = AppService()
app_model = app_service.update_app_name(app_model, args["name"])
app_model = app_service.update_app_name(app_model, args.name)
return app_model
@ -525,16 +530,7 @@ class AppIconApi(Resource):
@console_ns.doc("update_app_icon")
@console_ns.doc(description="Update application icon")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(
console_ns.model(
"AppIconRequest",
{
"icon": fields.String(required=True, description="Icon data"),
"icon_type": fields.String(description="Icon type"),
"icon_background": fields.String(description="Icon background color"),
},
)
)
@console_ns.expect(console_ns.models[AppIconPayload.__name__])
@console_ns.response(200, "Icon updated successfully")
@console_ns.response(403, "Insufficient permissions")
@setup_required
@ -544,15 +540,10 @@ class AppIconApi(Resource):
@marshal_with(app_detail_model)
@edit_permission_required
def post(self, app_model):
parser = (
reqparse.RequestParser()
.add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
)
args = parser.parse_args()
args = AppIconPayload.model_validate(console_ns.payload or {})
app_service = AppService()
app_model = app_service.update_app_icon(app_model, args.get("icon") or "", args.get("icon_background") or "")
app_model = app_service.update_app_icon(app_model, args.icon or "", args.icon_background or "")
return app_model
@ -562,11 +553,7 @@ class AppSiteStatus(Resource):
@console_ns.doc("update_app_site_status")
@console_ns.doc(description="Enable or disable app site")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(
console_ns.model(
"AppSiteStatusRequest", {"enable_site": fields.Boolean(required=True, description="Enable or disable site")}
)
)
@console_ns.expect(console_ns.models[AppSiteStatusPayload.__name__])
@console_ns.response(200, "Site status updated successfully", app_detail_model)
@console_ns.response(403, "Insufficient permissions")
@setup_required
@ -576,11 +563,10 @@ class AppSiteStatus(Resource):
@marshal_with(app_detail_model)
@edit_permission_required
def post(self, app_model):
parser = reqparse.RequestParser().add_argument("enable_site", type=bool, required=True, location="json")
args = parser.parse_args()
args = AppSiteStatusPayload.model_validate(console_ns.payload)
app_service = AppService()
app_model = app_service.update_app_site_status(app_model, args["enable_site"])
app_model = app_service.update_app_site_status(app_model, args.enable_site)
return app_model
@ -590,11 +576,7 @@ class AppApiStatus(Resource):
@console_ns.doc("update_app_api_status")
@console_ns.doc(description="Enable or disable app API")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(
console_ns.model(
"AppApiStatusRequest", {"enable_api": fields.Boolean(required=True, description="Enable or disable API")}
)
)
@console_ns.expect(console_ns.models[AppApiStatusPayload.__name__])
@console_ns.response(200, "API status updated successfully", app_detail_model)
@console_ns.response(403, "Insufficient permissions")
@setup_required
@ -604,11 +586,10 @@ class AppApiStatus(Resource):
@get_app_model
@marshal_with(app_detail_model)
def post(self, app_model):
parser = reqparse.RequestParser().add_argument("enable_api", type=bool, required=True, location="json")
args = parser.parse_args()
args = AppApiStatusPayload.model_validate(console_ns.payload)
app_service = AppService()
app_model = app_service.update_app_api_status(app_model, args["enable_api"])
app_model = app_service.update_app_api_status(app_model, args.enable_api)
return app_model
@ -631,15 +612,7 @@ class AppTraceApi(Resource):
@console_ns.doc("update_app_trace")
@console_ns.doc(description="Update app tracing configuration")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(
console_ns.model(
"AppTraceRequest",
{
"enabled": fields.Boolean(required=True, description="Enable or disable tracing"),
"tracing_provider": fields.String(required=True, description="Tracing provider"),
},
)
)
@console_ns.expect(console_ns.models[AppTracePayload.__name__])
@console_ns.response(200, "Trace configuration updated successfully")
@console_ns.response(403, "Insufficient permissions")
@setup_required
@ -648,17 +621,12 @@ class AppTraceApi(Resource):
@edit_permission_required
def post(self, app_id):
# add app trace
parser = (
reqparse.RequestParser()
.add_argument("enabled", type=bool, required=True, location="json")
.add_argument("tracing_provider", type=str, required=True, location="json")
)
args = parser.parse_args()
args = AppTracePayload.model_validate(console_ns.payload)
OpsTraceManager.update_app_tracing_config(
app_id=app_id,
enabled=args["enabled"],
tracing_provider=args["tracing_provider"],
enabled=args.enabled,
tracing_provider=args.tracing_provider,
)
return {"result": "success"}

View File

@ -1,7 +1,9 @@
import logging
from typing import Any, Literal
from flask import request
from flask_restx import Resource, fields, reqparse
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import InternalServerError, NotFound
import services
@ -35,6 +37,41 @@ from services.app_task_service import AppTaskService
from services.errors.llm import InvokeRateLimitError
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class BaseMessagePayload(BaseModel):
inputs: dict[str, Any]
model_config_data: dict[str, Any] = Field(..., alias="model_config")
files: list[Any] | None = Field(default=None, description="Uploaded files")
response_mode: Literal["blocking", "streaming"] = Field(default="blocking", description="Response mode")
retriever_from: str = Field(default="dev", description="Retriever source")
class CompletionMessagePayload(BaseMessagePayload):
query: str = Field(default="", description="Query text")
class ChatMessagePayload(BaseMessagePayload):
query: str = Field(..., description="User query")
conversation_id: str | None = Field(default=None, description="Conversation ID")
parent_message_id: str | None = Field(default=None, description="Parent message ID")
@field_validator("conversation_id", "parent_message_id")
@classmethod
def validate_uuid(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
console_ns.schema_model(
CompletionMessagePayload.__name__,
CompletionMessagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ChatMessagePayload.__name__, ChatMessagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
# define completion message api for user
@ -43,19 +80,7 @@ class CompletionMessageApi(Resource):
@console_ns.doc("create_completion_message")
@console_ns.doc(description="Generate completion message for debugging")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(
console_ns.model(
"CompletionMessageRequest",
{
"inputs": fields.Raw(required=True, description="Input variables"),
"query": fields.String(description="Query text", default=""),
"files": fields.List(fields.Raw(), description="Uploaded files"),
"model_config": fields.Raw(required=True, description="Model configuration"),
"response_mode": fields.String(enum=["blocking", "streaming"], description="Response mode"),
"retriever_from": fields.String(default="dev", description="Retriever source"),
},
)
)
@console_ns.expect(console_ns.models[CompletionMessagePayload.__name__])
@console_ns.response(200, "Completion generated successfully")
@console_ns.response(400, "Invalid request parameters")
@console_ns.response(404, "App not found")
@ -64,18 +89,10 @@ class CompletionMessageApi(Resource):
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
def post(self, app_model):
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, location="json")
.add_argument("query", type=str, location="json", default="")
.add_argument("files", type=list, required=False, location="json")
.add_argument("model_config", type=dict, required=True, location="json")
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
)
args = parser.parse_args()
args_model = CompletionMessagePayload.model_validate(console_ns.payload)
args = args_model.model_dump(exclude_none=True, by_alias=True)
streaming = args["response_mode"] != "blocking"
streaming = args_model.response_mode != "blocking"
args["auto_generate_name"] = False
try:
@ -137,21 +154,7 @@ class ChatMessageApi(Resource):
@console_ns.doc("create_chat_message")
@console_ns.doc(description="Generate chat message for debugging")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(
console_ns.model(
"ChatMessageRequest",
{
"inputs": fields.Raw(required=True, description="Input variables"),
"query": fields.String(required=True, description="User query"),
"files": fields.List(fields.Raw(), description="Uploaded files"),
"model_config": fields.Raw(required=True, description="Model configuration"),
"conversation_id": fields.String(description="Conversation ID"),
"parent_message_id": fields.String(description="Parent message ID"),
"response_mode": fields.String(enum=["blocking", "streaming"], description="Response mode"),
"retriever_from": fields.String(default="dev", description="Retriever source"),
},
)
)
@console_ns.expect(console_ns.models[ChatMessagePayload.__name__])
@console_ns.response(200, "Chat message generated successfully")
@console_ns.response(400, "Invalid request parameters")
@console_ns.response(404, "App or conversation not found")
@ -161,20 +164,10 @@ class ChatMessageApi(Resource):
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
@edit_permission_required
def post(self, app_model):
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, location="json")
.add_argument("query", type=str, required=True, location="json")
.add_argument("files", type=list, required=False, location="json")
.add_argument("model_config", type=dict, required=True, location="json")
.add_argument("conversation_id", type=uuid_value, location="json")
.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
)
args = parser.parse_args()
args_model = ChatMessagePayload.model_validate(console_ns.payload)
args = args_model.model_dump(exclude_none=True, by_alias=True)
streaming = args["response_mode"] != "blocking"
streaming = args_model.response_mode != "blocking"
args["auto_generate_name"] = False
external_trace_id = get_external_trace_id(request)

View File

@ -1,7 +1,9 @@
from typing import Literal
import sqlalchemy as sa
from flask import abort
from flask_restx import Resource, fields, marshal_with, reqparse
from flask_restx.inputs import int_range
from flask import abort, request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import func, or_
from sqlalchemy.orm import joinedload
from werkzeug.exceptions import NotFound
@ -14,13 +16,54 @@ from extensions.ext_database import db
from fields.conversation_fields import MessageTextField
from fields.raws import FilesContainedField
from libs.datetime_utils import naive_utc_now, parse_time_range
from libs.helper import DatetimeString, TimestampField
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models import Conversation, EndUser, Message, MessageAnnotation
from models.model import AppMode
from services.conversation_service import ConversationService
from services.errors.conversation import ConversationNotExistsError
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class BaseConversationQuery(BaseModel):
keyword: str | None = Field(default=None, description="Search keyword")
start: str | None = Field(default=None, description="Start date (YYYY-MM-DD HH:MM)")
end: str | None = Field(default=None, description="End date (YYYY-MM-DD HH:MM)")
annotation_status: Literal["annotated", "not_annotated", "all"] = Field(
default="all", description="Annotation status filter"
)
page: int = Field(default=1, ge=1, le=99999, description="Page number")
limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")
@field_validator("start", "end", mode="before")
@classmethod
def blank_to_none(cls, value: str | None) -> str | None:
if value == "":
return None
return value
class CompletionConversationQuery(BaseConversationQuery):
pass
class ChatConversationQuery(BaseConversationQuery):
message_count_gte: int | None = Field(default=None, ge=1, description="Minimum message count")
sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field(
default="-updated_at", description="Sort field and direction"
)
console_ns.schema_model(
CompletionConversationQuery.__name__,
CompletionConversationQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ChatConversationQuery.__name__,
ChatConversationQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
# Register models for flask_restx to avoid dict type issues in Swagger
# Register in dependency order: base models first, then dependent models
@ -283,22 +326,7 @@ class CompletionConversationApi(Resource):
@console_ns.doc("list_completion_conversations")
@console_ns.doc(description="Get completion conversations with pagination and filtering")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(
console_ns.parser()
.add_argument("keyword", type=str, location="args", help="Search keyword")
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
.add_argument(
"annotation_status",
type=str,
location="args",
choices=["annotated", "not_annotated", "all"],
default="all",
help="Annotation status filter",
)
.add_argument("page", type=int, location="args", default=1, help="Page number")
.add_argument("limit", type=int, location="args", default=20, help="Page size (1-100)")
)
@console_ns.expect(console_ns.models[CompletionConversationQuery.__name__])
@console_ns.response(200, "Success", conversation_pagination_model)
@console_ns.response(403, "Insufficient permissions")
@setup_required
@ -309,32 +337,17 @@ class CompletionConversationApi(Resource):
@edit_permission_required
def get(self, app_model):
current_user, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("keyword", type=str, location="args")
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument(
"annotation_status",
type=str,
choices=["annotated", "not_annotated", "all"],
default="all",
location="args",
)
.add_argument("page", type=int_range(1, 99999), default=1, location="args")
.add_argument("limit", type=int_range(1, 100), default=20, location="args")
)
args = parser.parse_args()
args = CompletionConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
query = sa.select(Conversation).where(
Conversation.app_id == app_model.id, Conversation.mode == "completion", Conversation.is_deleted.is_(False)
)
if args["keyword"]:
if args.keyword:
query = query.join(Message, Message.conversation_id == Conversation.id).where(
or_(
Message.query.ilike(f"%{args['keyword']}%"),
Message.answer.ilike(f"%{args['keyword']}%"),
Message.query.ilike(f"%{args.keyword}%"),
Message.answer.ilike(f"%{args.keyword}%"),
)
)
@ -342,7 +355,7 @@ class CompletionConversationApi(Resource):
assert account.timezone is not None
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e:
abort(400, description=str(e))
@ -354,11 +367,11 @@ class CompletionConversationApi(Resource):
query = query.where(Conversation.created_at < end_datetime_utc)
# FIXME, the type ignore in this file
if args["annotation_status"] == "annotated":
if args.annotation_status == "annotated":
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
)
elif args["annotation_status"] == "not_annotated":
elif args.annotation_status == "not_annotated":
query = (
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
.group_by(Conversation.id)
@ -367,7 +380,7 @@ class CompletionConversationApi(Resource):
query = query.order_by(Conversation.created_at.desc())
conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False)
conversations = db.paginate(query, page=args.page, per_page=args.limit, error_out=False)
return conversations
@ -419,31 +432,7 @@ class ChatConversationApi(Resource):
@console_ns.doc("list_chat_conversations")
@console_ns.doc(description="Get chat conversations with pagination, filtering and summary")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(
console_ns.parser()
.add_argument("keyword", type=str, location="args", help="Search keyword")
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
.add_argument(
"annotation_status",
type=str,
location="args",
choices=["annotated", "not_annotated", "all"],
default="all",
help="Annotation status filter",
)
.add_argument("message_count_gte", type=int, location="args", help="Minimum message count")
.add_argument("page", type=int, location="args", default=1, help="Page number")
.add_argument("limit", type=int, location="args", default=20, help="Page size (1-100)")
.add_argument(
"sort_by",
type=str,
location="args",
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
default="-updated_at",
help="Sort field and direction",
)
)
@console_ns.expect(console_ns.models[ChatConversationQuery.__name__])
@console_ns.response(200, "Success", conversation_with_summary_pagination_model)
@console_ns.response(403, "Insufficient permissions")
@setup_required
@ -454,31 +443,7 @@ class ChatConversationApi(Resource):
@edit_permission_required
def get(self, app_model):
current_user, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("keyword", type=str, location="args")
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument(
"annotation_status",
type=str,
choices=["annotated", "not_annotated", "all"],
default="all",
location="args",
)
.add_argument("message_count_gte", type=int_range(1, 99999), required=False, location="args")
.add_argument("page", type=int_range(1, 99999), required=False, default=1, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
.add_argument(
"sort_by",
type=str,
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
required=False,
default="-updated_at",
location="args",
)
)
args = parser.parse_args()
args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
subquery = (
db.session.query(
@ -490,8 +455,8 @@ class ChatConversationApi(Resource):
query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))
if args["keyword"]:
keyword_filter = f"%{args['keyword']}%"
if args.keyword:
keyword_filter = f"%{args.keyword}%"
query = (
query.join(
Message,
@ -514,12 +479,12 @@ class ChatConversationApi(Resource):
assert account.timezone is not None
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e:
abort(400, description=str(e))
if start_datetime_utc:
match args["sort_by"]:
match args.sort_by:
case "updated_at" | "-updated_at":
query = query.where(Conversation.updated_at >= start_datetime_utc)
case "created_at" | "-created_at" | _:
@ -527,35 +492,35 @@ class ChatConversationApi(Resource):
if end_datetime_utc:
end_datetime_utc = end_datetime_utc.replace(second=59)
match args["sort_by"]:
match args.sort_by:
case "updated_at" | "-updated_at":
query = query.where(Conversation.updated_at <= end_datetime_utc)
case "created_at" | "-created_at" | _:
query = query.where(Conversation.created_at <= end_datetime_utc)
if args["annotation_status"] == "annotated":
if args.annotation_status == "annotated":
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
)
elif args["annotation_status"] == "not_annotated":
elif args.annotation_status == "not_annotated":
query = (
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
.group_by(Conversation.id)
.having(func.count(MessageAnnotation.id) == 0)
)
if args["message_count_gte"] and args["message_count_gte"] >= 1:
if args.message_count_gte and args.message_count_gte >= 1:
query = (
query.options(joinedload(Conversation.messages)) # type: ignore
.join(Message, Message.conversation_id == Conversation.id)
.group_by(Conversation.id)
.having(func.count(Message.id) >= args["message_count_gte"])
.having(func.count(Message.id) >= args.message_count_gte)
)
if app_model.mode == AppMode.ADVANCED_CHAT:
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER)
match args["sort_by"]:
match args.sort_by:
case "created_at":
query = query.order_by(Conversation.created_at.asc())
case "-created_at":
@ -567,7 +532,7 @@ class ChatConversationApi(Resource):
case _:
query = query.order_by(Conversation.created_at.desc())
conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False)
conversations = db.paginate(query, page=args.page, per_page=args.limit, error_out=False)
return conversations

View File

@ -1,4 +1,6 @@
from flask_restx import Resource, fields, marshal_with, reqparse
from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -14,6 +16,18 @@ from libs.login import login_required
from models import ConversationVariable
from models.model import AppMode
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ConversationVariablesQuery(BaseModel):
conversation_id: str = Field(..., description="Conversation ID to filter variables")
console_ns.schema_model(
ConversationVariablesQuery.__name__,
ConversationVariablesQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
# Register models for flask_restx to avoid dict type issues in Swagger
# Register base model first
conversation_variable_model = console_ns.model("ConversationVariable", conversation_variable_fields)
@ -33,11 +47,7 @@ class ConversationVariablesApi(Resource):
@console_ns.doc("get_conversation_variables")
@console_ns.doc(description="Get conversation variables for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(
console_ns.parser().add_argument(
"conversation_id", type=str, location="args", help="Conversation ID to filter variables"
)
)
@console_ns.expect(console_ns.models[ConversationVariablesQuery.__name__])
@console_ns.response(200, "Conversation variables retrieved successfully", paginated_conversation_variable_model)
@setup_required
@login_required
@ -45,18 +55,14 @@ class ConversationVariablesApi(Resource):
@get_app_model(mode=AppMode.ADVANCED_CHAT)
@marshal_with(paginated_conversation_variable_model)
def get(self, app_model):
parser = reqparse.RequestParser().add_argument("conversation_id", type=str, location="args")
args = parser.parse_args()
args = ConversationVariablesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
stmt = (
select(ConversationVariable)
.where(ConversationVariable.app_id == app_model.id)
.order_by(ConversationVariable.created_at)
)
if args["conversation_id"]:
stmt = stmt.where(ConversationVariable.conversation_id == args["conversation_id"])
else:
raise ValueError("conversation_id is required")
stmt = stmt.where(ConversationVariable.conversation_id == args.conversation_id)
# NOTE: This is a temporary solution to avoid performance issues.
page = 1

View File

@ -1,6 +1,8 @@
from collections.abc import Sequence
from typing import Any
from flask_restx import Resource, fields, reqparse
from flask_restx import Resource
from pydantic import BaseModel, Field
from controllers.console import console_ns
from controllers.console.app.error import (
@ -21,21 +23,54 @@ from libs.login import current_account_with_tenant, login_required
from models import App
from services.workflow_service import WorkflowService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class RuleGeneratePayload(BaseModel):
instruction: str = Field(..., description="Rule generation instruction")
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
no_variable: bool = Field(default=False, description="Whether to exclude variables")
class RuleCodeGeneratePayload(RuleGeneratePayload):
code_language: str = Field(default="javascript", description="Programming language for code generation")
class RuleStructuredOutputPayload(BaseModel):
instruction: str = Field(..., description="Structured output generation instruction")
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
class InstructionGeneratePayload(BaseModel):
flow_id: str = Field(..., description="Workflow/Flow ID")
node_id: str = Field(default="", description="Node ID for workflow context")
current: str = Field(default="", description="Current instruction text")
language: str = Field(default="javascript", description="Programming language (javascript/python)")
instruction: str = Field(..., description="Instruction for generation")
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
ideal_output: str = Field(default="", description="Expected ideal output")
class InstructionTemplatePayload(BaseModel):
type: str = Field(..., description="Instruction template type")
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(RuleGeneratePayload)
reg(RuleCodeGeneratePayload)
reg(RuleStructuredOutputPayload)
reg(InstructionGeneratePayload)
reg(InstructionTemplatePayload)
@console_ns.route("/rule-generate")
class RuleGenerateApi(Resource):
@console_ns.doc("generate_rule_config")
@console_ns.doc(description="Generate rule configuration using LLM")
@console_ns.expect(
console_ns.model(
"RuleGenerateRequest",
{
"instruction": fields.String(required=True, description="Rule generation instruction"),
"model_config": fields.Raw(required=True, description="Model configuration"),
"no_variable": fields.Boolean(required=True, default=False, description="Whether to exclude variables"),
},
)
)
@console_ns.expect(console_ns.models[RuleGeneratePayload.__name__])
@console_ns.response(200, "Rule configuration generated successfully")
@console_ns.response(400, "Invalid request parameters")
@console_ns.response(402, "Provider quota exceeded")
@ -43,21 +78,15 @@ class RuleGenerateApi(Resource):
@login_required
@account_initialization_required
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
.add_argument("no_variable", type=bool, required=True, default=False, location="json")
)
args = parser.parse_args()
args = RuleGeneratePayload.model_validate(console_ns.payload)
_, current_tenant_id = current_account_with_tenant()
try:
rules = LLMGenerator.generate_rule_config(
tenant_id=current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
no_variable=args["no_variable"],
instruction=args.instruction,
model_config=args.model_config_data,
no_variable=args.no_variable,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@ -75,19 +104,7 @@ class RuleGenerateApi(Resource):
class RuleCodeGenerateApi(Resource):
@console_ns.doc("generate_rule_code")
@console_ns.doc(description="Generate code rules using LLM")
@console_ns.expect(
console_ns.model(
"RuleCodeGenerateRequest",
{
"instruction": fields.String(required=True, description="Code generation instruction"),
"model_config": fields.Raw(required=True, description="Model configuration"),
"no_variable": fields.Boolean(required=True, default=False, description="Whether to exclude variables"),
"code_language": fields.String(
default="javascript", description="Programming language for code generation"
),
},
)
)
@console_ns.expect(console_ns.models[RuleCodeGeneratePayload.__name__])
@console_ns.response(200, "Code rules generated successfully")
@console_ns.response(400, "Invalid request parameters")
@console_ns.response(402, "Provider quota exceeded")
@ -95,22 +112,15 @@ class RuleCodeGenerateApi(Resource):
@login_required
@account_initialization_required
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
.add_argument("no_variable", type=bool, required=True, default=False, location="json")
.add_argument("code_language", type=str, required=False, default="javascript", location="json")
)
args = parser.parse_args()
args = RuleCodeGeneratePayload.model_validate(console_ns.payload)
_, current_tenant_id = current_account_with_tenant()
try:
code_result = LLMGenerator.generate_code(
tenant_id=current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
code_language=args["code_language"],
instruction=args.instruction,
model_config=args.model_config_data,
code_language=args.code_language,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@ -128,15 +138,7 @@ class RuleCodeGenerateApi(Resource):
class RuleStructuredOutputGenerateApi(Resource):
@console_ns.doc("generate_structured_output")
@console_ns.doc(description="Generate structured output rules using LLM")
@console_ns.expect(
console_ns.model(
"StructuredOutputGenerateRequest",
{
"instruction": fields.String(required=True, description="Structured output generation instruction"),
"model_config": fields.Raw(required=True, description="Model configuration"),
},
)
)
@console_ns.expect(console_ns.models[RuleStructuredOutputPayload.__name__])
@console_ns.response(200, "Structured output generated successfully")
@console_ns.response(400, "Invalid request parameters")
@console_ns.response(402, "Provider quota exceeded")
@ -144,19 +146,14 @@ class RuleStructuredOutputGenerateApi(Resource):
@login_required
@account_initialization_required
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
)
args = parser.parse_args()
args = RuleStructuredOutputPayload.model_validate(console_ns.payload)
_, current_tenant_id = current_account_with_tenant()
try:
structured_output = LLMGenerator.generate_structured_output(
tenant_id=current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
instruction=args.instruction,
model_config=args.model_config_data,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@ -174,20 +171,7 @@ class RuleStructuredOutputGenerateApi(Resource):
class InstructionGenerateApi(Resource):
@console_ns.doc("generate_instruction")
@console_ns.doc(description="Generate instruction for workflow nodes or general use")
@console_ns.expect(
console_ns.model(
"InstructionGenerateRequest",
{
"flow_id": fields.String(required=True, description="Workflow/Flow ID"),
"node_id": fields.String(description="Node ID for workflow context"),
"current": fields.String(description="Current instruction text"),
"language": fields.String(default="javascript", description="Programming language (javascript/python)"),
"instruction": fields.String(required=True, description="Instruction for generation"),
"model_config": fields.Raw(required=True, description="Model configuration"),
"ideal_output": fields.String(description="Expected ideal output"),
},
)
)
@console_ns.expect(console_ns.models[InstructionGeneratePayload.__name__])
@console_ns.response(200, "Instruction generated successfully")
@console_ns.response(400, "Invalid request parameters or flow/workflow not found")
@console_ns.response(402, "Provider quota exceeded")
@ -195,79 +179,69 @@ class InstructionGenerateApi(Resource):
@login_required
@account_initialization_required
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("flow_id", type=str, required=True, default="", location="json")
.add_argument("node_id", type=str, required=False, default="", location="json")
.add_argument("current", type=str, required=False, default="", location="json")
.add_argument("language", type=str, required=False, default="javascript", location="json")
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
.add_argument("ideal_output", type=str, required=False, default="", location="json")
)
args = parser.parse_args()
args = InstructionGeneratePayload.model_validate(console_ns.payload)
_, current_tenant_id = current_account_with_tenant()
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
code_provider: type[CodeNodeProvider] | None = next(
(p for p in providers if p.is_accept_language(args["language"])), None
(p for p in providers if p.is_accept_language(args.language)), None
)
code_template = code_provider.get_default_code() if code_provider else ""
try:
# Generate from nothing for a workflow node
if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "":
app = db.session.query(App).where(App.id == args["flow_id"]).first()
if (args.current in (code_template, "")) and args.node_id != "":
app = db.session.query(App).where(App.id == args.flow_id).first()
if not app:
return {"error": f"app {args['flow_id']} not found"}, 400
return {"error": f"app {args.flow_id} not found"}, 400
workflow = WorkflowService().get_draft_workflow(app_model=app)
if not workflow:
return {"error": f"workflow {args['flow_id']} not found"}, 400
return {"error": f"workflow {args.flow_id} not found"}, 400
nodes: Sequence = workflow.graph_dict["nodes"]
node = [node for node in nodes if node["id"] == args["node_id"]]
node = [node for node in nodes if node["id"] == args.node_id]
if len(node) == 0:
return {"error": f"node {args['node_id']} not found"}, 400
return {"error": f"node {args.node_id} not found"}, 400
node_type = node[0]["data"]["type"]
match node_type:
case "llm":
return LLMGenerator.generate_rule_config(
current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
instruction=args.instruction,
model_config=args.model_config_data,
no_variable=True,
)
case "agent":
return LLMGenerator.generate_rule_config(
current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
instruction=args.instruction,
model_config=args.model_config_data,
no_variable=True,
)
case "code":
return LLMGenerator.generate_code(
tenant_id=current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
code_language=args["language"],
instruction=args.instruction,
model_config=args.model_config_data,
code_language=args.language,
)
case _:
return {"error": f"invalid node type: {node_type}"}
if args["node_id"] == "" and args["current"] != "": # For legacy app without a workflow
if args.node_id == "" and args.current != "": # For legacy app without a workflow
return LLMGenerator.instruction_modify_legacy(
tenant_id=current_tenant_id,
flow_id=args["flow_id"],
current=args["current"],
instruction=args["instruction"],
model_config=args["model_config"],
ideal_output=args["ideal_output"],
flow_id=args.flow_id,
current=args.current,
instruction=args.instruction,
model_config=args.model_config_data,
ideal_output=args.ideal_output,
)
if args["node_id"] != "" and args["current"] != "": # For workflow node
if args.node_id != "" and args.current != "": # For workflow node
return LLMGenerator.instruction_modify_workflow(
tenant_id=current_tenant_id,
flow_id=args["flow_id"],
node_id=args["node_id"],
current=args["current"],
instruction=args["instruction"],
model_config=args["model_config"],
ideal_output=args["ideal_output"],
flow_id=args.flow_id,
node_id=args.node_id,
current=args.current,
instruction=args.instruction,
model_config=args.model_config_data,
ideal_output=args.ideal_output,
workflow_service=WorkflowService(),
)
return {"error": "incompatible parameters"}, 400
@ -285,24 +259,15 @@ class InstructionGenerateApi(Resource):
class InstructionGenerationTemplateApi(Resource):
@console_ns.doc("get_instruction_template")
@console_ns.doc(description="Get instruction generation template")
@console_ns.expect(
console_ns.model(
"InstructionTemplateRequest",
{
"instruction": fields.String(required=True, description="Template instruction"),
"ideal_output": fields.String(description="Expected ideal output"),
},
)
)
@console_ns.expect(console_ns.models[InstructionTemplatePayload.__name__])
@console_ns.response(200, "Template retrieved successfully")
@console_ns.response(400, "Invalid request parameters")
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser().add_argument("type", type=str, required=True, default=False, location="json")
args = parser.parse_args()
match args["type"]:
args = InstructionTemplatePayload.model_validate(console_ns.payload)
match args.type:
case "prompt":
from core.llm_generator.prompts import INSTRUCTION_GENERATE_TEMPLATE_PROMPT
@ -312,4 +277,4 @@ class InstructionGenerationTemplateApi(Resource):
return {"data": INSTRUCTION_GENERATE_TEMPLATE_CODE}
case _:
raise ValueError(f"Invalid type: {args['type']}")
raise ValueError(f"Invalid type: {args.type}")

View File

@ -1,7 +1,9 @@
import logging
from typing import Literal
from flask_restx import Resource, fields, marshal_with, reqparse
from flask_restx.inputs import int_range
from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import exists, select
from werkzeug.exceptions import InternalServerError, NotFound
@ -33,6 +35,67 @@ from services.errors.message import MessageNotExistsError, SuggestedQuestionsAft
from services.message_service import MessageService
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ChatMessagesQuery(BaseModel):
conversation_id: str = Field(..., description="Conversation ID")
first_id: str | None = Field(default=None, description="First message ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return (1-100)")
@field_validator("first_id", mode="before")
@classmethod
def empty_to_none(cls, value: str | None) -> str | None:
if value == "":
return None
return value
@field_validator("conversation_id", "first_id")
@classmethod
def validate_uuid(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
class MessageFeedbackPayload(BaseModel):
message_id: str = Field(..., description="Message ID")
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
@field_validator("message_id")
@classmethod
def validate_message_id(cls, value: str) -> str:
return uuid_value(value)
class FeedbackExportQuery(BaseModel):
from_source: Literal["user", "admin"] | None = Field(default=None, description="Filter by feedback source")
rating: Literal["like", "dislike"] | None = Field(default=None, description="Filter by rating")
has_comment: bool | None = Field(default=None, description="Only include feedback with comments")
start_date: str | None = Field(default=None, description="Start date (YYYY-MM-DD)")
end_date: str | None = Field(default=None, description="End date (YYYY-MM-DD)")
format: Literal["csv", "json"] = Field(default="csv", description="Export format")
@field_validator("has_comment", mode="before")
@classmethod
def parse_bool(cls, value: bool | str | None) -> bool | None:
if isinstance(value, bool) or value is None:
return value
lowered = value.lower()
if lowered in {"true", "1", "yes", "on"}:
return True
if lowered in {"false", "0", "no", "off"}:
return False
raise ValueError("has_comment must be a boolean value")
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(ChatMessagesQuery)
reg(MessageFeedbackPayload)
reg(FeedbackExportQuery)
# Register models for flask_restx to avoid dict type issues in Swagger
# Register in dependency order: base models first, then dependent models
@ -157,12 +220,7 @@ class ChatMessageListApi(Resource):
@console_ns.doc("list_chat_messages")
@console_ns.doc(description="Get chat messages for a conversation with pagination")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(
console_ns.parser()
.add_argument("conversation_id", type=str, required=True, location="args", help="Conversation ID")
.add_argument("first_id", type=str, location="args", help="First message ID for pagination")
.add_argument("limit", type=int, location="args", default=20, help="Number of messages to return (1-100)")
)
@console_ns.expect(console_ns.models[ChatMessagesQuery.__name__])
@console_ns.response(200, "Success", message_infinite_scroll_pagination_model)
@console_ns.response(404, "Conversation not found")
@login_required
@ -172,27 +230,21 @@ class ChatMessageListApi(Resource):
@marshal_with(message_infinite_scroll_pagination_model)
@edit_permission_required
def get(self, app_model):
parser = (
reqparse.RequestParser()
.add_argument("conversation_id", required=True, type=uuid_value, location="args")
.add_argument("first_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
)
args = parser.parse_args()
args = ChatMessagesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
conversation = (
db.session.query(Conversation)
.where(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id)
.where(Conversation.id == args.conversation_id, Conversation.app_id == app_model.id)
.first()
)
if not conversation:
raise NotFound("Conversation Not Exists.")
if args["first_id"]:
if args.first_id:
first_message = (
db.session.query(Message)
.where(Message.conversation_id == conversation.id, Message.id == args["first_id"])
.where(Message.conversation_id == conversation.id, Message.id == args.first_id)
.first()
)
@ -207,7 +259,7 @@ class ChatMessageListApi(Resource):
Message.id != first_message.id,
)
.order_by(Message.created_at.desc())
.limit(args["limit"])
.limit(args.limit)
.all()
)
else:
@ -215,12 +267,12 @@ class ChatMessageListApi(Resource):
db.session.query(Message)
.where(Message.conversation_id == conversation.id)
.order_by(Message.created_at.desc())
.limit(args["limit"])
.limit(args.limit)
.all()
)
# Initialize has_more based on whether we have a full page
if len(history_messages) == args["limit"]:
if len(history_messages) == args.limit:
current_page_first_message = history_messages[-1]
# Check if there are more messages before the current page
has_more = db.session.scalar(
@ -238,7 +290,7 @@ class ChatMessageListApi(Resource):
history_messages = list(reversed(history_messages))
return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more)
return InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more)
@console_ns.route("/apps/<uuid:app_id>/feedbacks")
@ -246,15 +298,7 @@ class MessageFeedbackApi(Resource):
@console_ns.doc("create_message_feedback")
@console_ns.doc(description="Create or update message feedback (like/dislike)")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(
console_ns.model(
"MessageFeedbackRequest",
{
"message_id": fields.String(required=True, description="Message ID"),
"rating": fields.String(enum=["like", "dislike"], description="Feedback rating"),
},
)
)
@console_ns.expect(console_ns.models[MessageFeedbackPayload.__name__])
@console_ns.response(200, "Feedback updated successfully")
@console_ns.response(404, "Message not found")
@console_ns.response(403, "Insufficient permissions")
@ -265,14 +309,9 @@ class MessageFeedbackApi(Resource):
def post(self, app_model):
current_user, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("message_id", required=True, type=uuid_value, location="json")
.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
)
args = parser.parse_args()
args = MessageFeedbackPayload.model_validate(console_ns.payload)
message_id = str(args["message_id"])
message_id = str(args.message_id)
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
@ -281,18 +320,21 @@ class MessageFeedbackApi(Resource):
feedback = message.admin_feedback
if not args["rating"] and feedback:
if not args.rating and feedback:
db.session.delete(feedback)
elif args["rating"] and feedback:
feedback.rating = args["rating"]
elif not args["rating"] and not feedback:
elif args.rating and feedback:
feedback.rating = args.rating
elif not args.rating and not feedback:
raise ValueError("rating cannot be None when feedback not exists")
else:
rating_value = args.rating
if rating_value is None:
raise ValueError("rating is required to create feedback")
feedback = MessageFeedback(
app_id=app_model.id,
conversation_id=message.conversation_id,
message_id=message.id,
rating=args["rating"],
rating=rating_value,
from_source="admin",
from_account_id=current_user.id,
)
@ -369,24 +411,12 @@ class MessageSuggestedQuestionApi(Resource):
return {"data": questions}
# Shared parser for feedback export (used for both documentation and runtime parsing)
feedback_export_parser = (
console_ns.parser()
.add_argument("from_source", type=str, choices=["user", "admin"], location="args", help="Filter by feedback source")
.add_argument("rating", type=str, choices=["like", "dislike"], location="args", help="Filter by rating")
.add_argument("has_comment", type=bool, location="args", help="Only include feedback with comments")
.add_argument("start_date", type=str, location="args", help="Start date (YYYY-MM-DD)")
.add_argument("end_date", type=str, location="args", help="End date (YYYY-MM-DD)")
.add_argument("format", type=str, choices=["csv", "json"], default="csv", location="args", help="Export format")
)
@console_ns.route("/apps/<uuid:app_id>/feedbacks/export")
class MessageFeedbackExportApi(Resource):
@console_ns.doc("export_feedbacks")
@console_ns.doc(description="Export user feedback data for Google Sheets")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(feedback_export_parser)
@console_ns.expect(console_ns.models[FeedbackExportQuery.__name__])
@console_ns.response(200, "Feedback data exported successfully")
@console_ns.response(400, "Invalid parameters")
@console_ns.response(500, "Internal server error")
@ -395,7 +425,7 @@ class MessageFeedbackExportApi(Resource):
@login_required
@account_initialization_required
def get(self, app_model):
args = feedback_export_parser.parse_args()
args = FeedbackExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
# Import the service function
from services.feedback_service import FeedbackService
@ -403,12 +433,12 @@ class MessageFeedbackExportApi(Resource):
try:
export_data = FeedbackService.export_feedbacks(
app_id=app_model.id,
from_source=args.get("from_source"),
rating=args.get("rating"),
has_comment=args.get("has_comment"),
start_date=args.get("start_date"),
end_date=args.get("end_date"),
format_type=args.get("format", "csv"),
from_source=args.from_source,
rating=args.rating,
has_comment=args.has_comment,
start_date=args.start_date,
end_date=args.end_date,
format_type=args.format,
)
return export_data

View File

@ -1,8 +1,9 @@
from decimal import Decimal
import sqlalchemy as sa
from flask import abort, jsonify
from flask_restx import Resource, fields, reqparse
from flask import abort, jsonify, request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
@ -10,21 +11,37 @@ from controllers.console.wraps import account_initialization_required, setup_req
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from libs.datetime_utils import parse_time_range
from libs.helper import DatetimeString, convert_datetime_to_date
from libs.helper import convert_datetime_to_date
from libs.login import current_account_with_tenant, login_required
from models import AppMode
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class StatisticTimeRangeQuery(BaseModel):
start: str | None = Field(default=None, description="Start date (YYYY-MM-DD HH:MM)")
end: str | None = Field(default=None, description="End date (YYYY-MM-DD HH:MM)")
@field_validator("start", "end", mode="before")
@classmethod
def empty_string_to_none(cls, value: str | None) -> str | None:
if value == "":
return None
return value
console_ns.schema_model(
StatisticTimeRangeQuery.__name__,
StatisticTimeRangeQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-messages")
class DailyMessageStatistic(Resource):
@console_ns.doc("get_daily_message_statistics")
@console_ns.doc(description="Get daily message statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(
console_ns.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
@console_ns.response(
200,
"Daily message statistics retrieved successfully",
@ -37,12 +54,7 @@ class DailyMessageStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT
@ -57,7 +69,7 @@ WHERE
assert account.timezone is not None
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e:
abort(400, description=str(e))
@ -81,19 +93,12 @@ WHERE
return jsonify({"data": response_data})
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-conversations")
class DailyConversationStatistic(Resource):
@console_ns.doc("get_daily_conversation_statistics")
@console_ns.doc(description="Get daily conversation statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(parser)
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
@console_ns.response(
200,
"Daily conversation statistics retrieved successfully",
@ -106,7 +111,7 @@ class DailyConversationStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
args = parser.parse_args()
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT
@ -121,7 +126,7 @@ WHERE
assert account.timezone is not None
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e:
abort(400, description=str(e))
@ -149,7 +154,7 @@ class DailyTerminalsStatistic(Resource):
@console_ns.doc("get_daily_terminals_statistics")
@console_ns.doc(description="Get daily terminal/end-user statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(parser)
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
@console_ns.response(
200,
"Daily terminal statistics retrieved successfully",
@ -162,7 +167,7 @@ class DailyTerminalsStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
args = parser.parse_args()
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT
@ -177,7 +182,7 @@ WHERE
assert account.timezone is not None
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e:
abort(400, description=str(e))
@ -206,7 +211,7 @@ class DailyTokenCostStatistic(Resource):
@console_ns.doc("get_daily_token_cost_statistics")
@console_ns.doc(description="Get daily token cost statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(parser)
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
@console_ns.response(
200,
"Daily token cost statistics retrieved successfully",
@ -219,7 +224,7 @@ class DailyTokenCostStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
args = parser.parse_args()
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT
@ -235,7 +240,7 @@ WHERE
assert account.timezone is not None
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e:
abort(400, description=str(e))
@ -266,7 +271,7 @@ class AverageSessionInteractionStatistic(Resource):
@console_ns.doc("get_average_session_interaction_statistics")
@console_ns.doc(description="Get average session interaction statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(parser)
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
@console_ns.response(
200,
"Average session interaction statistics retrieved successfully",
@ -279,7 +284,7 @@ class AverageSessionInteractionStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
args = parser.parse_args()
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
converted_created_at = convert_datetime_to_date("c.created_at")
sql_query = f"""SELECT
@ -302,7 +307,7 @@ FROM
assert account.timezone is not None
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e:
abort(400, description=str(e))
@ -342,7 +347,7 @@ class UserSatisfactionRateStatistic(Resource):
@console_ns.doc("get_user_satisfaction_rate_statistics")
@console_ns.doc(description="Get user satisfaction rate statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(parser)
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
@console_ns.response(
200,
"User satisfaction rate statistics retrieved successfully",
@ -355,7 +360,7 @@ class UserSatisfactionRateStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
args = parser.parse_args()
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
converted_created_at = convert_datetime_to_date("m.created_at")
sql_query = f"""SELECT
@ -374,7 +379,7 @@ WHERE
assert account.timezone is not None
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e:
abort(400, description=str(e))
@ -408,7 +413,7 @@ class AverageResponseTimeStatistic(Resource):
@console_ns.doc("get_average_response_time_statistics")
@console_ns.doc(description="Get average response time statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(parser)
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
@console_ns.response(
200,
"Average response time statistics retrieved successfully",
@ -421,7 +426,7 @@ class AverageResponseTimeStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
args = parser.parse_args()
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT
@ -436,7 +441,7 @@ WHERE
assert account.timezone is not None
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e:
abort(400, description=str(e))
@ -465,7 +470,7 @@ class TokensPerSecondStatistic(Resource):
@console_ns.doc("get_tokens_per_second_statistics")
@console_ns.doc(description="Get tokens per second statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(parser)
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
@console_ns.response(
200,
"Tokens per second statistics retrieved successfully",
@ -477,7 +482,7 @@ class TokensPerSecondStatistic(Resource):
@account_initialization_required
def get(self, app_model):
account, _ = current_account_with_tenant()
args = parser.parse_args()
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT
@ -495,7 +500,7 @@ WHERE
assert account.timezone is not None
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e:
abort(400, description=str(e))

View File

@ -1,10 +1,11 @@
import json
import logging
from collections.abc import Sequence
from typing import cast
from typing import Any
from flask import abort, request
from flask_restx import Resource, fields, inputs, marshal_with, reqparse
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
@ -49,6 +50,7 @@ from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseE
logger = logging.getLogger(__name__)
LISTENING_RETRY_IN = 2000
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
# Register models for flask_restx to avoid dict type issues in Swagger
# Register in dependency order: base models first, then dependent models
@ -107,6 +109,104 @@ if workflow_run_node_execution_model is None:
workflow_run_node_execution_model = console_ns.model("WorkflowRunNodeExecution", workflow_run_node_execution_fields)
class SyncDraftWorkflowPayload(BaseModel):
graph: dict[str, Any]
features: dict[str, Any]
hash: str | None = None
environment_variables: list[dict[str, Any]] = Field(default_factory=list)
conversation_variables: list[dict[str, Any]] = Field(default_factory=list)
class BaseWorkflowRunPayload(BaseModel):
files: list[dict[str, Any]] | None = None
class AdvancedChatWorkflowRunPayload(BaseWorkflowRunPayload):
inputs: dict[str, Any] | None = None
query: str = ""
conversation_id: str | None = None
parent_message_id: str | None = None
@field_validator("conversation_id", "parent_message_id")
@classmethod
def validate_uuid(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
class IterationNodeRunPayload(BaseModel):
inputs: dict[str, Any] | None = None
class LoopNodeRunPayload(BaseModel):
inputs: dict[str, Any] | None = None
class DraftWorkflowRunPayload(BaseWorkflowRunPayload):
inputs: dict[str, Any]
class DraftWorkflowNodeRunPayload(BaseWorkflowRunPayload):
inputs: dict[str, Any]
query: str = ""
class PublishWorkflowPayload(BaseModel):
marked_name: str | None = Field(default=None, max_length=20)
marked_comment: str | None = Field(default=None, max_length=100)
class DefaultBlockConfigQuery(BaseModel):
q: str | None = None
class ConvertToWorkflowPayload(BaseModel):
name: str | None = None
icon_type: str | None = None
icon: str | None = None
icon_background: str | None = None
class WorkflowListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=99999)
limit: int = Field(default=10, ge=1, le=100)
user_id: str | None = None
named_only: bool = False
class WorkflowUpdatePayload(BaseModel):
marked_name: str | None = Field(default=None, max_length=20)
marked_comment: str | None = Field(default=None, max_length=100)
class DraftWorkflowTriggerRunPayload(BaseModel):
node_id: str
class DraftWorkflowTriggerRunAllPayload(BaseModel):
node_ids: list[str]
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(SyncDraftWorkflowPayload)
reg(AdvancedChatWorkflowRunPayload)
reg(IterationNodeRunPayload)
reg(LoopNodeRunPayload)
reg(DraftWorkflowRunPayload)
reg(DraftWorkflowNodeRunPayload)
reg(PublishWorkflowPayload)
reg(DefaultBlockConfigQuery)
reg(ConvertToWorkflowPayload)
reg(WorkflowListQuery)
reg(WorkflowUpdatePayload)
reg(DraftWorkflowTriggerRunPayload)
reg(DraftWorkflowTriggerRunAllPayload)
# TODO(QuantumGhost): Refactor existing node run API to handle file parameter parsing
# at the controller level rather than in the workflow logic. This would improve separation
# of concerns and make the code more maintainable.
@ -158,18 +258,7 @@ class DraftWorkflowApi(Resource):
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@console_ns.doc("sync_draft_workflow")
@console_ns.doc(description="Sync draft workflow configuration")
@console_ns.expect(
console_ns.model(
"SyncDraftWorkflowRequest",
{
"graph": fields.Raw(required=True, description="Workflow graph configuration"),
"features": fields.Raw(required=True, description="Workflow features configuration"),
"hash": fields.String(description="Workflow hash for validation"),
"environment_variables": fields.List(fields.Raw, required=True, description="Environment variables"),
"conversation_variables": fields.List(fields.Raw, description="Conversation variables"),
},
)
)
@console_ns.expect(console_ns.models[SyncDraftWorkflowPayload.__name__])
@console_ns.response(
200,
"Draft workflow synced successfully",
@ -193,36 +282,23 @@ class DraftWorkflowApi(Resource):
content_type = request.headers.get("Content-Type", "")
payload_data: dict[str, Any] | None = None
if "application/json" in content_type:
parser = (
reqparse.RequestParser()
.add_argument("graph", type=dict, required=True, nullable=False, location="json")
.add_argument("features", type=dict, required=True, nullable=False, location="json")
.add_argument("hash", type=str, required=False, location="json")
.add_argument("environment_variables", type=list, required=True, location="json")
.add_argument("conversation_variables", type=list, required=False, location="json")
)
args = parser.parse_args()
payload_data = request.get_json(silent=True)
if not isinstance(payload_data, dict):
return {"message": "Invalid JSON data"}, 400
elif "text/plain" in content_type:
try:
data = json.loads(request.data.decode("utf-8"))
if "graph" not in data or "features" not in data:
raise ValueError("graph or features not found in data")
if not isinstance(data.get("graph"), dict) or not isinstance(data.get("features"), dict):
raise ValueError("graph or features is not a dict")
args = {
"graph": data.get("graph"),
"features": data.get("features"),
"hash": data.get("hash"),
"environment_variables": data.get("environment_variables"),
"conversation_variables": data.get("conversation_variables"),
}
payload_data = json.loads(request.data.decode("utf-8"))
except json.JSONDecodeError:
return {"message": "Invalid JSON data"}, 400
if not isinstance(payload_data, dict):
return {"message": "Invalid JSON data"}, 400
else:
abort(415)
args_model = SyncDraftWorkflowPayload.model_validate(payload_data)
args = args_model.model_dump()
workflow_service = WorkflowService()
try:
@ -258,17 +334,7 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
@console_ns.doc("run_advanced_chat_draft_workflow")
@console_ns.doc(description="Run draft workflow for advanced chat application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(
console_ns.model(
"AdvancedChatWorkflowRunRequest",
{
"query": fields.String(required=True, description="User query"),
"inputs": fields.Raw(description="Input variables"),
"files": fields.List(fields.Raw, description="File uploads"),
"conversation_id": fields.String(description="Conversation ID"),
},
)
)
@console_ns.expect(console_ns.models[AdvancedChatWorkflowRunPayload.__name__])
@console_ns.response(200, "Workflow run started successfully")
@console_ns.response(400, "Invalid request parameters")
@console_ns.response(403, "Permission denied")
@ -283,16 +349,8 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
"""
current_user, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, location="json")
.add_argument("query", type=str, required=True, location="json", default="")
.add_argument("files", type=list, location="json")
.add_argument("conversation_id", type=uuid_value, location="json")
.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
)
args = parser.parse_args()
args_model = AdvancedChatWorkflowRunPayload.model_validate(console_ns.payload or {})
args = args_model.model_dump(exclude_none=True)
external_trace_id = get_external_trace_id(request)
if external_trace_id:
@ -322,15 +380,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
@console_ns.doc("run_advanced_chat_draft_iteration_node")
@console_ns.doc(description="Run draft workflow iteration node for advanced chat")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.expect(
console_ns.model(
"IterationNodeRunRequest",
{
"task_id": fields.String(required=True, description="Task ID"),
"inputs": fields.Raw(description="Input variables"),
},
)
)
@console_ns.expect(console_ns.models[IterationNodeRunPayload.__name__])
@console_ns.response(200, "Iteration node run started successfully")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Node not found")
@ -344,8 +394,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
Run draft workflow iteration node
"""
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
args = IterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
try:
response = AppGenerateService.generate_single_iteration(
@ -369,15 +418,7 @@ class WorkflowDraftRunIterationNodeApi(Resource):
@console_ns.doc("run_workflow_draft_iteration_node")
@console_ns.doc(description="Run draft workflow iteration node")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.expect(
console_ns.model(
"WorkflowIterationNodeRunRequest",
{
"task_id": fields.String(required=True, description="Task ID"),
"inputs": fields.Raw(description="Input variables"),
},
)
)
@console_ns.expect(console_ns.models[IterationNodeRunPayload.__name__])
@console_ns.response(200, "Workflow iteration node run started successfully")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Node not found")
@ -391,8 +432,7 @@ class WorkflowDraftRunIterationNodeApi(Resource):
Run draft workflow iteration node
"""
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
args = IterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
try:
response = AppGenerateService.generate_single_iteration(
@ -416,15 +456,7 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
@console_ns.doc("run_advanced_chat_draft_loop_node")
@console_ns.doc(description="Run draft workflow loop node for advanced chat")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.expect(
console_ns.model(
"LoopNodeRunRequest",
{
"task_id": fields.String(required=True, description="Task ID"),
"inputs": fields.Raw(description="Input variables"),
},
)
)
@console_ns.expect(console_ns.models[LoopNodeRunPayload.__name__])
@console_ns.response(200, "Loop node run started successfully")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Node not found")
@ -438,8 +470,7 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
Run draft workflow loop node
"""
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
args = LoopNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
try:
response = AppGenerateService.generate_single_loop(
@ -463,15 +494,7 @@ class WorkflowDraftRunLoopNodeApi(Resource):
@console_ns.doc("run_workflow_draft_loop_node")
@console_ns.doc(description="Run draft workflow loop node")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.expect(
console_ns.model(
"WorkflowLoopNodeRunRequest",
{
"task_id": fields.String(required=True, description="Task ID"),
"inputs": fields.Raw(description="Input variables"),
},
)
)
@console_ns.expect(console_ns.models[LoopNodeRunPayload.__name__])
@console_ns.response(200, "Workflow loop node run started successfully")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Node not found")
@ -485,8 +508,7 @@ class WorkflowDraftRunLoopNodeApi(Resource):
Run draft workflow loop node
"""
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
args = LoopNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
try:
response = AppGenerateService.generate_single_loop(
@ -510,15 +532,7 @@ class DraftWorkflowRunApi(Resource):
@console_ns.doc("run_draft_workflow")
@console_ns.doc(description="Run draft workflow")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(
console_ns.model(
"DraftWorkflowRunRequest",
{
"inputs": fields.Raw(required=True, description="Input variables"),
"files": fields.List(fields.Raw, description="File uploads"),
},
)
)
@console_ns.expect(console_ns.models[DraftWorkflowRunPayload.__name__])
@console_ns.response(200, "Draft workflow run started successfully")
@console_ns.response(403, "Permission denied")
@setup_required
@ -531,12 +545,7 @@ class DraftWorkflowRunApi(Resource):
Run draft workflow
"""
current_user, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("files", type=list, required=False, location="json")
)
args = parser.parse_args()
args = DraftWorkflowRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
external_trace_id = get_external_trace_id(request)
if external_trace_id:
@ -588,14 +597,7 @@ class DraftWorkflowNodeRunApi(Resource):
@console_ns.doc("run_draft_workflow_node")
@console_ns.doc(description="Run draft workflow node")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.expect(
console_ns.model(
"DraftWorkflowNodeRunRequest",
{
"inputs": fields.Raw(description="Input variables"),
},
)
)
@console_ns.expect(console_ns.models[DraftWorkflowNodeRunPayload.__name__])
@console_ns.response(200, "Node run started successfully", workflow_run_node_execution_model)
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Node not found")
@ -610,15 +612,10 @@ class DraftWorkflowNodeRunApi(Resource):
Run draft workflow node
"""
current_user, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("query", type=str, required=False, location="json", default="")
.add_argument("files", type=list, location="json", default=[])
)
args = parser.parse_args()
args_model = DraftWorkflowNodeRunPayload.model_validate(console_ns.payload or {})
args = args_model.model_dump(exclude_none=True)
user_inputs = args.get("inputs")
user_inputs = args_model.inputs
if user_inputs is None:
raise ValueError("missing inputs")
@ -643,13 +640,6 @@ class DraftWorkflowNodeRunApi(Resource):
return workflow_node_execution
parser_publish = (
reqparse.RequestParser()
.add_argument("marked_name", type=str, required=False, default="", location="json")
.add_argument("marked_comment", type=str, required=False, default="", location="json")
)
@console_ns.route("/apps/<uuid:app_id>/workflows/publish")
class PublishedWorkflowApi(Resource):
@console_ns.doc("get_published_workflow")
@ -674,7 +664,7 @@ class PublishedWorkflowApi(Resource):
# return workflow, if not found, return None
return workflow
@console_ns.expect(parser_publish)
@console_ns.expect(console_ns.models[PublishWorkflowPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -686,13 +676,7 @@ class PublishedWorkflowApi(Resource):
"""
current_user, _ = current_account_with_tenant()
args = parser_publish.parse_args()
# Validate name and comment length
if args.marked_name and len(args.marked_name) > 20:
raise ValueError("Marked name cannot exceed 20 characters")
if args.marked_comment and len(args.marked_comment) > 100:
raise ValueError("Marked comment cannot exceed 100 characters")
args = PublishWorkflowPayload.model_validate(console_ns.payload or {})
workflow_service = WorkflowService()
with Session(db.engine) as session:
@ -741,9 +725,6 @@ class DefaultBlockConfigsApi(Resource):
return workflow_service.get_default_block_configs()
parser_block = reqparse.RequestParser().add_argument("q", type=str, location="args")
@console_ns.route("/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>")
class DefaultBlockConfigApi(Resource):
@console_ns.doc("get_default_block_config")
@ -751,7 +732,7 @@ class DefaultBlockConfigApi(Resource):
@console_ns.doc(params={"app_id": "Application ID", "block_type": "Block type"})
@console_ns.response(200, "Default block configuration retrieved successfully")
@console_ns.response(404, "Block type not found")
@console_ns.expect(parser_block)
@console_ns.expect(console_ns.models[DefaultBlockConfigQuery.__name__])
@setup_required
@login_required
@account_initialization_required
@ -761,14 +742,12 @@ class DefaultBlockConfigApi(Resource):
"""
Get default block config
"""
args = parser_block.parse_args()
q = args.get("q")
args = DefaultBlockConfigQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
filters = None
if q:
if args.q:
try:
filters = json.loads(args.get("q", ""))
filters = json.loads(args.q)
except json.JSONDecodeError:
raise ValueError("Invalid filters")
@ -777,18 +756,9 @@ class DefaultBlockConfigApi(Resource):
return workflow_service.get_default_block_config(node_type=block_type, filters=filters)
parser_convert = (
reqparse.RequestParser()
.add_argument("name", type=str, required=False, nullable=True, location="json")
.add_argument("icon_type", type=str, required=False, nullable=True, location="json")
.add_argument("icon", type=str, required=False, nullable=True, location="json")
.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
)
@console_ns.route("/apps/<uuid:app_id>/convert-to-workflow")
class ConvertToWorkflowApi(Resource):
@console_ns.expect(parser_convert)
@console_ns.expect(console_ns.models[ConvertToWorkflowPayload.__name__])
@console_ns.doc("convert_to_workflow")
@console_ns.doc(description="Convert application to workflow mode")
@console_ns.doc(params={"app_id": "Application ID"})
@ -808,10 +778,8 @@ class ConvertToWorkflowApi(Resource):
"""
current_user, _ = current_account_with_tenant()
if request.data:
args = parser_convert.parse_args()
else:
args = {}
payload = console_ns.payload or {}
args = ConvertToWorkflowPayload.model_validate(payload).model_dump(exclude_none=True)
# convert to workflow mode
workflow_service = WorkflowService()
@ -823,18 +791,9 @@ class ConvertToWorkflowApi(Resource):
}
parser_workflows = (
reqparse.RequestParser()
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=10, location="args")
.add_argument("user_id", type=str, required=False, location="args")
.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
)
@console_ns.route("/apps/<uuid:app_id>/workflows")
class PublishedAllWorkflowApi(Resource):
@console_ns.expect(parser_workflows)
@console_ns.expect(console_ns.models[WorkflowListQuery.__name__])
@console_ns.doc("get_all_published_workflows")
@console_ns.doc(description="Get all published workflows for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@ -851,16 +810,15 @@ class PublishedAllWorkflowApi(Resource):
"""
current_user, _ = current_account_with_tenant()
args = parser_workflows.parse_args()
page = args["page"]
limit = args["limit"]
user_id = args.get("user_id")
named_only = args.get("named_only", False)
args = WorkflowListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
page = args.page
limit = args.limit
user_id = args.user_id
named_only = args.named_only
if user_id:
if user_id != current_user.id:
raise Forbidden()
user_id = cast(str, user_id)
workflow_service = WorkflowService()
with Session(db.engine) as session:
@ -886,15 +844,7 @@ class WorkflowByIdApi(Resource):
@console_ns.doc("update_workflow_by_id")
@console_ns.doc(description="Update workflow by ID")
@console_ns.doc(params={"app_id": "Application ID", "workflow_id": "Workflow ID"})
@console_ns.expect(
console_ns.model(
"UpdateWorkflowRequest",
{
"environment_variables": fields.List(fields.Raw, description="Environment variables"),
"conversation_variables": fields.List(fields.Raw, description="Conversation variables"),
},
)
)
@console_ns.expect(console_ns.models[WorkflowUpdatePayload.__name__])
@console_ns.response(200, "Workflow updated successfully", workflow_model)
@console_ns.response(404, "Workflow not found")
@console_ns.response(403, "Permission denied")
@ -909,25 +859,14 @@ class WorkflowByIdApi(Resource):
Update workflow attributes
"""
current_user, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("marked_name", type=str, required=False, location="json")
.add_argument("marked_comment", type=str, required=False, location="json")
)
args = parser.parse_args()
# Validate name and comment length
if args.marked_name and len(args.marked_name) > 20:
raise ValueError("Marked name cannot exceed 20 characters")
if args.marked_comment and len(args.marked_comment) > 100:
raise ValueError("Marked comment cannot exceed 100 characters")
args = WorkflowUpdatePayload.model_validate(console_ns.payload or {})
# Prepare update data
update_data = {}
if args.get("marked_name") is not None:
update_data["marked_name"] = args["marked_name"]
if args.get("marked_comment") is not None:
update_data["marked_comment"] = args["marked_comment"]
if args.marked_name is not None:
update_data["marked_name"] = args.marked_name
if args.marked_comment is not None:
update_data["marked_comment"] = args.marked_comment
if not update_data:
return {"message": "No valid fields to update"}, 400
@ -1040,11 +979,8 @@ class DraftWorkflowTriggerRunApi(Resource):
Poll for trigger events and execute full workflow when event arrives
"""
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument(
"node_id", type=str, required=True, location="json", nullable=False
)
args = parser.parse_args()
node_id = args["node_id"]
args = DraftWorkflowTriggerRunPayload.model_validate(console_ns.payload or {})
node_id = args.node_id
workflow_service = WorkflowService()
draft_workflow = workflow_service.get_draft_workflow(app_model)
if not draft_workflow:
@ -1172,14 +1108,7 @@ class DraftWorkflowTriggerRunAllApi(Resource):
@console_ns.doc("draft_workflow_trigger_run_all")
@console_ns.doc(description="Full workflow debug when the start node is a trigger")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(
console_ns.model(
"DraftWorkflowTriggerRunAllRequest",
{
"node_ids": fields.List(fields.String, required=True, description="Node IDs"),
},
)
)
@console_ns.expect(console_ns.models[DraftWorkflowTriggerRunAllPayload.__name__])
@console_ns.response(200, "Workflow executed successfully")
@console_ns.response(403, "Permission denied")
@console_ns.response(500, "Internal server error")
@ -1194,11 +1123,8 @@ class DraftWorkflowTriggerRunAllApi(Resource):
"""
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument(
"node_ids", type=list, required=True, location="json", nullable=False
)
args = parser.parse_args()
node_ids = args["node_ids"]
args = DraftWorkflowTriggerRunAllPayload.model_validate(console_ns.payload or {})
node_ids = args.node_ids
workflow_service = WorkflowService()
draft_workflow = workflow_service.get_draft_workflow(app_model)
if not draft_workflow:

View File

@ -1,6 +1,9 @@
from datetime import datetime
from dateutil.parser import isoparse
from flask_restx import Resource, marshal_with, reqparse
from flask_restx.inputs import int_range
from flask import request
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import Session
from controllers.console import console_ns
@ -14,6 +17,48 @@ from models import App
from models.model import AppMode
from services.workflow_app_service import WorkflowAppService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class WorkflowAppLogQuery(BaseModel):
keyword: str | None = Field(default=None, description="Search keyword for filtering logs")
status: WorkflowExecutionStatus | None = Field(
default=None, description="Execution status filter (succeeded, failed, stopped, partial-succeeded)"
)
created_at__before: datetime | None = Field(default=None, description="Filter logs created before this timestamp")
created_at__after: datetime | None = Field(default=None, description="Filter logs created after this timestamp")
created_by_end_user_session_id: str | None = Field(default=None, description="Filter by end user session ID")
created_by_account: str | None = Field(default=None, description="Filter by account")
detail: bool = Field(default=False, description="Whether to return detailed logs")
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
limit: int = Field(default=20, ge=1, le=100, description="Number of items per page (1-100)")
@field_validator("created_at__before", "created_at__after", mode="before")
@classmethod
def parse_datetime(cls, value: str | None) -> datetime | None:
if value in (None, ""):
return None
return isoparse(value) # type: ignore
@field_validator("detail", mode="before")
@classmethod
def parse_bool(cls, value: bool | str | None) -> bool:
if isinstance(value, bool):
return value
if value is None:
return False
lowered = value.lower()
if lowered in {"1", "true", "yes", "on"}:
return True
if lowered in {"0", "false", "no", "off"}:
return False
raise ValueError("Invalid boolean value for detail")
console_ns.schema_model(
WorkflowAppLogQuery.__name__, WorkflowAppLogQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
# Register model for flask_restx to avoid dict type issues in Swagger
workflow_app_log_pagination_model = build_workflow_app_log_pagination_model(console_ns)
@ -23,19 +68,7 @@ class WorkflowAppLogApi(Resource):
@console_ns.doc("get_workflow_app_logs")
@console_ns.doc(description="Get workflow application execution logs")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.doc(
params={
"keyword": "Search keyword for filtering logs",
"status": "Filter by execution status (succeeded, failed, stopped, partial-succeeded)",
"created_at__before": "Filter logs created before this timestamp",
"created_at__after": "Filter logs created after this timestamp",
"created_by_end_user_session_id": "Filter by end user session ID",
"created_by_account": "Filter by account",
"detail": "Whether to return detailed logs",
"page": "Page number (1-99999)",
"limit": "Number of items per page (1-100)",
}
)
@console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__])
@console_ns.response(200, "Workflow app logs retrieved successfully", workflow_app_log_pagination_model)
@setup_required
@login_required
@ -46,44 +79,7 @@ class WorkflowAppLogApi(Resource):
"""
Get workflow app logs
"""
parser = (
reqparse.RequestParser()
.add_argument("keyword", type=str, location="args")
.add_argument(
"status", type=str, choices=["succeeded", "failed", "stopped", "partial-succeeded"], location="args"
)
.add_argument(
"created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
)
.add_argument(
"created_at__after", type=str, location="args", help="Filter logs created after this timestamp"
)
.add_argument(
"created_by_end_user_session_id",
type=str,
location="args",
required=False,
default=None,
)
.add_argument(
"created_by_account",
type=str,
location="args",
required=False,
default=None,
)
.add_argument("detail", type=bool, location="args", required=False, default=False)
.add_argument("page", type=int_range(1, 99999), default=1, location="args")
.add_argument("limit", type=int_range(1, 100), default=20, location="args")
)
args = parser.parse_args()
args.status = WorkflowExecutionStatus(args.status) if args.status else None
if args.created_at__before:
args.created_at__before = isoparse(args.created_at__before)
if args.created_at__after:
args.created_at__after = isoparse(args.created_at__after)
args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
# get paginate workflow app logs
workflow_app_service = WorkflowAppService()

View File

@ -1,7 +1,8 @@
from typing import cast
from typing import Literal, cast
from flask_restx import Resource, fields, marshal_with, reqparse
from flask_restx.inputs import int_range
from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
@ -92,70 +93,51 @@ workflow_run_node_execution_list_model = console_ns.model(
"WorkflowRunNodeExecutionList", workflow_run_node_execution_list_fields_copy
)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
def _parse_workflow_run_list_args():
"""
Parse common arguments for workflow run list endpoints.
Returns:
Parsed arguments containing last_id, limit, status, and triggered_from filters
"""
parser = (
reqparse.RequestParser()
.add_argument("last_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
.add_argument(
"status",
type=str,
choices=WORKFLOW_RUN_STATUS_CHOICES,
location="args",
required=False,
)
.add_argument(
"triggered_from",
type=str,
choices=["debugging", "app-run"],
location="args",
required=False,
help="Filter by trigger source: debugging or app-run",
)
class WorkflowRunListQuery(BaseModel):
last_id: str | None = Field(default=None, description="Last run ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of items per page (1-100)")
status: Literal["running", "succeeded", "failed", "stopped", "partial-succeeded"] | None = Field(
default=None, description="Workflow run status filter"
)
return parser.parse_args()
def _parse_workflow_run_count_args():
"""
Parse common arguments for workflow run count endpoints.
Returns:
Parsed arguments containing status, time_range, and triggered_from filters
"""
parser = (
reqparse.RequestParser()
.add_argument(
"status",
type=str,
choices=WORKFLOW_RUN_STATUS_CHOICES,
location="args",
required=False,
)
.add_argument(
"time_range",
type=time_duration,
location="args",
required=False,
help="Time range filter (e.g., 7d, 4h, 30m, 30s)",
)
.add_argument(
"triggered_from",
type=str,
choices=["debugging", "app-run"],
location="args",
required=False,
help="Filter by trigger source: debugging or app-run",
)
triggered_from: Literal["debugging", "app-run"] | None = Field(
default=None, description="Filter by trigger source: debugging or app-run"
)
return parser.parse_args()
@field_validator("last_id")
@classmethod
def validate_last_id(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
class WorkflowRunCountQuery(BaseModel):
status: Literal["running", "succeeded", "failed", "stopped", "partial-succeeded"] | None = Field(
default=None, description="Workflow run status filter"
)
time_range: str | None = Field(default=None, description="Time range filter (e.g., 7d, 4h, 30m, 30s)")
triggered_from: Literal["debugging", "app-run"] | None = Field(
default=None, description="Filter by trigger source: debugging or app-run"
)
@field_validator("time_range")
@classmethod
def validate_time_range(cls, value: str | None) -> str | None:
if value is None:
return value
return time_duration(value)
console_ns.schema_model(
WorkflowRunListQuery.__name__, WorkflowRunListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
WorkflowRunCountQuery.__name__,
WorkflowRunCountQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs")
@ -170,6 +152,7 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
@console_ns.doc(
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
)
@console_ns.expect(console_ns.models[WorkflowRunListQuery.__name__])
@console_ns.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_model)
@setup_required
@login_required
@ -180,12 +163,13 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
"""
Get advanced chat app workflow run list
"""
args = _parse_workflow_run_list_args()
args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = args_model.model_dump(exclude_none=True)
# Default to DEBUGGING if not specified
triggered_from = (
WorkflowRunTriggeredFrom(args.get("triggered_from"))
if args.get("triggered_from")
WorkflowRunTriggeredFrom(args_model.triggered_from)
if args_model.triggered_from
else WorkflowRunTriggeredFrom.DEBUGGING
)
@ -217,6 +201,7 @@ class AdvancedChatAppWorkflowRunCountApi(Resource):
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
)
@console_ns.response(200, "Workflow runs count retrieved successfully", workflow_run_count_model)
@console_ns.expect(console_ns.models[WorkflowRunCountQuery.__name__])
@setup_required
@login_required
@account_initialization_required
@ -226,12 +211,13 @@ class AdvancedChatAppWorkflowRunCountApi(Resource):
"""
Get advanced chat workflow runs count statistics
"""
args = _parse_workflow_run_count_args()
args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = args_model.model_dump(exclude_none=True)
# Default to DEBUGGING if not specified
triggered_from = (
WorkflowRunTriggeredFrom(args.get("triggered_from"))
if args.get("triggered_from")
WorkflowRunTriggeredFrom(args_model.triggered_from)
if args_model.triggered_from
else WorkflowRunTriggeredFrom.DEBUGGING
)
@ -259,6 +245,7 @@ class WorkflowRunListApi(Resource):
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
)
@console_ns.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_model)
@console_ns.expect(console_ns.models[WorkflowRunListQuery.__name__])
@setup_required
@login_required
@account_initialization_required
@ -268,12 +255,13 @@ class WorkflowRunListApi(Resource):
"""
Get workflow run list
"""
args = _parse_workflow_run_list_args()
args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = args_model.model_dump(exclude_none=True)
# Default to DEBUGGING for workflow if not specified (backward compatibility)
triggered_from = (
WorkflowRunTriggeredFrom(args.get("triggered_from"))
if args.get("triggered_from")
WorkflowRunTriggeredFrom(args_model.triggered_from)
if args_model.triggered_from
else WorkflowRunTriggeredFrom.DEBUGGING
)
@ -305,6 +293,7 @@ class WorkflowRunCountApi(Resource):
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
)
@console_ns.response(200, "Workflow runs count retrieved successfully", workflow_run_count_model)
@console_ns.expect(console_ns.models[WorkflowRunCountQuery.__name__])
@setup_required
@login_required
@account_initialization_required
@ -314,12 +303,13 @@ class WorkflowRunCountApi(Resource):
"""
Get workflow runs count statistics
"""
args = _parse_workflow_run_count_args()
args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = args_model.model_dump(exclude_none=True)
# Default to DEBUGGING for workflow if not specified (backward compatibility)
triggered_from = (
WorkflowRunTriggeredFrom(args.get("triggered_from"))
if args.get("triggered_from")
WorkflowRunTriggeredFrom(args_model.triggered_from)
if args_model.triggered_from
else WorkflowRunTriggeredFrom.DEBUGGING
)

View File

@ -1,5 +1,6 @@
from flask import abort, jsonify
from flask_restx import Resource, reqparse
from flask import abort, jsonify, request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import sessionmaker
from controllers.console import console_ns
@ -7,12 +8,31 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
from libs.datetime_utils import parse_time_range
from libs.helper import DatetimeString
from libs.login import current_account_with_tenant, login_required
from models.enums import WorkflowRunTriggeredFrom
from models.model import AppMode
from repositories.factory import DifyAPIRepositoryFactory
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class WorkflowStatisticQuery(BaseModel):
start: str | None = Field(default=None, description="Start date and time (YYYY-MM-DD HH:MM)")
end: str | None = Field(default=None, description="End date and time (YYYY-MM-DD HH:MM)")
@field_validator("start", "end", mode="before")
@classmethod
def blank_to_none(cls, value: str | None) -> str | None:
if value == "":
return None
return value
console_ns.schema_model(
WorkflowStatisticQuery.__name__,
WorkflowStatisticQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/daily-conversations")
class WorkflowDailyRunsStatistic(Resource):
@ -24,9 +44,7 @@ class WorkflowDailyRunsStatistic(Resource):
@console_ns.doc("get_workflow_daily_runs_statistic")
@console_ns.doc(description="Get workflow daily runs statistics")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.doc(
params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}
)
@console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
@console_ns.response(200, "Daily runs statistics retrieved successfully")
@get_app_model
@setup_required
@ -35,17 +53,12 @@ class WorkflowDailyRunsStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
assert account.timezone is not None
try:
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e:
abort(400, description=str(e))
@ -71,9 +84,7 @@ class WorkflowDailyTerminalsStatistic(Resource):
@console_ns.doc("get_workflow_daily_terminals_statistic")
@console_ns.doc(description="Get workflow daily terminals statistics")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.doc(
params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}
)
@console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
@console_ns.response(200, "Daily terminals statistics retrieved successfully")
@get_app_model
@setup_required
@ -82,17 +93,12 @@ class WorkflowDailyTerminalsStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
assert account.timezone is not None
try:
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e:
abort(400, description=str(e))
@ -118,9 +124,7 @@ class WorkflowDailyTokenCostStatistic(Resource):
@console_ns.doc("get_workflow_daily_token_cost_statistic")
@console_ns.doc(description="Get workflow daily token cost statistics")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.doc(
params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}
)
@console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
@console_ns.response(200, "Daily token cost statistics retrieved successfully")
@get_app_model
@setup_required
@ -129,17 +133,12 @@ class WorkflowDailyTokenCostStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
assert account.timezone is not None
try:
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e:
abort(400, description=str(e))
@ -165,9 +164,7 @@ class WorkflowAverageAppInteractionStatistic(Resource):
@console_ns.doc("get_workflow_average_app_interaction_statistic")
@console_ns.doc(description="Get workflow average app interaction statistics")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.doc(
params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}
)
@console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
@console_ns.response(200, "Average app interaction statistics retrieved successfully")
@setup_required
@login_required
@ -176,17 +173,12 @@ class WorkflowAverageAppInteractionStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
assert account.timezone is not None
try:
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e:
abort(400, description=str(e))

View File

@ -174,63 +174,25 @@ class CheckEmailUniquePayload(BaseModel):
return email(value)
console_ns.schema_model(
AccountInitPayload.__name__, AccountInitPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
AccountNamePayload.__name__, AccountNamePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
AccountAvatarPayload.__name__, AccountAvatarPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
AccountInterfaceLanguagePayload.__name__,
AccountInterfaceLanguagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
AccountInterfaceThemePayload.__name__,
AccountInterfaceThemePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
AccountTimezonePayload.__name__,
AccountTimezonePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
AccountPasswordPayload.__name__,
AccountPasswordPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
AccountDeletePayload.__name__,
AccountDeletePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
AccountDeletionFeedbackPayload.__name__,
AccountDeletionFeedbackPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
EducationActivatePayload.__name__,
EducationActivatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
EducationAutocompleteQuery.__name__,
EducationAutocompleteQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ChangeEmailSendPayload.__name__,
ChangeEmailSendPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ChangeEmailValidityPayload.__name__,
ChangeEmailValidityPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ChangeEmailResetPayload.__name__,
ChangeEmailResetPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
CheckEmailUniquePayload.__name__,
CheckEmailUniquePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(AccountInitPayload)
reg(AccountNamePayload)
reg(AccountAvatarPayload)
reg(AccountInterfaceLanguagePayload)
reg(AccountInterfaceThemePayload)
reg(AccountTimezonePayload)
reg(AccountPasswordPayload)
reg(AccountDeletePayload)
reg(AccountDeletionFeedbackPayload)
reg(EducationActivatePayload)
reg(EducationAutocompleteQuery)
reg(ChangeEmailSendPayload)
reg(ChangeEmailValidityPayload)
reg(ChangeEmailResetPayload)
reg(CheckEmailUniquePayload)
@console_ns.route("/account/init")

View File

@ -1,4 +1,8 @@
from flask_restx import Resource, fields, reqparse
from typing import Any
from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
@ -7,21 +11,49 @@ from core.plugin.impl.exc import PluginPermissionDeniedError
from libs.login import current_account_with_tenant, login_required
from services.plugin.endpoint_service import EndpointService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class EndpointCreatePayload(BaseModel):
plugin_unique_identifier: str
settings: dict[str, Any]
name: str = Field(min_length=1)
class EndpointIdPayload(BaseModel):
endpoint_id: str
class EndpointUpdatePayload(EndpointIdPayload):
settings: dict[str, Any]
name: str = Field(min_length=1)
class EndpointListQuery(BaseModel):
page: int = Field(ge=1)
page_size: int = Field(gt=0)
class EndpointListForPluginQuery(EndpointListQuery):
plugin_id: str
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(EndpointCreatePayload)
reg(EndpointIdPayload)
reg(EndpointUpdatePayload)
reg(EndpointListQuery)
reg(EndpointListForPluginQuery)
@console_ns.route("/workspaces/current/endpoints/create")
class EndpointCreateApi(Resource):
@console_ns.doc("create_endpoint")
@console_ns.doc(description="Create a new plugin endpoint")
@console_ns.expect(
console_ns.model(
"EndpointCreateRequest",
{
"plugin_unique_identifier": fields.String(required=True, description="Plugin unique identifier"),
"settings": fields.Raw(required=True, description="Endpoint settings"),
"name": fields.String(required=True, description="Endpoint name"),
},
)
)
@console_ns.expect(console_ns.models[EndpointCreatePayload.__name__])
@console_ns.response(
200,
"Endpoint created successfully",
@ -35,26 +67,16 @@ class EndpointCreateApi(Resource):
def post(self):
user, tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("plugin_unique_identifier", type=str, required=True)
.add_argument("settings", type=dict, required=True)
.add_argument("name", type=str, required=True)
)
args = parser.parse_args()
plugin_unique_identifier = args["plugin_unique_identifier"]
settings = args["settings"]
name = args["name"]
args = EndpointCreatePayload.model_validate(console_ns.payload)
try:
return {
"success": EndpointService.create_endpoint(
tenant_id=tenant_id,
user_id=user.id,
plugin_unique_identifier=plugin_unique_identifier,
name=name,
settings=settings,
plugin_unique_identifier=args.plugin_unique_identifier,
name=args.name,
settings=args.settings,
)
}
except PluginPermissionDeniedError as e:
@ -65,11 +87,7 @@ class EndpointCreateApi(Resource):
class EndpointListApi(Resource):
@console_ns.doc("list_endpoints")
@console_ns.doc(description="List plugin endpoints with pagination")
@console_ns.expect(
console_ns.parser()
.add_argument("page", type=int, required=True, location="args", help="Page number")
.add_argument("page_size", type=int, required=True, location="args", help="Page size")
)
@console_ns.expect(console_ns.models[EndpointListQuery.__name__])
@console_ns.response(
200,
"Success",
@ -83,15 +101,10 @@ class EndpointListApi(Resource):
def get(self):
user, tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("page", type=int, required=True, location="args")
.add_argument("page_size", type=int, required=True, location="args")
)
args = parser.parse_args()
args = EndpointListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
page = args["page"]
page_size = args["page_size"]
page = args.page
page_size = args.page_size
return jsonable_encoder(
{
@ -109,12 +122,7 @@ class EndpointListApi(Resource):
class EndpointListForSinglePluginApi(Resource):
@console_ns.doc("list_plugin_endpoints")
@console_ns.doc(description="List endpoints for a specific plugin")
@console_ns.expect(
console_ns.parser()
.add_argument("page", type=int, required=True, location="args", help="Page number")
.add_argument("page_size", type=int, required=True, location="args", help="Page size")
.add_argument("plugin_id", type=str, required=True, location="args", help="Plugin ID")
)
@console_ns.expect(console_ns.models[EndpointListForPluginQuery.__name__])
@console_ns.response(
200,
"Success",
@ -128,17 +136,11 @@ class EndpointListForSinglePluginApi(Resource):
def get(self):
user, tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("page", type=int, required=True, location="args")
.add_argument("page_size", type=int, required=True, location="args")
.add_argument("plugin_id", type=str, required=True, location="args")
)
args = parser.parse_args()
args = EndpointListForPluginQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
page = args["page"]
page_size = args["page_size"]
plugin_id = args["plugin_id"]
page = args.page
page_size = args.page_size
plugin_id = args.plugin_id
return jsonable_encoder(
{
@ -157,11 +159,7 @@ class EndpointListForSinglePluginApi(Resource):
class EndpointDeleteApi(Resource):
@console_ns.doc("delete_endpoint")
@console_ns.doc(description="Delete a plugin endpoint")
@console_ns.expect(
console_ns.model(
"EndpointDeleteRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")}
)
)
@console_ns.expect(console_ns.models[EndpointIdPayload.__name__])
@console_ns.response(
200,
"Endpoint deleted successfully",
@ -175,13 +173,12 @@ class EndpointDeleteApi(Resource):
def post(self):
user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True)
args = parser.parse_args()
endpoint_id = args["endpoint_id"]
args = EndpointIdPayload.model_validate(console_ns.payload)
return {
"success": EndpointService.delete_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
"success": EndpointService.delete_endpoint(
tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
)
}
@ -189,16 +186,7 @@ class EndpointDeleteApi(Resource):
class EndpointUpdateApi(Resource):
@console_ns.doc("update_endpoint")
@console_ns.doc(description="Update a plugin endpoint")
@console_ns.expect(
console_ns.model(
"EndpointUpdateRequest",
{
"endpoint_id": fields.String(required=True, description="Endpoint ID"),
"settings": fields.Raw(required=True, description="Updated settings"),
"name": fields.String(required=True, description="Updated name"),
},
)
)
@console_ns.expect(console_ns.models[EndpointUpdatePayload.__name__])
@console_ns.response(
200,
"Endpoint updated successfully",
@ -212,25 +200,15 @@ class EndpointUpdateApi(Resource):
def post(self):
user, tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("endpoint_id", type=str, required=True)
.add_argument("settings", type=dict, required=True)
.add_argument("name", type=str, required=True)
)
args = parser.parse_args()
endpoint_id = args["endpoint_id"]
settings = args["settings"]
name = args["name"]
args = EndpointUpdatePayload.model_validate(console_ns.payload)
return {
"success": EndpointService.update_endpoint(
tenant_id=tenant_id,
user_id=user.id,
endpoint_id=endpoint_id,
name=name,
settings=settings,
endpoint_id=args.endpoint_id,
name=args.name,
settings=args.settings,
)
}
@ -239,11 +217,7 @@ class EndpointUpdateApi(Resource):
class EndpointEnableApi(Resource):
@console_ns.doc("enable_endpoint")
@console_ns.doc(description="Enable a plugin endpoint")
@console_ns.expect(
console_ns.model(
"EndpointEnableRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")}
)
)
@console_ns.expect(console_ns.models[EndpointIdPayload.__name__])
@console_ns.response(
200,
"Endpoint enabled successfully",
@ -257,13 +231,12 @@ class EndpointEnableApi(Resource):
def post(self):
user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True)
args = parser.parse_args()
endpoint_id = args["endpoint_id"]
args = EndpointIdPayload.model_validate(console_ns.payload)
return {
"success": EndpointService.enable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
"success": EndpointService.enable_endpoint(
tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
)
}
@ -271,11 +244,7 @@ class EndpointEnableApi(Resource):
class EndpointDisableApi(Resource):
@console_ns.doc("disable_endpoint")
@console_ns.doc(description="Disable a plugin endpoint")
@console_ns.expect(
console_ns.model(
"EndpointDisableRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")}
)
)
@console_ns.expect(console_ns.models[EndpointIdPayload.__name__])
@console_ns.response(
200,
"Endpoint disabled successfully",
@ -289,11 +258,10 @@ class EndpointDisableApi(Resource):
def post(self):
user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True)
args = parser.parse_args()
endpoint_id = args["endpoint_id"]
args = EndpointIdPayload.model_validate(console_ns.payload)
return {
"success": EndpointService.disable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
"success": EndpointService.disable_endpoint(
tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
)
}

View File

@ -58,26 +58,15 @@ class OwnerTransferPayload(BaseModel):
token: str
console_ns.schema_model(
MemberInvitePayload.__name__,
MemberInvitePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
MemberRoleUpdatePayload.__name__,
MemberRoleUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
OwnerTransferEmailPayload.__name__,
OwnerTransferEmailPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
OwnerTransferCheckPayload.__name__,
OwnerTransferCheckPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
OwnerTransferPayload.__name__,
OwnerTransferPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(MemberInvitePayload)
reg(MemberRoleUpdatePayload)
reg(OwnerTransferEmailPayload)
reg(OwnerTransferCheckPayload)
reg(OwnerTransferPayload)
@console_ns.route("/workspaces/current/members")

View File

@ -75,44 +75,18 @@ class ParserPreferredProviderType(BaseModel):
preferred_provider_type: Literal["system", "custom"]
console_ns.schema_model(
ParserModelList.__name__, ParserModelList.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
console_ns.schema_model(
ParserCredentialId.__name__,
ParserCredentialId.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserCredentialCreate.__name__,
ParserCredentialCreate.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserCredentialUpdate.__name__,
ParserCredentialUpdate.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserCredentialDelete.__name__,
ParserCredentialDelete.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserCredentialSwitch.__name__,
ParserCredentialSwitch.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserCredentialValidate.__name__,
ParserCredentialValidate.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserPreferredProviderType.__name__,
ParserPreferredProviderType.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
reg(ParserModelList)
reg(ParserCredentialId)
reg(ParserCredentialCreate)
reg(ParserCredentialUpdate)
reg(ParserCredentialDelete)
reg(ParserCredentialSwitch)
reg(ParserCredentialValidate)
reg(ParserPreferredProviderType)
@console_ns.route("/workspaces/current/model-providers")

View File

@ -32,25 +32,11 @@ class ParserPostDefault(BaseModel):
model_settings: list[Inner]
console_ns.schema_model(
ParserGetDefault.__name__, ParserGetDefault.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
ParserPostDefault.__name__, ParserPostDefault.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
class ParserDeleteModels(BaseModel):
model: str
model_type: ModelType
console_ns.schema_model(
ParserDeleteModels.__name__, ParserDeleteModels.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
class LoadBalancingPayload(BaseModel):
configs: list[dict[str, Any]] | None = None
enabled: bool | None = None
@ -119,33 +105,19 @@ class ParserParameter(BaseModel):
model: str
console_ns.schema_model(
ParserPostModels.__name__, ParserPostModels.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
console_ns.schema_model(
ParserGetCredentials.__name__,
ParserGetCredentials.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserCreateCredential.__name__,
ParserCreateCredential.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserUpdateCredential.__name__,
ParserUpdateCredential.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserDeleteCredential.__name__,
ParserDeleteCredential.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserParameter.__name__, ParserParameter.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
reg(ParserGetDefault)
reg(ParserPostDefault)
reg(ParserDeleteModels)
reg(ParserPostModels)
reg(ParserGetCredentials)
reg(ParserCreateCredential)
reg(ParserUpdateCredential)
reg(ParserDeleteCredential)
reg(ParserParameter)
@console_ns.route("/workspaces/current/default-model")

View File

@ -22,6 +22,10 @@ from services.plugin.plugin_service import PluginService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@console_ns.route("/workspaces/current/plugin/debugging-key")
class PluginDebuggingKeyApi(Resource):
@setup_required
@ -46,9 +50,7 @@ class ParserList(BaseModel):
page_size: int = Field(default=256)
console_ns.schema_model(
ParserList.__name__, ParserList.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
reg(ParserList)
@console_ns.route("/workspaces/current/plugin/list")
@ -72,11 +74,6 @@ class ParserLatest(BaseModel):
plugin_ids: list[str]
console_ns.schema_model(
ParserLatest.__name__, ParserLatest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
class ParserIcon(BaseModel):
tenant_id: str
filename: str
@ -173,72 +170,22 @@ class ParserReadme(BaseModel):
language: str = Field(default="en-US")
console_ns.schema_model(
ParserIcon.__name__, ParserIcon.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
ParserAsset.__name__, ParserAsset.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
ParserGithubUpload.__name__, ParserGithubUpload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
ParserPluginIdentifiers.__name__,
ParserPluginIdentifiers.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserGithubInstall.__name__, ParserGithubInstall.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
ParserPluginIdentifierQuery.__name__,
ParserPluginIdentifierQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserTasks.__name__, ParserTasks.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
ParserMarketplaceUpgrade.__name__,
ParserMarketplaceUpgrade.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserGithubUpgrade.__name__, ParserGithubUpgrade.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
ParserUninstall.__name__, ParserUninstall.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
ParserPermissionChange.__name__,
ParserPermissionChange.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserDynamicOptions.__name__,
ParserDynamicOptions.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserPreferencesChange.__name__,
ParserPreferencesChange.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserExcludePlugin.__name__,
ParserExcludePlugin.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserReadme.__name__, ParserReadme.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
reg(ParserLatest)
reg(ParserIcon)
reg(ParserAsset)
reg(ParserGithubUpload)
reg(ParserPluginIdentifiers)
reg(ParserGithubInstall)
reg(ParserPluginIdentifierQuery)
reg(ParserTasks)
reg(ParserMarketplaceUpgrade)
reg(ParserGithubUpgrade)
reg(ParserUninstall)
reg(ParserPermissionChange)
reg(ParserDynamicOptions)
reg(ParserPreferencesChange)
reg(ParserExcludePlugin)
reg(ParserReadme)
@console_ns.route("/workspaces/current/plugin/list/latest-versions")

View File

@ -54,25 +54,14 @@ class WorkspaceInfoPayload(BaseModel):
name: str
console_ns.schema_model(
WorkspaceListQuery.__name__, WorkspaceListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
console_ns.schema_model(
SwitchWorkspacePayload.__name__,
SwitchWorkspacePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
WorkspaceCustomConfigPayload.__name__,
WorkspaceCustomConfigPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
WorkspaceInfoPayload.__name__,
WorkspaceInfoPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
reg(WorkspaceListQuery)
reg(SwitchWorkspacePayload)
reg(WorkspaceCustomConfigPayload)
reg(WorkspaceInfoPayload)
provider_fields = {
"provider_name": fields.String,