Merge branch 'main' into feat/memory-orchestration-fed

This commit is contained in:
zxhlyh 2025-12-02 10:59:54 +08:00
commit 4e037d14d1
98 changed files with 25608 additions and 1689 deletions

226
.github/CODEOWNERS vendored Normal file
View File

@ -0,0 +1,226 @@
# CODEOWNERS
# This file defines code ownership for the Dify project.
# Each line is a file pattern followed by one or more owners.
# Owners can be @username, @org/team-name, or email addresses.
# For more information, see: https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-code-owners
* @crazywoola @laipz8200 @Yeuoly
# Backend (default owner, more specific rules below will override)
api/ @QuantumGhost
# Backend - Workflow - Engine (Core graph execution engine)
api/core/workflow/graph_engine/ @laipz8200 @QuantumGhost
api/core/workflow/runtime/ @laipz8200 @QuantumGhost
api/core/workflow/graph/ @laipz8200 @QuantumGhost
api/core/workflow/graph_events/ @laipz8200 @QuantumGhost
api/core/workflow/node_events/ @laipz8200 @QuantumGhost
api/core/model_runtime/ @laipz8200 @QuantumGhost
# Backend - Workflow - Nodes (Agent, Iteration, Loop, LLM)
api/core/workflow/nodes/agent/ @Nov1c444
api/core/workflow/nodes/iteration/ @Nov1c444
api/core/workflow/nodes/loop/ @Nov1c444
api/core/workflow/nodes/llm/ @Nov1c444
# Backend - RAG (Retrieval Augmented Generation)
api/core/rag/ @JohnJyong
api/services/rag_pipeline/ @JohnJyong
api/services/dataset_service.py @JohnJyong
api/services/knowledge_service.py @JohnJyong
api/services/external_knowledge_service.py @JohnJyong
api/services/hit_testing_service.py @JohnJyong
api/services/metadata_service.py @JohnJyong
api/services/vector_service.py @JohnJyong
api/services/entities/knowledge_entities/ @JohnJyong
api/services/entities/external_knowledge_entities/ @JohnJyong
api/controllers/console/datasets/ @JohnJyong
api/controllers/service_api/dataset/ @JohnJyong
api/models/dataset.py @JohnJyong
api/tasks/rag_pipeline/ @JohnJyong
api/tasks/add_document_to_index_task.py @JohnJyong
api/tasks/batch_clean_document_task.py @JohnJyong
api/tasks/clean_document_task.py @JohnJyong
api/tasks/clean_notion_document_task.py @JohnJyong
api/tasks/document_indexing_task.py @JohnJyong
api/tasks/document_indexing_sync_task.py @JohnJyong
api/tasks/document_indexing_update_task.py @JohnJyong
api/tasks/duplicate_document_indexing_task.py @JohnJyong
api/tasks/recover_document_indexing_task.py @JohnJyong
api/tasks/remove_document_from_index_task.py @JohnJyong
api/tasks/retry_document_indexing_task.py @JohnJyong
api/tasks/sync_website_document_indexing_task.py @JohnJyong
api/tasks/batch_create_segment_to_index_task.py @JohnJyong
api/tasks/create_segment_to_index_task.py @JohnJyong
api/tasks/delete_segment_from_index_task.py @JohnJyong
api/tasks/disable_segment_from_index_task.py @JohnJyong
api/tasks/disable_segments_from_index_task.py @JohnJyong
api/tasks/enable_segment_to_index_task.py @JohnJyong
api/tasks/enable_segments_to_index_task.py @JohnJyong
api/tasks/clean_dataset_task.py @JohnJyong
api/tasks/deal_dataset_index_update_task.py @JohnJyong
api/tasks/deal_dataset_vector_index_task.py @JohnJyong
# Backend - Plugins
api/core/plugin/ @Mairuis @Yeuoly @Stream29
api/services/plugin/ @Mairuis @Yeuoly @Stream29
api/controllers/console/workspace/plugin.py @Mairuis @Yeuoly @Stream29
api/controllers/inner_api/plugin/ @Mairuis @Yeuoly @Stream29
api/tasks/process_tenant_plugin_autoupgrade_check_task.py @Mairuis @Yeuoly @Stream29
# Backend - Trigger/Schedule/Webhook
api/controllers/trigger/ @Mairuis @Yeuoly
api/controllers/console/app/workflow_trigger.py @Mairuis @Yeuoly
api/controllers/console/workspace/trigger_providers.py @Mairuis @Yeuoly
api/core/trigger/ @Mairuis @Yeuoly
api/core/app/layers/trigger_post_layer.py @Mairuis @Yeuoly
api/services/trigger/ @Mairuis @Yeuoly
api/models/trigger.py @Mairuis @Yeuoly
api/fields/workflow_trigger_fields.py @Mairuis @Yeuoly
api/repositories/workflow_trigger_log_repository.py @Mairuis @Yeuoly
api/repositories/sqlalchemy_workflow_trigger_log_repository.py @Mairuis @Yeuoly
api/libs/schedule_utils.py @Mairuis @Yeuoly
api/services/workflow/scheduler.py @Mairuis @Yeuoly
api/schedule/trigger_provider_refresh_task.py @Mairuis @Yeuoly
api/schedule/workflow_schedule_task.py @Mairuis @Yeuoly
api/tasks/trigger_processing_tasks.py @Mairuis @Yeuoly
api/tasks/trigger_subscription_refresh_tasks.py @Mairuis @Yeuoly
api/tasks/workflow_schedule_tasks.py @Mairuis @Yeuoly
api/tasks/workflow_cfs_scheduler/ @Mairuis @Yeuoly
api/events/event_handlers/sync_plugin_trigger_when_app_created.py @Mairuis @Yeuoly
api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py @Mairuis @Yeuoly
api/events/event_handlers/sync_workflow_schedule_when_app_published.py @Mairuis @Yeuoly
api/events/event_handlers/sync_webhook_when_app_created.py @Mairuis @Yeuoly
# Backend - Async Workflow
api/services/async_workflow_service.py @Mairuis @Yeuoly
api/tasks/async_workflow_tasks.py @Mairuis @Yeuoly
# Backend - Billing
api/services/billing_service.py @hj24 @zyssyz123
api/controllers/console/billing/ @hj24 @zyssyz123
# Backend - Enterprise
api/configs/enterprise/ @GarfieldDai @GareArc
api/services/enterprise/ @GarfieldDai @GareArc
api/services/feature_service.py @GarfieldDai @GareArc
api/controllers/console/feature.py @GarfieldDai @GareArc
api/controllers/web/feature.py @GarfieldDai @GareArc
# Backend - Database Migrations
api/migrations/ @snakevash @laipz8200
# Frontend
web/ @iamjoel
# Frontend - App - Orchestration
web/app/components/workflow/ @iamjoel @zxhlyh
web/app/components/workflow-app/ @iamjoel @zxhlyh
web/app/components/app/configuration/ @iamjoel @zxhlyh
web/app/components/app/app-publisher/ @iamjoel @zxhlyh
# Frontend - WebApp - Chat
web/app/components/base/chat/ @iamjoel @zxhlyh
# Frontend - WebApp - Completion
web/app/components/share/text-generation/ @iamjoel @zxhlyh
# Frontend - App - List and Creation
web/app/components/apps/ @JzoNgKVO @iamjoel
web/app/components/app/create-app-dialog/ @JzoNgKVO @iamjoel
web/app/components/app/create-app-modal/ @JzoNgKVO @iamjoel
web/app/components/app/create-from-dsl-modal/ @JzoNgKVO @iamjoel
# Frontend - App - API Documentation
web/app/components/develop/ @JzoNgKVO @iamjoel
# Frontend - App - Logs and Annotations
web/app/components/app/workflow-log/ @JzoNgKVO @iamjoel
web/app/components/app/log/ @JzoNgKVO @iamjoel
web/app/components/app/log-annotation/ @JzoNgKVO @iamjoel
web/app/components/app/annotation/ @JzoNgKVO @iamjoel
# Frontend - App - Monitoring
web/app/(commonLayout)/app/(appDetailLayout)/\[appId\]/overview/ @JzoNgKVO @iamjoel
web/app/components/app/overview/ @JzoNgKVO @iamjoel
# Frontend - App - Settings
web/app/components/app-sidebar/ @JzoNgKVO @iamjoel
# Frontend - RAG - Hit Testing
web/app/components/datasets/hit-testing/ @JzoNgKVO @iamjoel
# Frontend - RAG - List and Creation
web/app/components/datasets/list/ @iamjoel @WTW0313
web/app/components/datasets/create/ @iamjoel @WTW0313
web/app/components/datasets/create-from-pipeline/ @iamjoel @WTW0313
web/app/components/datasets/external-knowledge-base/ @iamjoel @WTW0313
# Frontend - RAG - Orchestration (general rule first, specific rules below override)
web/app/components/rag-pipeline/ @iamjoel @WTW0313
web/app/components/rag-pipeline/components/rag-pipeline-main.tsx @iamjoel @zxhlyh
web/app/components/rag-pipeline/store/ @iamjoel @zxhlyh
# Frontend - RAG - Documents List
web/app/components/datasets/documents/list.tsx @iamjoel @WTW0313
web/app/components/datasets/documents/create-from-pipeline/ @iamjoel @WTW0313
# Frontend - RAG - Segments List
web/app/components/datasets/documents/detail/ @iamjoel @WTW0313
# Frontend - RAG - Settings
web/app/components/datasets/settings/ @iamjoel @WTW0313
# Frontend - Ecosystem - Plugins
web/app/components/plugins/ @iamjoel @zhsama
# Frontend - Ecosystem - Tools
web/app/components/tools/ @iamjoel @Yessenia-d
# Frontend - Ecosystem - Marketplace
web/app/components/plugins/marketplace/ @iamjoel @Yessenia-d
# Frontend - Login and Registration
web/app/signin/ @douxc @iamjoel
web/app/signup/ @douxc @iamjoel
web/app/reset-password/ @douxc @iamjoel
web/app/install/ @douxc @iamjoel
web/app/init/ @douxc @iamjoel
web/app/forgot-password/ @douxc @iamjoel
web/app/account/ @douxc @iamjoel
# Frontend - Service Authentication
web/service/base.ts @douxc @iamjoel
# Frontend - WebApp Authentication and Access Control
web/app/(shareLayout)/components/ @douxc @iamjoel
web/app/(shareLayout)/webapp-signin/ @douxc @iamjoel
web/app/(shareLayout)/webapp-reset-password/ @douxc @iamjoel
web/app/components/app/app-access-control/ @douxc @iamjoel
# Frontend - Explore Page
web/app/components/explore/ @CodingOnStar @iamjoel
# Frontend - Personal Settings
web/app/components/header/account-setting/ @CodingOnStar @iamjoel
web/app/components/header/account-dropdown/ @CodingOnStar @iamjoel
# Frontend - Analytics
web/app/components/base/ga/ @CodingOnStar @iamjoel
# Frontend - Base Components
web/app/components/base/ @iamjoel @zxhlyh
# Frontend - Utils and Hooks
web/utils/classnames.ts @iamjoel @zxhlyh
web/utils/time.ts @iamjoel @zxhlyh
web/utils/format.ts @iamjoel @zxhlyh
web/utils/clipboard.ts @iamjoel @zxhlyh
web/hooks/use-document-title.ts @iamjoel @zxhlyh
# Frontend - Billing and Education
web/app/components/billing/ @iamjoel @zxhlyh
web/app/education-apply/ @iamjoel @zxhlyh
# Frontend - Workspace
web/app/components/header/account-dropdown/workplace-selector/ @iamjoel @zxhlyh

View File

@ -51,6 +51,7 @@ def initialize_extensions(app: DifyApp):
ext_commands, ext_commands,
ext_compress, ext_compress,
ext_database, ext_database,
ext_forward_refs,
ext_hosting_provider, ext_hosting_provider,
ext_import_modules, ext_import_modules,
ext_logging, ext_logging,
@ -75,6 +76,7 @@ def initialize_extensions(app: DifyApp):
ext_warnings, ext_warnings,
ext_import_modules, ext_import_modules,
ext_orjson, ext_orjson,
ext_forward_refs,
ext_set_secretkey, ext_set_secretkey,
ext_compress, ext_compress,
ext_code_based_extension, ext_code_based_extension,

View File

@ -1,16 +1,23 @@
from flask_restx import Resource, fields, reqparse from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import login_required from libs.login import login_required
from services.advanced_prompt_template_service import AdvancedPromptTemplateService from services.advanced_prompt_template_service import AdvancedPromptTemplateService
parser = (
reqparse.RequestParser() class AdvancedPromptTemplateQuery(BaseModel):
.add_argument("app_mode", type=str, required=True, location="args", help="Application mode") app_mode: str = Field(..., description="Application mode")
.add_argument("model_mode", type=str, required=True, location="args", help="Model mode") model_mode: str = Field(..., description="Model mode")
.add_argument("has_context", type=str, required=False, default="true", location="args", help="Whether has context") has_context: str = Field(default="true", description="Whether has context")
.add_argument("model_name", type=str, required=True, location="args", help="Model name") model_name: str = Field(..., description="Model name")
console_ns.schema_model(
AdvancedPromptTemplateQuery.__name__,
AdvancedPromptTemplateQuery.model_json_schema(ref_template="#/definitions/{model}"),
) )
@ -18,7 +25,7 @@ parser = (
class AdvancedPromptTemplateList(Resource): class AdvancedPromptTemplateList(Resource):
@console_ns.doc("get_advanced_prompt_templates") @console_ns.doc("get_advanced_prompt_templates")
@console_ns.doc(description="Get advanced prompt templates based on app mode and model configuration") @console_ns.doc(description="Get advanced prompt templates based on app mode and model configuration")
@console_ns.expect(parser) @console_ns.expect(console_ns.models[AdvancedPromptTemplateQuery.__name__])
@console_ns.response( @console_ns.response(
200, "Prompt templates retrieved successfully", fields.List(fields.Raw(description="Prompt template data")) 200, "Prompt templates retrieved successfully", fields.List(fields.Raw(description="Prompt template data"))
) )
@ -27,6 +34,6 @@ class AdvancedPromptTemplateList(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
args = parser.parse_args() args = AdvancedPromptTemplateQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
return AdvancedPromptTemplateService.get_prompt(args) return AdvancedPromptTemplateService.get_prompt(args.model_dump())

View File

@ -1,9 +1,12 @@
import uuid import uuid
from typing import Literal
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse from flask import request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, abort from werkzeug.exceptions import BadRequest
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
@ -36,6 +39,130 @@ from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService from services.feature_service import FeatureService
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"] ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class AppListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")
mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"] = Field(
default="all", description="App mode filter"
)
name: str | None = Field(default=None, description="Filter by app name")
tag_ids: list[str] | None = Field(default=None, description="Comma-separated tag IDs")
is_created_by_me: bool | None = Field(default=None, description="Filter by creator")
@field_validator("tag_ids", mode="before")
@classmethod
def validate_tag_ids(cls, value: str | list[str] | None) -> list[str] | None:
if not value:
return None
if isinstance(value, str):
items = [item.strip() for item in value.split(",") if item.strip()]
elif isinstance(value, list):
items = [str(item).strip() for item in value if item and str(item).strip()]
else:
raise TypeError("Unsupported tag_ids type.")
if not items:
return None
try:
return [str(uuid.UUID(item)) for item in items]
except ValueError as exc:
raise ValueError("Invalid UUID format in tag_ids.") from exc
class CreateAppPayload(BaseModel):
name: str = Field(..., min_length=1, description="App name")
description: str | None = Field(default=None, description="App description (max 400 chars)")
mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode")
icon_type: str | None = Field(default=None, description="Icon type")
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
@field_validator("description")
@classmethod
def validate_description(cls, value: str | None) -> str | None:
if value is None:
return value
return validate_description_length(value)
class UpdateAppPayload(BaseModel):
name: str = Field(..., min_length=1, description="App name")
description: str | None = Field(default=None, description="App description (max 400 chars)")
icon_type: str | None = Field(default=None, description="Icon type")
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon")
max_active_requests: int | None = Field(default=None, description="Maximum active requests")
@field_validator("description")
@classmethod
def validate_description(cls, value: str | None) -> str | None:
if value is None:
return value
return validate_description_length(value)
class CopyAppPayload(BaseModel):
name: str | None = Field(default=None, description="Name for the copied app")
description: str | None = Field(default=None, description="Description for the copied app")
icon_type: str | None = Field(default=None, description="Icon type")
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
@field_validator("description")
@classmethod
def validate_description(cls, value: str | None) -> str | None:
if value is None:
return value
return validate_description_length(value)
class AppExportQuery(BaseModel):
include_secret: bool = Field(default=False, description="Include secrets in export")
workflow_id: str | None = Field(default=None, description="Specific workflow ID to export")
class AppNamePayload(BaseModel):
name: str = Field(..., min_length=1, description="Name to check")
class AppIconPayload(BaseModel):
icon: str | None = Field(default=None, description="Icon data")
icon_background: str | None = Field(default=None, description="Icon background color")
class AppSiteStatusPayload(BaseModel):
enable_site: bool = Field(..., description="Enable or disable site")
class AppApiStatusPayload(BaseModel):
enable_api: bool = Field(..., description="Enable or disable API")
class AppTracePayload(BaseModel):
enabled: bool = Field(..., description="Enable or disable tracing")
tracing_provider: str = Field(..., description="Tracing provider")
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(AppListQuery)
reg(CreateAppPayload)
reg(UpdateAppPayload)
reg(CopyAppPayload)
reg(AppExportQuery)
reg(AppNamePayload)
reg(AppIconPayload)
reg(AppSiteStatusPayload)
reg(AppApiStatusPayload)
reg(AppTracePayload)
# Register models for flask_restx to avoid dict type issues in Swagger # Register models for flask_restx to avoid dict type issues in Swagger
# Register base models first # Register base models first
@ -147,22 +274,7 @@ app_pagination_model = console_ns.model(
class AppListApi(Resource): class AppListApi(Resource):
@console_ns.doc("list_apps") @console_ns.doc("list_apps")
@console_ns.doc(description="Get list of applications with pagination and filtering") @console_ns.doc(description="Get list of applications with pagination and filtering")
@console_ns.expect( @console_ns.expect(console_ns.models[AppListQuery.__name__])
console_ns.parser()
.add_argument("page", type=int, location="args", help="Page number (1-99999)", default=1)
.add_argument("limit", type=int, location="args", help="Page size (1-100)", default=20)
.add_argument(
"mode",
type=str,
location="args",
choices=["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"],
default="all",
help="App mode filter",
)
.add_argument("name", type=str, location="args", help="Filter by app name")
.add_argument("tag_ids", type=str, location="args", help="Comma-separated tag IDs")
.add_argument("is_created_by_me", type=bool, location="args", help="Filter by creator")
)
@console_ns.response(200, "Success", app_pagination_model) @console_ns.response(200, "Success", app_pagination_model)
@setup_required @setup_required
@login_required @login_required
@ -172,42 +284,12 @@ class AppListApi(Resource):
"""Get app list""" """Get app list"""
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
def uuid_list(value): args = AppListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
try: args_dict = args.model_dump()
return [str(uuid.UUID(v)) for v in value.split(",")]
except ValueError:
abort(400, message="Invalid UUID format in tag_ids.")
parser = (
reqparse.RequestParser()
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
.add_argument(
"mode",
type=str,
choices=[
"completion",
"chat",
"advanced-chat",
"workflow",
"agent-chat",
"channel",
"all",
],
default="all",
location="args",
required=False,
)
.add_argument("name", type=str, location="args", required=False)
.add_argument("tag_ids", type=uuid_list, location="args", required=False)
.add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False)
)
args = parser.parse_args()
# get app list # get app list
app_service = AppService() app_service = AppService()
app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args) app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args_dict)
if not app_pagination: if not app_pagination:
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False} return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
@ -254,19 +336,7 @@ class AppListApi(Resource):
@console_ns.doc("create_app") @console_ns.doc("create_app")
@console_ns.doc(description="Create a new application") @console_ns.doc(description="Create a new application")
@console_ns.expect( @console_ns.expect(console_ns.models[CreateAppPayload.__name__])
console_ns.model(
"CreateAppRequest",
{
"name": fields.String(required=True, description="App name"),
"description": fields.String(description="App description (max 400 chars)"),
"mode": fields.String(required=True, enum=ALLOW_CREATE_APP_MODES, description="App mode"),
"icon_type": fields.String(description="Icon type"),
"icon": fields.String(description="Icon"),
"icon_background": fields.String(description="Icon background color"),
},
)
)
@console_ns.response(201, "App created successfully", app_detail_model) @console_ns.response(201, "App created successfully", app_detail_model)
@console_ns.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@console_ns.response(400, "Invalid request parameters") @console_ns.response(400, "Invalid request parameters")
@ -279,22 +349,10 @@ class AppListApi(Resource):
def post(self): def post(self):
"""Create app""" """Create app"""
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
parser = ( args = CreateAppPayload.model_validate(console_ns.payload)
reqparse.RequestParser()
.add_argument("name", type=str, required=True, location="json")
.add_argument("description", type=validate_description_length, location="json")
.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json")
.add_argument("icon_type", type=str, location="json")
.add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
)
args = parser.parse_args()
if "mode" not in args or args["mode"] is None:
raise BadRequest("mode is required")
app_service = AppService() app_service = AppService()
app = app_service.create_app(current_tenant_id, args, current_user) app = app_service.create_app(current_tenant_id, args.model_dump(), current_user)
return app, 201 return app, 201
@ -326,20 +384,7 @@ class AppApi(Resource):
@console_ns.doc("update_app") @console_ns.doc("update_app")
@console_ns.doc(description="Update application details") @console_ns.doc(description="Update application details")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[UpdateAppPayload.__name__])
console_ns.model(
"UpdateAppRequest",
{
"name": fields.String(required=True, description="App name"),
"description": fields.String(description="App description (max 400 chars)"),
"icon_type": fields.String(description="Icon type"),
"icon": fields.String(description="Icon"),
"icon_background": fields.String(description="Icon background color"),
"use_icon_as_answer_icon": fields.Boolean(description="Use icon as answer icon"),
"max_active_requests": fields.Integer(description="Maximum active requests"),
},
)
)
@console_ns.response(200, "App updated successfully", app_detail_with_site_model) @console_ns.response(200, "App updated successfully", app_detail_with_site_model)
@console_ns.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@console_ns.response(400, "Invalid request parameters") @console_ns.response(400, "Invalid request parameters")
@ -351,28 +396,18 @@ class AppApi(Resource):
@marshal_with(app_detail_with_site_model) @marshal_with(app_detail_with_site_model)
def put(self, app_model): def put(self, app_model):
"""Update app""" """Update app"""
parser = ( args = UpdateAppPayload.model_validate(console_ns.payload)
reqparse.RequestParser()
.add_argument("name", type=str, required=True, nullable=False, location="json")
.add_argument("description", type=validate_description_length, location="json")
.add_argument("icon_type", type=str, location="json")
.add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
.add_argument("use_icon_as_answer_icon", type=bool, location="json")
.add_argument("max_active_requests", type=int, location="json")
)
args = parser.parse_args()
app_service = AppService() app_service = AppService()
args_dict: AppService.ArgsDict = { args_dict: AppService.ArgsDict = {
"name": args["name"], "name": args.name,
"description": args.get("description", ""), "description": args.description or "",
"icon_type": args.get("icon_type", ""), "icon_type": args.icon_type or "",
"icon": args.get("icon", ""), "icon": args.icon or "",
"icon_background": args.get("icon_background", ""), "icon_background": args.icon_background or "",
"use_icon_as_answer_icon": args.get("use_icon_as_answer_icon", False), "use_icon_as_answer_icon": args.use_icon_as_answer_icon or False,
"max_active_requests": args.get("max_active_requests", 0), "max_active_requests": args.max_active_requests or 0,
} }
app_model = app_service.update_app(app_model, args_dict) app_model = app_service.update_app(app_model, args_dict)
@ -401,18 +436,7 @@ class AppCopyApi(Resource):
@console_ns.doc("copy_app") @console_ns.doc("copy_app")
@console_ns.doc(description="Create a copy of an existing application") @console_ns.doc(description="Create a copy of an existing application")
@console_ns.doc(params={"app_id": "Application ID to copy"}) @console_ns.doc(params={"app_id": "Application ID to copy"})
@console_ns.expect( @console_ns.expect(console_ns.models[CopyAppPayload.__name__])
console_ns.model(
"CopyAppRequest",
{
"name": fields.String(description="Name for the copied app"),
"description": fields.String(description="Description for the copied app"),
"icon_type": fields.String(description="Icon type"),
"icon": fields.String(description="Icon"),
"icon_background": fields.String(description="Icon background color"),
},
)
)
@console_ns.response(201, "App copied successfully", app_detail_with_site_model) @console_ns.response(201, "App copied successfully", app_detail_with_site_model)
@console_ns.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@setup_required @setup_required
@ -426,15 +450,7 @@ class AppCopyApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
parser = ( args = CopyAppPayload.model_validate(console_ns.payload or {})
reqparse.RequestParser()
.add_argument("name", type=str, location="json")
.add_argument("description", type=validate_description_length, location="json")
.add_argument("icon_type", type=str, location="json")
.add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
)
args = parser.parse_args()
with Session(db.engine) as session: with Session(db.engine) as session:
import_service = AppDslService(session) import_service = AppDslService(session)
@ -443,11 +459,11 @@ class AppCopyApi(Resource):
account=current_user, account=current_user,
import_mode=ImportMode.YAML_CONTENT, import_mode=ImportMode.YAML_CONTENT,
yaml_content=yaml_content, yaml_content=yaml_content,
name=args.get("name"), name=args.name,
description=args.get("description"), description=args.description,
icon_type=args.get("icon_type"), icon_type=args.icon_type,
icon=args.get("icon"), icon=args.icon,
icon_background=args.get("icon_background"), icon_background=args.icon_background,
) )
session.commit() session.commit()
@ -462,11 +478,7 @@ class AppExportApi(Resource):
@console_ns.doc("export_app") @console_ns.doc("export_app")
@console_ns.doc(description="Export application configuration as DSL") @console_ns.doc(description="Export application configuration as DSL")
@console_ns.doc(params={"app_id": "Application ID to export"}) @console_ns.doc(params={"app_id": "Application ID to export"})
@console_ns.expect( @console_ns.expect(console_ns.models[AppExportQuery.__name__])
console_ns.parser()
.add_argument("include_secret", type=bool, location="args", default=False, help="Include secrets in export")
.add_argument("workflow_id", type=str, location="args", help="Specific workflow ID to export")
)
@console_ns.response( @console_ns.response(
200, 200,
"App exported successfully", "App exported successfully",
@ -480,30 +492,23 @@ class AppExportApi(Resource):
@edit_permission_required @edit_permission_required
def get(self, app_model): def get(self, app_model):
"""Export app""" """Export app"""
# Add include_secret params args = AppExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
parser = (
reqparse.RequestParser()
.add_argument("include_secret", type=inputs.boolean, default=False, location="args")
.add_argument("workflow_id", type=str, location="args")
)
args = parser.parse_args()
return { return {
"data": AppDslService.export_dsl( "data": AppDslService.export_dsl(
app_model=app_model, include_secret=args["include_secret"], workflow_id=args.get("workflow_id") app_model=app_model,
include_secret=args.include_secret,
workflow_id=args.workflow_id,
) )
} }
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json", help="Name to check")
@console_ns.route("/apps/<uuid:app_id>/name") @console_ns.route("/apps/<uuid:app_id>/name")
class AppNameApi(Resource): class AppNameApi(Resource):
@console_ns.doc("check_app_name") @console_ns.doc("check_app_name")
@console_ns.doc(description="Check if app name is available") @console_ns.doc(description="Check if app name is available")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(parser) @console_ns.expect(console_ns.models[AppNamePayload.__name__])
@console_ns.response(200, "Name availability checked") @console_ns.response(200, "Name availability checked")
@setup_required @setup_required
@login_required @login_required
@ -512,10 +517,10 @@ class AppNameApi(Resource):
@marshal_with(app_detail_model) @marshal_with(app_detail_model)
@edit_permission_required @edit_permission_required
def post(self, app_model): def post(self, app_model):
args = parser.parse_args() args = AppNamePayload.model_validate(console_ns.payload)
app_service = AppService() app_service = AppService()
app_model = app_service.update_app_name(app_model, args["name"]) app_model = app_service.update_app_name(app_model, args.name)
return app_model return app_model
@ -525,16 +530,7 @@ class AppIconApi(Resource):
@console_ns.doc("update_app_icon") @console_ns.doc("update_app_icon")
@console_ns.doc(description="Update application icon") @console_ns.doc(description="Update application icon")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[AppIconPayload.__name__])
console_ns.model(
"AppIconRequest",
{
"icon": fields.String(required=True, description="Icon data"),
"icon_type": fields.String(description="Icon type"),
"icon_background": fields.String(description="Icon background color"),
},
)
)
@console_ns.response(200, "Icon updated successfully") @console_ns.response(200, "Icon updated successfully")
@console_ns.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@setup_required @setup_required
@ -544,15 +540,10 @@ class AppIconApi(Resource):
@marshal_with(app_detail_model) @marshal_with(app_detail_model)
@edit_permission_required @edit_permission_required
def post(self, app_model): def post(self, app_model):
parser = ( args = AppIconPayload.model_validate(console_ns.payload or {})
reqparse.RequestParser()
.add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
)
args = parser.parse_args()
app_service = AppService() app_service = AppService()
app_model = app_service.update_app_icon(app_model, args.get("icon") or "", args.get("icon_background") or "") app_model = app_service.update_app_icon(app_model, args.icon or "", args.icon_background or "")
return app_model return app_model
@ -562,11 +553,7 @@ class AppSiteStatus(Resource):
@console_ns.doc("update_app_site_status") @console_ns.doc("update_app_site_status")
@console_ns.doc(description="Enable or disable app site") @console_ns.doc(description="Enable or disable app site")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[AppSiteStatusPayload.__name__])
console_ns.model(
"AppSiteStatusRequest", {"enable_site": fields.Boolean(required=True, description="Enable or disable site")}
)
)
@console_ns.response(200, "Site status updated successfully", app_detail_model) @console_ns.response(200, "Site status updated successfully", app_detail_model)
@console_ns.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@setup_required @setup_required
@ -576,11 +563,10 @@ class AppSiteStatus(Resource):
@marshal_with(app_detail_model) @marshal_with(app_detail_model)
@edit_permission_required @edit_permission_required
def post(self, app_model): def post(self, app_model):
parser = reqparse.RequestParser().add_argument("enable_site", type=bool, required=True, location="json") args = AppSiteStatusPayload.model_validate(console_ns.payload)
args = parser.parse_args()
app_service = AppService() app_service = AppService()
app_model = app_service.update_app_site_status(app_model, args["enable_site"]) app_model = app_service.update_app_site_status(app_model, args.enable_site)
return app_model return app_model
@ -590,11 +576,7 @@ class AppApiStatus(Resource):
@console_ns.doc("update_app_api_status") @console_ns.doc("update_app_api_status")
@console_ns.doc(description="Enable or disable app API") @console_ns.doc(description="Enable or disable app API")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[AppApiStatusPayload.__name__])
console_ns.model(
"AppApiStatusRequest", {"enable_api": fields.Boolean(required=True, description="Enable or disable API")}
)
)
@console_ns.response(200, "API status updated successfully", app_detail_model) @console_ns.response(200, "API status updated successfully", app_detail_model)
@console_ns.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@setup_required @setup_required
@ -604,11 +586,10 @@ class AppApiStatus(Resource):
@get_app_model @get_app_model
@marshal_with(app_detail_model) @marshal_with(app_detail_model)
def post(self, app_model): def post(self, app_model):
parser = reqparse.RequestParser().add_argument("enable_api", type=bool, required=True, location="json") args = AppApiStatusPayload.model_validate(console_ns.payload)
args = parser.parse_args()
app_service = AppService() app_service = AppService()
app_model = app_service.update_app_api_status(app_model, args["enable_api"]) app_model = app_service.update_app_api_status(app_model, args.enable_api)
return app_model return app_model
@ -631,15 +612,7 @@ class AppTraceApi(Resource):
@console_ns.doc("update_app_trace") @console_ns.doc("update_app_trace")
@console_ns.doc(description="Update app tracing configuration") @console_ns.doc(description="Update app tracing configuration")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[AppTracePayload.__name__])
console_ns.model(
"AppTraceRequest",
{
"enabled": fields.Boolean(required=True, description="Enable or disable tracing"),
"tracing_provider": fields.String(required=True, description="Tracing provider"),
},
)
)
@console_ns.response(200, "Trace configuration updated successfully") @console_ns.response(200, "Trace configuration updated successfully")
@console_ns.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@setup_required @setup_required
@ -648,17 +621,12 @@ class AppTraceApi(Resource):
@edit_permission_required @edit_permission_required
def post(self, app_id): def post(self, app_id):
# add app trace # add app trace
parser = ( args = AppTracePayload.model_validate(console_ns.payload)
reqparse.RequestParser()
.add_argument("enabled", type=bool, required=True, location="json")
.add_argument("tracing_provider", type=str, required=True, location="json")
)
args = parser.parse_args()
OpsTraceManager.update_app_tracing_config( OpsTraceManager.update_app_tracing_config(
app_id=app_id, app_id=app_id,
enabled=args["enabled"], enabled=args.enabled,
tracing_provider=args["tracing_provider"], tracing_provider=args.tracing_provider,
) )
return {"result": "success"} return {"result": "success"}

View File

@ -1,7 +1,9 @@
import logging import logging
from typing import Any, Literal
from flask import request from flask import request
from flask_restx import Resource, fields, reqparse from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import InternalServerError, NotFound from werkzeug.exceptions import InternalServerError, NotFound
import services import services
@ -35,6 +37,41 @@ from services.app_task_service import AppTaskService
from services.errors.llm import InvokeRateLimitError from services.errors.llm import InvokeRateLimitError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class BaseMessagePayload(BaseModel):
inputs: dict[str, Any]
model_config_data: dict[str, Any] = Field(..., alias="model_config")
files: list[Any] | None = Field(default=None, description="Uploaded files")
response_mode: Literal["blocking", "streaming"] = Field(default="blocking", description="Response mode")
retriever_from: str = Field(default="dev", description="Retriever source")
class CompletionMessagePayload(BaseMessagePayload):
query: str = Field(default="", description="Query text")
class ChatMessagePayload(BaseMessagePayload):
query: str = Field(..., description="User query")
conversation_id: str | None = Field(default=None, description="Conversation ID")
parent_message_id: str | None = Field(default=None, description="Parent message ID")
@field_validator("conversation_id", "parent_message_id")
@classmethod
def validate_uuid(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
console_ns.schema_model(
CompletionMessagePayload.__name__,
CompletionMessagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ChatMessagePayload.__name__, ChatMessagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
# define completion message api for user # define completion message api for user
@ -43,19 +80,7 @@ class CompletionMessageApi(Resource):
@console_ns.doc("create_completion_message") @console_ns.doc("create_completion_message")
@console_ns.doc(description="Generate completion message for debugging") @console_ns.doc(description="Generate completion message for debugging")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[CompletionMessagePayload.__name__])
console_ns.model(
"CompletionMessageRequest",
{
"inputs": fields.Raw(required=True, description="Input variables"),
"query": fields.String(description="Query text", default=""),
"files": fields.List(fields.Raw(), description="Uploaded files"),
"model_config": fields.Raw(required=True, description="Model configuration"),
"response_mode": fields.String(enum=["blocking", "streaming"], description="Response mode"),
"retriever_from": fields.String(default="dev", description="Retriever source"),
},
)
)
@console_ns.response(200, "Completion generated successfully") @console_ns.response(200, "Completion generated successfully")
@console_ns.response(400, "Invalid request parameters") @console_ns.response(400, "Invalid request parameters")
@console_ns.response(404, "App not found") @console_ns.response(404, "App not found")
@ -64,18 +89,10 @@ class CompletionMessageApi(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=AppMode.COMPLETION) @get_app_model(mode=AppMode.COMPLETION)
def post(self, app_model): def post(self, app_model):
parser = ( args_model = CompletionMessagePayload.model_validate(console_ns.payload)
reqparse.RequestParser() args = args_model.model_dump(exclude_none=True, by_alias=True)
.add_argument("inputs", type=dict, required=True, location="json")
.add_argument("query", type=str, location="json", default="")
.add_argument("files", type=list, required=False, location="json")
.add_argument("model_config", type=dict, required=True, location="json")
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
)
args = parser.parse_args()
streaming = args["response_mode"] != "blocking" streaming = args_model.response_mode != "blocking"
args["auto_generate_name"] = False args["auto_generate_name"] = False
try: try:
@ -137,21 +154,7 @@ class ChatMessageApi(Resource):
@console_ns.doc("create_chat_message") @console_ns.doc("create_chat_message")
@console_ns.doc(description="Generate chat message for debugging") @console_ns.doc(description="Generate chat message for debugging")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[ChatMessagePayload.__name__])
console_ns.model(
"ChatMessageRequest",
{
"inputs": fields.Raw(required=True, description="Input variables"),
"query": fields.String(required=True, description="User query"),
"files": fields.List(fields.Raw(), description="Uploaded files"),
"model_config": fields.Raw(required=True, description="Model configuration"),
"conversation_id": fields.String(description="Conversation ID"),
"parent_message_id": fields.String(description="Parent message ID"),
"response_mode": fields.String(enum=["blocking", "streaming"], description="Response mode"),
"retriever_from": fields.String(default="dev", description="Retriever source"),
},
)
)
@console_ns.response(200, "Chat message generated successfully") @console_ns.response(200, "Chat message generated successfully")
@console_ns.response(400, "Invalid request parameters") @console_ns.response(400, "Invalid request parameters")
@console_ns.response(404, "App or conversation not found") @console_ns.response(404, "App or conversation not found")
@ -161,20 +164,10 @@ class ChatMessageApi(Resource):
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
@edit_permission_required @edit_permission_required
def post(self, app_model): def post(self, app_model):
parser = ( args_model = ChatMessagePayload.model_validate(console_ns.payload)
reqparse.RequestParser() args = args_model.model_dump(exclude_none=True, by_alias=True)
.add_argument("inputs", type=dict, required=True, location="json")
.add_argument("query", type=str, required=True, location="json")
.add_argument("files", type=list, required=False, location="json")
.add_argument("model_config", type=dict, required=True, location="json")
.add_argument("conversation_id", type=uuid_value, location="json")
.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
)
args = parser.parse_args()
streaming = args["response_mode"] != "blocking" streaming = args_model.response_mode != "blocking"
args["auto_generate_name"] = False args["auto_generate_name"] = False
external_trace_id = get_external_trace_id(request) external_trace_id = get_external_trace_id(request)

View File

@ -1,7 +1,9 @@
from typing import Literal
import sqlalchemy as sa import sqlalchemy as sa
from flask import abort from flask import abort, request
from flask_restx import Resource, fields, marshal_with, reqparse from flask_restx import Resource, fields, marshal_with
from flask_restx.inputs import int_range from pydantic import BaseModel, Field, field_validator
from sqlalchemy import func, or_ from sqlalchemy import func, or_
from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
@ -14,13 +16,54 @@ from extensions.ext_database import db
from fields.conversation_fields import MessageTextField from fields.conversation_fields import MessageTextField
from fields.raws import FilesContainedField from fields.raws import FilesContainedField
from libs.datetime_utils import naive_utc_now, parse_time_range from libs.datetime_utils import naive_utc_now, parse_time_range
from libs.helper import DatetimeString, TimestampField from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from models import Conversation, EndUser, Message, MessageAnnotation from models import Conversation, EndUser, Message, MessageAnnotation
from models.model import AppMode from models.model import AppMode
from services.conversation_service import ConversationService from services.conversation_service import ConversationService
from services.errors.conversation import ConversationNotExistsError from services.errors.conversation import ConversationNotExistsError
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class BaseConversationQuery(BaseModel):
keyword: str | None = Field(default=None, description="Search keyword")
start: str | None = Field(default=None, description="Start date (YYYY-MM-DD HH:MM)")
end: str | None = Field(default=None, description="End date (YYYY-MM-DD HH:MM)")
annotation_status: Literal["annotated", "not_annotated", "all"] = Field(
default="all", description="Annotation status filter"
)
page: int = Field(default=1, ge=1, le=99999, description="Page number")
limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")
@field_validator("start", "end", mode="before")
@classmethod
def blank_to_none(cls, value: str | None) -> str | None:
if value == "":
return None
return value
class CompletionConversationQuery(BaseConversationQuery):
pass
class ChatConversationQuery(BaseConversationQuery):
message_count_gte: int | None = Field(default=None, ge=1, description="Minimum message count")
sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field(
default="-updated_at", description="Sort field and direction"
)
console_ns.schema_model(
CompletionConversationQuery.__name__,
CompletionConversationQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ChatConversationQuery.__name__,
ChatConversationQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
# Register models for flask_restx to avoid dict type issues in Swagger # Register models for flask_restx to avoid dict type issues in Swagger
# Register in dependency order: base models first, then dependent models # Register in dependency order: base models first, then dependent models
@ -283,22 +326,7 @@ class CompletionConversationApi(Resource):
@console_ns.doc("list_completion_conversations") @console_ns.doc("list_completion_conversations")
@console_ns.doc(description="Get completion conversations with pagination and filtering") @console_ns.doc(description="Get completion conversations with pagination and filtering")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[CompletionConversationQuery.__name__])
console_ns.parser()
.add_argument("keyword", type=str, location="args", help="Search keyword")
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
.add_argument(
"annotation_status",
type=str,
location="args",
choices=["annotated", "not_annotated", "all"],
default="all",
help="Annotation status filter",
)
.add_argument("page", type=int, location="args", default=1, help="Page number")
.add_argument("limit", type=int, location="args", default=20, help="Page size (1-100)")
)
@console_ns.response(200, "Success", conversation_pagination_model) @console_ns.response(200, "Success", conversation_pagination_model)
@console_ns.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@setup_required @setup_required
@ -309,32 +337,17 @@ class CompletionConversationApi(Resource):
@edit_permission_required @edit_permission_required
def get(self, app_model): def get(self, app_model):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
parser = ( args = CompletionConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
reqparse.RequestParser()
.add_argument("keyword", type=str, location="args")
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument(
"annotation_status",
type=str,
choices=["annotated", "not_annotated", "all"],
default="all",
location="args",
)
.add_argument("page", type=int_range(1, 99999), default=1, location="args")
.add_argument("limit", type=int_range(1, 100), default=20, location="args")
)
args = parser.parse_args()
query = sa.select(Conversation).where( query = sa.select(Conversation).where(
Conversation.app_id == app_model.id, Conversation.mode == "completion", Conversation.is_deleted.is_(False) Conversation.app_id == app_model.id, Conversation.mode == "completion", Conversation.is_deleted.is_(False)
) )
if args["keyword"]: if args.keyword:
query = query.join(Message, Message.conversation_id == Conversation.id).where( query = query.join(Message, Message.conversation_id == Conversation.id).where(
or_( or_(
Message.query.ilike(f"%{args['keyword']}%"), Message.query.ilike(f"%{args.keyword}%"),
Message.answer.ilike(f"%{args['keyword']}%"), Message.answer.ilike(f"%{args.keyword}%"),
) )
) )
@ -342,7 +355,7 @@ class CompletionConversationApi(Resource):
assert account.timezone is not None assert account.timezone is not None
try: try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone) start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e: except ValueError as e:
abort(400, description=str(e)) abort(400, description=str(e))
@ -354,11 +367,11 @@ class CompletionConversationApi(Resource):
query = query.where(Conversation.created_at < end_datetime_utc) query = query.where(Conversation.created_at < end_datetime_utc)
# FIXME, the type ignore in this file # FIXME, the type ignore in this file
if args["annotation_status"] == "annotated": if args.annotation_status == "annotated":
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
) )
elif args["annotation_status"] == "not_annotated": elif args.annotation_status == "not_annotated":
query = ( query = (
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id) query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
.group_by(Conversation.id) .group_by(Conversation.id)
@ -367,7 +380,7 @@ class CompletionConversationApi(Resource):
query = query.order_by(Conversation.created_at.desc()) query = query.order_by(Conversation.created_at.desc())
conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False) conversations = db.paginate(query, page=args.page, per_page=args.limit, error_out=False)
return conversations return conversations
@ -419,31 +432,7 @@ class ChatConversationApi(Resource):
@console_ns.doc("list_chat_conversations") @console_ns.doc("list_chat_conversations")
@console_ns.doc(description="Get chat conversations with pagination, filtering and summary") @console_ns.doc(description="Get chat conversations with pagination, filtering and summary")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[ChatConversationQuery.__name__])
console_ns.parser()
.add_argument("keyword", type=str, location="args", help="Search keyword")
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
.add_argument(
"annotation_status",
type=str,
location="args",
choices=["annotated", "not_annotated", "all"],
default="all",
help="Annotation status filter",
)
.add_argument("message_count_gte", type=int, location="args", help="Minimum message count")
.add_argument("page", type=int, location="args", default=1, help="Page number")
.add_argument("limit", type=int, location="args", default=20, help="Page size (1-100)")
.add_argument(
"sort_by",
type=str,
location="args",
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
default="-updated_at",
help="Sort field and direction",
)
)
@console_ns.response(200, "Success", conversation_with_summary_pagination_model) @console_ns.response(200, "Success", conversation_with_summary_pagination_model)
@console_ns.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@setup_required @setup_required
@ -454,31 +443,7 @@ class ChatConversationApi(Resource):
@edit_permission_required @edit_permission_required
def get(self, app_model): def get(self, app_model):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
parser = ( args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
reqparse.RequestParser()
.add_argument("keyword", type=str, location="args")
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument(
"annotation_status",
type=str,
choices=["annotated", "not_annotated", "all"],
default="all",
location="args",
)
.add_argument("message_count_gte", type=int_range(1, 99999), required=False, location="args")
.add_argument("page", type=int_range(1, 99999), required=False, default=1, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
.add_argument(
"sort_by",
type=str,
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
required=False,
default="-updated_at",
location="args",
)
)
args = parser.parse_args()
subquery = ( subquery = (
db.session.query( db.session.query(
@ -490,8 +455,8 @@ class ChatConversationApi(Resource):
query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False)) query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))
if args["keyword"]: if args.keyword:
keyword_filter = f"%{args['keyword']}%" keyword_filter = f"%{args.keyword}%"
query = ( query = (
query.join( query.join(
Message, Message,
@ -514,12 +479,12 @@ class ChatConversationApi(Resource):
assert account.timezone is not None assert account.timezone is not None
try: try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone) start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e: except ValueError as e:
abort(400, description=str(e)) abort(400, description=str(e))
if start_datetime_utc: if start_datetime_utc:
match args["sort_by"]: match args.sort_by:
case "updated_at" | "-updated_at": case "updated_at" | "-updated_at":
query = query.where(Conversation.updated_at >= start_datetime_utc) query = query.where(Conversation.updated_at >= start_datetime_utc)
case "created_at" | "-created_at" | _: case "created_at" | "-created_at" | _:
@ -527,35 +492,35 @@ class ChatConversationApi(Resource):
if end_datetime_utc: if end_datetime_utc:
end_datetime_utc = end_datetime_utc.replace(second=59) end_datetime_utc = end_datetime_utc.replace(second=59)
match args["sort_by"]: match args.sort_by:
case "updated_at" | "-updated_at": case "updated_at" | "-updated_at":
query = query.where(Conversation.updated_at <= end_datetime_utc) query = query.where(Conversation.updated_at <= end_datetime_utc)
case "created_at" | "-created_at" | _: case "created_at" | "-created_at" | _:
query = query.where(Conversation.created_at <= end_datetime_utc) query = query.where(Conversation.created_at <= end_datetime_utc)
if args["annotation_status"] == "annotated": if args.annotation_status == "annotated":
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
) )
elif args["annotation_status"] == "not_annotated": elif args.annotation_status == "not_annotated":
query = ( query = (
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id) query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
.group_by(Conversation.id) .group_by(Conversation.id)
.having(func.count(MessageAnnotation.id) == 0) .having(func.count(MessageAnnotation.id) == 0)
) )
if args["message_count_gte"] and args["message_count_gte"] >= 1: if args.message_count_gte and args.message_count_gte >= 1:
query = ( query = (
query.options(joinedload(Conversation.messages)) # type: ignore query.options(joinedload(Conversation.messages)) # type: ignore
.join(Message, Message.conversation_id == Conversation.id) .join(Message, Message.conversation_id == Conversation.id)
.group_by(Conversation.id) .group_by(Conversation.id)
.having(func.count(Message.id) >= args["message_count_gte"]) .having(func.count(Message.id) >= args.message_count_gte)
) )
if app_model.mode == AppMode.ADVANCED_CHAT: if app_model.mode == AppMode.ADVANCED_CHAT:
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER) query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER)
match args["sort_by"]: match args.sort_by:
case "created_at": case "created_at":
query = query.order_by(Conversation.created_at.asc()) query = query.order_by(Conversation.created_at.asc())
case "-created_at": case "-created_at":
@ -567,7 +532,7 @@ class ChatConversationApi(Resource):
case _: case _:
query = query.order_by(Conversation.created_at.desc()) query = query.order_by(Conversation.created_at.desc())
conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False) conversations = db.paginate(query, page=args.page, per_page=args.limit, error_out=False)
return conversations return conversations

View File

@ -1,4 +1,6 @@
from flask_restx import Resource, fields, marshal_with, reqparse from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -14,6 +16,18 @@ from libs.login import login_required
from models import ConversationVariable from models import ConversationVariable
from models.model import AppMode from models.model import AppMode
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ConversationVariablesQuery(BaseModel):
conversation_id: str = Field(..., description="Conversation ID to filter variables")
console_ns.schema_model(
ConversationVariablesQuery.__name__,
ConversationVariablesQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
# Register models for flask_restx to avoid dict type issues in Swagger # Register models for flask_restx to avoid dict type issues in Swagger
# Register base model first # Register base model first
conversation_variable_model = console_ns.model("ConversationVariable", conversation_variable_fields) conversation_variable_model = console_ns.model("ConversationVariable", conversation_variable_fields)
@ -33,11 +47,7 @@ class ConversationVariablesApi(Resource):
@console_ns.doc("get_conversation_variables") @console_ns.doc("get_conversation_variables")
@console_ns.doc(description="Get conversation variables for an application") @console_ns.doc(description="Get conversation variables for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[ConversationVariablesQuery.__name__])
console_ns.parser().add_argument(
"conversation_id", type=str, location="args", help="Conversation ID to filter variables"
)
)
@console_ns.response(200, "Conversation variables retrieved successfully", paginated_conversation_variable_model) @console_ns.response(200, "Conversation variables retrieved successfully", paginated_conversation_variable_model)
@setup_required @setup_required
@login_required @login_required
@ -45,18 +55,14 @@ class ConversationVariablesApi(Resource):
@get_app_model(mode=AppMode.ADVANCED_CHAT) @get_app_model(mode=AppMode.ADVANCED_CHAT)
@marshal_with(paginated_conversation_variable_model) @marshal_with(paginated_conversation_variable_model)
def get(self, app_model): def get(self, app_model):
parser = reqparse.RequestParser().add_argument("conversation_id", type=str, location="args") args = ConversationVariablesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = parser.parse_args()
stmt = ( stmt = (
select(ConversationVariable) select(ConversationVariable)
.where(ConversationVariable.app_id == app_model.id) .where(ConversationVariable.app_id == app_model.id)
.order_by(ConversationVariable.created_at) .order_by(ConversationVariable.created_at)
) )
if args["conversation_id"]: stmt = stmt.where(ConversationVariable.conversation_id == args.conversation_id)
stmt = stmt.where(ConversationVariable.conversation_id == args["conversation_id"])
else:
raise ValueError("conversation_id is required")
# NOTE: This is a temporary solution to avoid performance issues. # NOTE: This is a temporary solution to avoid performance issues.
page = 1 page = 1

View File

@ -1,6 +1,8 @@
from collections.abc import Sequence from collections.abc import Sequence
from typing import Any
from flask_restx import Resource, fields, reqparse from flask_restx import Resource
from pydantic import BaseModel, Field
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.app.error import ( from controllers.console.app.error import (
@ -21,21 +23,54 @@ from libs.login import current_account_with_tenant, login_required
from models import App from models import App
from services.workflow_service import WorkflowService from services.workflow_service import WorkflowService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class RuleGeneratePayload(BaseModel):
instruction: str = Field(..., description="Rule generation instruction")
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
no_variable: bool = Field(default=False, description="Whether to exclude variables")
class RuleCodeGeneratePayload(RuleGeneratePayload):
code_language: str = Field(default="javascript", description="Programming language for code generation")
class RuleStructuredOutputPayload(BaseModel):
instruction: str = Field(..., description="Structured output generation instruction")
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
class InstructionGeneratePayload(BaseModel):
flow_id: str = Field(..., description="Workflow/Flow ID")
node_id: str = Field(default="", description="Node ID for workflow context")
current: str = Field(default="", description="Current instruction text")
language: str = Field(default="javascript", description="Programming language (javascript/python)")
instruction: str = Field(..., description="Instruction for generation")
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
ideal_output: str = Field(default="", description="Expected ideal output")
class InstructionTemplatePayload(BaseModel):
type: str = Field(..., description="Instruction template type")
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(RuleGeneratePayload)
reg(RuleCodeGeneratePayload)
reg(RuleStructuredOutputPayload)
reg(InstructionGeneratePayload)
reg(InstructionTemplatePayload)
@console_ns.route("/rule-generate") @console_ns.route("/rule-generate")
class RuleGenerateApi(Resource): class RuleGenerateApi(Resource):
@console_ns.doc("generate_rule_config") @console_ns.doc("generate_rule_config")
@console_ns.doc(description="Generate rule configuration using LLM") @console_ns.doc(description="Generate rule configuration using LLM")
@console_ns.expect( @console_ns.expect(console_ns.models[RuleGeneratePayload.__name__])
console_ns.model(
"RuleGenerateRequest",
{
"instruction": fields.String(required=True, description="Rule generation instruction"),
"model_config": fields.Raw(required=True, description="Model configuration"),
"no_variable": fields.Boolean(required=True, default=False, description="Whether to exclude variables"),
},
)
)
@console_ns.response(200, "Rule configuration generated successfully") @console_ns.response(200, "Rule configuration generated successfully")
@console_ns.response(400, "Invalid request parameters") @console_ns.response(400, "Invalid request parameters")
@console_ns.response(402, "Provider quota exceeded") @console_ns.response(402, "Provider quota exceeded")
@ -43,21 +78,15 @@ class RuleGenerateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = ( args = RuleGeneratePayload.model_validate(console_ns.payload)
reqparse.RequestParser()
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
.add_argument("no_variable", type=bool, required=True, default=False, location="json")
)
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
try: try:
rules = LLMGenerator.generate_rule_config( rules = LLMGenerator.generate_rule_config(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
instruction=args["instruction"], instruction=args.instruction,
model_config=args["model_config"], model_config=args.model_config_data,
no_variable=args["no_variable"], no_variable=args.no_variable,
) )
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
@ -75,19 +104,7 @@ class RuleGenerateApi(Resource):
class RuleCodeGenerateApi(Resource): class RuleCodeGenerateApi(Resource):
@console_ns.doc("generate_rule_code") @console_ns.doc("generate_rule_code")
@console_ns.doc(description="Generate code rules using LLM") @console_ns.doc(description="Generate code rules using LLM")
@console_ns.expect( @console_ns.expect(console_ns.models[RuleCodeGeneratePayload.__name__])
console_ns.model(
"RuleCodeGenerateRequest",
{
"instruction": fields.String(required=True, description="Code generation instruction"),
"model_config": fields.Raw(required=True, description="Model configuration"),
"no_variable": fields.Boolean(required=True, default=False, description="Whether to exclude variables"),
"code_language": fields.String(
default="javascript", description="Programming language for code generation"
),
},
)
)
@console_ns.response(200, "Code rules generated successfully") @console_ns.response(200, "Code rules generated successfully")
@console_ns.response(400, "Invalid request parameters") @console_ns.response(400, "Invalid request parameters")
@console_ns.response(402, "Provider quota exceeded") @console_ns.response(402, "Provider quota exceeded")
@ -95,22 +112,15 @@ class RuleCodeGenerateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = ( args = RuleCodeGeneratePayload.model_validate(console_ns.payload)
reqparse.RequestParser()
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
.add_argument("no_variable", type=bool, required=True, default=False, location="json")
.add_argument("code_language", type=str, required=False, default="javascript", location="json")
)
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
try: try:
code_result = LLMGenerator.generate_code( code_result = LLMGenerator.generate_code(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
instruction=args["instruction"], instruction=args.instruction,
model_config=args["model_config"], model_config=args.model_config_data,
code_language=args["code_language"], code_language=args.code_language,
) )
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
@ -128,15 +138,7 @@ class RuleCodeGenerateApi(Resource):
class RuleStructuredOutputGenerateApi(Resource): class RuleStructuredOutputGenerateApi(Resource):
@console_ns.doc("generate_structured_output") @console_ns.doc("generate_structured_output")
@console_ns.doc(description="Generate structured output rules using LLM") @console_ns.doc(description="Generate structured output rules using LLM")
@console_ns.expect( @console_ns.expect(console_ns.models[RuleStructuredOutputPayload.__name__])
console_ns.model(
"StructuredOutputGenerateRequest",
{
"instruction": fields.String(required=True, description="Structured output generation instruction"),
"model_config": fields.Raw(required=True, description="Model configuration"),
},
)
)
@console_ns.response(200, "Structured output generated successfully") @console_ns.response(200, "Structured output generated successfully")
@console_ns.response(400, "Invalid request parameters") @console_ns.response(400, "Invalid request parameters")
@console_ns.response(402, "Provider quota exceeded") @console_ns.response(402, "Provider quota exceeded")
@ -144,19 +146,14 @@ class RuleStructuredOutputGenerateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = ( args = RuleStructuredOutputPayload.model_validate(console_ns.payload)
reqparse.RequestParser()
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
)
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
try: try:
structured_output = LLMGenerator.generate_structured_output( structured_output = LLMGenerator.generate_structured_output(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
instruction=args["instruction"], instruction=args.instruction,
model_config=args["model_config"], model_config=args.model_config_data,
) )
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
@ -174,20 +171,7 @@ class RuleStructuredOutputGenerateApi(Resource):
class InstructionGenerateApi(Resource): class InstructionGenerateApi(Resource):
@console_ns.doc("generate_instruction") @console_ns.doc("generate_instruction")
@console_ns.doc(description="Generate instruction for workflow nodes or general use") @console_ns.doc(description="Generate instruction for workflow nodes or general use")
@console_ns.expect( @console_ns.expect(console_ns.models[InstructionGeneratePayload.__name__])
console_ns.model(
"InstructionGenerateRequest",
{
"flow_id": fields.String(required=True, description="Workflow/Flow ID"),
"node_id": fields.String(description="Node ID for workflow context"),
"current": fields.String(description="Current instruction text"),
"language": fields.String(default="javascript", description="Programming language (javascript/python)"),
"instruction": fields.String(required=True, description="Instruction for generation"),
"model_config": fields.Raw(required=True, description="Model configuration"),
"ideal_output": fields.String(description="Expected ideal output"),
},
)
)
@console_ns.response(200, "Instruction generated successfully") @console_ns.response(200, "Instruction generated successfully")
@console_ns.response(400, "Invalid request parameters or flow/workflow not found") @console_ns.response(400, "Invalid request parameters or flow/workflow not found")
@console_ns.response(402, "Provider quota exceeded") @console_ns.response(402, "Provider quota exceeded")
@ -195,79 +179,69 @@ class InstructionGenerateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = ( args = InstructionGeneratePayload.model_validate(console_ns.payload)
reqparse.RequestParser()
.add_argument("flow_id", type=str, required=True, default="", location="json")
.add_argument("node_id", type=str, required=False, default="", location="json")
.add_argument("current", type=str, required=False, default="", location="json")
.add_argument("language", type=str, required=False, default="javascript", location="json")
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
.add_argument("ideal_output", type=str, required=False, default="", location="json")
)
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider] providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
code_provider: type[CodeNodeProvider] | None = next( code_provider: type[CodeNodeProvider] | None = next(
(p for p in providers if p.is_accept_language(args["language"])), None (p for p in providers if p.is_accept_language(args.language)), None
) )
code_template = code_provider.get_default_code() if code_provider else "" code_template = code_provider.get_default_code() if code_provider else ""
try: try:
# Generate from nothing for a workflow node # Generate from nothing for a workflow node
if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "": if (args.current in (code_template, "")) and args.node_id != "":
app = db.session.query(App).where(App.id == args["flow_id"]).first() app = db.session.query(App).where(App.id == args.flow_id).first()
if not app: if not app:
return {"error": f"app {args['flow_id']} not found"}, 400 return {"error": f"app {args.flow_id} not found"}, 400
workflow = WorkflowService().get_draft_workflow(app_model=app) workflow = WorkflowService().get_draft_workflow(app_model=app)
if not workflow: if not workflow:
return {"error": f"workflow {args['flow_id']} not found"}, 400 return {"error": f"workflow {args.flow_id} not found"}, 400
nodes: Sequence = workflow.graph_dict["nodes"] nodes: Sequence = workflow.graph_dict["nodes"]
node = [node for node in nodes if node["id"] == args["node_id"]] node = [node for node in nodes if node["id"] == args.node_id]
if len(node) == 0: if len(node) == 0:
return {"error": f"node {args['node_id']} not found"}, 400 return {"error": f"node {args.node_id} not found"}, 400
node_type = node[0]["data"]["type"] node_type = node[0]["data"]["type"]
match node_type: match node_type:
case "llm": case "llm":
return LLMGenerator.generate_rule_config( return LLMGenerator.generate_rule_config(
current_tenant_id, current_tenant_id,
instruction=args["instruction"], instruction=args.instruction,
model_config=args["model_config"], model_config=args.model_config_data,
no_variable=True, no_variable=True,
) )
case "agent": case "agent":
return LLMGenerator.generate_rule_config( return LLMGenerator.generate_rule_config(
current_tenant_id, current_tenant_id,
instruction=args["instruction"], instruction=args.instruction,
model_config=args["model_config"], model_config=args.model_config_data,
no_variable=True, no_variable=True,
) )
case "code": case "code":
return LLMGenerator.generate_code( return LLMGenerator.generate_code(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
instruction=args["instruction"], instruction=args.instruction,
model_config=args["model_config"], model_config=args.model_config_data,
code_language=args["language"], code_language=args.language,
) )
case _: case _:
return {"error": f"invalid node type: {node_type}"} return {"error": f"invalid node type: {node_type}"}
if args["node_id"] == "" and args["current"] != "": # For legacy app without a workflow if args.node_id == "" and args.current != "": # For legacy app without a workflow
return LLMGenerator.instruction_modify_legacy( return LLMGenerator.instruction_modify_legacy(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
flow_id=args["flow_id"], flow_id=args.flow_id,
current=args["current"], current=args.current,
instruction=args["instruction"], instruction=args.instruction,
model_config=args["model_config"], model_config=args.model_config_data,
ideal_output=args["ideal_output"], ideal_output=args.ideal_output,
) )
if args["node_id"] != "" and args["current"] != "": # For workflow node if args.node_id != "" and args.current != "": # For workflow node
return LLMGenerator.instruction_modify_workflow( return LLMGenerator.instruction_modify_workflow(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
flow_id=args["flow_id"], flow_id=args.flow_id,
node_id=args["node_id"], node_id=args.node_id,
current=args["current"], current=args.current,
instruction=args["instruction"], instruction=args.instruction,
model_config=args["model_config"], model_config=args.model_config_data,
ideal_output=args["ideal_output"], ideal_output=args.ideal_output,
workflow_service=WorkflowService(), workflow_service=WorkflowService(),
) )
return {"error": "incompatible parameters"}, 400 return {"error": "incompatible parameters"}, 400
@ -285,24 +259,15 @@ class InstructionGenerateApi(Resource):
class InstructionGenerationTemplateApi(Resource): class InstructionGenerationTemplateApi(Resource):
@console_ns.doc("get_instruction_template") @console_ns.doc("get_instruction_template")
@console_ns.doc(description="Get instruction generation template") @console_ns.doc(description="Get instruction generation template")
@console_ns.expect( @console_ns.expect(console_ns.models[InstructionTemplatePayload.__name__])
console_ns.model(
"InstructionTemplateRequest",
{
"instruction": fields.String(required=True, description="Template instruction"),
"ideal_output": fields.String(description="Expected ideal output"),
},
)
)
@console_ns.response(200, "Template retrieved successfully") @console_ns.response(200, "Template retrieved successfully")
@console_ns.response(400, "Invalid request parameters") @console_ns.response(400, "Invalid request parameters")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = reqparse.RequestParser().add_argument("type", type=str, required=True, default=False, location="json") args = InstructionTemplatePayload.model_validate(console_ns.payload)
args = parser.parse_args() match args.type:
match args["type"]:
case "prompt": case "prompt":
from core.llm_generator.prompts import INSTRUCTION_GENERATE_TEMPLATE_PROMPT from core.llm_generator.prompts import INSTRUCTION_GENERATE_TEMPLATE_PROMPT
@ -312,4 +277,4 @@ class InstructionGenerationTemplateApi(Resource):
return {"data": INSTRUCTION_GENERATE_TEMPLATE_CODE} return {"data": INSTRUCTION_GENERATE_TEMPLATE_CODE}
case _: case _:
raise ValueError(f"Invalid type: {args['type']}") raise ValueError(f"Invalid type: {args.type}")

View File

@ -1,7 +1,9 @@
import logging import logging
from typing import Literal
from flask_restx import Resource, fields, marshal_with, reqparse from flask import request
from flask_restx.inputs import int_range from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import exists, select from sqlalchemy import exists, select
from werkzeug.exceptions import InternalServerError, NotFound from werkzeug.exceptions import InternalServerError, NotFound
@ -33,6 +35,67 @@ from services.errors.message import MessageNotExistsError, SuggestedQuestionsAft
from services.message_service import MessageService from services.message_service import MessageService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ChatMessagesQuery(BaseModel):
conversation_id: str = Field(..., description="Conversation ID")
first_id: str | None = Field(default=None, description="First message ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return (1-100)")
@field_validator("first_id", mode="before")
@classmethod
def empty_to_none(cls, value: str | None) -> str | None:
if value == "":
return None
return value
@field_validator("conversation_id", "first_id")
@classmethod
def validate_uuid(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
class MessageFeedbackPayload(BaseModel):
message_id: str = Field(..., description="Message ID")
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
@field_validator("message_id")
@classmethod
def validate_message_id(cls, value: str) -> str:
return uuid_value(value)
class FeedbackExportQuery(BaseModel):
from_source: Literal["user", "admin"] | None = Field(default=None, description="Filter by feedback source")
rating: Literal["like", "dislike"] | None = Field(default=None, description="Filter by rating")
has_comment: bool | None = Field(default=None, description="Only include feedback with comments")
start_date: str | None = Field(default=None, description="Start date (YYYY-MM-DD)")
end_date: str | None = Field(default=None, description="End date (YYYY-MM-DD)")
format: Literal["csv", "json"] = Field(default="csv", description="Export format")
@field_validator("has_comment", mode="before")
@classmethod
def parse_bool(cls, value: bool | str | None) -> bool | None:
if isinstance(value, bool) or value is None:
return value
lowered = value.lower()
if lowered in {"true", "1", "yes", "on"}:
return True
if lowered in {"false", "0", "no", "off"}:
return False
raise ValueError("has_comment must be a boolean value")
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(ChatMessagesQuery)
reg(MessageFeedbackPayload)
reg(FeedbackExportQuery)
# Register models for flask_restx to avoid dict type issues in Swagger # Register models for flask_restx to avoid dict type issues in Swagger
# Register in dependency order: base models first, then dependent models # Register in dependency order: base models first, then dependent models
@ -157,12 +220,7 @@ class ChatMessageListApi(Resource):
@console_ns.doc("list_chat_messages") @console_ns.doc("list_chat_messages")
@console_ns.doc(description="Get chat messages for a conversation with pagination") @console_ns.doc(description="Get chat messages for a conversation with pagination")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[ChatMessagesQuery.__name__])
console_ns.parser()
.add_argument("conversation_id", type=str, required=True, location="args", help="Conversation ID")
.add_argument("first_id", type=str, location="args", help="First message ID for pagination")
.add_argument("limit", type=int, location="args", default=20, help="Number of messages to return (1-100)")
)
@console_ns.response(200, "Success", message_infinite_scroll_pagination_model) @console_ns.response(200, "Success", message_infinite_scroll_pagination_model)
@console_ns.response(404, "Conversation not found") @console_ns.response(404, "Conversation not found")
@login_required @login_required
@ -172,27 +230,21 @@ class ChatMessageListApi(Resource):
@marshal_with(message_infinite_scroll_pagination_model) @marshal_with(message_infinite_scroll_pagination_model)
@edit_permission_required @edit_permission_required
def get(self, app_model): def get(self, app_model):
parser = ( args = ChatMessagesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
reqparse.RequestParser()
.add_argument("conversation_id", required=True, type=uuid_value, location="args")
.add_argument("first_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
)
args = parser.parse_args()
conversation = ( conversation = (
db.session.query(Conversation) db.session.query(Conversation)
.where(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id) .where(Conversation.id == args.conversation_id, Conversation.app_id == app_model.id)
.first() .first()
) )
if not conversation: if not conversation:
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")
if args["first_id"]: if args.first_id:
first_message = ( first_message = (
db.session.query(Message) db.session.query(Message)
.where(Message.conversation_id == conversation.id, Message.id == args["first_id"]) .where(Message.conversation_id == conversation.id, Message.id == args.first_id)
.first() .first()
) )
@ -207,7 +259,7 @@ class ChatMessageListApi(Resource):
Message.id != first_message.id, Message.id != first_message.id,
) )
.order_by(Message.created_at.desc()) .order_by(Message.created_at.desc())
.limit(args["limit"]) .limit(args.limit)
.all() .all()
) )
else: else:
@ -215,12 +267,12 @@ class ChatMessageListApi(Resource):
db.session.query(Message) db.session.query(Message)
.where(Message.conversation_id == conversation.id) .where(Message.conversation_id == conversation.id)
.order_by(Message.created_at.desc()) .order_by(Message.created_at.desc())
.limit(args["limit"]) .limit(args.limit)
.all() .all()
) )
# Initialize has_more based on whether we have a full page # Initialize has_more based on whether we have a full page
if len(history_messages) == args["limit"]: if len(history_messages) == args.limit:
current_page_first_message = history_messages[-1] current_page_first_message = history_messages[-1]
# Check if there are more messages before the current page # Check if there are more messages before the current page
has_more = db.session.scalar( has_more = db.session.scalar(
@ -238,7 +290,7 @@ class ChatMessageListApi(Resource):
history_messages = list(reversed(history_messages)) history_messages = list(reversed(history_messages))
return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more) return InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more)
@console_ns.route("/apps/<uuid:app_id>/feedbacks") @console_ns.route("/apps/<uuid:app_id>/feedbacks")
@ -246,15 +298,7 @@ class MessageFeedbackApi(Resource):
@console_ns.doc("create_message_feedback") @console_ns.doc("create_message_feedback")
@console_ns.doc(description="Create or update message feedback (like/dislike)") @console_ns.doc(description="Create or update message feedback (like/dislike)")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[MessageFeedbackPayload.__name__])
console_ns.model(
"MessageFeedbackRequest",
{
"message_id": fields.String(required=True, description="Message ID"),
"rating": fields.String(enum=["like", "dislike"], description="Feedback rating"),
},
)
)
@console_ns.response(200, "Feedback updated successfully") @console_ns.response(200, "Feedback updated successfully")
@console_ns.response(404, "Message not found") @console_ns.response(404, "Message not found")
@console_ns.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@ -265,14 +309,9 @@ class MessageFeedbackApi(Resource):
def post(self, app_model): def post(self, app_model):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
parser = ( args = MessageFeedbackPayload.model_validate(console_ns.payload)
reqparse.RequestParser()
.add_argument("message_id", required=True, type=uuid_value, location="json")
.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
)
args = parser.parse_args()
message_id = str(args["message_id"]) message_id = str(args.message_id)
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first() message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
@ -281,18 +320,21 @@ class MessageFeedbackApi(Resource):
feedback = message.admin_feedback feedback = message.admin_feedback
if not args["rating"] and feedback: if not args.rating and feedback:
db.session.delete(feedback) db.session.delete(feedback)
elif args["rating"] and feedback: elif args.rating and feedback:
feedback.rating = args["rating"] feedback.rating = args.rating
elif not args["rating"] and not feedback: elif not args.rating and not feedback:
raise ValueError("rating cannot be None when feedback not exists") raise ValueError("rating cannot be None when feedback not exists")
else: else:
rating_value = args.rating
if rating_value is None:
raise ValueError("rating is required to create feedback")
feedback = MessageFeedback( feedback = MessageFeedback(
app_id=app_model.id, app_id=app_model.id,
conversation_id=message.conversation_id, conversation_id=message.conversation_id,
message_id=message.id, message_id=message.id,
rating=args["rating"], rating=rating_value,
from_source="admin", from_source="admin",
from_account_id=current_user.id, from_account_id=current_user.id,
) )
@ -369,24 +411,12 @@ class MessageSuggestedQuestionApi(Resource):
return {"data": questions} return {"data": questions}
# Shared parser for feedback export (used for both documentation and runtime parsing)
feedback_export_parser = (
console_ns.parser()
.add_argument("from_source", type=str, choices=["user", "admin"], location="args", help="Filter by feedback source")
.add_argument("rating", type=str, choices=["like", "dislike"], location="args", help="Filter by rating")
.add_argument("has_comment", type=bool, location="args", help="Only include feedback with comments")
.add_argument("start_date", type=str, location="args", help="Start date (YYYY-MM-DD)")
.add_argument("end_date", type=str, location="args", help="End date (YYYY-MM-DD)")
.add_argument("format", type=str, choices=["csv", "json"], default="csv", location="args", help="Export format")
)
@console_ns.route("/apps/<uuid:app_id>/feedbacks/export") @console_ns.route("/apps/<uuid:app_id>/feedbacks/export")
class MessageFeedbackExportApi(Resource): class MessageFeedbackExportApi(Resource):
@console_ns.doc("export_feedbacks") @console_ns.doc("export_feedbacks")
@console_ns.doc(description="Export user feedback data for Google Sheets") @console_ns.doc(description="Export user feedback data for Google Sheets")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(feedback_export_parser) @console_ns.expect(console_ns.models[FeedbackExportQuery.__name__])
@console_ns.response(200, "Feedback data exported successfully") @console_ns.response(200, "Feedback data exported successfully")
@console_ns.response(400, "Invalid parameters") @console_ns.response(400, "Invalid parameters")
@console_ns.response(500, "Internal server error") @console_ns.response(500, "Internal server error")
@ -395,7 +425,7 @@ class MessageFeedbackExportApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_model): def get(self, app_model):
args = feedback_export_parser.parse_args() args = FeedbackExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
# Import the service function # Import the service function
from services.feedback_service import FeedbackService from services.feedback_service import FeedbackService
@ -403,12 +433,12 @@ class MessageFeedbackExportApi(Resource):
try: try:
export_data = FeedbackService.export_feedbacks( export_data = FeedbackService.export_feedbacks(
app_id=app_model.id, app_id=app_model.id,
from_source=args.get("from_source"), from_source=args.from_source,
rating=args.get("rating"), rating=args.rating,
has_comment=args.get("has_comment"), has_comment=args.has_comment,
start_date=args.get("start_date"), start_date=args.start_date,
end_date=args.get("end_date"), end_date=args.end_date,
format_type=args.get("format", "csv"), format_type=args.format,
) )
return export_data return export_data

View File

@ -1,8 +1,9 @@
from decimal import Decimal from decimal import Decimal
import sqlalchemy as sa import sqlalchemy as sa
from flask import abort, jsonify from flask import abort, jsonify, request
from flask_restx import Resource, fields, reqparse from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
@ -10,21 +11,37 @@ from controllers.console.wraps import account_initialization_required, setup_req
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db from extensions.ext_database import db
from libs.datetime_utils import parse_time_range from libs.datetime_utils import parse_time_range
from libs.helper import DatetimeString, convert_datetime_to_date from libs.helper import convert_datetime_to_date
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from models import AppMode from models import AppMode
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class StatisticTimeRangeQuery(BaseModel):
start: str | None = Field(default=None, description="Start date (YYYY-MM-DD HH:MM)")
end: str | None = Field(default=None, description="End date (YYYY-MM-DD HH:MM)")
@field_validator("start", "end", mode="before")
@classmethod
def empty_string_to_none(cls, value: str | None) -> str | None:
if value == "":
return None
return value
console_ns.schema_model(
StatisticTimeRangeQuery.__name__,
StatisticTimeRangeQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-messages") @console_ns.route("/apps/<uuid:app_id>/statistics/daily-messages")
class DailyMessageStatistic(Resource): class DailyMessageStatistic(Resource):
@console_ns.doc("get_daily_message_statistics") @console_ns.doc("get_daily_message_statistics")
@console_ns.doc(description="Get daily message statistics for an application") @console_ns.doc(description="Get daily message statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
console_ns.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@console_ns.response( @console_ns.response(
200, 200,
"Daily message statistics retrieved successfully", "Daily message statistics retrieved successfully",
@ -37,12 +54,7 @@ class DailyMessageStatistic(Resource):
def get(self, app_model): def get(self, app_model):
account, _ = current_account_with_tenant() account, _ = current_account_with_tenant()
parser = ( args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
converted_created_at = convert_datetime_to_date("created_at") converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT sql_query = f"""SELECT
@ -57,7 +69,7 @@ WHERE
assert account.timezone is not None assert account.timezone is not None
try: try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone) start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e: except ValueError as e:
abort(400, description=str(e)) abort(400, description=str(e))
@ -81,19 +93,12 @@ WHERE
return jsonify({"data": response_data}) return jsonify({"data": response_data})
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-conversations") @console_ns.route("/apps/<uuid:app_id>/statistics/daily-conversations")
class DailyConversationStatistic(Resource): class DailyConversationStatistic(Resource):
@console_ns.doc("get_daily_conversation_statistics") @console_ns.doc("get_daily_conversation_statistics")
@console_ns.doc(description="Get daily conversation statistics for an application") @console_ns.doc(description="Get daily conversation statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(parser) @console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
@console_ns.response( @console_ns.response(
200, 200,
"Daily conversation statistics retrieved successfully", "Daily conversation statistics retrieved successfully",
@ -106,7 +111,7 @@ class DailyConversationStatistic(Resource):
def get(self, app_model): def get(self, app_model):
account, _ = current_account_with_tenant() account, _ = current_account_with_tenant()
args = parser.parse_args() args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
converted_created_at = convert_datetime_to_date("created_at") converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT sql_query = f"""SELECT
@ -121,7 +126,7 @@ WHERE
assert account.timezone is not None assert account.timezone is not None
try: try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone) start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e: except ValueError as e:
abort(400, description=str(e)) abort(400, description=str(e))
@ -149,7 +154,7 @@ class DailyTerminalsStatistic(Resource):
@console_ns.doc("get_daily_terminals_statistics") @console_ns.doc("get_daily_terminals_statistics")
@console_ns.doc(description="Get daily terminal/end-user statistics for an application") @console_ns.doc(description="Get daily terminal/end-user statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(parser) @console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
@console_ns.response( @console_ns.response(
200, 200,
"Daily terminal statistics retrieved successfully", "Daily terminal statistics retrieved successfully",
@ -162,7 +167,7 @@ class DailyTerminalsStatistic(Resource):
def get(self, app_model): def get(self, app_model):
account, _ = current_account_with_tenant() account, _ = current_account_with_tenant()
args = parser.parse_args() args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
converted_created_at = convert_datetime_to_date("created_at") converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT sql_query = f"""SELECT
@ -177,7 +182,7 @@ WHERE
assert account.timezone is not None assert account.timezone is not None
try: try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone) start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e: except ValueError as e:
abort(400, description=str(e)) abort(400, description=str(e))
@ -206,7 +211,7 @@ class DailyTokenCostStatistic(Resource):
@console_ns.doc("get_daily_token_cost_statistics") @console_ns.doc("get_daily_token_cost_statistics")
@console_ns.doc(description="Get daily token cost statistics for an application") @console_ns.doc(description="Get daily token cost statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(parser) @console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
@console_ns.response( @console_ns.response(
200, 200,
"Daily token cost statistics retrieved successfully", "Daily token cost statistics retrieved successfully",
@ -219,7 +224,7 @@ class DailyTokenCostStatistic(Resource):
def get(self, app_model): def get(self, app_model):
account, _ = current_account_with_tenant() account, _ = current_account_with_tenant()
args = parser.parse_args() args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
converted_created_at = convert_datetime_to_date("created_at") converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT sql_query = f"""SELECT
@ -235,7 +240,7 @@ WHERE
assert account.timezone is not None assert account.timezone is not None
try: try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone) start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e: except ValueError as e:
abort(400, description=str(e)) abort(400, description=str(e))
@ -266,7 +271,7 @@ class AverageSessionInteractionStatistic(Resource):
@console_ns.doc("get_average_session_interaction_statistics") @console_ns.doc("get_average_session_interaction_statistics")
@console_ns.doc(description="Get average session interaction statistics for an application") @console_ns.doc(description="Get average session interaction statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(parser) @console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
@console_ns.response( @console_ns.response(
200, 200,
"Average session interaction statistics retrieved successfully", "Average session interaction statistics retrieved successfully",
@ -279,7 +284,7 @@ class AverageSessionInteractionStatistic(Resource):
def get(self, app_model): def get(self, app_model):
account, _ = current_account_with_tenant() account, _ = current_account_with_tenant()
args = parser.parse_args() args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
converted_created_at = convert_datetime_to_date("c.created_at") converted_created_at = convert_datetime_to_date("c.created_at")
sql_query = f"""SELECT sql_query = f"""SELECT
@ -302,7 +307,7 @@ FROM
assert account.timezone is not None assert account.timezone is not None
try: try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone) start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e: except ValueError as e:
abort(400, description=str(e)) abort(400, description=str(e))
@ -342,7 +347,7 @@ class UserSatisfactionRateStatistic(Resource):
@console_ns.doc("get_user_satisfaction_rate_statistics") @console_ns.doc("get_user_satisfaction_rate_statistics")
@console_ns.doc(description="Get user satisfaction rate statistics for an application") @console_ns.doc(description="Get user satisfaction rate statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(parser) @console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
@console_ns.response( @console_ns.response(
200, 200,
"User satisfaction rate statistics retrieved successfully", "User satisfaction rate statistics retrieved successfully",
@ -355,7 +360,7 @@ class UserSatisfactionRateStatistic(Resource):
def get(self, app_model): def get(self, app_model):
account, _ = current_account_with_tenant() account, _ = current_account_with_tenant()
args = parser.parse_args() args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
converted_created_at = convert_datetime_to_date("m.created_at") converted_created_at = convert_datetime_to_date("m.created_at")
sql_query = f"""SELECT sql_query = f"""SELECT
@ -374,7 +379,7 @@ WHERE
assert account.timezone is not None assert account.timezone is not None
try: try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone) start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e: except ValueError as e:
abort(400, description=str(e)) abort(400, description=str(e))
@ -408,7 +413,7 @@ class AverageResponseTimeStatistic(Resource):
@console_ns.doc("get_average_response_time_statistics") @console_ns.doc("get_average_response_time_statistics")
@console_ns.doc(description="Get average response time statistics for an application") @console_ns.doc(description="Get average response time statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(parser) @console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
@console_ns.response( @console_ns.response(
200, 200,
"Average response time statistics retrieved successfully", "Average response time statistics retrieved successfully",
@ -421,7 +426,7 @@ class AverageResponseTimeStatistic(Resource):
def get(self, app_model): def get(self, app_model):
account, _ = current_account_with_tenant() account, _ = current_account_with_tenant()
args = parser.parse_args() args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
converted_created_at = convert_datetime_to_date("created_at") converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT sql_query = f"""SELECT
@ -436,7 +441,7 @@ WHERE
assert account.timezone is not None assert account.timezone is not None
try: try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone) start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e: except ValueError as e:
abort(400, description=str(e)) abort(400, description=str(e))
@ -465,7 +470,7 @@ class TokensPerSecondStatistic(Resource):
@console_ns.doc("get_tokens_per_second_statistics") @console_ns.doc("get_tokens_per_second_statistics")
@console_ns.doc(description="Get tokens per second statistics for an application") @console_ns.doc(description="Get tokens per second statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(parser) @console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
@console_ns.response( @console_ns.response(
200, 200,
"Tokens per second statistics retrieved successfully", "Tokens per second statistics retrieved successfully",
@ -477,7 +482,7 @@ class TokensPerSecondStatistic(Resource):
@account_initialization_required @account_initialization_required
def get(self, app_model): def get(self, app_model):
account, _ = current_account_with_tenant() account, _ = current_account_with_tenant()
args = parser.parse_args() args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
converted_created_at = convert_datetime_to_date("created_at") converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT sql_query = f"""SELECT
@ -495,7 +500,7 @@ WHERE
assert account.timezone is not None assert account.timezone is not None
try: try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone) start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e: except ValueError as e:
abort(400, description=str(e)) abort(400, description=str(e))

View File

@ -1,10 +1,11 @@
import json import json
import logging import logging
from collections.abc import Sequence from collections.abc import Sequence
from typing import cast from typing import Any
from flask import abort, request from flask import abort, request
from flask_restx import Resource, fields, inputs, marshal_with, reqparse from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
@ -49,6 +50,7 @@ from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseE
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
LISTENING_RETRY_IN = 2000 LISTENING_RETRY_IN = 2000
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
# Register models for flask_restx to avoid dict type issues in Swagger # Register models for flask_restx to avoid dict type issues in Swagger
# Register in dependency order: base models first, then dependent models # Register in dependency order: base models first, then dependent models
@ -107,6 +109,104 @@ if workflow_run_node_execution_model is None:
workflow_run_node_execution_model = console_ns.model("WorkflowRunNodeExecution", workflow_run_node_execution_fields) workflow_run_node_execution_model = console_ns.model("WorkflowRunNodeExecution", workflow_run_node_execution_fields)
class SyncDraftWorkflowPayload(BaseModel):
graph: dict[str, Any]
features: dict[str, Any]
hash: str | None = None
environment_variables: list[dict[str, Any]] = Field(default_factory=list)
conversation_variables: list[dict[str, Any]] = Field(default_factory=list)
class BaseWorkflowRunPayload(BaseModel):
files: list[dict[str, Any]] | None = None
class AdvancedChatWorkflowRunPayload(BaseWorkflowRunPayload):
inputs: dict[str, Any] | None = None
query: str = ""
conversation_id: str | None = None
parent_message_id: str | None = None
@field_validator("conversation_id", "parent_message_id")
@classmethod
def validate_uuid(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
class IterationNodeRunPayload(BaseModel):
inputs: dict[str, Any] | None = None
class LoopNodeRunPayload(BaseModel):
inputs: dict[str, Any] | None = None
class DraftWorkflowRunPayload(BaseWorkflowRunPayload):
inputs: dict[str, Any]
class DraftWorkflowNodeRunPayload(BaseWorkflowRunPayload):
inputs: dict[str, Any]
query: str = ""
class PublishWorkflowPayload(BaseModel):
marked_name: str | None = Field(default=None, max_length=20)
marked_comment: str | None = Field(default=None, max_length=100)
class DefaultBlockConfigQuery(BaseModel):
q: str | None = None
class ConvertToWorkflowPayload(BaseModel):
name: str | None = None
icon_type: str | None = None
icon: str | None = None
icon_background: str | None = None
class WorkflowListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=99999)
limit: int = Field(default=10, ge=1, le=100)
user_id: str | None = None
named_only: bool = False
class WorkflowUpdatePayload(BaseModel):
marked_name: str | None = Field(default=None, max_length=20)
marked_comment: str | None = Field(default=None, max_length=100)
class DraftWorkflowTriggerRunPayload(BaseModel):
node_id: str
class DraftWorkflowTriggerRunAllPayload(BaseModel):
node_ids: list[str]
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(SyncDraftWorkflowPayload)
reg(AdvancedChatWorkflowRunPayload)
reg(IterationNodeRunPayload)
reg(LoopNodeRunPayload)
reg(DraftWorkflowRunPayload)
reg(DraftWorkflowNodeRunPayload)
reg(PublishWorkflowPayload)
reg(DefaultBlockConfigQuery)
reg(ConvertToWorkflowPayload)
reg(WorkflowListQuery)
reg(WorkflowUpdatePayload)
reg(DraftWorkflowTriggerRunPayload)
reg(DraftWorkflowTriggerRunAllPayload)
# TODO(QuantumGhost): Refactor existing node run API to handle file parameter parsing # TODO(QuantumGhost): Refactor existing node run API to handle file parameter parsing
# at the controller level rather than in the workflow logic. This would improve separation # at the controller level rather than in the workflow logic. This would improve separation
# of concerns and make the code more maintainable. # of concerns and make the code more maintainable.
@ -158,18 +258,7 @@ class DraftWorkflowApi(Resource):
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@console_ns.doc("sync_draft_workflow") @console_ns.doc("sync_draft_workflow")
@console_ns.doc(description="Sync draft workflow configuration") @console_ns.doc(description="Sync draft workflow configuration")
@console_ns.expect( @console_ns.expect(console_ns.models[SyncDraftWorkflowPayload.__name__])
console_ns.model(
"SyncDraftWorkflowRequest",
{
"graph": fields.Raw(required=True, description="Workflow graph configuration"),
"features": fields.Raw(required=True, description="Workflow features configuration"),
"hash": fields.String(description="Workflow hash for validation"),
"environment_variables": fields.List(fields.Raw, required=True, description="Environment variables"),
"conversation_variables": fields.List(fields.Raw, description="Conversation variables"),
},
)
)
@console_ns.response( @console_ns.response(
200, 200,
"Draft workflow synced successfully", "Draft workflow synced successfully",
@ -193,36 +282,23 @@ class DraftWorkflowApi(Resource):
content_type = request.headers.get("Content-Type", "") content_type = request.headers.get("Content-Type", "")
payload_data: dict[str, Any] | None = None
if "application/json" in content_type: if "application/json" in content_type:
parser = ( payload_data = request.get_json(silent=True)
reqparse.RequestParser() if not isinstance(payload_data, dict):
.add_argument("graph", type=dict, required=True, nullable=False, location="json") return {"message": "Invalid JSON data"}, 400
.add_argument("features", type=dict, required=True, nullable=False, location="json")
.add_argument("hash", type=str, required=False, location="json")
.add_argument("environment_variables", type=list, required=True, location="json")
.add_argument("conversation_variables", type=list, required=False, location="json")
)
args = parser.parse_args()
elif "text/plain" in content_type: elif "text/plain" in content_type:
try: try:
data = json.loads(request.data.decode("utf-8")) payload_data = json.loads(request.data.decode("utf-8"))
if "graph" not in data or "features" not in data:
raise ValueError("graph or features not found in data")
if not isinstance(data.get("graph"), dict) or not isinstance(data.get("features"), dict):
raise ValueError("graph or features is not a dict")
args = {
"graph": data.get("graph"),
"features": data.get("features"),
"hash": data.get("hash"),
"environment_variables": data.get("environment_variables"),
"conversation_variables": data.get("conversation_variables"),
}
except json.JSONDecodeError: except json.JSONDecodeError:
return {"message": "Invalid JSON data"}, 400 return {"message": "Invalid JSON data"}, 400
if not isinstance(payload_data, dict):
return {"message": "Invalid JSON data"}, 400
else: else:
abort(415) abort(415)
args_model = SyncDraftWorkflowPayload.model_validate(payload_data)
args = args_model.model_dump()
workflow_service = WorkflowService() workflow_service = WorkflowService()
try: try:
@ -258,17 +334,7 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
@console_ns.doc("run_advanced_chat_draft_workflow") @console_ns.doc("run_advanced_chat_draft_workflow")
@console_ns.doc(description="Run draft workflow for advanced chat application") @console_ns.doc(description="Run draft workflow for advanced chat application")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[AdvancedChatWorkflowRunPayload.__name__])
console_ns.model(
"AdvancedChatWorkflowRunRequest",
{
"query": fields.String(required=True, description="User query"),
"inputs": fields.Raw(description="Input variables"),
"files": fields.List(fields.Raw, description="File uploads"),
"conversation_id": fields.String(description="Conversation ID"),
},
)
)
@console_ns.response(200, "Workflow run started successfully") @console_ns.response(200, "Workflow run started successfully")
@console_ns.response(400, "Invalid request parameters") @console_ns.response(400, "Invalid request parameters")
@console_ns.response(403, "Permission denied") @console_ns.response(403, "Permission denied")
@ -283,16 +349,8 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
""" """
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
parser = ( args_model = AdvancedChatWorkflowRunPayload.model_validate(console_ns.payload or {})
reqparse.RequestParser() args = args_model.model_dump(exclude_none=True)
.add_argument("inputs", type=dict, location="json")
.add_argument("query", type=str, required=True, location="json", default="")
.add_argument("files", type=list, location="json")
.add_argument("conversation_id", type=uuid_value, location="json")
.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
)
args = parser.parse_args()
external_trace_id = get_external_trace_id(request) external_trace_id = get_external_trace_id(request)
if external_trace_id: if external_trace_id:
@ -322,15 +380,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
@console_ns.doc("run_advanced_chat_draft_iteration_node") @console_ns.doc("run_advanced_chat_draft_iteration_node")
@console_ns.doc(description="Run draft workflow iteration node for advanced chat") @console_ns.doc(description="Run draft workflow iteration node for advanced chat")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[IterationNodeRunPayload.__name__])
console_ns.model(
"IterationNodeRunRequest",
{
"task_id": fields.String(required=True, description="Task ID"),
"inputs": fields.Raw(description="Input variables"),
},
)
)
@console_ns.response(200, "Iteration node run started successfully") @console_ns.response(200, "Iteration node run started successfully")
@console_ns.response(403, "Permission denied") @console_ns.response(403, "Permission denied")
@console_ns.response(404, "Node not found") @console_ns.response(404, "Node not found")
@ -344,8 +394,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
Run draft workflow iteration node Run draft workflow iteration node
""" """
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json") args = IterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
args = parser.parse_args()
try: try:
response = AppGenerateService.generate_single_iteration( response = AppGenerateService.generate_single_iteration(
@ -369,15 +418,7 @@ class WorkflowDraftRunIterationNodeApi(Resource):
@console_ns.doc("run_workflow_draft_iteration_node") @console_ns.doc("run_workflow_draft_iteration_node")
@console_ns.doc(description="Run draft workflow iteration node") @console_ns.doc(description="Run draft workflow iteration node")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[IterationNodeRunPayload.__name__])
console_ns.model(
"WorkflowIterationNodeRunRequest",
{
"task_id": fields.String(required=True, description="Task ID"),
"inputs": fields.Raw(description="Input variables"),
},
)
)
@console_ns.response(200, "Workflow iteration node run started successfully") @console_ns.response(200, "Workflow iteration node run started successfully")
@console_ns.response(403, "Permission denied") @console_ns.response(403, "Permission denied")
@console_ns.response(404, "Node not found") @console_ns.response(404, "Node not found")
@ -391,8 +432,7 @@ class WorkflowDraftRunIterationNodeApi(Resource):
Run draft workflow iteration node Run draft workflow iteration node
""" """
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json") args = IterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
args = parser.parse_args()
try: try:
response = AppGenerateService.generate_single_iteration( response = AppGenerateService.generate_single_iteration(
@ -416,15 +456,7 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
@console_ns.doc("run_advanced_chat_draft_loop_node") @console_ns.doc("run_advanced_chat_draft_loop_node")
@console_ns.doc(description="Run draft workflow loop node for advanced chat") @console_ns.doc(description="Run draft workflow loop node for advanced chat")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[LoopNodeRunPayload.__name__])
console_ns.model(
"LoopNodeRunRequest",
{
"task_id": fields.String(required=True, description="Task ID"),
"inputs": fields.Raw(description="Input variables"),
},
)
)
@console_ns.response(200, "Loop node run started successfully") @console_ns.response(200, "Loop node run started successfully")
@console_ns.response(403, "Permission denied") @console_ns.response(403, "Permission denied")
@console_ns.response(404, "Node not found") @console_ns.response(404, "Node not found")
@ -438,8 +470,7 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
Run draft workflow loop node Run draft workflow loop node
""" """
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json") args = LoopNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
args = parser.parse_args()
try: try:
response = AppGenerateService.generate_single_loop( response = AppGenerateService.generate_single_loop(
@ -463,15 +494,7 @@ class WorkflowDraftRunLoopNodeApi(Resource):
@console_ns.doc("run_workflow_draft_loop_node") @console_ns.doc("run_workflow_draft_loop_node")
@console_ns.doc(description="Run draft workflow loop node") @console_ns.doc(description="Run draft workflow loop node")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[LoopNodeRunPayload.__name__])
console_ns.model(
"WorkflowLoopNodeRunRequest",
{
"task_id": fields.String(required=True, description="Task ID"),
"inputs": fields.Raw(description="Input variables"),
},
)
)
@console_ns.response(200, "Workflow loop node run started successfully") @console_ns.response(200, "Workflow loop node run started successfully")
@console_ns.response(403, "Permission denied") @console_ns.response(403, "Permission denied")
@console_ns.response(404, "Node not found") @console_ns.response(404, "Node not found")
@ -485,8 +508,7 @@ class WorkflowDraftRunLoopNodeApi(Resource):
Run draft workflow loop node Run draft workflow loop node
""" """
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json") args = LoopNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
args = parser.parse_args()
try: try:
response = AppGenerateService.generate_single_loop( response = AppGenerateService.generate_single_loop(
@ -510,15 +532,7 @@ class DraftWorkflowRunApi(Resource):
@console_ns.doc("run_draft_workflow") @console_ns.doc("run_draft_workflow")
@console_ns.doc(description="Run draft workflow") @console_ns.doc(description="Run draft workflow")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[DraftWorkflowRunPayload.__name__])
console_ns.model(
"DraftWorkflowRunRequest",
{
"inputs": fields.Raw(required=True, description="Input variables"),
"files": fields.List(fields.Raw, description="File uploads"),
},
)
)
@console_ns.response(200, "Draft workflow run started successfully") @console_ns.response(200, "Draft workflow run started successfully")
@console_ns.response(403, "Permission denied") @console_ns.response(403, "Permission denied")
@setup_required @setup_required
@ -531,12 +545,7 @@ class DraftWorkflowRunApi(Resource):
Run draft workflow Run draft workflow
""" """
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
parser = ( args = DraftWorkflowRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("files", type=list, required=False, location="json")
)
args = parser.parse_args()
external_trace_id = get_external_trace_id(request) external_trace_id = get_external_trace_id(request)
if external_trace_id: if external_trace_id:
@ -588,14 +597,7 @@ class DraftWorkflowNodeRunApi(Resource):
@console_ns.doc("run_draft_workflow_node") @console_ns.doc("run_draft_workflow_node")
@console_ns.doc(description="Run draft workflow node") @console_ns.doc(description="Run draft workflow node")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[DraftWorkflowNodeRunPayload.__name__])
console_ns.model(
"DraftWorkflowNodeRunRequest",
{
"inputs": fields.Raw(description="Input variables"),
},
)
)
@console_ns.response(200, "Node run started successfully", workflow_run_node_execution_model) @console_ns.response(200, "Node run started successfully", workflow_run_node_execution_model)
@console_ns.response(403, "Permission denied") @console_ns.response(403, "Permission denied")
@console_ns.response(404, "Node not found") @console_ns.response(404, "Node not found")
@ -610,15 +612,10 @@ class DraftWorkflowNodeRunApi(Resource):
Run draft workflow node Run draft workflow node
""" """
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
parser = ( args_model = DraftWorkflowNodeRunPayload.model_validate(console_ns.payload or {})
reqparse.RequestParser() args = args_model.model_dump(exclude_none=True)
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("query", type=str, required=False, location="json", default="")
.add_argument("files", type=list, location="json", default=[])
)
args = parser.parse_args()
user_inputs = args.get("inputs") user_inputs = args_model.inputs
if user_inputs is None: if user_inputs is None:
raise ValueError("missing inputs") raise ValueError("missing inputs")
@ -643,13 +640,6 @@ class DraftWorkflowNodeRunApi(Resource):
return workflow_node_execution return workflow_node_execution
parser_publish = (
reqparse.RequestParser()
.add_argument("marked_name", type=str, required=False, default="", location="json")
.add_argument("marked_comment", type=str, required=False, default="", location="json")
)
@console_ns.route("/apps/<uuid:app_id>/workflows/publish") @console_ns.route("/apps/<uuid:app_id>/workflows/publish")
class PublishedWorkflowApi(Resource): class PublishedWorkflowApi(Resource):
@console_ns.doc("get_published_workflow") @console_ns.doc("get_published_workflow")
@ -674,7 +664,7 @@ class PublishedWorkflowApi(Resource):
# return workflow, if not found, return None # return workflow, if not found, return None
return workflow return workflow
@console_ns.expect(parser_publish) @console_ns.expect(console_ns.models[PublishWorkflowPayload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -686,13 +676,7 @@ class PublishedWorkflowApi(Resource):
""" """
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
args = parser_publish.parse_args() args = PublishWorkflowPayload.model_validate(console_ns.payload or {})
# Validate name and comment length
if args.marked_name and len(args.marked_name) > 20:
raise ValueError("Marked name cannot exceed 20 characters")
if args.marked_comment and len(args.marked_comment) > 100:
raise ValueError("Marked comment cannot exceed 100 characters")
workflow_service = WorkflowService() workflow_service = WorkflowService()
with Session(db.engine) as session: with Session(db.engine) as session:
@ -741,9 +725,6 @@ class DefaultBlockConfigsApi(Resource):
return workflow_service.get_default_block_configs() return workflow_service.get_default_block_configs()
parser_block = reqparse.RequestParser().add_argument("q", type=str, location="args")
@console_ns.route("/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>") @console_ns.route("/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>")
class DefaultBlockConfigApi(Resource): class DefaultBlockConfigApi(Resource):
@console_ns.doc("get_default_block_config") @console_ns.doc("get_default_block_config")
@ -751,7 +732,7 @@ class DefaultBlockConfigApi(Resource):
@console_ns.doc(params={"app_id": "Application ID", "block_type": "Block type"}) @console_ns.doc(params={"app_id": "Application ID", "block_type": "Block type"})
@console_ns.response(200, "Default block configuration retrieved successfully") @console_ns.response(200, "Default block configuration retrieved successfully")
@console_ns.response(404, "Block type not found") @console_ns.response(404, "Block type not found")
@console_ns.expect(parser_block) @console_ns.expect(console_ns.models[DefaultBlockConfigQuery.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -761,14 +742,12 @@ class DefaultBlockConfigApi(Resource):
""" """
Get default block config Get default block config
""" """
args = parser_block.parse_args() args = DefaultBlockConfigQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
q = args.get("q")
filters = None filters = None
if q: if args.q:
try: try:
filters = json.loads(args.get("q", "")) filters = json.loads(args.q)
except json.JSONDecodeError: except json.JSONDecodeError:
raise ValueError("Invalid filters") raise ValueError("Invalid filters")
@ -777,18 +756,9 @@ class DefaultBlockConfigApi(Resource):
return workflow_service.get_default_block_config(node_type=block_type, filters=filters) return workflow_service.get_default_block_config(node_type=block_type, filters=filters)
parser_convert = (
reqparse.RequestParser()
.add_argument("name", type=str, required=False, nullable=True, location="json")
.add_argument("icon_type", type=str, required=False, nullable=True, location="json")
.add_argument("icon", type=str, required=False, nullable=True, location="json")
.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
)
@console_ns.route("/apps/<uuid:app_id>/convert-to-workflow") @console_ns.route("/apps/<uuid:app_id>/convert-to-workflow")
class ConvertToWorkflowApi(Resource): class ConvertToWorkflowApi(Resource):
@console_ns.expect(parser_convert) @console_ns.expect(console_ns.models[ConvertToWorkflowPayload.__name__])
@console_ns.doc("convert_to_workflow") @console_ns.doc("convert_to_workflow")
@console_ns.doc(description="Convert application to workflow mode") @console_ns.doc(description="Convert application to workflow mode")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@ -808,10 +778,8 @@ class ConvertToWorkflowApi(Resource):
""" """
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
if request.data: payload = console_ns.payload or {}
args = parser_convert.parse_args() args = ConvertToWorkflowPayload.model_validate(payload).model_dump(exclude_none=True)
else:
args = {}
# convert to workflow mode # convert to workflow mode
workflow_service = WorkflowService() workflow_service = WorkflowService()
@ -823,18 +791,9 @@ class ConvertToWorkflowApi(Resource):
} }
parser_workflows = (
reqparse.RequestParser()
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=10, location="args")
.add_argument("user_id", type=str, required=False, location="args")
.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
)
@console_ns.route("/apps/<uuid:app_id>/workflows") @console_ns.route("/apps/<uuid:app_id>/workflows")
class PublishedAllWorkflowApi(Resource): class PublishedAllWorkflowApi(Resource):
@console_ns.expect(parser_workflows) @console_ns.expect(console_ns.models[WorkflowListQuery.__name__])
@console_ns.doc("get_all_published_workflows") @console_ns.doc("get_all_published_workflows")
@console_ns.doc(description="Get all published workflows for an application") @console_ns.doc(description="Get all published workflows for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@ -851,16 +810,15 @@ class PublishedAllWorkflowApi(Resource):
""" """
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
args = parser_workflows.parse_args() args = WorkflowListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
page = args["page"] page = args.page
limit = args["limit"] limit = args.limit
user_id = args.get("user_id") user_id = args.user_id
named_only = args.get("named_only", False) named_only = args.named_only
if user_id: if user_id:
if user_id != current_user.id: if user_id != current_user.id:
raise Forbidden() raise Forbidden()
user_id = cast(str, user_id)
workflow_service = WorkflowService() workflow_service = WorkflowService()
with Session(db.engine) as session: with Session(db.engine) as session:
@ -886,15 +844,7 @@ class WorkflowByIdApi(Resource):
@console_ns.doc("update_workflow_by_id") @console_ns.doc("update_workflow_by_id")
@console_ns.doc(description="Update workflow by ID") @console_ns.doc(description="Update workflow by ID")
@console_ns.doc(params={"app_id": "Application ID", "workflow_id": "Workflow ID"}) @console_ns.doc(params={"app_id": "Application ID", "workflow_id": "Workflow ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[WorkflowUpdatePayload.__name__])
console_ns.model(
"UpdateWorkflowRequest",
{
"environment_variables": fields.List(fields.Raw, description="Environment variables"),
"conversation_variables": fields.List(fields.Raw, description="Conversation variables"),
},
)
)
@console_ns.response(200, "Workflow updated successfully", workflow_model) @console_ns.response(200, "Workflow updated successfully", workflow_model)
@console_ns.response(404, "Workflow not found") @console_ns.response(404, "Workflow not found")
@console_ns.response(403, "Permission denied") @console_ns.response(403, "Permission denied")
@ -909,25 +859,14 @@ class WorkflowByIdApi(Resource):
Update workflow attributes Update workflow attributes
""" """
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
parser = ( args = WorkflowUpdatePayload.model_validate(console_ns.payload or {})
reqparse.RequestParser()
.add_argument("marked_name", type=str, required=False, location="json")
.add_argument("marked_comment", type=str, required=False, location="json")
)
args = parser.parse_args()
# Validate name and comment length
if args.marked_name and len(args.marked_name) > 20:
raise ValueError("Marked name cannot exceed 20 characters")
if args.marked_comment and len(args.marked_comment) > 100:
raise ValueError("Marked comment cannot exceed 100 characters")
# Prepare update data # Prepare update data
update_data = {} update_data = {}
if args.get("marked_name") is not None: if args.marked_name is not None:
update_data["marked_name"] = args["marked_name"] update_data["marked_name"] = args.marked_name
if args.get("marked_comment") is not None: if args.marked_comment is not None:
update_data["marked_comment"] = args["marked_comment"] update_data["marked_comment"] = args.marked_comment
if not update_data: if not update_data:
return {"message": "No valid fields to update"}, 400 return {"message": "No valid fields to update"}, 400
@ -1040,11 +979,8 @@ class DraftWorkflowTriggerRunApi(Resource):
Poll for trigger events and execute full workflow when event arrives Poll for trigger events and execute full workflow when event arrives
""" """
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument( args = DraftWorkflowTriggerRunPayload.model_validate(console_ns.payload or {})
"node_id", type=str, required=True, location="json", nullable=False node_id = args.node_id
)
args = parser.parse_args()
node_id = args["node_id"]
workflow_service = WorkflowService() workflow_service = WorkflowService()
draft_workflow = workflow_service.get_draft_workflow(app_model) draft_workflow = workflow_service.get_draft_workflow(app_model)
if not draft_workflow: if not draft_workflow:
@ -1172,14 +1108,7 @@ class DraftWorkflowTriggerRunAllApi(Resource):
@console_ns.doc("draft_workflow_trigger_run_all") @console_ns.doc("draft_workflow_trigger_run_all")
@console_ns.doc(description="Full workflow debug when the start node is a trigger") @console_ns.doc(description="Full workflow debug when the start node is a trigger")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[DraftWorkflowTriggerRunAllPayload.__name__])
console_ns.model(
"DraftWorkflowTriggerRunAllRequest",
{
"node_ids": fields.List(fields.String, required=True, description="Node IDs"),
},
)
)
@console_ns.response(200, "Workflow executed successfully") @console_ns.response(200, "Workflow executed successfully")
@console_ns.response(403, "Permission denied") @console_ns.response(403, "Permission denied")
@console_ns.response(500, "Internal server error") @console_ns.response(500, "Internal server error")
@ -1194,11 +1123,8 @@ class DraftWorkflowTriggerRunAllApi(Resource):
""" """
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument( args = DraftWorkflowTriggerRunAllPayload.model_validate(console_ns.payload or {})
"node_ids", type=list, required=True, location="json", nullable=False node_ids = args.node_ids
)
args = parser.parse_args()
node_ids = args["node_ids"]
workflow_service = WorkflowService() workflow_service = WorkflowService()
draft_workflow = workflow_service.get_draft_workflow(app_model) draft_workflow = workflow_service.get_draft_workflow(app_model)
if not draft_workflow: if not draft_workflow:

View File

@ -1,6 +1,9 @@
from datetime import datetime
from dateutil.parser import isoparse from dateutil.parser import isoparse
from flask_restx import Resource, marshal_with, reqparse from flask import request
from flask_restx.inputs import int_range from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from controllers.console import console_ns from controllers.console import console_ns
@ -14,6 +17,48 @@ from models import App
from models.model import AppMode from models.model import AppMode
from services.workflow_app_service import WorkflowAppService from services.workflow_app_service import WorkflowAppService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class WorkflowAppLogQuery(BaseModel):
keyword: str | None = Field(default=None, description="Search keyword for filtering logs")
status: WorkflowExecutionStatus | None = Field(
default=None, description="Execution status filter (succeeded, failed, stopped, partial-succeeded)"
)
created_at__before: datetime | None = Field(default=None, description="Filter logs created before this timestamp")
created_at__after: datetime | None = Field(default=None, description="Filter logs created after this timestamp")
created_by_end_user_session_id: str | None = Field(default=None, description="Filter by end user session ID")
created_by_account: str | None = Field(default=None, description="Filter by account")
detail: bool = Field(default=False, description="Whether to return detailed logs")
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
limit: int = Field(default=20, ge=1, le=100, description="Number of items per page (1-100)")
@field_validator("created_at__before", "created_at__after", mode="before")
@classmethod
def parse_datetime(cls, value: str | None) -> datetime | None:
if value in (None, ""):
return None
return isoparse(value) # type: ignore
@field_validator("detail", mode="before")
@classmethod
def parse_bool(cls, value: bool | str | None) -> bool:
if isinstance(value, bool):
return value
if value is None:
return False
lowered = value.lower()
if lowered in {"1", "true", "yes", "on"}:
return True
if lowered in {"0", "false", "no", "off"}:
return False
raise ValueError("Invalid boolean value for detail")
console_ns.schema_model(
WorkflowAppLogQuery.__name__, WorkflowAppLogQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
# Register model for flask_restx to avoid dict type issues in Swagger # Register model for flask_restx to avoid dict type issues in Swagger
workflow_app_log_pagination_model = build_workflow_app_log_pagination_model(console_ns) workflow_app_log_pagination_model = build_workflow_app_log_pagination_model(console_ns)
@ -23,19 +68,7 @@ class WorkflowAppLogApi(Resource):
@console_ns.doc("get_workflow_app_logs") @console_ns.doc("get_workflow_app_logs")
@console_ns.doc(description="Get workflow application execution logs") @console_ns.doc(description="Get workflow application execution logs")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.doc( @console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__])
params={
"keyword": "Search keyword for filtering logs",
"status": "Filter by execution status (succeeded, failed, stopped, partial-succeeded)",
"created_at__before": "Filter logs created before this timestamp",
"created_at__after": "Filter logs created after this timestamp",
"created_by_end_user_session_id": "Filter by end user session ID",
"created_by_account": "Filter by account",
"detail": "Whether to return detailed logs",
"page": "Page number (1-99999)",
"limit": "Number of items per page (1-100)",
}
)
@console_ns.response(200, "Workflow app logs retrieved successfully", workflow_app_log_pagination_model) @console_ns.response(200, "Workflow app logs retrieved successfully", workflow_app_log_pagination_model)
@setup_required @setup_required
@login_required @login_required
@ -46,44 +79,7 @@ class WorkflowAppLogApi(Resource):
""" """
Get workflow app logs Get workflow app logs
""" """
parser = ( args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
reqparse.RequestParser()
.add_argument("keyword", type=str, location="args")
.add_argument(
"status", type=str, choices=["succeeded", "failed", "stopped", "partial-succeeded"], location="args"
)
.add_argument(
"created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
)
.add_argument(
"created_at__after", type=str, location="args", help="Filter logs created after this timestamp"
)
.add_argument(
"created_by_end_user_session_id",
type=str,
location="args",
required=False,
default=None,
)
.add_argument(
"created_by_account",
type=str,
location="args",
required=False,
default=None,
)
.add_argument("detail", type=bool, location="args", required=False, default=False)
.add_argument("page", type=int_range(1, 99999), default=1, location="args")
.add_argument("limit", type=int_range(1, 100), default=20, location="args")
)
args = parser.parse_args()
args.status = WorkflowExecutionStatus(args.status) if args.status else None
if args.created_at__before:
args.created_at__before = isoparse(args.created_at__before)
if args.created_at__after:
args.created_at__after = isoparse(args.created_at__after)
# get paginate workflow app logs # get paginate workflow app logs
workflow_app_service = WorkflowAppService() workflow_app_service = WorkflowAppService()

View File

@ -1,7 +1,8 @@
from typing import cast from typing import Literal, cast
from flask_restx import Resource, fields, marshal_with, reqparse from flask import request
from flask_restx.inputs import int_range from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
@ -92,70 +93,51 @@ workflow_run_node_execution_list_model = console_ns.model(
"WorkflowRunNodeExecutionList", workflow_run_node_execution_list_fields_copy "WorkflowRunNodeExecutionList", workflow_run_node_execution_list_fields_copy
) )
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
def _parse_workflow_run_list_args():
"""
Parse common arguments for workflow run list endpoints.
Returns: class WorkflowRunListQuery(BaseModel):
Parsed arguments containing last_id, limit, status, and triggered_from filters last_id: str | None = Field(default=None, description="Last run ID for pagination")
""" limit: int = Field(default=20, ge=1, le=100, description="Number of items per page (1-100)")
parser = ( status: Literal["running", "succeeded", "failed", "stopped", "partial-succeeded"] | None = Field(
reqparse.RequestParser() default=None, description="Workflow run status filter"
.add_argument("last_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
.add_argument(
"status",
type=str,
choices=WORKFLOW_RUN_STATUS_CHOICES,
location="args",
required=False,
)
.add_argument(
"triggered_from",
type=str,
choices=["debugging", "app-run"],
location="args",
required=False,
help="Filter by trigger source: debugging or app-run",
)
) )
return parser.parse_args() triggered_from: Literal["debugging", "app-run"] | None = Field(
default=None, description="Filter by trigger source: debugging or app-run"
def _parse_workflow_run_count_args():
"""
Parse common arguments for workflow run count endpoints.
Returns:
Parsed arguments containing status, time_range, and triggered_from filters
"""
parser = (
reqparse.RequestParser()
.add_argument(
"status",
type=str,
choices=WORKFLOW_RUN_STATUS_CHOICES,
location="args",
required=False,
)
.add_argument(
"time_range",
type=time_duration,
location="args",
required=False,
help="Time range filter (e.g., 7d, 4h, 30m, 30s)",
)
.add_argument(
"triggered_from",
type=str,
choices=["debugging", "app-run"],
location="args",
required=False,
help="Filter by trigger source: debugging or app-run",
)
) )
return parser.parse_args()
@field_validator("last_id")
@classmethod
def validate_last_id(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
class WorkflowRunCountQuery(BaseModel):
status: Literal["running", "succeeded", "failed", "stopped", "partial-succeeded"] | None = Field(
default=None, description="Workflow run status filter"
)
time_range: str | None = Field(default=None, description="Time range filter (e.g., 7d, 4h, 30m, 30s)")
triggered_from: Literal["debugging", "app-run"] | None = Field(
default=None, description="Filter by trigger source: debugging or app-run"
)
@field_validator("time_range")
@classmethod
def validate_time_range(cls, value: str | None) -> str | None:
if value is None:
return value
return time_duration(value)
console_ns.schema_model(
WorkflowRunListQuery.__name__, WorkflowRunListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
WorkflowRunCountQuery.__name__,
WorkflowRunCountQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs") @console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs")
@ -170,6 +152,7 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
@console_ns.doc( @console_ns.doc(
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"} params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
) )
@console_ns.expect(console_ns.models[WorkflowRunListQuery.__name__])
@console_ns.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_model) @console_ns.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_model)
@setup_required @setup_required
@login_required @login_required
@ -180,12 +163,13 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
""" """
Get advanced chat app workflow run list Get advanced chat app workflow run list
""" """
args = _parse_workflow_run_list_args() args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = args_model.model_dump(exclude_none=True)
# Default to DEBUGGING if not specified # Default to DEBUGGING if not specified
triggered_from = ( triggered_from = (
WorkflowRunTriggeredFrom(args.get("triggered_from")) WorkflowRunTriggeredFrom(args_model.triggered_from)
if args.get("triggered_from") if args_model.triggered_from
else WorkflowRunTriggeredFrom.DEBUGGING else WorkflowRunTriggeredFrom.DEBUGGING
) )
@ -217,6 +201,7 @@ class AdvancedChatAppWorkflowRunCountApi(Resource):
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"} params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
) )
@console_ns.response(200, "Workflow runs count retrieved successfully", workflow_run_count_model) @console_ns.response(200, "Workflow runs count retrieved successfully", workflow_run_count_model)
@console_ns.expect(console_ns.models[WorkflowRunCountQuery.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -226,12 +211,13 @@ class AdvancedChatAppWorkflowRunCountApi(Resource):
""" """
Get advanced chat workflow runs count statistics Get advanced chat workflow runs count statistics
""" """
args = _parse_workflow_run_count_args() args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = args_model.model_dump(exclude_none=True)
# Default to DEBUGGING if not specified # Default to DEBUGGING if not specified
triggered_from = ( triggered_from = (
WorkflowRunTriggeredFrom(args.get("triggered_from")) WorkflowRunTriggeredFrom(args_model.triggered_from)
if args.get("triggered_from") if args_model.triggered_from
else WorkflowRunTriggeredFrom.DEBUGGING else WorkflowRunTriggeredFrom.DEBUGGING
) )
@ -259,6 +245,7 @@ class WorkflowRunListApi(Resource):
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"} params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
) )
@console_ns.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_model) @console_ns.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_model)
@console_ns.expect(console_ns.models[WorkflowRunListQuery.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -268,12 +255,13 @@ class WorkflowRunListApi(Resource):
""" """
Get workflow run list Get workflow run list
""" """
args = _parse_workflow_run_list_args() args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = args_model.model_dump(exclude_none=True)
# Default to DEBUGGING for workflow if not specified (backward compatibility) # Default to DEBUGGING for workflow if not specified (backward compatibility)
triggered_from = ( triggered_from = (
WorkflowRunTriggeredFrom(args.get("triggered_from")) WorkflowRunTriggeredFrom(args_model.triggered_from)
if args.get("triggered_from") if args_model.triggered_from
else WorkflowRunTriggeredFrom.DEBUGGING else WorkflowRunTriggeredFrom.DEBUGGING
) )
@ -305,6 +293,7 @@ class WorkflowRunCountApi(Resource):
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"} params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
) )
@console_ns.response(200, "Workflow runs count retrieved successfully", workflow_run_count_model) @console_ns.response(200, "Workflow runs count retrieved successfully", workflow_run_count_model)
@console_ns.expect(console_ns.models[WorkflowRunCountQuery.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -314,12 +303,13 @@ class WorkflowRunCountApi(Resource):
""" """
Get workflow runs count statistics Get workflow runs count statistics
""" """
args = _parse_workflow_run_count_args() args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = args_model.model_dump(exclude_none=True)
# Default to DEBUGGING for workflow if not specified (backward compatibility) # Default to DEBUGGING for workflow if not specified (backward compatibility)
triggered_from = ( triggered_from = (
WorkflowRunTriggeredFrom(args.get("triggered_from")) WorkflowRunTriggeredFrom(args_model.triggered_from)
if args.get("triggered_from") if args_model.triggered_from
else WorkflowRunTriggeredFrom.DEBUGGING else WorkflowRunTriggeredFrom.DEBUGGING
) )

View File

@ -1,5 +1,6 @@
from flask import abort, jsonify from flask import abort, jsonify, request
from flask_restx import Resource, reqparse from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from controllers.console import console_ns from controllers.console import console_ns
@ -7,12 +8,31 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db from extensions.ext_database import db
from libs.datetime_utils import parse_time_range from libs.datetime_utils import parse_time_range
from libs.helper import DatetimeString
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from models.enums import WorkflowRunTriggeredFrom from models.enums import WorkflowRunTriggeredFrom
from models.model import AppMode from models.model import AppMode
from repositories.factory import DifyAPIRepositoryFactory from repositories.factory import DifyAPIRepositoryFactory
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class WorkflowStatisticQuery(BaseModel):
start: str | None = Field(default=None, description="Start date and time (YYYY-MM-DD HH:MM)")
end: str | None = Field(default=None, description="End date and time (YYYY-MM-DD HH:MM)")
@field_validator("start", "end", mode="before")
@classmethod
def blank_to_none(cls, value: str | None) -> str | None:
if value == "":
return None
return value
console_ns.schema_model(
WorkflowStatisticQuery.__name__,
WorkflowStatisticQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/daily-conversations") @console_ns.route("/apps/<uuid:app_id>/workflow/statistics/daily-conversations")
class WorkflowDailyRunsStatistic(Resource): class WorkflowDailyRunsStatistic(Resource):
@ -24,9 +44,7 @@ class WorkflowDailyRunsStatistic(Resource):
@console_ns.doc("get_workflow_daily_runs_statistic") @console_ns.doc("get_workflow_daily_runs_statistic")
@console_ns.doc(description="Get workflow daily runs statistics") @console_ns.doc(description="Get workflow daily runs statistics")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.doc( @console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}
)
@console_ns.response(200, "Daily runs statistics retrieved successfully") @console_ns.response(200, "Daily runs statistics retrieved successfully")
@get_app_model @get_app_model
@setup_required @setup_required
@ -35,17 +53,12 @@ class WorkflowDailyRunsStatistic(Resource):
def get(self, app_model): def get(self, app_model):
account, _ = current_account_with_tenant() account, _ = current_account_with_tenant()
parser = ( args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
assert account.timezone is not None assert account.timezone is not None
try: try:
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone) start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e: except ValueError as e:
abort(400, description=str(e)) abort(400, description=str(e))
@ -71,9 +84,7 @@ class WorkflowDailyTerminalsStatistic(Resource):
@console_ns.doc("get_workflow_daily_terminals_statistic") @console_ns.doc("get_workflow_daily_terminals_statistic")
@console_ns.doc(description="Get workflow daily terminals statistics") @console_ns.doc(description="Get workflow daily terminals statistics")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.doc( @console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}
)
@console_ns.response(200, "Daily terminals statistics retrieved successfully") @console_ns.response(200, "Daily terminals statistics retrieved successfully")
@get_app_model @get_app_model
@setup_required @setup_required
@ -82,17 +93,12 @@ class WorkflowDailyTerminalsStatistic(Resource):
def get(self, app_model): def get(self, app_model):
account, _ = current_account_with_tenant() account, _ = current_account_with_tenant()
parser = ( args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
assert account.timezone is not None assert account.timezone is not None
try: try:
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone) start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e: except ValueError as e:
abort(400, description=str(e)) abort(400, description=str(e))
@ -118,9 +124,7 @@ class WorkflowDailyTokenCostStatistic(Resource):
@console_ns.doc("get_workflow_daily_token_cost_statistic") @console_ns.doc("get_workflow_daily_token_cost_statistic")
@console_ns.doc(description="Get workflow daily token cost statistics") @console_ns.doc(description="Get workflow daily token cost statistics")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.doc( @console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}
)
@console_ns.response(200, "Daily token cost statistics retrieved successfully") @console_ns.response(200, "Daily token cost statistics retrieved successfully")
@get_app_model @get_app_model
@setup_required @setup_required
@ -129,17 +133,12 @@ class WorkflowDailyTokenCostStatistic(Resource):
def get(self, app_model): def get(self, app_model):
account, _ = current_account_with_tenant() account, _ = current_account_with_tenant()
parser = ( args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
assert account.timezone is not None assert account.timezone is not None
try: try:
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone) start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e: except ValueError as e:
abort(400, description=str(e)) abort(400, description=str(e))
@ -165,9 +164,7 @@ class WorkflowAverageAppInteractionStatistic(Resource):
@console_ns.doc("get_workflow_average_app_interaction_statistic") @console_ns.doc("get_workflow_average_app_interaction_statistic")
@console_ns.doc(description="Get workflow average app interaction statistics") @console_ns.doc(description="Get workflow average app interaction statistics")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.doc( @console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}
)
@console_ns.response(200, "Average app interaction statistics retrieved successfully") @console_ns.response(200, "Average app interaction statistics retrieved successfully")
@setup_required @setup_required
@login_required @login_required
@ -176,17 +173,12 @@ class WorkflowAverageAppInteractionStatistic(Resource):
def get(self, app_model): def get(self, app_model):
account, _ = current_account_with_tenant() account, _ = current_account_with_tenant()
parser = ( args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
assert account.timezone is not None assert account.timezone is not None
try: try:
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone) start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e: except ValueError as e:
abort(400, description=str(e)) abort(400, description=str(e))

View File

@ -58,7 +58,7 @@ class VersionApi(Resource):
response = httpx.get( response = httpx.get(
check_update_url, check_update_url,
params={"current_version": args["current_version"]}, params={"current_version": args["current_version"]},
timeout=httpx.Timeout(connect=3, read=10), timeout=httpx.Timeout(timeout=10.0, connect=3.0),
) )
except Exception as error: except Exception as error:
logger.warning("Check update version error: %s.", str(error)) logger.warning("Check update version error: %s.", str(error))

View File

@ -174,63 +174,25 @@ class CheckEmailUniquePayload(BaseModel):
return email(value) return email(value)
console_ns.schema_model( def reg(cls: type[BaseModel]):
AccountInitPayload.__name__, AccountInitPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
)
console_ns.schema_model(
AccountNamePayload.__name__, AccountNamePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) reg(AccountInitPayload)
) reg(AccountNamePayload)
console_ns.schema_model( reg(AccountAvatarPayload)
AccountAvatarPayload.__name__, AccountAvatarPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) reg(AccountInterfaceLanguagePayload)
) reg(AccountInterfaceThemePayload)
console_ns.schema_model( reg(AccountTimezonePayload)
AccountInterfaceLanguagePayload.__name__, reg(AccountPasswordPayload)
AccountInterfaceLanguagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), reg(AccountDeletePayload)
) reg(AccountDeletionFeedbackPayload)
console_ns.schema_model( reg(EducationActivatePayload)
AccountInterfaceThemePayload.__name__, reg(EducationAutocompleteQuery)
AccountInterfaceThemePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), reg(ChangeEmailSendPayload)
) reg(ChangeEmailValidityPayload)
console_ns.schema_model( reg(ChangeEmailResetPayload)
AccountTimezonePayload.__name__, reg(CheckEmailUniquePayload)
AccountTimezonePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
AccountPasswordPayload.__name__,
AccountPasswordPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
AccountDeletePayload.__name__,
AccountDeletePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
AccountDeletionFeedbackPayload.__name__,
AccountDeletionFeedbackPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
EducationActivatePayload.__name__,
EducationActivatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
EducationAutocompleteQuery.__name__,
EducationAutocompleteQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ChangeEmailSendPayload.__name__,
ChangeEmailSendPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ChangeEmailValidityPayload.__name__,
ChangeEmailValidityPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ChangeEmailResetPayload.__name__,
ChangeEmailResetPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
CheckEmailUniquePayload.__name__,
CheckEmailUniquePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/account/init") @console_ns.route("/account/init")

View File

@ -1,4 +1,8 @@
from flask_restx import Resource, fields, reqparse from typing import Any
from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
@ -7,21 +11,49 @@ from core.plugin.impl.exc import PluginPermissionDeniedError
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from services.plugin.endpoint_service import EndpointService from services.plugin.endpoint_service import EndpointService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class EndpointCreatePayload(BaseModel):
plugin_unique_identifier: str
settings: dict[str, Any]
name: str = Field(min_length=1)
class EndpointIdPayload(BaseModel):
endpoint_id: str
class EndpointUpdatePayload(EndpointIdPayload):
settings: dict[str, Any]
name: str = Field(min_length=1)
class EndpointListQuery(BaseModel):
page: int = Field(ge=1)
page_size: int = Field(gt=0)
class EndpointListForPluginQuery(EndpointListQuery):
plugin_id: str
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(EndpointCreatePayload)
reg(EndpointIdPayload)
reg(EndpointUpdatePayload)
reg(EndpointListQuery)
reg(EndpointListForPluginQuery)
@console_ns.route("/workspaces/current/endpoints/create") @console_ns.route("/workspaces/current/endpoints/create")
class EndpointCreateApi(Resource): class EndpointCreateApi(Resource):
@console_ns.doc("create_endpoint") @console_ns.doc("create_endpoint")
@console_ns.doc(description="Create a new plugin endpoint") @console_ns.doc(description="Create a new plugin endpoint")
@console_ns.expect( @console_ns.expect(console_ns.models[EndpointCreatePayload.__name__])
console_ns.model(
"EndpointCreateRequest",
{
"plugin_unique_identifier": fields.String(required=True, description="Plugin unique identifier"),
"settings": fields.Raw(required=True, description="Endpoint settings"),
"name": fields.String(required=True, description="Endpoint name"),
},
)
)
@console_ns.response( @console_ns.response(
200, 200,
"Endpoint created successfully", "Endpoint created successfully",
@ -35,26 +67,16 @@ class EndpointCreateApi(Resource):
def post(self): def post(self):
user, tenant_id = current_account_with_tenant() user, tenant_id = current_account_with_tenant()
parser = ( args = EndpointCreatePayload.model_validate(console_ns.payload)
reqparse.RequestParser()
.add_argument("plugin_unique_identifier", type=str, required=True)
.add_argument("settings", type=dict, required=True)
.add_argument("name", type=str, required=True)
)
args = parser.parse_args()
plugin_unique_identifier = args["plugin_unique_identifier"]
settings = args["settings"]
name = args["name"]
try: try:
return { return {
"success": EndpointService.create_endpoint( "success": EndpointService.create_endpoint(
tenant_id=tenant_id, tenant_id=tenant_id,
user_id=user.id, user_id=user.id,
plugin_unique_identifier=plugin_unique_identifier, plugin_unique_identifier=args.plugin_unique_identifier,
name=name, name=args.name,
settings=settings, settings=args.settings,
) )
} }
except PluginPermissionDeniedError as e: except PluginPermissionDeniedError as e:
@ -65,11 +87,7 @@ class EndpointCreateApi(Resource):
class EndpointListApi(Resource): class EndpointListApi(Resource):
@console_ns.doc("list_endpoints") @console_ns.doc("list_endpoints")
@console_ns.doc(description="List plugin endpoints with pagination") @console_ns.doc(description="List plugin endpoints with pagination")
@console_ns.expect( @console_ns.expect(console_ns.models[EndpointListQuery.__name__])
console_ns.parser()
.add_argument("page", type=int, required=True, location="args", help="Page number")
.add_argument("page_size", type=int, required=True, location="args", help="Page size")
)
@console_ns.response( @console_ns.response(
200, 200,
"Success", "Success",
@ -83,15 +101,10 @@ class EndpointListApi(Resource):
def get(self): def get(self):
user, tenant_id = current_account_with_tenant() user, tenant_id = current_account_with_tenant()
parser = ( args = EndpointListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
reqparse.RequestParser()
.add_argument("page", type=int, required=True, location="args")
.add_argument("page_size", type=int, required=True, location="args")
)
args = parser.parse_args()
page = args["page"] page = args.page
page_size = args["page_size"] page_size = args.page_size
return jsonable_encoder( return jsonable_encoder(
{ {
@ -109,12 +122,7 @@ class EndpointListApi(Resource):
class EndpointListForSinglePluginApi(Resource): class EndpointListForSinglePluginApi(Resource):
@console_ns.doc("list_plugin_endpoints") @console_ns.doc("list_plugin_endpoints")
@console_ns.doc(description="List endpoints for a specific plugin") @console_ns.doc(description="List endpoints for a specific plugin")
@console_ns.expect( @console_ns.expect(console_ns.models[EndpointListForPluginQuery.__name__])
console_ns.parser()
.add_argument("page", type=int, required=True, location="args", help="Page number")
.add_argument("page_size", type=int, required=True, location="args", help="Page size")
.add_argument("plugin_id", type=str, required=True, location="args", help="Plugin ID")
)
@console_ns.response( @console_ns.response(
200, 200,
"Success", "Success",
@ -128,17 +136,11 @@ class EndpointListForSinglePluginApi(Resource):
def get(self): def get(self):
user, tenant_id = current_account_with_tenant() user, tenant_id = current_account_with_tenant()
parser = ( args = EndpointListForPluginQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
reqparse.RequestParser()
.add_argument("page", type=int, required=True, location="args")
.add_argument("page_size", type=int, required=True, location="args")
.add_argument("plugin_id", type=str, required=True, location="args")
)
args = parser.parse_args()
page = args["page"] page = args.page
page_size = args["page_size"] page_size = args.page_size
plugin_id = args["plugin_id"] plugin_id = args.plugin_id
return jsonable_encoder( return jsonable_encoder(
{ {
@ -157,11 +159,7 @@ class EndpointListForSinglePluginApi(Resource):
class EndpointDeleteApi(Resource): class EndpointDeleteApi(Resource):
@console_ns.doc("delete_endpoint") @console_ns.doc("delete_endpoint")
@console_ns.doc(description="Delete a plugin endpoint") @console_ns.doc(description="Delete a plugin endpoint")
@console_ns.expect( @console_ns.expect(console_ns.models[EndpointIdPayload.__name__])
console_ns.model(
"EndpointDeleteRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")}
)
)
@console_ns.response( @console_ns.response(
200, 200,
"Endpoint deleted successfully", "Endpoint deleted successfully",
@ -175,13 +173,12 @@ class EndpointDeleteApi(Resource):
def post(self): def post(self):
user, tenant_id = current_account_with_tenant() user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True) args = EndpointIdPayload.model_validate(console_ns.payload)
args = parser.parse_args()
endpoint_id = args["endpoint_id"]
return { return {
"success": EndpointService.delete_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id) "success": EndpointService.delete_endpoint(
tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
)
} }
@ -189,16 +186,7 @@ class EndpointDeleteApi(Resource):
class EndpointUpdateApi(Resource): class EndpointUpdateApi(Resource):
@console_ns.doc("update_endpoint") @console_ns.doc("update_endpoint")
@console_ns.doc(description="Update a plugin endpoint") @console_ns.doc(description="Update a plugin endpoint")
@console_ns.expect( @console_ns.expect(console_ns.models[EndpointUpdatePayload.__name__])
console_ns.model(
"EndpointUpdateRequest",
{
"endpoint_id": fields.String(required=True, description="Endpoint ID"),
"settings": fields.Raw(required=True, description="Updated settings"),
"name": fields.String(required=True, description="Updated name"),
},
)
)
@console_ns.response( @console_ns.response(
200, 200,
"Endpoint updated successfully", "Endpoint updated successfully",
@ -212,25 +200,15 @@ class EndpointUpdateApi(Resource):
def post(self): def post(self):
user, tenant_id = current_account_with_tenant() user, tenant_id = current_account_with_tenant()
parser = ( args = EndpointUpdatePayload.model_validate(console_ns.payload)
reqparse.RequestParser()
.add_argument("endpoint_id", type=str, required=True)
.add_argument("settings", type=dict, required=True)
.add_argument("name", type=str, required=True)
)
args = parser.parse_args()
endpoint_id = args["endpoint_id"]
settings = args["settings"]
name = args["name"]
return { return {
"success": EndpointService.update_endpoint( "success": EndpointService.update_endpoint(
tenant_id=tenant_id, tenant_id=tenant_id,
user_id=user.id, user_id=user.id,
endpoint_id=endpoint_id, endpoint_id=args.endpoint_id,
name=name, name=args.name,
settings=settings, settings=args.settings,
) )
} }
@ -239,11 +217,7 @@ class EndpointUpdateApi(Resource):
class EndpointEnableApi(Resource): class EndpointEnableApi(Resource):
@console_ns.doc("enable_endpoint") @console_ns.doc("enable_endpoint")
@console_ns.doc(description="Enable a plugin endpoint") @console_ns.doc(description="Enable a plugin endpoint")
@console_ns.expect( @console_ns.expect(console_ns.models[EndpointIdPayload.__name__])
console_ns.model(
"EndpointEnableRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")}
)
)
@console_ns.response( @console_ns.response(
200, 200,
"Endpoint enabled successfully", "Endpoint enabled successfully",
@ -257,13 +231,12 @@ class EndpointEnableApi(Resource):
def post(self): def post(self):
user, tenant_id = current_account_with_tenant() user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True) args = EndpointIdPayload.model_validate(console_ns.payload)
args = parser.parse_args()
endpoint_id = args["endpoint_id"]
return { return {
"success": EndpointService.enable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id) "success": EndpointService.enable_endpoint(
tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
)
} }
@ -271,11 +244,7 @@ class EndpointEnableApi(Resource):
class EndpointDisableApi(Resource): class EndpointDisableApi(Resource):
@console_ns.doc("disable_endpoint") @console_ns.doc("disable_endpoint")
@console_ns.doc(description="Disable a plugin endpoint") @console_ns.doc(description="Disable a plugin endpoint")
@console_ns.expect( @console_ns.expect(console_ns.models[EndpointIdPayload.__name__])
console_ns.model(
"EndpointDisableRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")}
)
)
@console_ns.response( @console_ns.response(
200, 200,
"Endpoint disabled successfully", "Endpoint disabled successfully",
@ -289,11 +258,10 @@ class EndpointDisableApi(Resource):
def post(self): def post(self):
user, tenant_id = current_account_with_tenant() user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True) args = EndpointIdPayload.model_validate(console_ns.payload)
args = parser.parse_args()
endpoint_id = args["endpoint_id"]
return { return {
"success": EndpointService.disable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id) "success": EndpointService.disable_endpoint(
tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
)
} }

View File

@ -58,26 +58,15 @@ class OwnerTransferPayload(BaseModel):
token: str token: str
console_ns.schema_model( def reg(cls: type[BaseModel]):
MemberInvitePayload.__name__, console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
MemberInvitePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model( reg(MemberInvitePayload)
MemberRoleUpdatePayload.__name__, reg(MemberRoleUpdatePayload)
MemberRoleUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), reg(OwnerTransferEmailPayload)
) reg(OwnerTransferCheckPayload)
console_ns.schema_model( reg(OwnerTransferPayload)
OwnerTransferEmailPayload.__name__,
OwnerTransferEmailPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
OwnerTransferCheckPayload.__name__,
OwnerTransferCheckPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
OwnerTransferPayload.__name__,
OwnerTransferPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/workspaces/current/members") @console_ns.route("/workspaces/current/members")

View File

@ -75,44 +75,18 @@ class ParserPreferredProviderType(BaseModel):
preferred_provider_type: Literal["system", "custom"] preferred_provider_type: Literal["system", "custom"]
console_ns.schema_model( def reg(cls: type[BaseModel]):
ParserModelList.__name__, ParserModelList.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
)
console_ns.schema_model(
ParserCredentialId.__name__,
ParserCredentialId.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model( reg(ParserModelList)
ParserCredentialCreate.__name__, reg(ParserCredentialId)
ParserCredentialCreate.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), reg(ParserCredentialCreate)
) reg(ParserCredentialUpdate)
reg(ParserCredentialDelete)
console_ns.schema_model( reg(ParserCredentialSwitch)
ParserCredentialUpdate.__name__, reg(ParserCredentialValidate)
ParserCredentialUpdate.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), reg(ParserPreferredProviderType)
)
console_ns.schema_model(
ParserCredentialDelete.__name__,
ParserCredentialDelete.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserCredentialSwitch.__name__,
ParserCredentialSwitch.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserCredentialValidate.__name__,
ParserCredentialValidate.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserPreferredProviderType.__name__,
ParserPreferredProviderType.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/workspaces/current/model-providers") @console_ns.route("/workspaces/current/model-providers")

View File

@ -32,25 +32,11 @@ class ParserPostDefault(BaseModel):
model_settings: list[Inner] model_settings: list[Inner]
console_ns.schema_model(
ParserGetDefault.__name__, ParserGetDefault.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
ParserPostDefault.__name__, ParserPostDefault.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
class ParserDeleteModels(BaseModel): class ParserDeleteModels(BaseModel):
model: str model: str
model_type: ModelType model_type: ModelType
console_ns.schema_model(
ParserDeleteModels.__name__, ParserDeleteModels.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
class LoadBalancingPayload(BaseModel): class LoadBalancingPayload(BaseModel):
configs: list[dict[str, Any]] | None = None configs: list[dict[str, Any]] | None = None
enabled: bool | None = None enabled: bool | None = None
@ -119,33 +105,19 @@ class ParserParameter(BaseModel):
model: str model: str
console_ns.schema_model( def reg(cls: type[BaseModel]):
ParserPostModels.__name__, ParserPostModels.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
)
console_ns.schema_model(
ParserGetCredentials.__name__,
ParserGetCredentials.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model( reg(ParserGetDefault)
ParserCreateCredential.__name__, reg(ParserPostDefault)
ParserCreateCredential.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), reg(ParserDeleteModels)
) reg(ParserPostModels)
reg(ParserGetCredentials)
console_ns.schema_model( reg(ParserCreateCredential)
ParserUpdateCredential.__name__, reg(ParserUpdateCredential)
ParserUpdateCredential.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), reg(ParserDeleteCredential)
) reg(ParserParameter)
console_ns.schema_model(
ParserDeleteCredential.__name__,
ParserDeleteCredential.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserParameter.__name__, ParserParameter.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@console_ns.route("/workspaces/current/default-model") @console_ns.route("/workspaces/current/default-model")

View File

@ -22,6 +22,10 @@ from services.plugin.plugin_service import PluginService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@console_ns.route("/workspaces/current/plugin/debugging-key") @console_ns.route("/workspaces/current/plugin/debugging-key")
class PluginDebuggingKeyApi(Resource): class PluginDebuggingKeyApi(Resource):
@setup_required @setup_required
@ -46,9 +50,7 @@ class ParserList(BaseModel):
page_size: int = Field(default=256) page_size: int = Field(default=256)
console_ns.schema_model( reg(ParserList)
ParserList.__name__, ParserList.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@console_ns.route("/workspaces/current/plugin/list") @console_ns.route("/workspaces/current/plugin/list")
@ -72,11 +74,6 @@ class ParserLatest(BaseModel):
plugin_ids: list[str] plugin_ids: list[str]
console_ns.schema_model(
ParserLatest.__name__, ParserLatest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
class ParserIcon(BaseModel): class ParserIcon(BaseModel):
tenant_id: str tenant_id: str
filename: str filename: str
@ -173,72 +170,22 @@ class ParserReadme(BaseModel):
language: str = Field(default="en-US") language: str = Field(default="en-US")
console_ns.schema_model( reg(ParserLatest)
ParserIcon.__name__, ParserIcon.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) reg(ParserIcon)
) reg(ParserAsset)
reg(ParserGithubUpload)
console_ns.schema_model( reg(ParserPluginIdentifiers)
ParserAsset.__name__, ParserAsset.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) reg(ParserGithubInstall)
) reg(ParserPluginIdentifierQuery)
reg(ParserTasks)
console_ns.schema_model( reg(ParserMarketplaceUpgrade)
ParserGithubUpload.__name__, ParserGithubUpload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) reg(ParserGithubUpgrade)
) reg(ParserUninstall)
reg(ParserPermissionChange)
console_ns.schema_model( reg(ParserDynamicOptions)
ParserPluginIdentifiers.__name__, reg(ParserPreferencesChange)
ParserPluginIdentifiers.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), reg(ParserExcludePlugin)
) reg(ParserReadme)
console_ns.schema_model(
ParserGithubInstall.__name__, ParserGithubInstall.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
ParserPluginIdentifierQuery.__name__,
ParserPluginIdentifierQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserTasks.__name__, ParserTasks.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
ParserMarketplaceUpgrade.__name__,
ParserMarketplaceUpgrade.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserGithubUpgrade.__name__, ParserGithubUpgrade.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
ParserUninstall.__name__, ParserUninstall.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
ParserPermissionChange.__name__,
ParserPermissionChange.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserDynamicOptions.__name__,
ParserDynamicOptions.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserPreferencesChange.__name__,
ParserPreferencesChange.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserExcludePlugin.__name__,
ParserExcludePlugin.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserReadme.__name__, ParserReadme.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@console_ns.route("/workspaces/current/plugin/list/latest-versions") @console_ns.route("/workspaces/current/plugin/list/latest-versions")

View File

@ -54,25 +54,14 @@ class WorkspaceInfoPayload(BaseModel):
name: str name: str
console_ns.schema_model( def reg(cls: type[BaseModel]):
WorkspaceListQuery.__name__, WorkspaceListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
)
console_ns.schema_model(
SwitchWorkspacePayload.__name__,
SwitchWorkspacePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
WorkspaceCustomConfigPayload.__name__,
WorkspaceCustomConfigPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
WorkspaceInfoPayload.__name__,
WorkspaceInfoPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
reg(WorkspaceListQuery)
reg(SwitchWorkspacePayload)
reg(WorkspaceCustomConfigPayload)
reg(WorkspaceInfoPayload)
provider_fields = { provider_fields = {
"provider_name": fields.String, "provider_name": fields.String,

View File

@ -4,15 +4,15 @@ from typing import TYPE_CHECKING, Any, Optional
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
if TYPE_CHECKING:
from core.ops.ops_trace_manager import TraceQueueManager
from constants import UUID_NIL from constants import UUID_NIL
from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
from core.entities.provider_configuration import ProviderModelBundle from core.entities.provider_configuration import ProviderModelBundle
from core.file import File, FileUploadConfig from core.file import File, FileUploadConfig
from core.model_runtime.entities.model_entities import AIModelEntity from core.model_runtime.entities.model_entities import AIModelEntity
if TYPE_CHECKING:
from core.ops.ops_trace_manager import TraceQueueManager
class InvokeFrom(StrEnum): class InvokeFrom(StrEnum):
""" """
@ -275,10 +275,8 @@ class RagPipelineGenerateEntity(WorkflowAppGenerateEntity):
start_node_id: str | None = None start_node_id: str | None = None
# Import TraceQueueManager at runtime to resolve forward references
from core.ops.ops_trace_manager import TraceQueueManager from core.ops.ops_trace_manager import TraceQueueManager
# Rebuild models that use forward references
AppGenerateEntity.model_rebuild() AppGenerateEntity.model_rebuild()
EasyUIBasedAppGenerateEntity.model_rebuild() EasyUIBasedAppGenerateEntity.model_rebuild()
ConversationAppGenerateEntity.model_rebuild() ConversationAppGenerateEntity.model_rebuild()

View File

@ -58,11 +58,39 @@ class OceanBaseVector(BaseVector):
password=self._config.password, password=self._config.password,
db_name=self._config.database, db_name=self._config.database,
) )
self._fields: list[str] = [] # List of fields in the collection
if self._client.check_table_exists(collection_name):
self._load_collection_fields()
self._hybrid_search_enabled = self._check_hybrid_search_support() # Check if hybrid search is supported self._hybrid_search_enabled = self._check_hybrid_search_support() # Check if hybrid search is supported
def get_type(self) -> str: def get_type(self) -> str:
return VectorType.OCEANBASE return VectorType.OCEANBASE
def _load_collection_fields(self):
"""
Load collection fields from the database table.
This method populates the _fields list with column names from the table.
"""
try:
if self._collection_name in self._client.metadata_obj.tables:
table = self._client.metadata_obj.tables[self._collection_name]
# Store all column names except 'id' (primary key)
self._fields = [column.name for column in table.columns if column.name != "id"]
logger.debug("Loaded fields for collection '%s': %s", self._collection_name, self._fields)
else:
logger.warning("Collection '%s' not found in metadata", self._collection_name)
except Exception as e:
logger.warning("Failed to load collection fields for '%s': %s", self._collection_name, str(e))
def field_exists(self, field: str) -> bool:
"""
Check if a field exists in the collection.
:param field: Field name to check
:return: True if field exists, False otherwise
"""
return field in self._fields
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
self._vec_dim = len(embeddings[0]) self._vec_dim = len(embeddings[0])
self._create_collection() self._create_collection()
@ -151,6 +179,7 @@ class OceanBaseVector(BaseVector):
logger.debug("DEBUG: Hybrid search is NOT enabled for '%s'", self._collection_name) logger.debug("DEBUG: Hybrid search is NOT enabled for '%s'", self._collection_name)
self._client.refresh_metadata([self._collection_name]) self._client.refresh_metadata([self._collection_name])
self._load_collection_fields()
redis_client.set(collection_exist_cache_key, 1, ex=3600) redis_client.set(collection_exist_cache_key, 1, ex=3600)
def _check_hybrid_search_support(self) -> bool: def _check_hybrid_search_support(self) -> bool:
@ -177,42 +206,134 @@ class OceanBaseVector(BaseVector):
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
ids = self._get_uuids(documents) ids = self._get_uuids(documents)
for id, doc, emb in zip(ids, documents, embeddings): for id, doc, emb in zip(ids, documents, embeddings):
self._client.insert( try:
table_name=self._collection_name, self._client.insert(
data={ table_name=self._collection_name,
"id": id, data={
"vector": emb, "id": id,
"text": doc.page_content, "vector": emb,
"metadata": doc.metadata, "text": doc.page_content,
}, "metadata": doc.metadata,
) },
)
except Exception as e:
logger.exception(
"Failed to insert document with id '%s' in collection '%s'",
id,
self._collection_name,
)
raise Exception(f"Failed to insert document with id '{id}'") from e
def text_exists(self, id: str) -> bool: def text_exists(self, id: str) -> bool:
cur = self._client.get(table_name=self._collection_name, ids=id) try:
return bool(cur.rowcount != 0) cur = self._client.get(table_name=self._collection_name, ids=id)
return bool(cur.rowcount != 0)
except Exception as e:
logger.exception(
"Failed to check if text exists with id '%s' in collection '%s'",
id,
self._collection_name,
)
raise Exception(f"Failed to check text existence for id '{id}'") from e
def delete_by_ids(self, ids: list[str]): def delete_by_ids(self, ids: list[str]):
if not ids: if not ids:
return return
self._client.delete(table_name=self._collection_name, ids=ids) try:
self._client.delete(table_name=self._collection_name, ids=ids)
logger.debug("Deleted %d documents from collection '%s'", len(ids), self._collection_name)
except Exception as e:
logger.exception(
"Failed to delete %d documents from collection '%s'",
len(ids),
self._collection_name,
)
raise Exception(f"Failed to delete documents from collection '{self._collection_name}'") from e
def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]: def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]:
from sqlalchemy import text try:
import re
cur = self._client.get( from sqlalchemy import text
table_name=self._collection_name,
ids=None, # Validate key to prevent injection in JSON path
where_clause=[text(f"metadata->>'$.{key}' = '{value}'")], if not re.match(r"^[a-zA-Z0-9_.]+$", key):
output_column_name=["id"], raise ValueError(f"Invalid characters in metadata key: {key}")
)
return [row[0] for row in cur] # Use parameterized query to prevent SQL injection
sql = text(f"SELECT id FROM `{self._collection_name}` WHERE metadata->>'$.{key}' = :value")
with self._client.engine.connect() as conn:
result = conn.execute(sql, {"value": value})
ids = [row[0] for row in result]
logger.debug(
"Found %d documents with metadata field '%s'='%s' in collection '%s'",
len(ids),
key,
value,
self._collection_name,
)
return ids
except Exception as e:
logger.exception(
"Failed to get IDs by metadata field '%s'='%s' in collection '%s'",
key,
value,
self._collection_name,
)
raise Exception(f"Failed to query documents by metadata field '{key}'") from e
def delete_by_metadata_field(self, key: str, value: str): def delete_by_metadata_field(self, key: str, value: str):
ids = self.get_ids_by_metadata_field(key, value) ids = self.get_ids_by_metadata_field(key, value)
self.delete_by_ids(ids) if ids:
self.delete_by_ids(ids)
else:
logger.debug("No documents found to delete with metadata field '%s'='%s'", key, value)
def _process_search_results(
self, results: list[tuple], score_threshold: float = 0.0, score_key: str = "score"
) -> list[Document]:
"""
Common method to process search results
:param results: Search results as list of tuples (text, metadata, score)
:param score_threshold: Score threshold for filtering
:param score_key: Key name for score in metadata
:return: List of documents
"""
docs = []
for row in results:
text, metadata_str, score = row[0], row[1], row[2]
# Parse metadata JSON
try:
metadata = json.loads(metadata_str) if isinstance(metadata_str, str) else metadata_str
except json.JSONDecodeError:
logger.warning("Invalid JSON metadata: %s", metadata_str)
metadata = {}
# Add score to metadata
metadata[score_key] = score
# Filter by score threshold
if score >= score_threshold:
docs.append(Document(page_content=text, metadata=metadata))
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
if not self._hybrid_search_enabled: if not self._hybrid_search_enabled:
logger.warning(
"Full-text search is disabled: set OCEANBASE_ENABLE_HYBRID_SEARCH=true (requires OceanBase >= 4.3.5.1)."
)
return []
if not self.field_exists("text"):
logger.warning(
"Full-text search unavailable: collection '%s' missing 'text' field; "
"recreate the collection after enabling OCEANBASE_ENABLE_HYBRID_SEARCH to add fulltext index.",
self._collection_name,
)
return [] return []
try: try:
@ -220,13 +341,24 @@ class OceanBaseVector(BaseVector):
if not isinstance(top_k, int) or top_k <= 0: if not isinstance(top_k, int) or top_k <= 0:
raise ValueError("top_k must be a positive integer") raise ValueError("top_k must be a positive integer")
document_ids_filter = kwargs.get("document_ids_filter") score_threshold = float(kwargs.get("score_threshold") or 0.0)
where_clause = ""
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause = f" AND metadata->>'$.document_id' IN ({document_ids})"
full_sql = f"""SELECT metadata, text, MATCH (text) AGAINST (:query) AS score # Build parameterized query to prevent SQL injection
from sqlalchemy import text
document_ids_filter = kwargs.get("document_ids_filter")
params = {"query": query}
where_clause = ""
if document_ids_filter:
# Create parameterized placeholders for document IDs
placeholders = ", ".join(f":doc_id_{i}" for i in range(len(document_ids_filter)))
where_clause = f" AND metadata->>'$.document_id' IN ({placeholders})"
# Add document IDs to parameters
for i, doc_id in enumerate(document_ids_filter):
params[f"doc_id_{i}"] = doc_id
full_sql = f"""SELECT text, metadata, MATCH (text) AGAINST (:query) AS score
FROM {self._collection_name} FROM {self._collection_name}
WHERE MATCH (text) AGAINST (:query) > 0 WHERE MATCH (text) AGAINST (:query) > 0
{where_clause} {where_clause}
@ -235,41 +367,45 @@ class OceanBaseVector(BaseVector):
with self._client.engine.connect() as conn: with self._client.engine.connect() as conn:
with conn.begin(): with conn.begin():
from sqlalchemy import text result = conn.execute(text(full_sql), params)
result = conn.execute(text(full_sql), {"query": query})
rows = result.fetchall() rows = result.fetchall()
docs = [] return self._process_search_results(rows, score_threshold=score_threshold)
for row in rows:
metadata_str, _text, score = row
try:
metadata = json.loads(metadata_str)
except json.JSONDecodeError:
logger.warning("Invalid JSON metadata: %s", metadata_str)
metadata = {}
metadata["score"] = score
docs.append(Document(page_content=_text, metadata=metadata))
return docs
except Exception as e: except Exception as e:
logger.warning("Failed to fulltext search: %s.", str(e)) logger.exception(
return [] "Failed to perform full-text search on collection '%s' with query '%s'",
self._collection_name,
query,
)
raise Exception(f"Full-text search failed for collection '{self._collection_name}'") from e
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
from sqlalchemy import text
document_ids_filter = kwargs.get("document_ids_filter") document_ids_filter = kwargs.get("document_ids_filter")
_where_clause = None _where_clause = None
if document_ids_filter: if document_ids_filter:
# Validate document IDs to prevent SQL injection
# Document IDs should be alphanumeric with hyphens and underscores
import re
for doc_id in document_ids_filter:
if not isinstance(doc_id, str) or not re.match(r"^[a-zA-Z0-9_-]+$", doc_id):
raise ValueError(f"Invalid document ID format: {doc_id}")
# Safe to use in query after validation
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause = f"metadata->>'$.document_id' in ({document_ids})" where_clause = f"metadata->>'$.document_id' in ({document_ids})"
from sqlalchemy import text
_where_clause = [text(where_clause)] _where_clause = [text(where_clause)]
ef_search = kwargs.get("ef_search", self._hnsw_ef_search) ef_search = kwargs.get("ef_search", self._hnsw_ef_search)
if ef_search != self._hnsw_ef_search: if ef_search != self._hnsw_ef_search:
self._client.set_ob_hnsw_ef_search(ef_search) self._client.set_ob_hnsw_ef_search(ef_search)
self._hnsw_ef_search = ef_search self._hnsw_ef_search = ef_search
topk = kwargs.get("top_k", 10) topk = kwargs.get("top_k", 10)
try:
score_threshold = float(val) if (val := kwargs.get("score_threshold")) is not None else 0.0
except (ValueError, TypeError) as e:
raise ValueError(f"Invalid score_threshold parameter: {e}") from e
try: try:
cur = self._client.ann_search( cur = self._client.ann_search(
table_name=self._collection_name, table_name=self._collection_name,
@ -282,21 +418,27 @@ class OceanBaseVector(BaseVector):
where_clause=_where_clause, where_clause=_where_clause,
) )
except Exception as e: except Exception as e:
raise Exception("Failed to search by vector. ", e) logger.exception(
docs = [] "Failed to perform vector search on collection '%s'",
for _text, metadata, distance in cur: self._collection_name,
metadata = json.loads(metadata)
metadata["score"] = 1 - distance / math.sqrt(2)
docs.append(
Document(
page_content=_text,
metadata=metadata,
)
) )
return docs raise Exception(f"Vector search failed for collection '{self._collection_name}'") from e
# Convert distance to score and prepare results for processing
results = []
for _text, metadata_str, distance in cur:
score = 1 - distance / math.sqrt(2)
results.append((_text, metadata_str, score))
return self._process_search_results(results, score_threshold=score_threshold)
def delete(self): def delete(self):
self._client.drop_table_if_exist(self._collection_name) try:
self._client.drop_table_if_exist(self._collection_name)
logger.debug("Dropped collection '%s'", self._collection_name)
except Exception as e:
logger.exception("Failed to delete collection '%s'", self._collection_name)
raise Exception(f"Failed to delete collection '{self._collection_name}'") from e
class OceanBaseVectorFactory(AbstractVectorFactory): class OceanBaseVectorFactory(AbstractVectorFactory):

View File

@ -54,6 +54,8 @@ class ToolProviderApiEntity(BaseModel):
configuration: MCPConfiguration | None = Field( configuration: MCPConfiguration | None = Field(
default=None, description="The timeout and sse_read_timeout of the MCP tool" default=None, description="The timeout and sse_read_timeout of the MCP tool"
) )
# Workflow
workflow_app_id: str | None = Field(default=None, description="The app id of the workflow tool")
@field_validator("tools", mode="before") @field_validator("tools", mode="before")
@classmethod @classmethod
@ -87,6 +89,8 @@ class ToolProviderApiEntity(BaseModel):
optional_fields.update(self.optional_field("is_dynamic_registration", self.is_dynamic_registration)) optional_fields.update(self.optional_field("is_dynamic_registration", self.is_dynamic_registration))
optional_fields.update(self.optional_field("masked_headers", self.masked_headers)) optional_fields.update(self.optional_field("masked_headers", self.masked_headers))
optional_fields.update(self.optional_field("original_headers", self.original_headers)) optional_fields.update(self.optional_field("original_headers", self.original_headers))
elif self.type == ToolProviderType.WORKFLOW:
optional_fields.update(self.optional_field("workflow_app_id", self.workflow_app_id))
return { return {
"id": self.id, "id": self.id,
"author": self.author, "author": self.author,

View File

@ -1,7 +1,11 @@
import importlib
import logging import logging
import operator
import pkgutil
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from functools import singledispatchmethod from functools import singledispatchmethod
from types import MappingProxyType
from typing import Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin from typing import Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin
from uuid import uuid4 from uuid import uuid4
@ -134,6 +138,34 @@ class Node(Generic[NodeDataT]):
cls._node_data_type = node_data_type cls._node_data_type = node_data_type
# Skip base class itself
if cls is Node:
return
# Only register production node implementations defined under core.workflow.nodes.*
# This prevents test helper subclasses from polluting the global registry and
# accidentally overriding real node types (e.g., a test Answer node).
module_name = getattr(cls, "__module__", "")
# Only register concrete subclasses that define node_type and version()
node_type = cls.node_type
version = cls.version()
bucket = Node._registry.setdefault(node_type, {})
if module_name.startswith("core.workflow.nodes."):
# Production node definitions take precedence and may override
bucket[version] = cls # type: ignore[index]
else:
# External/test subclasses may register but must not override production
bucket.setdefault(version, cls) # type: ignore[index]
# Maintain a "latest" pointer preferring numeric versions; fallback to lexicographic
version_keys = [v for v in bucket if v != "latest"]
numeric_pairs: list[tuple[str, int]] = []
for v in version_keys:
numeric_pairs.append((v, int(v)))
if numeric_pairs:
latest_key = max(numeric_pairs, key=operator.itemgetter(1))[0]
else:
latest_key = max(version_keys) if version_keys else version
bucket["latest"] = bucket[latest_key]
@classmethod @classmethod
def _extract_node_data_type_from_generic(cls) -> type[BaseNodeData] | None: def _extract_node_data_type_from_generic(cls) -> type[BaseNodeData] | None:
""" """
@ -165,6 +197,9 @@ class Node(Generic[NodeDataT]):
return None return None
# Global registry populated via __init_subclass__
_registry: ClassVar[dict["NodeType", dict[str, type["Node"]]]] = {}
def __init__( def __init__(
self, self,
id: str, id: str,
@ -240,23 +275,23 @@ class Node(Generic[NodeDataT]):
from core.workflow.nodes.tool.tool_node import ToolNode from core.workflow.nodes.tool.tool_node import ToolNode
if isinstance(self, ToolNode): if isinstance(self, ToolNode):
start_event.provider_id = getattr(self.get_base_node_data(), "provider_id", "") start_event.provider_id = getattr(self.node_data, "provider_id", "")
start_event.provider_type = getattr(self.get_base_node_data(), "provider_type", "") start_event.provider_type = getattr(self.node_data, "provider_type", "")
from core.workflow.nodes.datasource.datasource_node import DatasourceNode from core.workflow.nodes.datasource.datasource_node import DatasourceNode
if isinstance(self, DatasourceNode): if isinstance(self, DatasourceNode):
plugin_id = getattr(self.get_base_node_data(), "plugin_id", "") plugin_id = getattr(self.node_data, "plugin_id", "")
provider_name = getattr(self.get_base_node_data(), "provider_name", "") provider_name = getattr(self.node_data, "provider_name", "")
start_event.provider_id = f"{plugin_id}/{provider_name}" start_event.provider_id = f"{plugin_id}/{provider_name}"
start_event.provider_type = getattr(self.get_base_node_data(), "provider_type", "") start_event.provider_type = getattr(self.node_data, "provider_type", "")
from core.workflow.nodes.trigger_plugin.trigger_event_node import TriggerEventNode from core.workflow.nodes.trigger_plugin.trigger_event_node import TriggerEventNode
if isinstance(self, TriggerEventNode): if isinstance(self, TriggerEventNode):
start_event.provider_id = getattr(self.get_base_node_data(), "provider_id", "") start_event.provider_id = getattr(self.node_data, "provider_id", "")
start_event.provider_type = getattr(self.get_base_node_data(), "provider_type", "") start_event.provider_type = getattr(self.node_data, "provider_type", "")
from typing import cast from typing import cast
@ -265,7 +300,7 @@ class Node(Generic[NodeDataT]):
if isinstance(self, AgentNode): if isinstance(self, AgentNode):
start_event.agent_strategy = AgentNodeStrategyInit( start_event.agent_strategy = AgentNodeStrategyInit(
name=cast(AgentNodeData, self.get_base_node_data()).agent_strategy_name, name=cast(AgentNodeData, self.node_data).agent_strategy_name,
icon=self.agent_strategy_icon, icon=self.agent_strategy_icon,
) )
@ -395,6 +430,29 @@ class Node(Generic[NodeDataT]):
# in `api/core/workflow/nodes/__init__.py`. # in `api/core/workflow/nodes/__init__.py`.
raise NotImplementedError("subclasses of BaseNode must implement `version` method.") raise NotImplementedError("subclasses of BaseNode must implement `version` method.")
@classmethod
def get_node_type_classes_mapping(cls) -> Mapping["NodeType", Mapping[str, type["Node"]]]:
"""Return mapping of NodeType -> {version -> Node subclass} using __init_subclass__ registry.
Import all modules under core.workflow.nodes so subclasses register themselves on import.
Then we return a readonly view of the registry to avoid accidental mutation.
"""
# Import all node modules to ensure they are loaded (thus registered)
import core.workflow.nodes as _nodes_pkg
for _, _modname, _ in pkgutil.walk_packages(_nodes_pkg.__path__, _nodes_pkg.__name__ + "."):
# Avoid importing modules that depend on the registry to prevent circular imports
# e.g. node_factory imports node_mapping which builds the mapping here.
if _modname in {
"core.workflow.nodes.node_factory",
"core.workflow.nodes.node_mapping",
}:
continue
importlib.import_module(_modname)
# Return a readonly view so callers can't mutate the registry by accident
return {nt: MappingProxyType(ver_map) for nt, ver_map in cls._registry.items()}
@property @property
def retry(self) -> bool: def retry(self) -> bool:
return False return False
@ -419,10 +477,6 @@ class Node(Generic[NodeDataT]):
"""Get the default values dictionary for this node.""" """Get the default values dictionary for this node."""
return self._node_data.default_value_dict return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
"""Get the BaseNodeData object for this node."""
return self._node_data
# Public interface properties that delegate to abstract methods # Public interface properties that delegate to abstract methods
@property @property
def error_strategy(self) -> ErrorStrategy | None: def error_strategy(self) -> ErrorStrategy | None:
@ -548,7 +602,7 @@ class Node(Generic[NodeDataT]):
id=self._node_execution_id, id=self._node_execution_id,
node_id=self._node_id, node_id=self._node_id,
node_type=self.node_type, node_type=self.node_type,
node_title=self.get_base_node_data().title, node_title=self.node_data.title,
start_at=event.start_at, start_at=event.start_at,
inputs=event.inputs, inputs=event.inputs,
metadata=event.metadata, metadata=event.metadata,
@ -561,7 +615,7 @@ class Node(Generic[NodeDataT]):
id=self._node_execution_id, id=self._node_execution_id,
node_id=self._node_id, node_id=self._node_id,
node_type=self.node_type, node_type=self.node_type,
node_title=self.get_base_node_data().title, node_title=self.node_data.title,
index=event.index, index=event.index,
pre_loop_output=event.pre_loop_output, pre_loop_output=event.pre_loop_output,
) )
@ -572,7 +626,7 @@ class Node(Generic[NodeDataT]):
id=self._node_execution_id, id=self._node_execution_id,
node_id=self._node_id, node_id=self._node_id,
node_type=self.node_type, node_type=self.node_type,
node_title=self.get_base_node_data().title, node_title=self.node_data.title,
start_at=event.start_at, start_at=event.start_at,
inputs=event.inputs, inputs=event.inputs,
outputs=event.outputs, outputs=event.outputs,
@ -586,7 +640,7 @@ class Node(Generic[NodeDataT]):
id=self._node_execution_id, id=self._node_execution_id,
node_id=self._node_id, node_id=self._node_id,
node_type=self.node_type, node_type=self.node_type,
node_title=self.get_base_node_data().title, node_title=self.node_data.title,
start_at=event.start_at, start_at=event.start_at,
inputs=event.inputs, inputs=event.inputs,
outputs=event.outputs, outputs=event.outputs,
@ -601,7 +655,7 @@ class Node(Generic[NodeDataT]):
id=self._node_execution_id, id=self._node_execution_id,
node_id=self._node_id, node_id=self._node_id,
node_type=self.node_type, node_type=self.node_type,
node_title=self.get_base_node_data().title, node_title=self.node_data.title,
start_at=event.start_at, start_at=event.start_at,
inputs=event.inputs, inputs=event.inputs,
metadata=event.metadata, metadata=event.metadata,
@ -614,7 +668,7 @@ class Node(Generic[NodeDataT]):
id=self._node_execution_id, id=self._node_execution_id,
node_id=self._node_id, node_id=self._node_id,
node_type=self.node_type, node_type=self.node_type,
node_title=self.get_base_node_data().title, node_title=self.node_data.title,
index=event.index, index=event.index,
pre_iteration_output=event.pre_iteration_output, pre_iteration_output=event.pre_iteration_output,
) )
@ -625,7 +679,7 @@ class Node(Generic[NodeDataT]):
id=self._node_execution_id, id=self._node_execution_id,
node_id=self._node_id, node_id=self._node_id,
node_type=self.node_type, node_type=self.node_type,
node_title=self.get_base_node_data().title, node_title=self.node_data.title,
start_at=event.start_at, start_at=event.start_at,
inputs=event.inputs, inputs=event.inputs,
outputs=event.outputs, outputs=event.outputs,
@ -639,7 +693,7 @@ class Node(Generic[NodeDataT]):
id=self._node_execution_id, id=self._node_execution_id,
node_id=self._node_id, node_id=self._node_id,
node_type=self.node_type, node_type=self.node_type,
node_title=self.get_base_node_data().title, node_title=self.node_data.title,
start_at=event.start_at, start_at=event.start_at,
inputs=event.inputs, inputs=event.inputs,
outputs=event.outputs, outputs=event.outputs,

View File

@ -1,165 +1,9 @@
from collections.abc import Mapping from collections.abc import Mapping
from core.workflow.enums import NodeType from core.workflow.enums import NodeType
from core.workflow.nodes.agent.agent_node import AgentNode
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code import CodeNode
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
from core.workflow.nodes.document_extractor import DocumentExtractorNode
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.http_request import HttpRequestNode
from core.workflow.nodes.human_input import HumanInputNode
from core.workflow.nodes.if_else import IfElseNode
from core.workflow.nodes.iteration import IterationNode, IterationStartNode
from core.workflow.nodes.knowledge_index import KnowledgeIndexNode
from core.workflow.nodes.knowledge_retrieval import KnowledgeRetrievalNode
from core.workflow.nodes.list_operator import ListOperatorNode
from core.workflow.nodes.llm import LLMNode
from core.workflow.nodes.loop import LoopEndNode, LoopNode, LoopStartNode
from core.workflow.nodes.parameter_extractor import ParameterExtractorNode
from core.workflow.nodes.question_classifier import QuestionClassifierNode
from core.workflow.nodes.start import StartNode
from core.workflow.nodes.template_transform import TemplateTransformNode
from core.workflow.nodes.tool import ToolNode
from core.workflow.nodes.trigger_plugin import TriggerEventNode
from core.workflow.nodes.trigger_schedule import TriggerScheduleNode
from core.workflow.nodes.trigger_webhook import TriggerWebhookNode
from core.workflow.nodes.variable_aggregator import VariableAggregatorNode
from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode as VariableAssignerNodeV1
from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode as VariableAssignerNodeV2
LATEST_VERSION = "latest" LATEST_VERSION = "latest"
# NOTE(QuantumGhost): This should be in sync with subclasses of BaseNode. # Mapping is built by Node.get_node_type_classes_mapping(), which imports and walks core.workflow.nodes
# Specifically, if you have introduced new node types, you should add them here. NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = Node.get_node_type_classes_mapping()
#
# TODO(QuantumGhost): This could be automated with either metaclass or `__init_subclass__`
# hook. Try to avoid duplication of node information.
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = {
NodeType.START: {
LATEST_VERSION: StartNode,
"1": StartNode,
},
NodeType.END: {
LATEST_VERSION: EndNode,
"1": EndNode,
},
NodeType.ANSWER: {
LATEST_VERSION: AnswerNode,
"1": AnswerNode,
},
NodeType.LLM: {
LATEST_VERSION: LLMNode,
"1": LLMNode,
},
NodeType.KNOWLEDGE_RETRIEVAL: {
LATEST_VERSION: KnowledgeRetrievalNode,
"1": KnowledgeRetrievalNode,
},
NodeType.IF_ELSE: {
LATEST_VERSION: IfElseNode,
"1": IfElseNode,
},
NodeType.CODE: {
LATEST_VERSION: CodeNode,
"1": CodeNode,
},
NodeType.TEMPLATE_TRANSFORM: {
LATEST_VERSION: TemplateTransformNode,
"1": TemplateTransformNode,
},
NodeType.QUESTION_CLASSIFIER: {
LATEST_VERSION: QuestionClassifierNode,
"1": QuestionClassifierNode,
},
NodeType.HTTP_REQUEST: {
LATEST_VERSION: HttpRequestNode,
"1": HttpRequestNode,
},
NodeType.TOOL: {
LATEST_VERSION: ToolNode,
# This is an issue that caused problems before.
# Logically, we shouldn't use two different versions to point to the same class here,
# but in order to maintain compatibility with historical data, this approach has been retained.
"2": ToolNode,
"1": ToolNode,
},
NodeType.VARIABLE_AGGREGATOR: {
LATEST_VERSION: VariableAggregatorNode,
"1": VariableAggregatorNode,
},
NodeType.LEGACY_VARIABLE_AGGREGATOR: {
LATEST_VERSION: VariableAggregatorNode,
"1": VariableAggregatorNode,
}, # original name of VARIABLE_AGGREGATOR
NodeType.ITERATION: {
LATEST_VERSION: IterationNode,
"1": IterationNode,
},
NodeType.ITERATION_START: {
LATEST_VERSION: IterationStartNode,
"1": IterationStartNode,
},
NodeType.LOOP: {
LATEST_VERSION: LoopNode,
"1": LoopNode,
},
NodeType.LOOP_START: {
LATEST_VERSION: LoopStartNode,
"1": LoopStartNode,
},
NodeType.LOOP_END: {
LATEST_VERSION: LoopEndNode,
"1": LoopEndNode,
},
NodeType.PARAMETER_EXTRACTOR: {
LATEST_VERSION: ParameterExtractorNode,
"1": ParameterExtractorNode,
},
NodeType.VARIABLE_ASSIGNER: {
LATEST_VERSION: VariableAssignerNodeV2,
"1": VariableAssignerNodeV1,
"2": VariableAssignerNodeV2,
},
NodeType.DOCUMENT_EXTRACTOR: {
LATEST_VERSION: DocumentExtractorNode,
"1": DocumentExtractorNode,
},
NodeType.LIST_OPERATOR: {
LATEST_VERSION: ListOperatorNode,
"1": ListOperatorNode,
},
NodeType.AGENT: {
LATEST_VERSION: AgentNode,
# This is an issue that caused problems before.
# Logically, we shouldn't use two different versions to point to the same class here,
# but in order to maintain compatibility with historical data, this approach has been retained.
"2": AgentNode,
"1": AgentNode,
},
NodeType.HUMAN_INPUT: {
LATEST_VERSION: HumanInputNode,
"1": HumanInputNode,
},
NodeType.DATASOURCE: {
LATEST_VERSION: DatasourceNode,
"1": DatasourceNode,
},
NodeType.KNOWLEDGE_INDEX: {
LATEST_VERSION: KnowledgeIndexNode,
"1": KnowledgeIndexNode,
},
NodeType.TRIGGER_WEBHOOK: {
LATEST_VERSION: TriggerWebhookNode,
"1": TriggerWebhookNode,
},
NodeType.TRIGGER_PLUGIN: {
LATEST_VERSION: TriggerEventNode,
"1": TriggerEventNode,
},
NodeType.TRIGGER_SCHEDULE: {
LATEST_VERSION: TriggerScheduleNode,
"1": TriggerScheduleNode,
},
}

View File

@ -12,7 +12,6 @@ from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.errors import ToolInvokeError from core.tools.errors import ToolInvokeError
from core.tools.tool_engine import ToolEngine from core.tools.tool_engine import ToolEngine
from core.tools.utils.message_transformer import ToolFileMessageTransformer from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.tools.workflow_as_tool.tool import WorkflowTool
from core.variables.segments import ArrayAnySegment, ArrayFileSegment from core.variables.segments import ArrayAnySegment, ArrayFileSegment
from core.variables.variables import ArrayAnyVariable from core.variables.variables import ArrayAnyVariable
from core.workflow.enums import ( from core.workflow.enums import (
@ -430,7 +429,7 @@ class ToolNode(Node[ToolNodeData]):
metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = { metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = {
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info, WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
} }
if usage.total_tokens > 0: if isinstance(usage.total_tokens, int) and usage.total_tokens > 0:
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS] = usage.total_tokens metadata[WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS] = usage.total_tokens
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_PRICE] = usage.total_price metadata[WorkflowNodeExecutionMetadataKey.TOTAL_PRICE] = usage.total_price
metadata[WorkflowNodeExecutionMetadataKey.CURRENCY] = usage.currency metadata[WorkflowNodeExecutionMetadataKey.CURRENCY] = usage.currency
@ -449,8 +448,17 @@ class ToolNode(Node[ToolNodeData]):
@staticmethod @staticmethod
def _extract_tool_usage(tool_runtime: Tool) -> LLMUsage: def _extract_tool_usage(tool_runtime: Tool) -> LLMUsage:
if isinstance(tool_runtime, WorkflowTool): # Avoid importing WorkflowTool at module import time; rely on duck typing
return tool_runtime.latest_usage # Some runtimes expose `latest_usage`; mocks may synthesize arbitrary attributes.
latest = getattr(tool_runtime, "latest_usage", None)
# Normalize into a concrete LLMUsage. MagicMock returns truthy attribute objects
# for any name, so we must type-check here.
if isinstance(latest, LLMUsage):
return latest
if isinstance(latest, dict):
# Allow dict payloads from external runtimes
return LLMUsage.model_validate(latest)
# Fallback to empty usage when attribute is missing or not a valid payload
return LLMUsage.empty_usage() return LLMUsage.empty_usage()
@classmethod @classmethod

View File

@ -0,0 +1,49 @@
import logging
from dify_app import DifyApp
def is_enabled() -> bool:
    """Always-on extension flag; no configuration option gates this module."""
    return True
def init_app(app: DifyApp):
    """Rebuild Pydantic models whose annotations forward-reference TraceQueueManager.

    The entities in ``core.app.entities.app_invoke_entities`` declare
    ``TraceQueueManager`` as a string annotation to break a circular import; this
    hook supplies the concrete type and rebuilds each model at startup. The whole
    procedure is best-effort and idempotent: every failure is logged at debug
    level and never blocks application startup.
    """
    logger = logging.getLogger(__name__)
    try:
        from core.app.entities.app_invoke_entities import (
            AdvancedChatAppGenerateEntity,
            AgentChatAppGenerateEntity,
            AppGenerateEntity,
            ChatAppGenerateEntity,
            CompletionAppGenerateEntity,
            ConversationAppGenerateEntity,
            EasyUIBasedAppGenerateEntity,
            RagPipelineGenerateEntity,
            WorkflowAppGenerateEntity,
        )

        # Heavy import: deferred to startup so plain module import stays cheap.
        from core.ops.ops_trace_manager import TraceQueueManager

        namespace = {"TraceQueueManager": TraceQueueManager}
        entity_models = (
            AppGenerateEntity,
            EasyUIBasedAppGenerateEntity,
            ConversationAppGenerateEntity,
            ChatAppGenerateEntity,
            CompletionAppGenerateEntity,
            AgentChatAppGenerateEntity,
            AdvancedChatAppGenerateEntity,
            WorkflowAppGenerateEntity,
            RagPipelineGenerateEntity,
        )
        for entity_model in entity_models:
            try:
                entity_model.model_rebuild(_types_namespace=namespace)
            except Exception as e:
                logger.debug("model_rebuild skipped for %s: %s", entity_model.__name__, e)
    except Exception as e:
        # Never block app startup over trace-manager wiring; record and move on.
        logger.debug("ext_forward_refs init skipped: %s", e)

View File

@ -111,7 +111,7 @@ package = false
dev = [ dev = [
"coverage~=7.2.4", "coverage~=7.2.4",
"dotenv-linter~=0.5.0", "dotenv-linter~=0.5.0",
"faker~=32.1.0", "faker~=38.2.0",
"lxml-stubs~=0.5.1", "lxml-stubs~=0.5.1",
"ty~=0.0.1a19", "ty~=0.0.1a19",
"basedpyright~=1.31.0", "basedpyright~=1.31.0",

View File

@ -201,7 +201,9 @@ class ToolTransformService:
@staticmethod @staticmethod
def workflow_provider_to_user_provider( def workflow_provider_to_user_provider(
provider_controller: WorkflowToolProviderController, labels: list[str] | None = None provider_controller: WorkflowToolProviderController,
labels: list[str] | None = None,
workflow_app_id: str | None = None,
): ):
""" """
convert provider controller to user provider convert provider controller to user provider
@ -221,6 +223,7 @@ class ToolTransformService:
plugin_unique_identifier=None, plugin_unique_identifier=None,
tools=[], tools=[],
labels=labels or [], labels=labels or [],
workflow_app_id=workflow_app_id,
) )
@staticmethod @staticmethod

View File

@ -189,6 +189,9 @@ class WorkflowToolManageService:
select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id) select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id)
).all() ).all()
# Create a mapping from provider_id to app_id
provider_id_to_app_id = {provider.id: provider.app_id for provider in db_tools}
tools: list[WorkflowToolProviderController] = [] tools: list[WorkflowToolProviderController] = []
for provider in db_tools: for provider in db_tools:
try: try:
@ -202,8 +205,11 @@ class WorkflowToolManageService:
result = [] result = []
for tool in tools: for tool in tools:
workflow_app_id = provider_id_to_app_id.get(tool.provider_id)
user_tool_provider = ToolTransformService.workflow_provider_to_user_provider( user_tool_provider = ToolTransformService.workflow_provider_to_user_provider(
provider_controller=tool, labels=labels.get(tool.provider_id, []) provider_controller=tool,
labels=labels.get(tool.provider_id, []),
workflow_app_id=workflow_app_id,
) )
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_tool_provider) ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_tool_provider)
user_tool_provider.tools = [ user_tool_provider.tools = [

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,100 @@
"""
Unit tests for ToolProviderApiEntity workflow_app_id field.
This test suite covers:
- ToolProviderApiEntity workflow_app_id field creation and default value
- ToolProviderApiEntity.to_dict() method behavior with workflow_app_id
"""
from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType
class TestToolProviderApiEntityWorkflowAppId:
    """Test suite for the ``workflow_app_id`` field on ``ToolProviderApiEntity``.

    Covers the field's default value and the conditional serialization in
    ``to_dict()``: the key is only emitted for WORKFLOW-type providers carrying
    a truthy value.
    """

    @staticmethod
    def _make_entity(**overrides) -> ToolProviderApiEntity:
        """Build a minimal valid WORKFLOW entity; keyword overrides replace defaults."""
        fields = {
            "id": "test_id",
            "author": "test_author",
            "name": "test_name",
            "description": I18nObject(en_US="Test description"),
            "icon": "test_icon",
            "label": I18nObject(en_US="Test label"),
            "type": ToolProviderType.WORKFLOW,
        }
        fields.update(overrides)
        return ToolProviderApiEntity(**fields)

    def test_workflow_app_id_field_default_none(self):
        """workflow_app_id defaults to None when not provided."""
        entity = self._make_entity()
        assert entity.workflow_app_id is None

    def test_to_dict_includes_workflow_app_id_when_workflow_type_and_has_value(self):
        """to_dict() includes workflow_app_id when type is WORKFLOW and value is set."""
        entity = self._make_entity(workflow_app_id="app_123")
        result = entity.to_dict()
        assert "workflow_app_id" in result
        assert result["workflow_app_id"] == "app_123"

    def test_to_dict_excludes_workflow_app_id_when_workflow_type_and_none(self):
        """to_dict() excludes workflow_app_id when type is WORKFLOW but value is None."""
        result = self._make_entity(workflow_app_id=None).to_dict()
        assert "workflow_app_id" not in result

    def test_to_dict_excludes_workflow_app_id_when_not_workflow_type(self):
        """to_dict() excludes workflow_app_id when type is not WORKFLOW."""
        entity = self._make_entity(type=ToolProviderType.BUILT_IN, workflow_app_id="app_123")
        assert "workflow_app_id" not in entity.to_dict()

    def test_to_dict_excludes_workflow_app_id_for_workflow_type_with_empty_string(self):
        """to_dict() excludes workflow_app_id when the value is an empty string (falsy).

        Renamed from ``...includes...`` — the original name contradicted the
        asserted (exclusion) behavior.
        """
        result = self._make_entity(workflow_app_id="").to_dict()
        assert "workflow_app_id" not in result

View File

@ -29,7 +29,7 @@ class _TestNode(Node[_TestNodeData]):
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
return "test" return "1"
def __init__( def __init__(
self, self,

View File

@ -744,7 +744,7 @@ def test_graph_run_emits_partial_success_when_node_failure_recovered():
) )
llm_node = graph.nodes["llm"] llm_node = graph.nodes["llm"]
base_node_data = llm_node.get_base_node_data() base_node_data = llm_node.node_data
base_node_data.error_strategy = ErrorStrategy.DEFAULT_VALUE base_node_data.error_strategy = ErrorStrategy.DEFAULT_VALUE
base_node_data.default_value = [DefaultValue(key="text", value="fallback response", type=DefaultValueType.STRING)] base_node_data.default_value = [DefaultValue(key="text", value="fallback response", type=DefaultValueType.STRING)]

View File

@ -92,7 +92,7 @@ class MockLLMNode(MockNodeMixin, LLMNode):
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
"""Return the version of this mock node.""" """Return the version of this mock node."""
return "mock-1" return "1"
def _run(self) -> Generator: def _run(self) -> Generator:
"""Execute mock LLM node.""" """Execute mock LLM node."""
@ -189,7 +189,7 @@ class MockAgentNode(MockNodeMixin, AgentNode):
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
"""Return the version of this mock node.""" """Return the version of this mock node."""
return "mock-1" return "1"
def _run(self) -> Generator: def _run(self) -> Generator:
"""Execute mock agent node.""" """Execute mock agent node."""
@ -241,7 +241,7 @@ class MockToolNode(MockNodeMixin, ToolNode):
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
"""Return the version of this mock node.""" """Return the version of this mock node."""
return "mock-1" return "1"
def _run(self) -> Generator: def _run(self) -> Generator:
"""Execute mock tool node.""" """Execute mock tool node."""
@ -294,7 +294,7 @@ class MockKnowledgeRetrievalNode(MockNodeMixin, KnowledgeRetrievalNode):
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
"""Return the version of this mock node.""" """Return the version of this mock node."""
return "mock-1" return "1"
def _run(self) -> Generator: def _run(self) -> Generator:
"""Execute mock knowledge retrieval node.""" """Execute mock knowledge retrieval node."""
@ -351,7 +351,7 @@ class MockHttpRequestNode(MockNodeMixin, HttpRequestNode):
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
"""Return the version of this mock node.""" """Return the version of this mock node."""
return "mock-1" return "1"
def _run(self) -> Generator: def _run(self) -> Generator:
"""Execute mock HTTP request node.""" """Execute mock HTTP request node."""
@ -404,7 +404,7 @@ class MockQuestionClassifierNode(MockNodeMixin, QuestionClassifierNode):
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
"""Return the version of this mock node.""" """Return the version of this mock node."""
return "mock-1" return "1"
def _run(self) -> Generator: def _run(self) -> Generator:
"""Execute mock question classifier node.""" """Execute mock question classifier node."""
@ -452,7 +452,7 @@ class MockParameterExtractorNode(MockNodeMixin, ParameterExtractorNode):
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
"""Return the version of this mock node.""" """Return the version of this mock node."""
return "mock-1" return "1"
def _run(self) -> Generator: def _run(self) -> Generator:
"""Execute mock parameter extractor node.""" """Execute mock parameter extractor node."""
@ -502,7 +502,7 @@ class MockDocumentExtractorNode(MockNodeMixin, DocumentExtractorNode):
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
"""Return the version of this mock node.""" """Return the version of this mock node."""
return "mock-1" return "1"
def _run(self) -> Generator: def _run(self) -> Generator:
"""Execute mock document extractor node.""" """Execute mock document extractor node."""
@ -557,7 +557,7 @@ class MockIterationNode(MockNodeMixin, IterationNode):
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
"""Return the version of this mock node.""" """Return the version of this mock node."""
return "mock-1" return "1"
def _create_graph_engine(self, index: int, item: Any): def _create_graph_engine(self, index: int, item: Any):
"""Create a graph engine with MockNodeFactory instead of DifyNodeFactory.""" """Create a graph engine with MockNodeFactory instead of DifyNodeFactory."""
@ -632,7 +632,7 @@ class MockLoopNode(MockNodeMixin, LoopNode):
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
"""Return the version of this mock node.""" """Return the version of this mock node."""
return "mock-1" return "1"
def _create_graph_engine(self, start_at, root_node_id: str): def _create_graph_engine(self, start_at, root_node_id: str):
"""Create a graph engine with MockNodeFactory instead of DifyNodeFactory.""" """Create a graph engine with MockNodeFactory instead of DifyNodeFactory."""
@ -694,7 +694,7 @@ class MockTemplateTransformNode(MockNodeMixin, TemplateTransformNode):
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
"""Return the version of this mock node.""" """Return the version of this mock node."""
return "mock-1" return "1"
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
"""Execute mock template transform node.""" """Execute mock template transform node."""
@ -780,7 +780,7 @@ class MockCodeNode(MockNodeMixin, CodeNode):
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
"""Return the version of this mock node.""" """Return the version of this mock node."""
return "mock-1" return "1"
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
"""Execute mock code node.""" """Execute mock code node."""

View File

@ -33,6 +33,10 @@ def test_ensure_subclasses_of_base_node_has_node_type_and_version_method_defined
type_version_set: set[tuple[NodeType, str]] = set() type_version_set: set[tuple[NodeType, str]] = set()
for cls in classes: for cls in classes:
# Only validate production node classes; skip test-defined subclasses and external helpers
module_name = getattr(cls, "__module__", "")
if not module_name.startswith("core."):
continue
# Validate that 'version' is directly defined in the class (not inherited) by checking the class's __dict__ # Validate that 'version' is directly defined in the class (not inherited) by checking the class's __dict__
assert "version" in cls.__dict__, f"class {cls} should have version method defined (NOT INHERITED.)" assert "version" in cls.__dict__, f"class {cls} should have version method defined (NOT INHERITED.)"
node_type = cls.node_type node_type = cls.node_type

View File

@ -0,0 +1,84 @@
import types
from collections.abc import Mapping
from core.workflow.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.nodes.base.node import Node
# Import concrete nodes we will assert on (numeric version path)
from core.workflow.nodes.variable_assigner.v1.node import (
VariableAssignerNode as VariableAssignerV1,
)
from core.workflow.nodes.variable_assigner.v2.node import (
VariableAssignerNode as VariableAssignerV2,
)
def test_variable_assigner_latest_prefers_highest_numeric_version():
    """'latest' for VARIABLE_ASSIGNER must resolve to the highest numeric version."""
    versions_by_type: Mapping[NodeType, Mapping[str, type[Node]]] = Node.get_node_type_classes_mapping()

    assert NodeType.VARIABLE_ASSIGNER in versions_by_type
    assigner_versions = versions_by_type[NodeType.VARIABLE_ASSIGNER]

    # Both concrete implementations are registered under their numeric versions...
    assert assigner_versions.get("1") is VariableAssignerV1
    assert assigner_versions.get("2") is VariableAssignerV2
    # ...and "latest" tracks the numerically-highest one.
    assert assigner_versions.get("latest") is VariableAssignerV2
def test_latest_prefers_highest_numeric_version():
    """'latest' should point at the numerically-highest version among subclasses."""

    # NOTE(review): this relies on NodeType.LEGACY_VARIABLE_AGGREGATOR having no
    # competing production implementation registered under the same version keys
    # — confirm against the node registry if this assertion ever flakes.
    class _LegacyAggV1(Node[BaseNodeData]):  # type: ignore[misc]
        node_type = NodeType.LEGACY_VARIABLE_AGGREGATOR

        @classmethod
        def version(cls) -> str:
            return "1"

        def init_node_data(self, data):
            pass

        def _run(self):
            raise NotImplementedError

        def _get_error_strategy(self):
            return None

        def _get_retry_config(self):
            return types.SimpleNamespace()  # not used

        def _get_title(self) -> str:
            return "version1"

        def _get_description(self):
            return None

        def _get_default_value_dict(self):
            return {}

        def get_base_node_data(self):
            return types.SimpleNamespace(title="version1")

    class _LegacyAggV2(_LegacyAggV1):  # type: ignore[misc]
        @classmethod
        def version(cls) -> str:
            return "2"

        def _get_title(self) -> str:
            return "version2"

    # Rebuilding the mapping must pick up the ephemeral subclasses defined above.
    mapping: Mapping[NodeType, Mapping[str, type[Node]]] = Node.get_node_type_classes_mapping()

    assert NodeType.LEGACY_VARIABLE_AGGREGATOR in mapping
    registered = mapping[NodeType.LEGACY_VARIABLE_AGGREGATOR]
    assert registered.get("1") is _LegacyAggV1
    assert registered.get("2") is _LegacyAggV2
    assert registered.get("latest") is _LegacyAggV2

View File

@ -471,8 +471,8 @@ class TestCodeNodeInitialization:
assert node._get_description() is None assert node._get_description() is None
def test_get_base_node_data(self): def test_node_data_property(self):
"""Test get_base_node_data returns node data.""" """Test node_data property returns node data."""
node = CodeNode.__new__(CodeNode) node = CodeNode.__new__(CodeNode)
node._node_data = CodeNodeData( node._node_data = CodeNodeData(
title="Base Test", title="Base Test",
@ -482,7 +482,7 @@ class TestCodeNodeInitialization:
outputs={}, outputs={},
) )
result = node.get_base_node_data() result = node.node_data
assert result == node._node_data assert result == node._node_data
assert result.title == "Base Test" assert result.title == "Base Test"

View File

@ -240,8 +240,8 @@ class TestIterationNodeInitialization:
assert node._get_description() == "This is a description" assert node._get_description() == "This is a description"
def test_get_base_node_data(self): def test_node_data_property(self):
"""Test get_base_node_data returns node data.""" """Test node_data property returns node data."""
node = IterationNode.__new__(IterationNode) node = IterationNode.__new__(IterationNode)
node._node_data = IterationNodeData( node._node_data = IterationNodeData(
title="Base Test", title="Base Test",
@ -249,7 +249,7 @@ class TestIterationNodeInitialization:
output_selector=["y"], output_selector=["y"],
) )
result = node.get_base_node_data() result = node.node_data
assert result == node._node_data assert result == node._node_data

View File

@ -19,7 +19,7 @@ class _SampleNode(Node[_SampleNodeData]):
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
return "sample-test" return "1"
def _run(self): def _run(self):
raise NotImplementedError raise NotImplementedError

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,718 @@
"""
Comprehensive unit tests for AudioService.
This test suite provides complete coverage of audio processing operations in Dify,
following TDD principles with the Arrange-Act-Assert pattern.
## Test Coverage
### 1. Speech-to-Text (ASR) Operations (TestAudioServiceASR)
Tests audio transcription functionality:
- Successful transcription for different app modes
- File validation (size, type, presence)
- Feature flag validation (speech-to-text enabled)
- Error handling for various failure scenarios
- Model instance availability checks
### 2. Text-to-Speech (TTS) Operations (TestAudioServiceTTS)
Tests text-to-audio conversion:
- TTS with text input
- TTS with message ID
- Voice selection (explicit and default)
- Feature flag validation (text-to-speech enabled)
- Draft workflow handling
- Streaming response handling
- Error handling for missing/invalid inputs
### 3. TTS Voice Listing (TestAudioServiceTTSVoices)
Tests available voice retrieval:
- Get available voices for a tenant
- Language filtering
- Error handling for missing provider
## Testing Approach
- **Mocking Strategy**: All external dependencies (ModelManager, db, FileStorage) are mocked
for fast, isolated unit tests
- **Factory Pattern**: AudioServiceTestDataFactory provides consistent test data
- **Fixtures**: Mock objects are configured per test method
- **Assertions**: Each test verifies return values, side effects, and error conditions
## Key Concepts
**Audio Formats:**
- Supported: mp3, wav, m4a, flac, ogg, opus, webm
- File size limit: 30 MB
**App Modes:**
- ADVANCED_CHAT/WORKFLOW: Use workflow features
- CHAT/COMPLETION: Use app_model_config
**Feature Flags:**
- speech_to_text: Enables ASR functionality
- text_to_speech: Enables TTS functionality
"""
from unittest.mock import MagicMock, Mock, create_autospec, patch
import pytest
from werkzeug.datastructures import FileStorage
from models.enums import MessageStatus
from models.model import App, AppMode, AppModelConfig, Message
from models.workflow import Workflow
from services.audio_service import AudioService
from services.errors.audio import (
AudioTooLargeServiceError,
NoAudioUploadedServiceError,
ProviderNotSupportSpeechToTextServiceError,
ProviderNotSupportTextToSpeechServiceError,
UnsupportedAudioTypeServiceError,
)
class AudioServiceTestDataFactory:
    """Builders for the mock objects used across the audio-service tests.

    Each method returns a configured ``Mock``; any extra keyword arguments are
    applied verbatim as attributes so individual tests can tweak any field.
    """

    @staticmethod
    def create_app_mock(
        app_id: str = "app-123",
        mode: AppMode = AppMode.CHAT,
        tenant_id: str = "tenant-123",
        **kwargs,
    ) -> Mock:
        """Return a mock ``App`` with the given id/mode/tenant plus any extras."""
        mock_app = create_autospec(App, instance=True)
        mock_app.id = app_id
        mock_app.mode = mode
        mock_app.tenant_id = tenant_id
        # Explicit defaults so these read as None instead of autospec children.
        mock_app.workflow = kwargs.get("workflow")
        mock_app.app_model_config = kwargs.get("app_model_config")
        for attr, attr_value in kwargs.items():
            setattr(mock_app, attr, attr_value)
        return mock_app

    @staticmethod
    def create_workflow_mock(features_dict: dict | None = None, **kwargs) -> Mock:
        """Return a mock ``Workflow`` exposing ``features_dict`` plus any extras."""
        mock_workflow = create_autospec(Workflow, instance=True)
        mock_workflow.features_dict = features_dict or {}
        for attr, attr_value in kwargs.items():
            setattr(mock_workflow, attr, attr_value)
        return mock_workflow

    @staticmethod
    def create_app_model_config_mock(
        speech_to_text_dict: dict | None = None,
        text_to_speech_dict: dict | None = None,
        **kwargs,
    ) -> Mock:
        """Return a mock ``AppModelConfig``; both feature dicts default to disabled."""
        mock_config = create_autospec(AppModelConfig, instance=True)
        mock_config.speech_to_text_dict = speech_to_text_dict or {"enabled": False}
        mock_config.text_to_speech_dict = text_to_speech_dict or {"enabled": False}
        for attr, attr_value in kwargs.items():
            setattr(mock_config, attr, attr_value)
        return mock_config

    @staticmethod
    def create_file_storage_mock(
        filename: str = "test.mp3",
        mimetype: str = "audio/mp3",
        content: bytes = b"fake audio content",
        **kwargs,
    ) -> Mock:
        """Return a mock ``FileStorage`` whose ``read()`` yields ``content``."""
        mock_file = Mock(spec=FileStorage)
        mock_file.filename = filename
        mock_file.mimetype = mimetype
        mock_file.read = Mock(return_value=content)
        for attr, attr_value in kwargs.items():
            setattr(mock_file, attr, attr_value)
        return mock_file

    @staticmethod
    def create_message_mock(
        message_id: str = "msg-123",
        answer: str = "Test answer",
        status: MessageStatus = MessageStatus.NORMAL,
        **kwargs,
    ) -> Mock:
        """Return a mock ``Message`` with id/answer/status plus any extras."""
        mock_message = create_autospec(Message, instance=True)
        mock_message.id = message_id
        mock_message.answer = answer
        mock_message.status = status
        for attr, attr_value in kwargs.items():
            setattr(mock_message, attr, attr_value)
        return mock_message
@pytest.fixture
def factory():
    """Expose AudioServiceTestDataFactory to tests via the ``factory`` fixture."""
    return AudioServiceTestDataFactory
class TestAudioServiceASR:
    """Speech-to-text (ASR) behavior of ``AudioService.transcript_asr``."""

    @patch("services.audio_service.ModelManager")
    def test_transcript_asr_success_chat_mode(self, mock_manager_cls, factory):
        """CHAT mode with ASR enabled transcribes and forwards the end user."""
        enabled_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True})
        chat_app = factory.create_app_mock(mode=AppMode.CHAT, app_model_config=enabled_config)
        audio_file = factory.create_file_storage_mock()

        # MagicMock auto-creates chained children; wire them up by reference.
        speech_model = mock_manager_cls.return_value.get_default_model_instance.return_value
        speech_model.invoke_speech2text.return_value = "Transcribed text"

        result = AudioService.transcript_asr(app_model=chat_app, file=audio_file, end_user="user-123")

        assert result == {"text": "Transcribed text"}
        speech_model.invoke_speech2text.assert_called_once()
        assert speech_model.invoke_speech2text.call_args.kwargs["user"] == "user-123"

    @patch("services.audio_service.ModelManager")
    def test_transcript_asr_success_advanced_chat_mode(self, mock_manager_cls, factory):
        """ADVANCED_CHAT mode reads the speech-to-text flag from the workflow."""
        enabled_workflow = factory.create_workflow_mock(features_dict={"speech_to_text": {"enabled": True}})
        advanced_app = factory.create_app_mock(mode=AppMode.ADVANCED_CHAT, workflow=enabled_workflow)
        audio_file = factory.create_file_storage_mock()

        speech_model = mock_manager_cls.return_value.get_default_model_instance.return_value
        speech_model.invoke_speech2text.return_value = "Workflow transcribed text"

        result = AudioService.transcript_asr(app_model=advanced_app, file=audio_file)

        assert result == {"text": "Workflow transcribed text"}

    def test_transcript_asr_raises_error_when_feature_disabled_chat_mode(self, factory):
        """Disabled speech-to-text in CHAT mode is rejected."""
        disabled_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": False})
        chat_app = factory.create_app_mock(mode=AppMode.CHAT, app_model_config=disabled_config)

        with pytest.raises(ValueError, match="Speech to text is not enabled"):
            AudioService.transcript_asr(app_model=chat_app, file=factory.create_file_storage_mock())

    def test_transcript_asr_raises_error_when_feature_disabled_workflow_mode(self, factory):
        """Disabled speech-to-text in WORKFLOW mode is rejected."""
        disabled_workflow = factory.create_workflow_mock(features_dict={"speech_to_text": {"enabled": False}})
        workflow_app = factory.create_app_mock(mode=AppMode.WORKFLOW, workflow=disabled_workflow)

        with pytest.raises(ValueError, match="Speech to text is not enabled"):
            AudioService.transcript_asr(app_model=workflow_app, file=factory.create_file_storage_mock())

    def test_transcript_asr_raises_error_when_workflow_missing(self, factory):
        """A WORKFLOW-mode app without a workflow behaves as feature-disabled."""
        workflow_app = factory.create_app_mock(mode=AppMode.WORKFLOW, workflow=None)

        with pytest.raises(ValueError, match="Speech to text is not enabled"):
            AudioService.transcript_asr(app_model=workflow_app, file=factory.create_file_storage_mock())

    def test_transcript_asr_raises_error_when_no_file_uploaded(self, factory):
        """A missing upload raises NoAudioUploadedServiceError."""
        enabled_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True})
        chat_app = factory.create_app_mock(mode=AppMode.CHAT, app_model_config=enabled_config)

        with pytest.raises(NoAudioUploadedServiceError):
            AudioService.transcript_asr(app_model=chat_app, file=None)

    def test_transcript_asr_raises_error_for_unsupported_audio_type(self, factory):
        """Non-audio MIME types raise UnsupportedAudioTypeServiceError."""
        enabled_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True})
        chat_app = factory.create_app_mock(mode=AppMode.CHAT, app_model_config=enabled_config)
        video_file = factory.create_file_storage_mock(mimetype="video/mp4")

        with pytest.raises(UnsupportedAudioTypeServiceError):
            AudioService.transcript_asr(app_model=chat_app, file=video_file)

    def test_transcript_asr_raises_error_for_large_file(self, factory):
        """Files above the 30 MB cap raise AudioTooLargeServiceError."""
        enabled_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True})
        chat_app = factory.create_app_mock(mode=AppMode.CHAT, app_model_config=enabled_config)
        oversized_file = factory.create_file_storage_mock(content=b"x" * (31 * 1024 * 1024))

        with pytest.raises(AudioTooLargeServiceError, match="Audio size larger than 30 mb"):
            AudioService.transcript_asr(app_model=chat_app, file=oversized_file)

    @patch("services.audio_service.ModelManager")
    def test_transcript_asr_raises_error_when_no_model_instance(self, mock_manager_cls, factory):
        """A missing speech2text model raises ProviderNotSupportSpeechToTextServiceError."""
        enabled_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True})
        chat_app = factory.create_app_mock(mode=AppMode.CHAT, app_model_config=enabled_config)
        audio_file = factory.create_file_storage_mock()

        mock_manager_cls.return_value.get_default_model_instance.return_value = None

        with pytest.raises(ProviderNotSupportSpeechToTextServiceError):
            AudioService.transcript_asr(app_model=chat_app, file=audio_file)
class TestAudioServiceTTS:
    """Test text-to-speech (TTS) operations."""

    @staticmethod
    def _install_manager(mock_manager_class, model_instance):
        """Wire a ModelManager mock so it hands back *model_instance* as the default model."""
        manager = MagicMock()
        manager.get_default_model_instance.return_value = model_instance
        mock_manager_class.return_value = manager
        return manager

    @patch("services.audio_service.ModelManager")
    def test_transcript_tts_with_text_success(self, mock_model_manager_class, factory):
        """Plain-text input is synthesized through the default TTS model."""
        config = factory.create_app_model_config_mock(
            text_to_speech_dict={"enabled": True, "voice": "en-US-Neural"}
        )
        chat_app = factory.create_app_mock(
            mode=AppMode.CHAT,
            app_model_config=config,
        )
        tts_model = MagicMock()
        tts_model.invoke_tts.return_value = b"audio data"
        self._install_manager(mock_model_manager_class, tts_model)

        audio = AudioService.transcript_tts(
            app_model=chat_app,
            text="Hello world",
            voice="en-US-Neural",
            end_user="user-123",
        )

        assert audio == b"audio data"
        tts_model.invoke_tts.assert_called_once_with(
            content_text="Hello world",
            user="user-123",
            tenant_id=chat_app.tenant_id,
            voice="en-US-Neural",
        )

    @patch("services.audio_service.db.session")
    @patch("services.audio_service.ModelManager")
    def test_transcript_tts_with_message_id_success(self, mock_model_manager_class, mock_db_session, factory):
        """A valid message id resolves to that message's answer, which is synthesized."""
        config = factory.create_app_model_config_mock(
            text_to_speech_dict={"enabled": True, "voice": "en-US-Neural"}
        )
        chat_app = factory.create_app_mock(
            mode=AppMode.CHAT,
            app_model_config=config,
        )
        message = factory.create_message_mock(
            message_id="550e8400-e29b-41d4-a716-446655440000",
            answer="Message answer text",
        )
        # The message lookup finds the mock message.
        query = MagicMock()
        mock_db_session.query.return_value = query
        query.where.return_value = query
        query.first.return_value = message
        tts_model = MagicMock()
        tts_model.invoke_tts.return_value = b"audio from message"
        self._install_manager(mock_model_manager_class, tts_model)

        audio = AudioService.transcript_tts(
            app_model=chat_app,
            message_id="550e8400-e29b-41d4-a716-446655440000",
        )

        assert audio == b"audio from message"
        tts_model.invoke_tts.assert_called_once()

    @patch("services.audio_service.ModelManager")
    def test_transcript_tts_with_default_voice(self, mock_model_manager_class, factory):
        """When no voice is passed, the app-configured default voice is used."""
        config = factory.create_app_model_config_mock(
            text_to_speech_dict={"enabled": True, "voice": "default-voice"}
        )
        chat_app = factory.create_app_mock(
            mode=AppMode.CHAT,
            app_model_config=config,
        )
        tts_model = MagicMock()
        tts_model.invoke_tts.return_value = b"audio data"
        self._install_manager(mock_model_manager_class, tts_model)

        audio = AudioService.transcript_tts(
            app_model=chat_app,
            text="Test",
        )

        assert audio == b"audio data"
        # The configured default voice must have been forwarded to the model.
        assert tts_model.invoke_tts.call_args.kwargs["voice"] == "default-voice"

    @patch("services.audio_service.ModelManager")
    def test_transcript_tts_gets_first_available_voice_when_none_configured(self, mock_model_manager_class, factory):
        """With no configured voice, the first voice the model offers is used."""
        config = factory.create_app_model_config_mock(
            text_to_speech_dict={"enabled": True}  # no voice configured
        )
        chat_app = factory.create_app_mock(
            mode=AppMode.CHAT,
            app_model_config=config,
        )
        tts_model = MagicMock()
        tts_model.get_tts_voices.return_value = [{"value": "auto-voice"}]
        tts_model.invoke_tts.return_value = b"audio data"
        self._install_manager(mock_model_manager_class, tts_model)

        audio = AudioService.transcript_tts(
            app_model=chat_app,
            text="Test",
        )

        assert audio == b"audio data"
        assert tts_model.invoke_tts.call_args.kwargs["voice"] == "auto-voice"

    @patch("services.audio_service.WorkflowService")
    @patch("services.audio_service.ModelManager")
    def test_transcript_tts_workflow_mode_with_draft(
        self, mock_model_manager_class, mock_workflow_service_class, factory
    ):
        """WORKFLOW apps in draft mode read TTS settings from the draft workflow."""
        draft_workflow = factory.create_workflow_mock(
            features_dict={"text_to_speech": {"enabled": True, "voice": "draft-voice"}}
        )
        workflow_app = factory.create_app_mock(
            mode=AppMode.WORKFLOW,
        )
        workflow_service = MagicMock()
        workflow_service.get_draft_workflow.return_value = draft_workflow
        mock_workflow_service_class.return_value = workflow_service
        tts_model = MagicMock()
        tts_model.invoke_tts.return_value = b"draft audio"
        self._install_manager(mock_model_manager_class, tts_model)

        audio = AudioService.transcript_tts(
            app_model=workflow_app,
            text="Draft test",
            is_draft=True,
        )

        assert audio == b"draft audio"
        workflow_service.get_draft_workflow.assert_called_once_with(app_model=workflow_app)

    def test_transcript_tts_raises_error_when_text_missing(self, factory):
        """TTS without text (and without a message id) is rejected."""
        app = factory.create_app_mock()

        with pytest.raises(ValueError, match="Text is required"):
            AudioService.transcript_tts(app_model=app, text=None)

    @patch("services.audio_service.db.session")
    def test_transcript_tts_returns_none_for_invalid_message_id(self, mock_db_session, factory):
        """A message id that is not a valid UUID yields None instead of audio."""
        app = factory.create_app_mock()

        result = AudioService.transcript_tts(
            app_model=app,
            message_id="invalid-uuid",
        )

        assert result is None

    @patch("services.audio_service.db.session")
    def test_transcript_tts_returns_none_for_nonexistent_message(self, mock_db_session, factory):
        """A well-formed message id with no matching row yields None."""
        app = factory.create_app_mock()
        # The message lookup comes back empty.
        query = MagicMock()
        mock_db_session.query.return_value = query
        query.where.return_value = query
        query.first.return_value = None

        result = AudioService.transcript_tts(
            app_model=app,
            message_id="550e8400-e29b-41d4-a716-446655440000",
        )

        assert result is None

    @patch("services.audio_service.db.session")
    def test_transcript_tts_returns_none_for_empty_message_answer(self, mock_db_session, factory):
        """A message whose answer is empty yields None."""
        app = factory.create_app_mock()
        message = factory.create_message_mock(
            answer="",
            status=MessageStatus.NORMAL,
        )
        query = MagicMock()
        mock_db_session.query.return_value = query
        query.where.return_value = query
        query.first.return_value = message

        result = AudioService.transcript_tts(
            app_model=app,
            message_id="550e8400-e29b-41d4-a716-446655440000",
        )

        assert result is None

    @patch("services.audio_service.ModelManager")
    def test_transcript_tts_raises_error_when_no_voices_available(self, mock_model_manager_class, factory):
        """If the model exposes no voices at all, TTS fails loudly."""
        config = factory.create_app_model_config_mock(
            text_to_speech_dict={"enabled": True}  # no voice configured
        )
        chat_app = factory.create_app_mock(
            mode=AppMode.CHAT,
            app_model_config=config,
        )
        tts_model = MagicMock()
        tts_model.get_tts_voices.return_value = []  # nothing to fall back to
        self._install_manager(mock_model_manager_class, tts_model)

        with pytest.raises(ValueError, match="Sorry, no voice available"):
            AudioService.transcript_tts(app_model=chat_app, text="Test")
class TestAudioServiceTTSVoices:
    """Test TTS voice listing operations."""

    @patch("services.audio_service.ModelManager")
    def test_transcript_tts_voices_success(self, mock_model_manager_class, factory):
        """The voice list provided by the default model is returned verbatim."""
        expected_voices = [
            {"name": "Voice 1", "value": "voice-1"},
            {"name": "Voice 2", "value": "voice-2"},
        ]
        tts_model = MagicMock()
        tts_model.get_tts_voices.return_value = expected_voices
        manager = MagicMock()
        manager.get_default_model_instance.return_value = tts_model
        mock_model_manager_class.return_value = manager

        voices = AudioService.transcript_tts_voices(tenant_id="tenant-123", language="en-US")

        assert voices == expected_voices
        tts_model.get_tts_voices.assert_called_once_with("en-US")

    @patch("services.audio_service.ModelManager")
    def test_transcript_tts_voices_raises_error_when_no_model_instance(self, mock_model_manager_class, factory):
        """Listing voices fails when no default TTS model instance exists."""
        manager = MagicMock()
        manager.get_default_model_instance.return_value = None
        mock_model_manager_class.return_value = manager

        with pytest.raises(ProviderNotSupportTextToSpeechServiceError):
            AudioService.transcript_tts_voices(tenant_id="tenant-123", language="en-US")

    @patch("services.audio_service.ModelManager")
    def test_transcript_tts_voices_propagates_exceptions(self, mock_model_manager_class, factory):
        """Errors raised by the model while listing voices are not swallowed."""
        tts_model = MagicMock()
        tts_model.get_tts_voices.side_effect = RuntimeError("Model error")
        manager = MagicMock()
        manager.get_default_model_instance.return_value = tts_model
        mock_model_manager_class.return_value = manager

        with pytest.raises(RuntimeError, match="Model error"):
            AudioService.transcript_tts_voices(tenant_id="tenant-123", language="en-US")

View File

@ -0,0 +1,494 @@
from unittest.mock import MagicMock, patch
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from models.model import App, DefaultEndUserSessionID, EndUser
from services.end_user_service import EndUserService
class TestEndUserServiceFactory:
    """Factory class for creating test data and mock objects for end user service tests."""

    @staticmethod
    def create_app_mock(
        app_id: str = "app-123",
        tenant_id: str = "tenant-456",
        name: str = "Test App",
    ) -> MagicMock:
        """Build a mock App carrying the given identifiers."""
        mock_app = MagicMock(spec=App)
        mock_app.id = app_id
        mock_app.tenant_id = tenant_id
        mock_app.name = name
        return mock_app

    @staticmethod
    def create_end_user_mock(
        user_id: str = "user-789",
        tenant_id: str = "tenant-456",
        app_id: str = "app-123",
        session_id: str = "session-001",
        type: InvokeFrom = InvokeFrom.SERVICE_API,
        is_anonymous: bool = False,
    ) -> MagicMock:
        """Build a mock EndUser; external_user_id mirrors session_id."""
        mock_user = MagicMock(spec=EndUser)
        mock_user.id = user_id
        mock_user.tenant_id = tenant_id
        mock_user.app_id = app_id
        mock_user.session_id = session_id
        mock_user.type = type
        mock_user.is_anonymous = is_anonymous
        mock_user.external_user_id = session_id
        return mock_user
class TestEndUserServiceGetOrCreateEndUser:
    """
    Unit tests for EndUserService.get_or_create_end_user method.

    Covers creating new end users, retrieving existing ones, default
    session-ID handling, and anonymous-user creation.
    """

    @pytest.fixture
    def factory(self):
        """Expose the shared mock factory to each test."""
        return TestEndUserServiceFactory()

    @staticmethod
    def _install_session(mock_session_class, found=None):
        """Wire a Session mock whose query chain resolves .first() to *found*."""
        session = MagicMock()
        mock_session_class.return_value.__enter__.return_value = session
        query = MagicMock()
        session.query.return_value = query
        query.where.return_value = query
        query.order_by.return_value = query
        query.first.return_value = found
        return session

    # Test 01: Get or create with custom user_id
    @patch("services.end_user_service.Session")
    @patch("services.end_user_service.db")
    def test_get_or_create_end_user_with_custom_user_id(self, mock_db, mock_session_class, factory):
        """A missing user with an explicit user_id is created under that session id."""
        app = factory.create_app_mock()
        session = self._install_session(mock_session_class, found=None)

        EndUserService.get_or_create_end_user(app_model=app, user_id="custom-user-123")

        session.add.assert_called_once()
        session.commit.assert_called_once()
        created = session.add.call_args[0][0]
        assert created.tenant_id == app.tenant_id
        assert created.app_id == app.id
        assert created.session_id == "custom-user-123"
        assert created.type == InvokeFrom.SERVICE_API
        assert created.is_anonymous is False

    # Test 02: Get or create without user_id (default session)
    @patch("services.end_user_service.Session")
    @patch("services.end_user_service.db")
    def test_get_or_create_end_user_without_user_id(self, mock_db, mock_session_class, factory):
        """Omitting user_id falls back to the default session id (anonymous user)."""
        app = factory.create_app_mock()
        session = self._install_session(mock_session_class, found=None)

        EndUserService.get_or_create_end_user(app_model=app, user_id=None)

        session.add.assert_called_once()
        created = session.add.call_args[0][0]
        assert created.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
        # Assert on the private attribute: the public is_anonymous property
        # reportedly always returns False — confirm against the service impl.
        assert created._is_anonymous is True

    # Test 03: Get existing end user
    @patch("services.end_user_service.Session")
    @patch("services.end_user_service.db")
    def test_get_existing_end_user(self, mock_db, mock_session_class, factory):
        """An existing end user is returned untouched — nothing new is persisted."""
        app = factory.create_app_mock()
        existing = factory.create_end_user_mock(
            tenant_id=app.tenant_id,
            app_id=app.id,
            session_id="existing-user-123",
            type=InvokeFrom.SERVICE_API,
        )
        session = self._install_session(mock_session_class, found=existing)

        result = EndUserService.get_or_create_end_user(app_model=app, user_id="existing-user-123")

        assert result == existing
        session.add.assert_not_called()  # no new user should be created
class TestEndUserServiceGetOrCreateEndUserByType:
    """
    Unit tests for EndUserService.get_or_create_end_user_by_type method.

    Covers creation with different InvokeFrom types, type migration for
    legacy users, query ordering/prioritization, and session management.
    """

    @pytest.fixture
    def factory(self):
        """Expose the shared mock factory to each test."""
        return TestEndUserServiceFactory()

    @staticmethod
    def _install_session(mock_session_class, found=None):
        """Wire a Session mock; returns (session, query) with .first() -> *found*."""
        session = MagicMock()
        mock_session_class.return_value.__enter__.return_value = session
        query = MagicMock()
        session.query.return_value = query
        query.where.return_value = query
        query.order_by.return_value = query
        query.first.return_value = found
        return session, query

    # Test 04: Create new end user with SERVICE_API type
    @patch("services.end_user_service.Session")
    @patch("services.end_user_service.db")
    def test_create_end_user_service_api_type(self, mock_db, mock_session_class, factory):
        """A new SERVICE_API end user is persisted with all identifiers set."""
        session, _ = self._install_session(mock_session_class)

        EndUserService.get_or_create_end_user_by_type(
            type=InvokeFrom.SERVICE_API,
            tenant_id="tenant-123",
            app_id="app-456",
            user_id="user-789",
        )

        session.add.assert_called_once()
        session.commit.assert_called_once()
        created = session.add.call_args[0][0]
        assert created.type == InvokeFrom.SERVICE_API
        assert created.tenant_id == "tenant-123"
        assert created.app_id == "app-456"
        assert created.session_id == "user-789"

    # Test 05: Create new end user with WEB_APP type
    @patch("services.end_user_service.Session")
    @patch("services.end_user_service.db")
    def test_create_end_user_web_app_type(self, mock_db, mock_session_class, factory):
        """A new WEB_APP end user records the WEB_APP invoke type."""
        session, _ = self._install_session(mock_session_class)

        EndUserService.get_or_create_end_user_by_type(
            type=InvokeFrom.WEB_APP,
            tenant_id="tenant-123",
            app_id="app-456",
            user_id="user-789",
        )

        session.add.assert_called_once()
        created = session.add.call_args[0][0]
        assert created.type == InvokeFrom.WEB_APP

    # Test 06: Upgrade legacy end user type
    @patch("services.end_user_service.logger")
    @patch("services.end_user_service.Session")
    @patch("services.end_user_service.db")
    def test_upgrade_legacy_end_user_type(self, mock_db, mock_session_class, mock_logger, factory):
        """Requesting a different type migrates the stored legacy record in place."""
        legacy = factory.create_end_user_mock(
            tenant_id="tenant-123",
            app_id="app-456",
            session_id="user-789",
            type=InvokeFrom.SERVICE_API,
        )
        session, _ = self._install_session(mock_session_class, found=legacy)

        result = EndUserService.get_or_create_end_user_by_type(
            type=InvokeFrom.WEB_APP,
            tenant_id="tenant-123",
            app_id="app-456",
            user_id="user-789",
        )

        assert result == legacy
        assert legacy.type == InvokeFrom.WEB_APP  # type migrated to the new value
        session.commit.assert_called_once()
        mock_logger.info.assert_called_once()
        # The log line should announce the migration.
        assert "Upgrading legacy EndUser" in mock_logger.info.call_args[0][0]

    # Test 07: Get existing end user with matching type (no upgrade needed)
    @patch("services.end_user_service.logger")
    @patch("services.end_user_service.Session")
    @patch("services.end_user_service.db")
    def test_get_existing_end_user_matching_type(self, mock_db, mock_session_class, mock_logger, factory):
        """A matching-type record is returned as-is: no commit, no log."""
        existing = factory.create_end_user_mock(
            tenant_id="tenant-123",
            app_id="app-456",
            session_id="user-789",
            type=InvokeFrom.SERVICE_API,
        )
        session, _ = self._install_session(mock_session_class, found=existing)

        result = EndUserService.get_or_create_end_user_by_type(
            type=InvokeFrom.SERVICE_API,
            tenant_id="tenant-123",
            app_id="app-456",
            user_id="user-789",
        )

        assert result == existing
        assert existing.type == InvokeFrom.SERVICE_API
        session.commit.assert_not_called()  # no type update needed
        mock_logger.info.assert_not_called()

    # Test 08: Create anonymous user with default session ID
    @patch("services.end_user_service.Session")
    @patch("services.end_user_service.db")
    def test_create_anonymous_user_with_default_session(self, mock_db, mock_session_class, factory):
        """user_id=None produces an anonymous user on the default session id."""
        session, _ = self._install_session(mock_session_class)

        EndUserService.get_or_create_end_user_by_type(
            type=InvokeFrom.SERVICE_API,
            tenant_id="tenant-123",
            app_id="app-456",
            user_id=None,
        )

        session.add.assert_called_once()
        created = session.add.call_args[0][0]
        assert created.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
        # Assert on the private attribute: the public is_anonymous property
        # reportedly always returns False — confirm against the service impl.
        assert created._is_anonymous is True
        assert created.external_user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID

    # Test 09: Query ordering prioritizes matching type
    @patch("services.end_user_service.Session")
    @patch("services.end_user_service.db")
    def test_query_ordering_prioritizes_matching_type(self, mock_db, mock_session_class, factory):
        """The lookup orders results so matching-type rows are preferred."""
        _, query = self._install_session(mock_session_class)

        EndUserService.get_or_create_end_user_by_type(
            type=InvokeFrom.SERVICE_API,
            tenant_id="tenant-123",
            app_id="app-456",
            user_id="user-789",
        )

        query.order_by.assert_called_once()

    # Test 10: Session context manager properly closes
    @patch("services.end_user_service.Session")
    @patch("services.end_user_service.db")
    def test_session_context_manager_closes(self, mock_db, mock_session_class, factory):
        """The Session is used as a context manager (entered and exited once)."""
        session = MagicMock()
        context = MagicMock()
        context.__enter__.return_value = session
        mock_session_class.return_value = context
        query = MagicMock()
        session.query.return_value = query
        query.where.return_value = query
        query.order_by.return_value = query
        query.first.return_value = None

        EndUserService.get_or_create_end_user_by_type(
            type=InvokeFrom.SERVICE_API,
            tenant_id="tenant-123",
            app_id="app-456",
            user_id="user-789",
        )

        context.__enter__.assert_called_once()
        context.__exit__.assert_called_once()

    # Test 11: External user ID matches session ID
    @patch("services.end_user_service.Session")
    @patch("services.end_user_service.db")
    def test_external_user_id_matches_session_id(self, mock_db, mock_session_class, factory):
        """external_user_id mirrors the caller-supplied session id."""
        session, _ = self._install_session(mock_session_class)

        EndUserService.get_or_create_end_user_by_type(
            type=InvokeFrom.SERVICE_API,
            tenant_id="tenant-123",
            app_id="app-456",
            user_id="custom-external-id",
        )

        created = session.add.call_args[0][0]
        assert created.external_user_id == "custom-external-id"
        assert created.session_id == "custom-external-id"

    # Test 12: Different InvokeFrom types
    @pytest.mark.parametrize(
        "invoke_type",
        [
            InvokeFrom.SERVICE_API,
            InvokeFrom.WEB_APP,
            InvokeFrom.EXPLORE,
            InvokeFrom.DEBUGGER,
        ],
    )
    @patch("services.end_user_service.Session")
    @patch("services.end_user_service.db")
    def test_create_end_user_with_different_invoke_types(self, mock_db, mock_session_class, invoke_type, factory):
        """Each supported InvokeFrom value is stored on the created user."""
        session, _ = self._install_session(mock_session_class)

        EndUserService.get_or_create_end_user_by_type(
            type=invoke_type,
            tenant_id="tenant-123",
            app_id="app-456",
            user_id="user-789",
        )

        created = session.add.call_args[0][0]
        assert created.type == invoke_type

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,649 @@
from datetime import datetime
from unittest.mock import MagicMock, patch
import pytest
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.model import App, AppMode, EndUser, Message
from services.errors.message import FirstMessageNotExistsError, LastMessageNotExistsError
from services.message_service import MessageService
class TestMessageServiceFactory:
    """Factory class for creating test data and mock objects for message service tests."""

    @staticmethod
    def create_app_mock(
        app_id: str = "app-123",
        mode: str = AppMode.ADVANCED_CHAT.value,
        name: str = "Test App",
    ) -> MagicMock:
        """Build a mock App with the given id, mode, and name."""
        mock_app = MagicMock(spec=App)
        mock_app.id = app_id
        mock_app.mode = mode
        mock_app.name = name
        return mock_app

    @staticmethod
    def create_end_user_mock(
        user_id: str = "user-456",
        session_id: str = "session-789",
    ) -> MagicMock:
        """Build a mock EndUser."""
        mock_user = MagicMock(spec=EndUser)
        mock_user.id = user_id
        mock_user.session_id = session_id
        return mock_user

    @staticmethod
    def create_conversation_mock(
        conversation_id: str = "conv-001",
        app_id: str = "app-123",
    ) -> MagicMock:
        """Build a mock Conversation bound to an app."""
        mock_conversation = MagicMock()
        mock_conversation.id = conversation_id
        mock_conversation.app_id = app_id
        return mock_conversation

    @staticmethod
    def create_message_mock(
        message_id: str = "msg-001",
        conversation_id: str = "conv-001",
        query: str = "What is AI?",
        answer: str = "AI stands for Artificial Intelligence.",
        created_at: datetime | None = None,
    ) -> MagicMock:
        """Build a mock Message; created_at defaults to the current time."""
        mock_message = MagicMock(spec=Message)
        mock_message.id = message_id
        mock_message.conversation_id = conversation_id
        mock_message.query = query
        mock_message.answer = answer
        mock_message.created_at = created_at or datetime.now()
        return mock_message
class TestMessageServicePaginationByFirstId:
"""
Unit tests for MessageService.pagination_by_first_id method.
This test suite covers:
- Basic pagination with and without first_id
- Order handling (asc/desc)
- Edge cases (no user, no conversation, invalid first_id)
- Has_more flag logic
"""
@pytest.fixture
def factory(self):
"""Provide test data factory."""
return TestMessageServiceFactory()
# Test 01: No user provided
def test_pagination_by_first_id_no_user(self, factory):
"""Test pagination returns empty result when no user is provided."""
# Arrange
app = factory.create_app_mock()
# Act
result = MessageService.pagination_by_first_id(
app_model=app,
user=None,
conversation_id="conv-001",
first_id=None,
limit=10,
)
# Assert
assert isinstance(result, InfiniteScrollPagination)
assert result.data == []
assert result.limit == 10
assert result.has_more is False
# Test 02: No conversation_id provided
def test_pagination_by_first_id_no_conversation(self, factory):
"""Test pagination returns empty result when no conversation_id is provided."""
# Arrange
app = factory.create_app_mock()
user = factory.create_end_user_mock()
# Act
result = MessageService.pagination_by_first_id(
app_model=app,
user=user,
conversation_id="",
first_id=None,
limit=10,
)
# Assert
assert isinstance(result, InfiniteScrollPagination)
assert result.data == []
assert result.limit == 10
assert result.has_more is False
# Test 03: Basic pagination without first_id (desc order)
@patch("services.message_service.db")
@patch("services.message_service.ConversationService")
def test_pagination_by_first_id_without_first_id_desc(self, mock_conversation_service, mock_db, factory):
"""Test basic pagination without first_id in descending order."""
# Arrange
app = factory.create_app_mock()
user = factory.create_end_user_mock()
conversation = factory.create_conversation_mock()
mock_conversation_service.get_conversation.return_value = conversation
# Create 5 messages
messages = [
factory.create_message_mock(
message_id=f"msg-{i:03d}",
created_at=datetime(2024, 1, 1, 12, i),
)
for i in range(5)
]
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.limit.return_value = mock_query
mock_query.all.return_value = messages
# Act
result = MessageService.pagination_by_first_id(
app_model=app,
user=user,
conversation_id="conv-001",
first_id=None,
limit=10,
order="desc",
)
# Assert
assert len(result.data) == 5
assert result.has_more is False
assert result.limit == 10
# Messages should remain in desc order (not reversed)
assert result.data[0].id == "msg-000"
# Test 04: Basic pagination without first_id (asc order)
@patch("services.message_service.db")
@patch("services.message_service.ConversationService")
def test_pagination_by_first_id_without_first_id_asc(self, mock_conversation_service, mock_db, factory):
"""Test basic pagination without first_id in ascending order."""
# Arrange
app = factory.create_app_mock()
user = factory.create_end_user_mock()
conversation = factory.create_conversation_mock()
mock_conversation_service.get_conversation.return_value = conversation
# Create 5 messages (returned in desc order from DB)
messages = [
factory.create_message_mock(
message_id=f"msg-{i:03d}",
created_at=datetime(2024, 1, 1, 12, 4 - i), # Descending timestamps
)
for i in range(5)
]
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.limit.return_value = mock_query
mock_query.all.return_value = messages
# Act
result = MessageService.pagination_by_first_id(
app_model=app,
user=user,
conversation_id="conv-001",
first_id=None,
limit=10,
order="asc",
)
# Assert
assert len(result.data) == 5
assert result.has_more is False
# Messages should be reversed to asc order
assert result.data[0].id == "msg-004"
assert result.data[4].id == "msg-000"
# Test 05: Pagination with first_id
@patch("services.message_service.db")
@patch("services.message_service.ConversationService")
def test_pagination_by_first_id_with_first_id(self, mock_conversation_service, mock_db, factory):
"""Test pagination with first_id to get messages before a specific message."""
# Arrange
app = factory.create_app_mock()
user = factory.create_end_user_mock()
conversation = factory.create_conversation_mock()
mock_conversation_service.get_conversation.return_value = conversation
first_message = factory.create_message_mock(
message_id="msg-005",
created_at=datetime(2024, 1, 1, 12, 5),
)
# Messages before first_message
history_messages = [
factory.create_message_mock(
message_id=f"msg-{i:03d}",
created_at=datetime(2024, 1, 1, 12, i),
)
for i in range(5)
]
# Setup query mocks
mock_query_first = MagicMock()
mock_query_history = MagicMock()
def query_side_effect(*args):
if args[0] == Message:
# First call returns mock for first_message query
if not hasattr(query_side_effect, "call_count"):
query_side_effect.call_count = 0
query_side_effect.call_count += 1
if query_side_effect.call_count == 1:
return mock_query_first
else:
return mock_query_history
mock_db.session.query.side_effect = [mock_query_first, mock_query_history]
# Setup first message query
mock_query_first.where.return_value = mock_query_first
mock_query_first.first.return_value = first_message
# Setup history messages query
mock_query_history.where.return_value = mock_query_history
mock_query_history.order_by.return_value = mock_query_history
mock_query_history.limit.return_value = mock_query_history
mock_query_history.all.return_value = history_messages
# Act
result = MessageService.pagination_by_first_id(
app_model=app,
user=user,
conversation_id="conv-001",
first_id="msg-005",
limit=10,
order="desc",
)
# Assert
assert len(result.data) == 5
assert result.has_more is False
mock_query_first.where.assert_called_once()
mock_query_history.where.assert_called_once()
# Test 06: First message not found
@patch("services.message_service.db")
@patch("services.message_service.ConversationService")
def test_pagination_by_first_id_first_message_not_exists(self, mock_conversation_service, mock_db, factory):
"""Test error handling when first_id doesn't exist."""
# Arrange
app = factory.create_app_mock()
user = factory.create_end_user_mock()
conversation = factory.create_conversation_mock()
mock_conversation_service.get_conversation.return_value = conversation
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.first.return_value = None # Message not found
# Act & Assert
with pytest.raises(FirstMessageNotExistsError):
MessageService.pagination_by_first_id(
app_model=app,
user=user,
conversation_id="conv-001",
first_id="nonexistent-msg",
limit=10,
)
# Test 07: Has_more flag when results exceed limit
@patch("services.message_service.db")
@patch("services.message_service.ConversationService")
def test_pagination_by_first_id_has_more_true(self, mock_conversation_service, mock_db, factory):
"""Test has_more flag is True when results exceed limit."""
# Arrange
app = factory.create_app_mock()
user = factory.create_end_user_mock()
conversation = factory.create_conversation_mock()
mock_conversation_service.get_conversation.return_value = conversation
# Create limit+1 messages (11 messages for limit=10)
messages = [
factory.create_message_mock(
message_id=f"msg-{i:03d}",
created_at=datetime(2024, 1, 1, 12, i),
)
for i in range(11)
]
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.limit.return_value = mock_query
mock_query.all.return_value = messages
# Act
result = MessageService.pagination_by_first_id(
app_model=app,
user=user,
conversation_id="conv-001",
first_id=None,
limit=10,
)
# Assert
assert len(result.data) == 10 # Last message trimmed
assert result.has_more is True
assert result.limit == 10
# Test 08: Empty conversation
@patch("services.message_service.db")
@patch("services.message_service.ConversationService")
def test_pagination_by_first_id_empty_conversation(self, mock_conversation_service, mock_db, factory):
"""Test pagination with conversation that has no messages."""
# Arrange
app = factory.create_app_mock()
user = factory.create_end_user_mock()
conversation = factory.create_conversation_mock()
mock_conversation_service.get_conversation.return_value = conversation
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.limit.return_value = mock_query
mock_query.all.return_value = []
# Act
result = MessageService.pagination_by_first_id(
app_model=app,
user=user,
conversation_id="conv-001",
first_id=None,
limit=10,
)
# Assert
assert len(result.data) == 0
assert result.has_more is False
assert result.limit == 10
class TestMessageServicePaginationByLastId:
    """
    Unit tests for MessageService.pagination_by_last_id method.
    This test suite covers:
    - Basic pagination with and without last_id
    - Conversation filtering
    - Include_ids filtering
    - Edge cases (no user, invalid last_id)
    """

    @pytest.fixture
    def factory(self):
        """Provide test data factory."""
        return TestMessageServiceFactory()

    # Test 09: No user provided
    def test_pagination_by_last_id_no_user(self, factory):
        """Test pagination returns empty result when no user is provided."""
        # Arrange
        app = factory.create_app_mock()
        # Act
        # No db/service patching needed: the service short-circuits on user=None.
        result = MessageService.pagination_by_last_id(
            app_model=app,
            user=None,
            last_id=None,
            limit=10,
        )
        # Assert: an empty but well-formed pagination object is returned.
        assert isinstance(result, InfiniteScrollPagination)
        assert result.data == []
        assert result.limit == 10
        assert result.has_more is False

    # Test 10: Basic pagination without last_id
    @patch("services.message_service.db")
    def test_pagination_by_last_id_without_last_id(self, mock_db, factory):
        """Test basic pagination without last_id."""
        # Arrange
        app = factory.create_app_mock()
        user = factory.create_end_user_mock()
        messages = [
            factory.create_message_mock(
                message_id=f"msg-{i:03d}",
                created_at=datetime(2024, 1, 1, 12, i),
            )
            for i in range(5)
        ]
        # Chainable query mock: where/order_by/limit all return the same mock.
        mock_query = MagicMock()
        mock_db.session.query.return_value = mock_query
        mock_query.where.return_value = mock_query
        mock_query.order_by.return_value = mock_query
        mock_query.limit.return_value = mock_query
        mock_query.all.return_value = messages
        # Act
        result = MessageService.pagination_by_last_id(
            app_model=app,
            user=user,
            last_id=None,
            limit=10,
        )
        # Assert
        assert len(result.data) == 5
        assert result.has_more is False
        assert result.limit == 10

    # Test 11: Pagination with last_id
    @patch("services.message_service.db")
    def test_pagination_by_last_id_with_last_id(self, mock_db, factory):
        """Test pagination with last_id to get messages after a specific message."""
        # Arrange
        app = factory.create_app_mock()
        user = factory.create_end_user_mock()
        last_message = factory.create_message_mock(
            message_id="msg-005",
            created_at=datetime(2024, 1, 1, 12, 5),
        )
        # Messages after last_message
        new_messages = [
            factory.create_message_mock(
                message_id=f"msg-{i:03d}",
                created_at=datetime(2024, 1, 1, 12, i),
            )
            for i in range(6, 10)
        ]
        # Setup base query mock that returns itself for chaining
        mock_base_query = MagicMock()
        mock_db.session.query.return_value = mock_base_query
        # First where() call for last_id lookup
        mock_query_last = MagicMock()
        mock_query_last.first.return_value = last_message
        # Second where() call for history messages
        mock_query_history = MagicMock()
        mock_query_history.order_by.return_value = mock_query_history
        mock_query_history.limit.return_value = mock_query_history
        mock_query_history.all.return_value = new_messages
        # Setup where() to return different mocks on consecutive calls
        # NOTE: order-sensitive — the service is expected to call where()
        # exactly twice, first for the last_id lookup, then for the page.
        mock_base_query.where.side_effect = [mock_query_last, mock_query_history]
        # Act
        result = MessageService.pagination_by_last_id(
            app_model=app,
            user=user,
            last_id="msg-005",
            limit=10,
        )
        # Assert
        assert len(result.data) == 4
        assert result.has_more is False

    # Test 12: Last message not found
    @patch("services.message_service.db")
    def test_pagination_by_last_id_last_message_not_exists(self, mock_db, factory):
        """Test error handling when last_id doesn't exist."""
        # Arrange
        app = factory.create_app_mock()
        user = factory.create_end_user_mock()
        mock_query = MagicMock()
        mock_db.session.query.return_value = mock_query
        mock_query.where.return_value = mock_query
        mock_query.first.return_value = None  # Message not found
        # Act & Assert
        with pytest.raises(LastMessageNotExistsError):
            MessageService.pagination_by_last_id(
                app_model=app,
                user=user,
                last_id="nonexistent-msg",
                limit=10,
            )

    # Test 13: Pagination with conversation_id filter
    @patch("services.message_service.ConversationService")
    @patch("services.message_service.db")
    def test_pagination_by_last_id_with_conversation_filter(self, mock_db, mock_conversation_service, factory):
        """Test pagination filtered by conversation_id."""
        # Arrange
        app = factory.create_app_mock()
        user = factory.create_end_user_mock()
        conversation = factory.create_conversation_mock(conversation_id="conv-001")
        mock_conversation_service.get_conversation.return_value = conversation
        messages = [
            factory.create_message_mock(
                message_id=f"msg-{i:03d}",
                conversation_id="conv-001",
                created_at=datetime(2024, 1, 1, 12, i),
            )
            for i in range(5)
        ]
        mock_query = MagicMock()
        mock_db.session.query.return_value = mock_query
        mock_query.where.return_value = mock_query
        mock_query.order_by.return_value = mock_query
        mock_query.limit.return_value = mock_query
        mock_query.all.return_value = messages
        # Act
        result = MessageService.pagination_by_last_id(
            app_model=app,
            user=user,
            last_id=None,
            limit=10,
            conversation_id="conv-001",
        )
        # Assert
        assert len(result.data) == 5
        assert result.has_more is False
        # Verify conversation_id was used in query
        mock_query.where.assert_called()
        mock_conversation_service.get_conversation.assert_called_once()

    # Test 14: Pagination with include_ids filter
    @patch("services.message_service.db")
    def test_pagination_by_last_id_with_include_ids(self, mock_db, factory):
        """Test pagination filtered by include_ids."""
        # Arrange
        app = factory.create_app_mock()
        user = factory.create_end_user_mock()
        # Only messages with IDs in include_ids should be returned
        messages = [
            factory.create_message_mock(message_id="msg-001"),
            factory.create_message_mock(message_id="msg-003"),
        ]
        mock_query = MagicMock()
        mock_db.session.query.return_value = mock_query
        mock_query.where.return_value = mock_query
        mock_query.order_by.return_value = mock_query
        mock_query.limit.return_value = mock_query
        mock_query.all.return_value = messages
        # Act
        result = MessageService.pagination_by_last_id(
            app_model=app,
            user=user,
            last_id=None,
            limit=10,
            include_ids=["msg-001", "msg-003"],
        )
        # Assert
        assert len(result.data) == 2
        assert result.data[0].id == "msg-001"
        assert result.data[1].id == "msg-003"

    # Test 15: Has_more flag when results exceed limit
    @patch("services.message_service.db")
    def test_pagination_by_last_id_has_more_true(self, mock_db, factory):
        """Test has_more flag is True when results exceed limit."""
        # Arrange
        app = factory.create_app_mock()
        user = factory.create_end_user_mock()
        # Create limit+1 messages (11 messages for limit=10)
        messages = [
            factory.create_message_mock(
                message_id=f"msg-{i:03d}",
                created_at=datetime(2024, 1, 1, 12, i),
            )
            for i in range(11)
        ]
        mock_query = MagicMock()
        mock_db.session.query.return_value = mock_query
        mock_query.where.return_value = mock_query
        mock_query.order_by.return_value = mock_query
        mock_query.limit.return_value = mock_query
        mock_query.all.return_value = messages
        # Act
        result = MessageService.pagination_by_last_id(
            app_model=app,
            user=user,
            last_id=None,
            limit=10,
        )
        # Assert
        assert len(result.data) == 10  # Last message trimmed
        assert result.has_more is True
        assert result.limit == 10

View File

@ -0,0 +1,440 @@
"""
Comprehensive unit tests for RecommendedAppService.
This test suite provides complete coverage of recommended app operations in Dify,
following TDD principles with the Arrange-Act-Assert pattern.
## Test Coverage
### 1. Get Recommended Apps and Categories (TestRecommendedAppServiceGetApps)
Tests fetching recommended apps with categories:
- Successful retrieval with recommended apps
- Fallback to builtin when no recommended apps
- Different language support
- Factory mode selection (remote, builtin, db)
- Empty result handling
### 2. Get Recommend App Detail (TestRecommendedAppServiceGetDetail)
Tests fetching individual app details:
- Successful app detail retrieval
- Different factory modes
- App not found scenarios
- Language-specific details
## Testing Approach
- **Mocking Strategy**: All external dependencies (dify_config, RecommendAppRetrievalFactory)
are mocked for fast, isolated unit tests
- **Factory Pattern**: Tests verify correct factory selection based on mode
- **Fixtures**: Mock objects are configured per test method
- **Assertions**: Each test verifies return values and factory method calls
## Key Concepts
**Factory Modes:**
- remote: Fetch from remote API
- builtin: Use built-in templates
- db: Fetch from database
**Fallback Logic:**
- If remote/db returns no apps, fallback to builtin en-US templates
- Ensures users always see some recommended apps
"""
from unittest.mock import MagicMock, patch
import pytest
from services.recommended_app_service import RecommendedAppService
class RecommendedAppServiceTestDataFactory:
"""
Factory for creating test data and mock objects.
Provides reusable methods to create consistent mock objects for testing
recommended app operations.
"""
@staticmethod
def create_recommended_apps_response(
recommended_apps: list[dict] | None = None,
categories: list[str] | None = None,
) -> dict:
"""
Create a mock response for recommended apps.
Args:
recommended_apps: List of recommended app dictionaries
categories: List of category names
Returns:
Dictionary with recommended_apps and categories
"""
if recommended_apps is None:
recommended_apps = [
{
"id": "app-1",
"name": "Test App 1",
"description": "Test description 1",
"category": "productivity",
},
{
"id": "app-2",
"name": "Test App 2",
"description": "Test description 2",
"category": "communication",
},
]
if categories is None:
categories = ["productivity", "communication", "utilities"]
return {
"recommended_apps": recommended_apps,
"categories": categories,
}
@staticmethod
def create_app_detail_response(
app_id: str = "app-123",
name: str = "Test App",
description: str = "Test description",
**kwargs,
) -> dict:
"""
Create a mock response for app detail.
Args:
app_id: App identifier
name: App name
description: App description
**kwargs: Additional fields
Returns:
Dictionary with app details
"""
detail = {
"id": app_id,
"name": name,
"description": description,
"category": kwargs.get("category", "productivity"),
"icon": kwargs.get("icon", "🚀"),
"model_config": kwargs.get("model_config", {}),
}
detail.update(kwargs)
return detail
@pytest.fixture
def factory():
    """Provide the test data factory to all tests."""
    # Returns the class itself (not an instance): all factory methods are static.
    return RecommendedAppServiceTestDataFactory
class TestRecommendedAppServiceGetApps:
    """Test get_recommended_apps_and_categories operations."""

    @patch("services.recommended_app_service.RecommendAppRetrievalFactory")
    @patch("services.recommended_app_service.dify_config")
    def test_get_recommended_apps_success_with_apps(self, mock_config, mock_factory_class, factory):
        """Test successful retrieval of recommended apps when apps are returned."""
        # Arrange
        mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
        expected_response = factory.create_recommended_apps_response()
        # Mock factory and retrieval instance
        # (mock_factory stands in for the retrieval class; calling it yields
        # the retrieval instance, mirroring factory(mode)() in the service)
        mock_retrieval_instance = MagicMock()
        mock_retrieval_instance.get_recommended_apps_and_categories.return_value = expected_response
        mock_factory = MagicMock()
        mock_factory.return_value = mock_retrieval_instance
        mock_factory_class.get_recommend_app_factory.return_value = mock_factory
        # Act
        result = RecommendedAppService.get_recommended_apps_and_categories("en-US")
        # Assert
        assert result == expected_response
        assert len(result["recommended_apps"]) == 2
        assert len(result["categories"]) == 3
        mock_factory_class.get_recommend_app_factory.assert_called_once_with("remote")
        mock_retrieval_instance.get_recommended_apps_and_categories.assert_called_once_with("en-US")

    @patch("services.recommended_app_service.RecommendAppRetrievalFactory")
    @patch("services.recommended_app_service.dify_config")
    def test_get_recommended_apps_fallback_to_builtin_when_empty(self, mock_config, mock_factory_class, factory):
        """Test fallback to builtin when no recommended apps are returned."""
        # Arrange
        mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
        # Remote returns empty recommended_apps
        empty_response = {"recommended_apps": [], "categories": []}
        # Builtin fallback response
        builtin_response = factory.create_recommended_apps_response(
            recommended_apps=[{"id": "builtin-1", "name": "Builtin App", "category": "default"}]
        )
        # Mock remote retrieval instance (returns empty)
        mock_remote_instance = MagicMock()
        mock_remote_instance.get_recommended_apps_and_categories.return_value = empty_response
        mock_remote_factory = MagicMock()
        mock_remote_factory.return_value = mock_remote_instance
        mock_factory_class.get_recommend_app_factory.return_value = mock_remote_factory
        # Mock builtin retrieval instance
        mock_builtin_instance = MagicMock()
        mock_builtin_instance.fetch_recommended_apps_from_builtin.return_value = builtin_response
        mock_factory_class.get_buildin_recommend_app_retrieval.return_value = mock_builtin_instance
        # Act
        result = RecommendedAppService.get_recommended_apps_and_categories("zh-CN")
        # Assert
        assert result == builtin_response
        assert len(result["recommended_apps"]) == 1
        assert result["recommended_apps"][0]["id"] == "builtin-1"
        # Verify fallback was called with en-US (hardcoded), not the requested zh-CN
        mock_builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once_with("en-US")

    @patch("services.recommended_app_service.RecommendAppRetrievalFactory")
    @patch("services.recommended_app_service.dify_config")
    def test_get_recommended_apps_fallback_when_none_recommended_apps(self, mock_config, mock_factory_class, factory):
        """Test fallback when recommended_apps key is None."""
        # Arrange
        mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "db"
        # Response with None recommended_apps
        none_response = {"recommended_apps": None, "categories": ["test"]}
        # Builtin fallback response
        builtin_response = factory.create_recommended_apps_response()
        # Mock db retrieval instance (returns None)
        mock_db_instance = MagicMock()
        mock_db_instance.get_recommended_apps_and_categories.return_value = none_response
        mock_db_factory = MagicMock()
        mock_db_factory.return_value = mock_db_instance
        mock_factory_class.get_recommend_app_factory.return_value = mock_db_factory
        # Mock builtin retrieval instance
        mock_builtin_instance = MagicMock()
        mock_builtin_instance.fetch_recommended_apps_from_builtin.return_value = builtin_response
        mock_factory_class.get_buildin_recommend_app_retrieval.return_value = mock_builtin_instance
        # Act
        result = RecommendedAppService.get_recommended_apps_and_categories("en-US")
        # Assert
        assert result == builtin_response
        mock_builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once()

    @patch("services.recommended_app_service.RecommendAppRetrievalFactory")
    @patch("services.recommended_app_service.dify_config")
    def test_get_recommended_apps_with_different_languages(self, mock_config, mock_factory_class, factory):
        """Test retrieval with different language codes."""
        # Arrange
        mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "builtin"
        languages = ["en-US", "zh-CN", "ja-JP", "fr-FR"]
        # Re-arm the mocks per language so each iteration is independent
        for language in languages:
            # Create language-specific response
            lang_response = factory.create_recommended_apps_response(
                recommended_apps=[{"id": f"app-{language}", "name": f"App {language}", "category": "test"}]
            )
            # Mock retrieval instance
            mock_instance = MagicMock()
            mock_instance.get_recommended_apps_and_categories.return_value = lang_response
            mock_factory = MagicMock()
            mock_factory.return_value = mock_instance
            mock_factory_class.get_recommend_app_factory.return_value = mock_factory
            # Act
            result = RecommendedAppService.get_recommended_apps_and_categories(language)
            # Assert
            assert result["recommended_apps"][0]["id"] == f"app-{language}"
            mock_instance.get_recommended_apps_and_categories.assert_called_with(language)

    @patch("services.recommended_app_service.RecommendAppRetrievalFactory")
    @patch("services.recommended_app_service.dify_config")
    def test_get_recommended_apps_uses_correct_factory_mode(self, mock_config, mock_factory_class, factory):
        """Test that correct factory is selected based on mode."""
        # Arrange
        modes = ["remote", "builtin", "db"]
        for mode in modes:
            mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = mode
            response = factory.create_recommended_apps_response()
            # Mock retrieval instance
            mock_instance = MagicMock()
            mock_instance.get_recommended_apps_and_categories.return_value = response
            mock_factory = MagicMock()
            mock_factory.return_value = mock_instance
            mock_factory_class.get_recommend_app_factory.return_value = mock_factory
            # Act
            RecommendedAppService.get_recommended_apps_and_categories("en-US")
            # Assert
            mock_factory_class.get_recommend_app_factory.assert_called_with(mode)
class TestRecommendedAppServiceGetDetail:
    """Test get_recommend_app_detail operations."""

    @patch("services.recommended_app_service.RecommendAppRetrievalFactory")
    @patch("services.recommended_app_service.dify_config")
    def test_get_recommend_app_detail_success(self, mock_config, mock_factory_class, factory):
        """Test successful retrieval of app detail."""
        # Arrange
        mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
        app_id = "app-123"
        expected_detail = factory.create_app_detail_response(
            app_id=app_id,
            name="Productivity App",
            description="A great productivity app",
            category="productivity",
        )
        # Mock retrieval instance (factory(mode)() -> instance in the service)
        mock_instance = MagicMock()
        mock_instance.get_recommend_app_detail.return_value = expected_detail
        mock_factory = MagicMock()
        mock_factory.return_value = mock_instance
        mock_factory_class.get_recommend_app_factory.return_value = mock_factory
        # Act
        result = RecommendedAppService.get_recommend_app_detail(app_id)
        # Assert
        assert result == expected_detail
        assert result["id"] == app_id
        assert result["name"] == "Productivity App"
        mock_instance.get_recommend_app_detail.assert_called_once_with(app_id)

    @patch("services.recommended_app_service.RecommendAppRetrievalFactory")
    @patch("services.recommended_app_service.dify_config")
    def test_get_recommend_app_detail_with_different_modes(self, mock_config, mock_factory_class, factory):
        """Test app detail retrieval with different factory modes."""
        # Arrange
        modes = ["remote", "builtin", "db"]
        app_id = "test-app"
        # Re-arm the mocks per mode so each iteration is independent
        for mode in modes:
            mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = mode
            detail = factory.create_app_detail_response(app_id=app_id, name=f"App from {mode}")
            # Mock retrieval instance
            mock_instance = MagicMock()
            mock_instance.get_recommend_app_detail.return_value = detail
            mock_factory = MagicMock()
            mock_factory.return_value = mock_instance
            mock_factory_class.get_recommend_app_factory.return_value = mock_factory
            # Act
            result = RecommendedAppService.get_recommend_app_detail(app_id)
            # Assert
            assert result["name"] == f"App from {mode}"
            mock_factory_class.get_recommend_app_factory.assert_called_with(mode)

    @patch("services.recommended_app_service.RecommendAppRetrievalFactory")
    @patch("services.recommended_app_service.dify_config")
    def test_get_recommend_app_detail_returns_none_when_not_found(self, mock_config, mock_factory_class, factory):
        """Test that None is returned when app is not found."""
        # Arrange
        mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
        app_id = "nonexistent-app"
        # Mock retrieval instance returning None
        mock_instance = MagicMock()
        mock_instance.get_recommend_app_detail.return_value = None
        mock_factory = MagicMock()
        mock_factory.return_value = mock_instance
        mock_factory_class.get_recommend_app_factory.return_value = mock_factory
        # Act
        result = RecommendedAppService.get_recommend_app_detail(app_id)
        # Assert: the service passes the miss through unchanged
        assert result is None
        mock_instance.get_recommend_app_detail.assert_called_once_with(app_id)

    @patch("services.recommended_app_service.RecommendAppRetrievalFactory")
    @patch("services.recommended_app_service.dify_config")
    def test_get_recommend_app_detail_returns_empty_dict(self, mock_config, mock_factory_class, factory):
        """Test handling of empty dict response."""
        # Arrange
        mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "builtin"
        app_id = "app-empty"
        # Mock retrieval instance returning empty dict
        mock_instance = MagicMock()
        mock_instance.get_recommend_app_detail.return_value = {}
        mock_factory = MagicMock()
        mock_factory.return_value = mock_instance
        mock_factory_class.get_recommend_app_factory.return_value = mock_factory
        # Act
        result = RecommendedAppService.get_recommend_app_detail(app_id)
        # Assert
        assert result == {}

    @patch("services.recommended_app_service.RecommendAppRetrievalFactory")
    @patch("services.recommended_app_service.dify_config")
    def test_get_recommend_app_detail_with_complex_model_config(self, mock_config, mock_factory_class, factory):
        """Test app detail with complex model configuration."""
        # Arrange
        mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
        app_id = "complex-app"
        complex_model_config = {
            "provider": "openai",
            "model": "gpt-4",
            "parameters": {
                "temperature": 0.7,
                "max_tokens": 2000,
                "top_p": 1.0,
            },
        }
        expected_detail = factory.create_app_detail_response(
            app_id=app_id,
            name="Complex App",
            model_config=complex_model_config,
            workflows=["workflow-1", "workflow-2"],
            tools=["tool-1", "tool-2", "tool-3"],
        )
        # Mock retrieval instance
        mock_instance = MagicMock()
        mock_instance.get_recommend_app_detail.return_value = expected_detail
        mock_factory = MagicMock()
        mock_factory.return_value = mock_instance
        mock_factory_class.get_recommend_app_factory.return_value = mock_factory
        # Act
        result = RecommendedAppService.get_recommend_app_detail(app_id)
        # Assert: nested structures survive the round trip intact
        assert result["model_config"] == complex_model_config
        assert len(result["workflows"]) == 2
        assert len(result["tools"]) == 3

View File

@ -0,0 +1,626 @@
"""
Comprehensive unit tests for SavedMessageService.
This test suite provides complete coverage of saved message operations in Dify,
following TDD principles with the Arrange-Act-Assert pattern.
## Test Coverage
### 1. Pagination (TestSavedMessageServicePagination)
Tests saved message listing and pagination:
- Pagination with valid user (Account and EndUser)
- Pagination without user raises ValueError
- Pagination with last_id parameter
- Empty results when no saved messages exist
- Integration with MessageService pagination
### 2. Save Operations (TestSavedMessageServiceSave)
Tests saving messages:
- Save message for Account user
- Save message for EndUser
- Save without user (no-op)
- Prevent duplicate saves (idempotent)
- Message validation through MessageService
### 3. Delete Operations (TestSavedMessageServiceDelete)
Tests deleting saved messages:
- Delete saved message for Account user
- Delete saved message for EndUser
- Delete without user (no-op)
- Delete non-existent saved message (no-op)
- Proper database cleanup
## Testing Approach
- **Mocking Strategy**: All external dependencies (database, MessageService) are mocked
for fast, isolated unit tests
- **Factory Pattern**: SavedMessageServiceTestDataFactory provides consistent test data
- **Fixtures**: Mock objects are configured per test method
- **Assertions**: Each test verifies return values and side effects
(database operations, method calls)
## Key Concepts
**User Types:**
- Account: Workspace members (console users)
- EndUser: API users (end users)
**Saved Messages:**
- Users can save messages for later reference
- Each user has their own saved message list
- Saving is idempotent (duplicate saves ignored)
- Deletion is safe (non-existent deletes ignored)
"""
from datetime import UTC, datetime
from unittest.mock import MagicMock, Mock, create_autospec, patch
import pytest
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models import Account
from models.model import App, EndUser, Message
from models.web import SavedMessage
from services.saved_message_service import SavedMessageService
class SavedMessageServiceTestDataFactory:
    """
    Factory for creating test data and mock objects.

    Provides reusable methods to create consistent mock objects for testing
    saved message operations.
    """

    @staticmethod
    def _autospec_with_attrs(model_cls, attrs: dict) -> Mock:
        """Create an autospec'd instance mock of ``model_cls`` and copy ``attrs`` onto it."""
        mock_obj = create_autospec(model_cls, instance=True)
        for attr_name, attr_value in attrs.items():
            setattr(mock_obj, attr_name, attr_value)
        return mock_obj

    @staticmethod
    def create_account_mock(account_id: str = "account-123", **kwargs) -> Mock:
        """
        Create a mock Account object.

        Args:
            account_id: Unique identifier for the account
            **kwargs: Additional attributes to set on the mock

        Returns:
            Mock Account object with specified attributes
        """
        return SavedMessageServiceTestDataFactory._autospec_with_attrs(
            Account, {"id": account_id, **kwargs}
        )

    @staticmethod
    def create_end_user_mock(user_id: str = "user-123", **kwargs) -> Mock:
        """
        Create a mock EndUser object.

        Args:
            user_id: Unique identifier for the end user
            **kwargs: Additional attributes to set on the mock

        Returns:
            Mock EndUser object with specified attributes
        """
        return SavedMessageServiceTestDataFactory._autospec_with_attrs(
            EndUser, {"id": user_id, **kwargs}
        )

    @staticmethod
    def create_app_mock(app_id: str = "app-123", tenant_id: str = "tenant-123", **kwargs) -> Mock:
        """
        Create a mock App object.

        Args:
            app_id: Unique identifier for the app
            tenant_id: Tenant/workspace identifier
            **kwargs: Additional attributes to set on the mock (name and mode
                default to "Test App" / "chat" when not supplied)

        Returns:
            Mock App object with specified attributes
        """
        return SavedMessageServiceTestDataFactory._autospec_with_attrs(
            App,
            {
                "id": app_id,
                "tenant_id": tenant_id,
                "name": kwargs.get("name", "Test App"),
                "mode": kwargs.get("mode", "chat"),
                **kwargs,
            },
        )

    @staticmethod
    def create_message_mock(
        message_id: str = "msg-123",
        app_id: str = "app-123",
        **kwargs,
    ) -> Mock:
        """
        Create a mock Message object.

        Args:
            message_id: Unique identifier for the message
            app_id: Associated app identifier
            **kwargs: Additional attributes to set on the mock (query, answer
                and created_at receive defaults when not supplied)

        Returns:
            Mock Message object with specified attributes
        """
        return SavedMessageServiceTestDataFactory._autospec_with_attrs(
            Message,
            {
                "id": message_id,
                "app_id": app_id,
                "query": kwargs.get("query", "Test query"),
                "answer": kwargs.get("answer", "Test answer"),
                "created_at": kwargs.get("created_at", datetime.now(UTC)),
                **kwargs,
            },
        )

    @staticmethod
    def create_saved_message_mock(
        saved_message_id: str = "saved-123",
        app_id: str = "app-123",
        message_id: str = "msg-123",
        created_by: str = "user-123",
        created_by_role: str = "account",
        **kwargs,
    ) -> Mock:
        """
        Create a mock SavedMessage object.

        Args:
            saved_message_id: Unique identifier for the saved message
            app_id: Associated app identifier
            message_id: Associated message identifier
            created_by: User who saved the message
            created_by_role: Role of the user ('account' or 'end_user')
            **kwargs: Additional attributes to set on the mock

        Returns:
            Mock SavedMessage object with specified attributes
        """
        return SavedMessageServiceTestDataFactory._autospec_with_attrs(
            SavedMessage,
            {
                "id": saved_message_id,
                "app_id": app_id,
                "message_id": message_id,
                "created_by": created_by,
                "created_by_role": created_by_role,
                "created_at": kwargs.get("created_at", datetime.now(UTC)),
                **kwargs,
            },
        )
@pytest.fixture
def factory():
    """Provide the test data factory to all tests."""
    # Returns the class itself (not an instance): all factory methods are static.
    return SavedMessageServiceTestDataFactory
class TestSavedMessageServicePagination:
"""Test saved message pagination operations."""
    @patch("services.saved_message_service.MessageService.pagination_by_last_id")
    @patch("services.saved_message_service.db.session")
    def test_pagination_with_account_user(self, mock_db_session, mock_message_pagination, factory):
        """Test pagination with an Account user."""
        # Arrange
        app = factory.create_app_mock()
        user = factory.create_account_mock()
        # Create saved messages for this user
        saved_messages = [
            factory.create_saved_message_mock(
                saved_message_id=f"saved-{i}",
                app_id=app.id,
                message_id=f"msg-{i}",
                created_by=user.id,
                created_by_role="account",
            )
            for i in range(3)
        ]
        # Mock database query (chainable: where/order_by return the same mock)
        mock_query = MagicMock()
        mock_db_session.query.return_value = mock_query
        mock_query.where.return_value = mock_query
        mock_query.order_by.return_value = mock_query
        mock_query.all.return_value = saved_messages
        # Mock MessageService pagination response
        expected_pagination = InfiniteScrollPagination(data=[], limit=20, has_more=False)
        mock_message_pagination.return_value = expected_pagination
        # Act
        result = SavedMessageService.pagination_by_last_id(app_model=app, user=user, last_id=None, limit=20)
        # Assert: the service delegates to MessageService and returns its result
        assert result == expected_pagination
        mock_db_session.query.assert_called_once_with(SavedMessage)
        # Verify MessageService was called with correct message IDs
        mock_message_pagination.assert_called_once_with(
            app_model=app,
            user=user,
            last_id=None,
            limit=20,
            include_ids=["msg-0", "msg-1", "msg-2"],
        )
@patch("services.saved_message_service.MessageService.pagination_by_last_id")
@patch("services.saved_message_service.db.session")
def test_pagination_with_end_user(self, mock_db_session, mock_message_pagination, factory):
"""Test pagination with an EndUser."""
# Arrange
app = factory.create_app_mock()
user = factory.create_end_user_mock()
# Create saved messages for this end user
saved_messages = [
factory.create_saved_message_mock(
saved_message_id=f"saved-{i}",
app_id=app.id,
message_id=f"msg-{i}",
created_by=user.id,
created_by_role="end_user",
)
for i in range(2)
]
# Mock database query
mock_query = MagicMock()
mock_db_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = saved_messages
# Mock MessageService pagination response
expected_pagination = InfiniteScrollPagination(data=[], limit=10, has_more=False)
mock_message_pagination.return_value = expected_pagination
# Act
result = SavedMessageService.pagination_by_last_id(app_model=app, user=user, last_id=None, limit=10)
# Assert
assert result == expected_pagination
# Verify correct role was used in query
mock_message_pagination.assert_called_once_with(
app_model=app,
user=user,
last_id=None,
limit=10,
include_ids=["msg-0", "msg-1"],
)
def test_pagination_without_user_raises_error(self, factory):
"""Test that pagination without user raises ValueError."""
# Arrange
app = factory.create_app_mock()
# Act & Assert
with pytest.raises(ValueError, match="User is required"):
SavedMessageService.pagination_by_last_id(app_model=app, user=None, last_id=None, limit=20)
@patch("services.saved_message_service.MessageService.pagination_by_last_id")
@patch("services.saved_message_service.db.session")
def test_pagination_with_last_id(self, mock_db_session, mock_message_pagination, factory):
"""Test pagination with last_id parameter."""
# Arrange
app = factory.create_app_mock()
user = factory.create_account_mock()
last_id = "msg-last"
saved_messages = [
factory.create_saved_message_mock(
message_id=f"msg-{i}",
app_id=app.id,
created_by=user.id,
)
for i in range(5)
]
# Mock database query
mock_query = MagicMock()
mock_db_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = saved_messages
# Mock MessageService pagination response
expected_pagination = InfiniteScrollPagination(data=[], limit=10, has_more=True)
mock_message_pagination.return_value = expected_pagination
# Act
result = SavedMessageService.pagination_by_last_id(app_model=app, user=user, last_id=last_id, limit=10)
# Assert
assert result == expected_pagination
# Verify last_id was passed to MessageService
mock_message_pagination.assert_called_once()
call_args = mock_message_pagination.call_args
assert call_args.kwargs["last_id"] == last_id
@patch("services.saved_message_service.MessageService.pagination_by_last_id")
@patch("services.saved_message_service.db.session")
def test_pagination_with_empty_saved_messages(self, mock_db_session, mock_message_pagination, factory):
"""Test pagination when user has no saved messages."""
# Arrange
app = factory.create_app_mock()
user = factory.create_account_mock()
# Mock database query returning empty list
mock_query = MagicMock()
mock_db_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = []
# Mock MessageService pagination response
expected_pagination = InfiniteScrollPagination(data=[], limit=20, has_more=False)
mock_message_pagination.return_value = expected_pagination
# Act
result = SavedMessageService.pagination_by_last_id(app_model=app, user=user, last_id=None, limit=20)
# Assert
assert result == expected_pagination
# Verify MessageService was called with empty include_ids
mock_message_pagination.assert_called_once_with(
app_model=app,
user=user,
last_id=None,
limit=20,
include_ids=[],
)
class TestSavedMessageServiceSave:
"""Test save message operations."""
@patch("services.saved_message_service.MessageService.get_message")
@patch("services.saved_message_service.db.session")
def test_save_message_for_account(self, mock_db_session, mock_get_message, factory):
"""Test saving a message for an Account user."""
# Arrange
app = factory.create_app_mock()
user = factory.create_account_mock()
message = factory.create_message_mock(message_id="msg-123", app_id=app.id)
# Mock database query - no existing saved message
mock_query = MagicMock()
mock_db_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.first.return_value = None
# Mock MessageService.get_message
mock_get_message.return_value = message
# Act
SavedMessageService.save(app_model=app, user=user, message_id=message.id)
# Assert
mock_db_session.add.assert_called_once()
saved_message = mock_db_session.add.call_args[0][0]
assert saved_message.app_id == app.id
assert saved_message.message_id == message.id
assert saved_message.created_by == user.id
assert saved_message.created_by_role == "account"
mock_db_session.commit.assert_called_once()
@patch("services.saved_message_service.MessageService.get_message")
@patch("services.saved_message_service.db.session")
def test_save_message_for_end_user(self, mock_db_session, mock_get_message, factory):
"""Test saving a message for an EndUser."""
# Arrange
app = factory.create_app_mock()
user = factory.create_end_user_mock()
message = factory.create_message_mock(message_id="msg-456", app_id=app.id)
# Mock database query - no existing saved message
mock_query = MagicMock()
mock_db_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.first.return_value = None
# Mock MessageService.get_message
mock_get_message.return_value = message
# Act
SavedMessageService.save(app_model=app, user=user, message_id=message.id)
# Assert
mock_db_session.add.assert_called_once()
saved_message = mock_db_session.add.call_args[0][0]
assert saved_message.app_id == app.id
assert saved_message.message_id == message.id
assert saved_message.created_by == user.id
assert saved_message.created_by_role == "end_user"
mock_db_session.commit.assert_called_once()
@patch("services.saved_message_service.db.session")
def test_save_without_user_does_nothing(self, mock_db_session, factory):
"""Test that saving without user is a no-op."""
# Arrange
app = factory.create_app_mock()
# Act
SavedMessageService.save(app_model=app, user=None, message_id="msg-123")
# Assert
mock_db_session.query.assert_not_called()
mock_db_session.add.assert_not_called()
mock_db_session.commit.assert_not_called()
@patch("services.saved_message_service.MessageService.get_message")
@patch("services.saved_message_service.db.session")
def test_save_duplicate_message_is_idempotent(self, mock_db_session, mock_get_message, factory):
"""Test that saving an already saved message is idempotent."""
# Arrange
app = factory.create_app_mock()
user = factory.create_account_mock()
message_id = "msg-789"
# Mock database query - existing saved message found
existing_saved = factory.create_saved_message_mock(
app_id=app.id,
message_id=message_id,
created_by=user.id,
created_by_role="account",
)
mock_query = MagicMock()
mock_db_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.first.return_value = existing_saved
# Act
SavedMessageService.save(app_model=app, user=user, message_id=message_id)
# Assert - no new saved message created
mock_db_session.add.assert_not_called()
mock_db_session.commit.assert_not_called()
mock_get_message.assert_not_called()
@patch("services.saved_message_service.MessageService.get_message")
@patch("services.saved_message_service.db.session")
def test_save_validates_message_exists(self, mock_db_session, mock_get_message, factory):
"""Test that save validates message exists through MessageService."""
# Arrange
app = factory.create_app_mock()
user = factory.create_account_mock()
message = factory.create_message_mock()
# Mock database query - no existing saved message
mock_query = MagicMock()
mock_db_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.first.return_value = None
# Mock MessageService.get_message
mock_get_message.return_value = message
# Act
SavedMessageService.save(app_model=app, user=user, message_id=message.id)
# Assert - MessageService.get_message was called for validation
mock_get_message.assert_called_once_with(app_model=app, user=user, message_id=message.id)
class TestSavedMessageServiceDelete:
"""Test delete saved message operations."""
@patch("services.saved_message_service.db.session")
def test_delete_saved_message_for_account(self, mock_db_session, factory):
"""Test deleting a saved message for an Account user."""
# Arrange
app = factory.create_app_mock()
user = factory.create_account_mock()
message_id = "msg-123"
# Mock database query - existing saved message found
saved_message = factory.create_saved_message_mock(
app_id=app.id,
message_id=message_id,
created_by=user.id,
created_by_role="account",
)
mock_query = MagicMock()
mock_db_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.first.return_value = saved_message
# Act
SavedMessageService.delete(app_model=app, user=user, message_id=message_id)
# Assert
mock_db_session.delete.assert_called_once_with(saved_message)
mock_db_session.commit.assert_called_once()
@patch("services.saved_message_service.db.session")
def test_delete_saved_message_for_end_user(self, mock_db_session, factory):
"""Test deleting a saved message for an EndUser."""
# Arrange
app = factory.create_app_mock()
user = factory.create_end_user_mock()
message_id = "msg-456"
# Mock database query - existing saved message found
saved_message = factory.create_saved_message_mock(
app_id=app.id,
message_id=message_id,
created_by=user.id,
created_by_role="end_user",
)
mock_query = MagicMock()
mock_db_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.first.return_value = saved_message
# Act
SavedMessageService.delete(app_model=app, user=user, message_id=message_id)
# Assert
mock_db_session.delete.assert_called_once_with(saved_message)
mock_db_session.commit.assert_called_once()
@patch("services.saved_message_service.db.session")
def test_delete_without_user_does_nothing(self, mock_db_session, factory):
"""Test that deleting without user is a no-op."""
# Arrange
app = factory.create_app_mock()
# Act
SavedMessageService.delete(app_model=app, user=None, message_id="msg-123")
# Assert
mock_db_session.query.assert_not_called()
mock_db_session.delete.assert_not_called()
mock_db_session.commit.assert_not_called()
@patch("services.saved_message_service.db.session")
def test_delete_non_existent_saved_message_does_nothing(self, mock_db_session, factory):
"""Test that deleting a non-existent saved message is a no-op."""
# Arrange
app = factory.create_app_mock()
user = factory.create_account_mock()
message_id = "msg-nonexistent"
# Mock database query - no saved message found
mock_query = MagicMock()
mock_db_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.first.return_value = None
# Act
SavedMessageService.delete(app_model=app, user=user, message_id=message_id)
# Assert - no deletion occurred
mock_db_session.delete.assert_not_called()
mock_db_session.commit.assert_not_called()
@patch("services.saved_message_service.db.session")
def test_delete_only_affects_user_own_saved_messages(self, mock_db_session, factory):
"""Test that delete only removes the user's own saved message."""
# Arrange
app = factory.create_app_mock()
user1 = factory.create_account_mock(account_id="user-1")
message_id = "msg-shared"
# Mock database query - finds user1's saved message
saved_message = factory.create_saved_message_mock(
app_id=app.id,
message_id=message_id,
created_by=user1.id,
created_by_role="account",
)
mock_query = MagicMock()
mock_db_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.first.return_value = saved_message
# Act
SavedMessageService.delete(app_model=app, user=user1, message_id=message_id)
# Assert - only user1's saved message is deleted
mock_db_session.delete.assert_called_once_with(saved_message)
# Verify the query filters by user
assert mock_query.where.called

File diff suppressed because it is too large Load Diff

View File

@ -1,9 +1,9 @@
from unittest.mock import Mock from unittest.mock import Mock
from core.tools.__base.tool import Tool from core.tools.__base.tool import Tool
from core.tools.entities.api_entities import ToolApiEntity from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolParameter from core.tools.entities.tool_entities import ToolParameter, ToolProviderType
from services.tools.tools_transform_service import ToolTransformService from services.tools.tools_transform_service import ToolTransformService
@ -299,3 +299,154 @@ class TestToolTransformService:
param2 = result.parameters[1] param2 = result.parameters[1]
assert param2.name == "param2" assert param2.name == "param2"
assert param2.label == "Runtime Param 2" assert param2.label == "Runtime Param 2"
class TestWorkflowProviderToUserProvider:
    """Test cases for ToolTransformService.workflow_provider_to_user_provider method"""

    @staticmethod
    def _make_mock_controller(
        provider_id="provider_123",
        author="test_author",
        name="test_workflow_tool",
        description=None,
        icon=None,
        icon_dark=None,
        label=None,
    ):
        """Build a mock WorkflowToolProviderController with the given identity fields.

        Defaults mirror the fixture previously duplicated in each test; pass
        explicit values to override individual identity attributes.
        """
        # Imported lazily to match the original tests' local-import style.
        from core.tools.workflow_as_tool.provider import WorkflowToolProviderController

        mock_controller = Mock(spec=WorkflowToolProviderController)
        mock_controller.provider_id = provider_id
        mock_controller.entity = Mock()
        mock_controller.entity.identity = Mock()
        mock_controller.entity.identity.author = author
        mock_controller.entity.identity.name = name
        mock_controller.entity.identity.description = (
            description if description is not None else I18nObject(en_US="Test description")
        )
        mock_controller.entity.identity.icon = icon if icon is not None else {"type": "emoji", "content": "🔧"}
        mock_controller.entity.identity.icon_dark = icon_dark
        mock_controller.entity.identity.label = (
            label if label is not None else I18nObject(en_US="Test Workflow Tool")
        )
        return mock_controller

    def test_workflow_provider_to_user_provider_with_workflow_app_id(self):
        """Test that workflow_provider_to_user_provider correctly sets workflow_app_id."""
        # Arrange
        workflow_app_id = "app_123"
        provider_id = "provider_123"
        mock_controller = self._make_mock_controller(provider_id=provider_id)
        # Act
        result = ToolTransformService.workflow_provider_to_user_provider(
            provider_controller=mock_controller,
            labels=["label1", "label2"],
            workflow_app_id=workflow_app_id,
        )
        # Assert
        assert isinstance(result, ToolProviderApiEntity)
        assert result.id == provider_id
        assert result.author == "test_author"
        assert result.name == "test_workflow_tool"
        assert result.type == ToolProviderType.WORKFLOW
        assert result.workflow_app_id == workflow_app_id
        assert result.labels == ["label1", "label2"]
        assert result.is_team_authorization is True
        assert result.plugin_id is None
        assert result.plugin_unique_identifier is None
        assert result.tools == []

    def test_workflow_provider_to_user_provider_without_workflow_app_id(self):
        """Test that workflow_provider_to_user_provider works when workflow_app_id is not provided."""
        # Arrange
        provider_id = "provider_123"
        mock_controller = self._make_mock_controller(provider_id=provider_id)
        # Act: call the method without workflow_app_id
        result = ToolTransformService.workflow_provider_to_user_provider(
            provider_controller=mock_controller,
            labels=["label1"],
        )
        # Assert
        assert isinstance(result, ToolProviderApiEntity)
        assert result.id == provider_id
        assert result.workflow_app_id is None
        assert result.labels == ["label1"]

    def test_workflow_provider_to_user_provider_workflow_app_id_none(self):
        """Test that workflow_provider_to_user_provider handles None workflow_app_id explicitly."""
        # Arrange
        provider_id = "provider_123"
        mock_controller = self._make_mock_controller(provider_id=provider_id)
        # Act: call the method with explicit None values
        result = ToolTransformService.workflow_provider_to_user_provider(
            provider_controller=mock_controller,
            labels=None,
            workflow_app_id=None,
        )
        # Assert
        assert isinstance(result, ToolProviderApiEntity)
        assert result.id == provider_id
        assert result.workflow_app_id is None
        assert result.labels == []

    def test_workflow_provider_to_user_provider_preserves_other_fields(self):
        """Test that workflow_provider_to_user_provider preserves all other entity fields."""
        # Arrange: a controller with every identity field customized
        workflow_app_id = "app_456"
        provider_id = "provider_456"
        mock_controller = self._make_mock_controller(
            provider_id=provider_id,
            author="another_author",
            name="another_workflow_tool",
            description=I18nObject(en_US="Another description", zh_Hans="Another description"),
            icon={"type": "emoji", "content": "⚙️"},
            icon_dark={"type": "emoji", "content": "🔧"},
            label=I18nObject(en_US="Another Workflow Tool", zh_Hans="Another Workflow Tool"),
        )
        # Act
        result = ToolTransformService.workflow_provider_to_user_provider(
            provider_controller=mock_controller,
            labels=["automation", "workflow"],
            workflow_app_id=workflow_app_id,
        )
        # Assert: all fields are preserved correctly
        assert isinstance(result, ToolProviderApiEntity)
        assert result.id == provider_id
        assert result.author == "another_author"
        assert result.name == "another_workflow_tool"
        assert result.description.en_US == "Another description"
        assert result.description.zh_Hans == "Another description"
        assert result.icon == {"type": "emoji", "content": "⚙️"}
        assert result.icon_dark == {"type": "emoji", "content": "🔧"}
        assert result.label.en_US == "Another Workflow Tool"
        assert result.label.zh_Hans == "Another Workflow Tool"
        assert result.type == ToolProviderType.WORKFLOW
        assert result.workflow_app_id == workflow_app_id
        assert result.labels == ["automation", "workflow"]
        assert result.masked_credentials == {}
        assert result.is_team_authorization is True
        assert result.allow_delete is True
        assert result.plugin_id is None
        assert result.plugin_unique_identifier is None
        assert result.tools == []

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1628,7 +1628,7 @@ dev = [
{ name = "celery-types", specifier = ">=0.23.0" }, { name = "celery-types", specifier = ">=0.23.0" },
{ name = "coverage", specifier = "~=7.2.4" }, { name = "coverage", specifier = "~=7.2.4" },
{ name = "dotenv-linter", specifier = "~=0.5.0" }, { name = "dotenv-linter", specifier = "~=0.5.0" },
{ name = "faker", specifier = "~=32.1.0" }, { name = "faker", specifier = "~=38.2.0" },
{ name = "hypothesis", specifier = ">=6.131.15" }, { name = "hypothesis", specifier = ">=6.131.15" },
{ name = "import-linter", specifier = ">=2.3" }, { name = "import-linter", specifier = ">=2.3" },
{ name = "lxml-stubs", specifier = "~=0.5.1" }, { name = "lxml-stubs", specifier = "~=0.5.1" },
@ -1859,15 +1859,14 @@ wheels = [
[[package]] [[package]]
name = "faker" name = "faker"
version = "32.1.0" version = "38.2.0"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
dependencies = [ dependencies = [
{ name = "python-dateutil" }, { name = "tzdata" },
{ name = "typing-extensions" },
] ]
sdist = { url = "https://files.pythonhosted.org/packages/1c/2a/dd2c8f55d69013d0eee30ec4c998250fb7da957f5fe860ed077b3df1725b/faker-32.1.0.tar.gz", hash = "sha256:aac536ba04e6b7beb2332c67df78485fc29c1880ff723beac6d1efd45e2f10f5", size = 1850193, upload-time = "2024-11-12T22:04:34.812Z" } sdist = { url = "https://files.pythonhosted.org/packages/64/27/022d4dbd4c20567b4c294f79a133cc2f05240ea61e0d515ead18c995c249/faker-38.2.0.tar.gz", hash = "sha256:20672803db9c7cb97f9b56c18c54b915b6f1d8991f63d1d673642dc43f5ce7ab", size = 1941469, upload-time = "2025-11-19T16:37:31.892Z" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/7e/fa/4a82dea32d6262a96e6841cdd4a45c11ac09eecdff018e745565410ac70e/Faker-32.1.0-py3-none-any.whl", hash = "sha256:c77522577863c264bdc9dad3a2a750ad3f7ee43ff8185072e482992288898814", size = 1889123, upload-time = "2024-11-12T22:04:32.298Z" }, { url = "https://files.pythonhosted.org/packages/17/93/00c94d45f55c336434a15f98d906387e87ce28f9918e4444829a8fda432d/faker-38.2.0-py3-none-any.whl", hash = "sha256:35fe4a0a79dee0dc4103a6083ee9224941e7d3594811a50e3969e547b0d2ee65", size = 1980505, upload-time = "2025-11-19T16:37:30.208Z" },
] ]
[[package]] [[package]]

View File

@ -123,7 +123,7 @@ services:
# plugin daemon # plugin daemon
plugin_daemon: plugin_daemon:
image: langgenius/dify-plugin-daemon:0.4.0-local image: langgenius/dify-plugin-daemon:0.4.1-local
restart: always restart: always
env_file: env_file:
- ./middleware.env - ./middleware.env

View File

@ -1,6 +1,6 @@
'use client' 'use client'
import type { FC, PropsWithChildren } from 'react' import type { FC, PropsWithChildren } from 'react'
import useAccessControlStore from '../../../../context/access-control-store' import useAccessControlStore from '@/context/access-control-store'
import type { AccessMode } from '@/models/access-control' import type { AccessMode } from '@/models/access-control'
type AccessControlItemProps = PropsWithChildren<{ type AccessControlItemProps = PropsWithChildren<{
@ -8,7 +8,8 @@ type AccessControlItemProps = PropsWithChildren<{
}> }>
const AccessControlItem: FC<AccessControlItemProps> = ({ type, children }) => { const AccessControlItem: FC<AccessControlItemProps> = ({ type, children }) => {
const { currentMenu, setCurrentMenu } = useAccessControlStore(s => ({ currentMenu: s.currentMenu, setCurrentMenu: s.setCurrentMenu })) const currentMenu = useAccessControlStore(s => s.currentMenu)
const setCurrentMenu = useAccessControlStore(s => s.setCurrentMenu)
if (currentMenu !== type) { if (currentMenu !== type) {
return <div return <div
className="cursor-pointer rounded-[10px] border-[1px] className="cursor-pointer rounded-[10px] border-[1px]

View File

@ -23,7 +23,7 @@ const Empty = () => {
return ( return (
<> <>
<DefaultCards /> <DefaultCards />
<div className='absolute bottom-0 left-0 right-0 top-0 flex items-center justify-center bg-gradient-to-t from-background-body to-transparent'> <div className='absolute inset-0 z-20 flex items-center justify-center bg-gradient-to-t from-background-body to-transparent pointer-events-none'>
<span className='system-md-medium text-text-tertiary'> <span className='system-md-medium text-text-tertiary'>
{t('app.newApp.noAppsFound')} {t('app.newApp.noAppsFound')}
</span> </span>

View File

@ -187,6 +187,19 @@ const GotoAnything: FC<Props> = ({
}, {} as { [key: string]: SearchResult[] }), }, {} as { [key: string]: SearchResult[] }),
[searchResults]) [searchResults])
useEffect(() => {
if (isCommandsMode)
return
if (!searchResults.length)
return
const currentValueExists = searchResults.some(result => `${result.type}-${result.id}` === cmdVal)
if (!currentValueExists)
setCmdVal(`${searchResults[0].type}-${searchResults[0].id}`)
}, [isCommandsMode, searchResults, cmdVal])
const emptyResult = useMemo(() => { const emptyResult = useMemo(() => {
if (searchResults.length || !searchQuery.trim() || isLoading || isCommandsMode) if (searchResults.length || !searchQuery.trim() || isLoading || isCommandsMode)
return null return null
@ -386,7 +399,7 @@ const GotoAnything: FC<Props> = ({
<Command.Item <Command.Item
key={`${result.type}-${result.id}`} key={`${result.type}-${result.id}`}
value={`${result.type}-${result.id}`} value={`${result.type}-${result.id}`}
className='flex cursor-pointer items-center gap-3 rounded-md p-3 will-change-[background-color] aria-[selected=true]:bg-state-base-hover data-[selected=true]:bg-state-base-hover' className='flex cursor-pointer items-center gap-3 rounded-md p-3 will-change-[background-color] hover:bg-state-base-hover aria-[selected=true]:bg-state-base-hover-alt data-[selected=true]:bg-state-base-hover-alt'
onSelect={() => handleNavigate(result)} onSelect={() => handleNavigate(result)}
> >
{result.icon} {result.icon}

View File

@ -52,7 +52,12 @@ const Nav = ({
`}> `}>
<Link href={link + (linkLastSearchParams && `?${linkLastSearchParams}`)}> <Link href={link + (linkLastSearchParams && `?${linkLastSearchParams}`)}>
<div <div
onClick={() => setAppDetail()} onClick={(e) => {
// Don't clear state if opening in new tab/window
if (e.metaKey || e.ctrlKey || e.shiftKey || e.button !== 0)
return
setAppDetail()
}}
className={classNames( className={classNames(
'flex h-7 cursor-pointer items-center rounded-[10px] px-2.5', 'flex h-7 cursor-pointer items-center rounded-[10px] px-2.5',
isActivated ? 'text-components-main-nav-nav-button-text-active' : 'text-components-main-nav-nav-button-text', isActivated ? 'text-components-main-nav-nav-button-text-active' : 'text-components-main-nav-nav-button-text',

View File

@ -77,6 +77,8 @@ export type Collection = {
timeout?: number timeout?: number
sse_read_timeout?: number sse_read_timeout?: number
} }
// Workflow
workflow_app_id?: string
} }
export type ToolParameter = { export type ToolParameter = {

View File

@ -1,5 +1,6 @@
import { import {
memo, memo,
useMemo,
} from 'react' } from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import { useEdges } from 'reactflow' import { useEdges } from 'reactflow'
@ -16,6 +17,10 @@ import {
} from '@/app/components/workflow/hooks' } from '@/app/components/workflow/hooks'
import ShortcutsName from '@/app/components/workflow/shortcuts-name' import ShortcutsName from '@/app/components/workflow/shortcuts-name'
import type { Node } from '@/app/components/workflow/types' import type { Node } from '@/app/components/workflow/types'
import { BlockEnum } from '@/app/components/workflow/types'
import { CollectionType } from '@/app/components/tools/types'
import { useAllWorkflowTools } from '@/service/use-tools'
import { canFindTool } from '@/utils'
type PanelOperatorPopupProps = { type PanelOperatorPopupProps = {
id: string id: string
@ -45,6 +50,14 @@ const PanelOperatorPopup = ({
const showChangeBlock = !nodeMetaData.isTypeFixed && !nodesReadOnly const showChangeBlock = !nodeMetaData.isTypeFixed && !nodesReadOnly
const isChildNode = !!(data.isInIteration || data.isInLoop) const isChildNode = !!(data.isInIteration || data.isInLoop)
const { data: workflowTools } = useAllWorkflowTools()
const isWorkflowTool = data.type === BlockEnum.Tool && data.provider_type === CollectionType.workflow
const workflowAppId = useMemo(() => {
if (!isWorkflowTool || !workflowTools || !data.provider_id) return undefined
const workflowTool = workflowTools.find(item => canFindTool(item.id, data.provider_id))
return workflowTool?.workflow_app_id
}, [isWorkflowTool, workflowTools, data.provider_id])
return ( return (
<div className='w-[240px] rounded-lg border-[0.5px] border-components-panel-border bg-components-panel-bg shadow-xl'> <div className='w-[240px] rounded-lg border-[0.5px] border-components-panel-border bg-components-panel-bg shadow-xl'>
{ {
@ -137,6 +150,22 @@ const PanelOperatorPopup = ({
</> </>
) )
} }
{
isWorkflowTool && workflowAppId && (
<>
<div className='p-1'>
<a
href={`/app/${workflowAppId}/workflow`}
target='_blank'
className='flex h-8 cursor-pointer items-center rounded-lg px-3 text-sm text-text-secondary hover:bg-state-base-hover'
>
{t('workflow.panel.openWorkflow')}
</a>
</div>
<div className='h-px bg-divider-regular'></div>
</>
)
}
{ {
showHelpLink && nodeMetaData.helpLinkUri && ( showHelpLink && nodeMetaData.helpLinkUri && (
<> <>

View File

@ -47,10 +47,8 @@ const ChatWrapper = (
const startVariables = startNode?.data.variables const startVariables = startNode?.data.variables
const appDetail = useAppStore(s => s.appDetail) const appDetail = useAppStore(s => s.appDetail)
const workflowStore = useWorkflowStore() const workflowStore = useWorkflowStore()
const { inputs, setInputs } = useStore(s => ({ const inputs = useStore(s => s.inputs)
inputs: s.inputs, const setInputs = useStore(s => s.setInputs)
setInputs: s.setInputs,
}))
const initialInputs = useMemo(() => { const initialInputs = useMemo(() => {
const initInputs: Record<string, any> = {} const initInputs: Record<string, any> = {}

View File

@ -32,10 +32,7 @@ type Props = {
const InputsPanel = ({ onRun }: Props) => { const InputsPanel = ({ onRun }: Props) => {
const { t } = useTranslation() const { t } = useTranslation()
const workflowStore = useWorkflowStore() const workflowStore = useWorkflowStore()
const { inputs } = useStore(s => ({ const inputs = useStore(s => s.inputs)
inputs: s.inputs,
setInputs: s.setInputs,
}))
const fileSettings = useHooksStore(s => s.configsMap?.fileSettings) const fileSettings = useHooksStore(s => s.configsMap?.fileSettings)
const nodes = useNodes<StartNodeType>() const nodes = useNodes<StartNodeType>()
const files = useStore(s => s.files) const files = useStore(s => s.files)

View File

@ -116,7 +116,7 @@ describe('useTabSearchParams', () => {
setActiveTab('settings') setActiveTab('settings')
}) })
expect(mockPush).toHaveBeenCalledWith('/test-path?category=settings') expect(mockPush).toHaveBeenCalledWith('/test-path?category=settings', { scroll: false })
expect(mockReplace).not.toHaveBeenCalled() expect(mockReplace).not.toHaveBeenCalled()
}) })
@ -137,7 +137,7 @@ describe('useTabSearchParams', () => {
setActiveTab('settings') setActiveTab('settings')
}) })
expect(mockReplace).toHaveBeenCalledWith('/test-path?category=settings') expect(mockReplace).toHaveBeenCalledWith('/test-path?category=settings', { scroll: false })
expect(mockPush).not.toHaveBeenCalled() expect(mockPush).not.toHaveBeenCalled()
}) })
@ -157,6 +157,7 @@ describe('useTabSearchParams', () => {
expect(mockPush).toHaveBeenCalledWith( expect(mockPush).toHaveBeenCalledWith(
'/test-path?category=settings%20%26%20config', '/test-path?category=settings%20%26%20config',
{ scroll: false },
) )
}) })
@ -211,7 +212,7 @@ describe('useTabSearchParams', () => {
setActiveTab('profile') setActiveTab('profile')
}) })
expect(mockPush).toHaveBeenCalledWith('/test-path?tab=profile') expect(mockPush).toHaveBeenCalledWith('/test-path?tab=profile', { scroll: false })
}) })
}) })
@ -294,7 +295,7 @@ describe('useTabSearchParams', () => {
const [activeTab] = result.current const [activeTab] = result.current
expect(activeTab).toBe('') expect(activeTab).toBe('')
expect(mockPush).toHaveBeenCalledWith('/test-path?category=') expect(mockPush).toHaveBeenCalledWith('/test-path?category=', { scroll: false })
}) })
/** /**
@ -345,7 +346,7 @@ describe('useTabSearchParams', () => {
setActiveTab('settings') setActiveTab('settings')
}) })
expect(mockPush).toHaveBeenCalledWith('/fallback-path?category=settings') expect(mockPush).toHaveBeenCalledWith('/fallback-path?category=settings', { scroll: false })
// Restore mock // Restore mock
;(usePathname as jest.Mock).mockReturnValue(mockPathname) ;(usePathname as jest.Mock).mockReturnValue(mockPathname)
@ -400,7 +401,7 @@ describe('useTabSearchParams', () => {
}) })
expect(result.current[0]).toBe('settings') expect(result.current[0]).toBe('settings')
expect(mockPush).toHaveBeenCalledWith('/test-path?category=settings') expect(mockPush).toHaveBeenCalledWith('/test-path?category=settings', { scroll: false })
// Change to profile tab // Change to profile tab
act(() => { act(() => {
@ -409,7 +410,7 @@ describe('useTabSearchParams', () => {
}) })
expect(result.current[0]).toBe('profile') expect(result.current[0]).toBe('profile')
expect(mockPush).toHaveBeenCalledWith('/test-path?category=profile') expect(mockPush).toHaveBeenCalledWith('/test-path?category=profile', { scroll: false })
// Verify push was called twice // Verify push was called twice
expect(mockPush).toHaveBeenCalledTimes(2) expect(mockPush).toHaveBeenCalledTimes(2)
@ -431,7 +432,7 @@ describe('useTabSearchParams', () => {
setActiveTab('advanced') setActiveTab('advanced')
}) })
expect(mockPush).toHaveBeenCalledWith('/app/123/settings?category=advanced') expect(mockPush).toHaveBeenCalledWith('/app/123/settings?category=advanced', { scroll: false })
// Restore mock // Restore mock
;(usePathname as jest.Mock).mockReturnValue(mockPathname) ;(usePathname as jest.Mock).mockReturnValue(mockPathname)

View File

@ -40,7 +40,7 @@ export const useTabSearchParams = ({
setTab(newActiveTab) setTab(newActiveTab)
if (disableSearchParams) if (disableSearchParams)
return return
router[`${routingBehavior}`](`${pathName}?${searchParamName}=${encodeURIComponent(newActiveTab)}`) router[`${routingBehavior}`](`${pathName}?${searchParamName}=${encodeURIComponent(newActiveTab)}`, { scroll: false })
} }
return [activeTab, setActiveTab] as const return [activeTab, setActiveTab] as const

View File

@ -374,6 +374,7 @@ const translation = {
optional_and_hidden: '(optional & hidden)', optional_and_hidden: '(optional & hidden)',
goTo: 'Gehe zu', goTo: 'Gehe zu',
startNode: 'Startknoten', startNode: 'Startknoten',
openWorkflow: 'Workflow öffnen',
}, },
nodes: { nodes: {
common: { common: {

View File

@ -400,6 +400,7 @@ const translation = {
userInputField: 'User Input Field', userInputField: 'User Input Field',
changeBlock: 'Change Node', changeBlock: 'Change Node',
helpLink: 'View Docs', helpLink: 'View Docs',
openWorkflow: 'Open Workflow',
about: 'About', about: 'About',
createdBy: 'Created By ', createdBy: 'Created By ',
nextStep: 'Next Step', nextStep: 'Next Step',

View File

@ -374,6 +374,7 @@ const translation = {
optional_and_hidden: '(opcional y oculto)', optional_and_hidden: '(opcional y oculto)',
goTo: 'Ir a', goTo: 'Ir a',
startNode: 'Nodo de inicio', startNode: 'Nodo de inicio',
openWorkflow: 'Abrir flujo de trabajo',
}, },
nodes: { nodes: {
common: { common: {

View File

@ -374,6 +374,7 @@ const translation = {
optional_and_hidden: '(اختیاری و پنهان)', optional_and_hidden: '(اختیاری و پنهان)',
goTo: 'برو به', goTo: 'برو به',
startNode: 'گره شروع', startNode: 'گره شروع',
openWorkflow: 'باز کردن جریان کاری',
}, },
nodes: { nodes: {
common: { common: {

View File

@ -374,6 +374,7 @@ const translation = {
optional_and_hidden: '(optionnel et caché)', optional_and_hidden: '(optionnel et caché)',
goTo: 'Aller à', goTo: 'Aller à',
startNode: 'Nœud de départ', startNode: 'Nœud de départ',
openWorkflow: 'Ouvrir le flux de travail',
}, },
nodes: { nodes: {
common: { common: {

View File

@ -386,6 +386,7 @@ const translation = {
optional_and_hidden: '(वैकल्पिक और छिपा हुआ)', optional_and_hidden: '(वैकल्पिक और छिपा हुआ)',
goTo: 'जाओ', goTo: 'जाओ',
startNode: 'प्रारंभ नोड', startNode: 'प्रारंभ नोड',
openWorkflow: 'वर्कफ़्लो खोलें',
}, },
nodes: { nodes: {
common: { common: {

View File

@ -381,6 +381,7 @@ const translation = {
goTo: 'Pergi ke', goTo: 'Pergi ke',
startNode: 'Mulai Node', startNode: 'Mulai Node',
scrollToSelectedNode: 'Gulir ke node yang dipilih', scrollToSelectedNode: 'Gulir ke node yang dipilih',
openWorkflow: 'Buka Alur Kerja',
}, },
nodes: { nodes: {
common: { common: {

View File

@ -389,6 +389,7 @@ const translation = {
optional_and_hidden: '(opzionale e nascosto)', optional_and_hidden: '(opzionale e nascosto)',
goTo: 'Vai a', goTo: 'Vai a',
startNode: 'Nodo iniziale', startNode: 'Nodo iniziale',
openWorkflow: 'Apri flusso di lavoro',
}, },
nodes: { nodes: {
common: { common: {

View File

@ -401,6 +401,7 @@ const translation = {
minimize: '全画面を終了する', minimize: '全画面を終了する',
scrollToSelectedNode: '選択したノードまでスクロール', scrollToSelectedNode: '選択したノードまでスクロール',
optional_and_hidden: '(オプションおよび非表示)', optional_and_hidden: '(オプションおよび非表示)',
openWorkflow: 'ワークフローを開く',
}, },
nodes: { nodes: {
common: { common: {

View File

@ -395,6 +395,7 @@ const translation = {
optional_and_hidden: '(선택 사항 및 숨김)', optional_and_hidden: '(선택 사항 및 숨김)',
goTo: '로 이동', goTo: '로 이동',
startNode: '시작 노드', startNode: '시작 노드',
openWorkflow: '워크플로 열기',
}, },
nodes: { nodes: {
common: { common: {

View File

@ -374,6 +374,7 @@ const translation = {
optional_and_hidden: '(opcjonalne i ukryte)', optional_and_hidden: '(opcjonalne i ukryte)',
goTo: 'Idź do', goTo: 'Idź do',
startNode: 'Węzeł początkowy', startNode: 'Węzeł początkowy',
openWorkflow: 'Otwórz przepływ pracy',
}, },
nodes: { nodes: {
common: { common: {

View File

@ -374,6 +374,7 @@ const translation = {
optional_and_hidden: '(opcional & oculto)', optional_and_hidden: '(opcional & oculto)',
goTo: 'Ir para', goTo: 'Ir para',
startNode: 'Iniciar Nó', startNode: 'Iniciar Nó',
openWorkflow: 'Abrir fluxo de trabalho',
}, },
nodes: { nodes: {
common: { common: {

View File

@ -374,6 +374,7 @@ const translation = {
optional_and_hidden: '(opțional și ascuns)', optional_and_hidden: '(opțional și ascuns)',
goTo: 'Du-te la', goTo: 'Du-te la',
startNode: 'Nod de start', startNode: 'Nod de start',
openWorkflow: 'Deschide fluxul de lucru',
}, },
nodes: { nodes: {
common: { common: {

View File

@ -374,6 +374,7 @@ const translation = {
optional_and_hidden: '(необязательно и скрыто)', optional_and_hidden: '(необязательно и скрыто)',
goTo: 'Перейти к', goTo: 'Перейти к',
startNode: 'Начальный узел', startNode: 'Начальный узел',
openWorkflow: 'Открыть рабочий процесс',
}, },
nodes: { nodes: {
common: { common: {

View File

@ -381,6 +381,7 @@ const translation = {
optional_and_hidden: '(neobvezno in skrito)', optional_and_hidden: '(neobvezno in skrito)',
goTo: 'Pojdi na', goTo: 'Pojdi na',
startNode: 'Začetni vozel', startNode: 'Začetni vozel',
openWorkflow: 'Odpri delovni tok',
}, },
nodes: { nodes: {
common: { common: {

View File

@ -374,6 +374,7 @@ const translation = {
optional_and_hidden: '(ตัวเลือก & ซ่อน)', optional_and_hidden: '(ตัวเลือก & ซ่อน)',
goTo: 'ไปที่', goTo: 'ไปที่',
startNode: 'เริ่มต้นโหนด', startNode: 'เริ่มต้นโหนด',
openWorkflow: 'เปิดเวิร์กโฟลว์',
}, },
nodes: { nodes: {
common: { common: {

View File

@ -374,6 +374,7 @@ const translation = {
optional_and_hidden: '(isteğe bağlı ve gizli)', optional_and_hidden: '(isteğe bağlı ve gizli)',
goTo: 'Git', goTo: 'Git',
startNode: 'Başlangıç Düğümü', startNode: 'Başlangıç Düğümü',
openWorkflow: 'İş Akışını Aç',
}, },
nodes: { nodes: {
common: { common: {

View File

@ -374,6 +374,7 @@ const translation = {
optional_and_hidden: '(необов\'язково & приховано)', optional_and_hidden: '(необов\'язково & приховано)',
goTo: 'Перейти до', goTo: 'Перейти до',
startNode: 'Початковий вузол', startNode: 'Початковий вузол',
openWorkflow: 'Відкрити робочий процес',
}, },
nodes: { nodes: {
common: { common: {

View File

@ -374,6 +374,7 @@ const translation = {
optional_and_hidden: '(tùy chọn & ẩn)', optional_and_hidden: '(tùy chọn & ẩn)',
goTo: 'Đi tới', goTo: 'Đi tới',
startNode: 'Nút Bắt đầu', startNode: 'Nút Bắt đầu',
openWorkflow: 'Mở quy trình làm việc',
}, },
nodes: { nodes: {
common: { common: {

View File

@ -400,6 +400,7 @@ const translation = {
userInputField: '用户输入字段', userInputField: '用户输入字段',
changeBlock: '更改节点', changeBlock: '更改节点',
helpLink: '查看帮助文档', helpLink: '查看帮助文档',
openWorkflow: '打开工作流',
about: '关于', about: '关于',
createdBy: '作者', createdBy: '作者',
nextStep: '下一步', nextStep: '下一步',

View File

@ -379,6 +379,7 @@ const translation = {
optional_and_hidden: '(可選且隱藏)', optional_and_hidden: '(可選且隱藏)',
goTo: '前往', goTo: '前往',
startNode: '起始節點', startNode: '起始節點',
openWorkflow: '打開工作流程',
}, },
nodes: { nodes: {
common: { common: {

View File

@ -53,10 +53,10 @@
"@hookform/resolvers": "^3.10.0", "@hookform/resolvers": "^3.10.0",
"@lexical/code": "^0.36.2", "@lexical/code": "^0.36.2",
"@lexical/link": "^0.36.2", "@lexical/link": "^0.36.2",
"@lexical/list": "^0.36.2", "@lexical/list": "^0.38.2",
"@lexical/react": "^0.36.2", "@lexical/react": "^0.36.2",
"@lexical/selection": "^0.37.0", "@lexical/selection": "^0.37.0",
"@lexical/text": "^0.36.2", "@lexical/text": "^0.38.2",
"@lexical/utils": "^0.37.0", "@lexical/utils": "^0.37.0",
"@monaco-editor/react": "^4.7.0", "@monaco-editor/react": "^4.7.0",
"@octokit/core": "^6.1.6", "@octokit/core": "^6.1.6",
@ -79,7 +79,7 @@
"decimal.js": "^10.6.0", "decimal.js": "^10.6.0",
"dompurify": "^3.3.0", "dompurify": "^3.3.0",
"echarts": "^5.6.0", "echarts": "^5.6.0",
"echarts-for-react": "^3.0.2", "echarts-for-react": "^3.0.5",
"elkjs": "^0.9.3", "elkjs": "^0.9.3",
"emoji-mart": "^5.6.0", "emoji-mart": "^5.6.0",
"fast-deep-equal": "^3.1.3", "fast-deep-equal": "^3.1.3",
@ -141,7 +141,7 @@
"uuid": "^10.0.0", "uuid": "^10.0.0",
"zod": "^3.25.76", "zod": "^3.25.76",
"zundo": "^2.3.0", "zundo": "^2.3.0",
"zustand": "^4.5.7" "zustand": "^5.0.9"
}, },
"devDependencies": { "devDependencies": {
"@antfu/eslint-config": "^5.4.1", "@antfu/eslint-config": "^5.4.1",

View File

@ -85,8 +85,8 @@ importers:
specifier: ^0.36.2 specifier: ^0.36.2
version: 0.36.2 version: 0.36.2
'@lexical/list': '@lexical/list':
specifier: ^0.36.2 specifier: ^0.38.2
version: 0.36.2 version: 0.38.2
'@lexical/react': '@lexical/react':
specifier: ^0.36.2 specifier: ^0.36.2
version: 0.36.2(react-dom@19.1.1(react@19.1.1))(react@19.1.1)(yjs@13.6.27) version: 0.36.2(react-dom@19.1.1(react@19.1.1))(react@19.1.1)(yjs@13.6.27)
@ -94,8 +94,8 @@ importers:
specifier: ^0.37.0 specifier: ^0.37.0
version: 0.37.0 version: 0.37.0
'@lexical/text': '@lexical/text':
specifier: ^0.36.2 specifier: ^0.38.2
version: 0.36.2 version: 0.38.2
'@lexical/utils': '@lexical/utils':
specifier: ^0.37.0 specifier: ^0.37.0
version: 0.37.0 version: 0.37.0
@ -163,8 +163,8 @@ importers:
specifier: ^5.6.0 specifier: ^5.6.0
version: 5.6.0 version: 5.6.0
echarts-for-react: echarts-for-react:
specifier: ^3.0.2 specifier: ^3.0.5
version: 3.0.2(echarts@5.6.0)(react@19.1.1) version: 3.0.5(echarts@5.6.0)(react@19.1.1)
elkjs: elkjs:
specifier: ^0.9.3 specifier: ^0.9.3
version: 0.9.3 version: 0.9.3
@ -347,10 +347,10 @@ importers:
version: 3.25.76 version: 3.25.76
zundo: zundo:
specifier: ^2.3.0 specifier: ^2.3.0
version: 2.3.0(zustand@4.5.7(@types/react@19.1.17)(immer@10.1.3)(react@19.1.1)) version: 2.3.0(zustand@5.0.9(@types/react@19.1.17)(immer@10.1.3)(react@19.1.1)(use-sync-external-store@1.6.0(react@19.1.1)))
zustand: zustand:
specifier: ^4.5.7 specifier: ^5.0.9
version: 4.5.7(@types/react@19.1.17)(immer@10.1.3)(react@19.1.1) version: 5.0.9(@types/react@19.1.17)(immer@10.1.3)(react@19.1.1)(use-sync-external-store@1.6.0(react@19.1.1))
devDependencies: devDependencies:
'@antfu/eslint-config': '@antfu/eslint-config':
specifier: ^5.4.1 specifier: ^5.4.1
@ -2009,6 +2009,9 @@ packages:
'@lexical/clipboard@0.37.0': '@lexical/clipboard@0.37.0':
resolution: {integrity: sha512-hRwASFX/ilaI5r8YOcZuQgONFshRgCPfdxfofNL7uruSFYAO6LkUhsjzZwUgf0DbmCJmbBADFw15FSthgCUhGA==} resolution: {integrity: sha512-hRwASFX/ilaI5r8YOcZuQgONFshRgCPfdxfofNL7uruSFYAO6LkUhsjzZwUgf0DbmCJmbBADFw15FSthgCUhGA==}
'@lexical/clipboard@0.38.2':
resolution: {integrity: sha512-dDShUplCu8/o6BB9ousr3uFZ9bltR+HtleF/Tl8FXFNPpZ4AXhbLKUoJuucRuIr+zqT7RxEv/3M6pk/HEoE6NQ==}
'@lexical/code@0.36.2': '@lexical/code@0.36.2':
resolution: {integrity: sha512-dfS62rNo3uKwNAJQ39zC+8gYX0k8UAoW7u+JPIqx+K2VPukZlvpsPLNGft15pdWBkHc7Pv+o9gJlB6gGv+EBfA==} resolution: {integrity: sha512-dfS62rNo3uKwNAJQ39zC+8gYX0k8UAoW7u+JPIqx+K2VPukZlvpsPLNGft15pdWBkHc7Pv+o9gJlB6gGv+EBfA==}
@ -2027,6 +2030,9 @@ packages:
'@lexical/extension@0.37.0': '@lexical/extension@0.37.0':
resolution: {integrity: sha512-Z58f2tIdz9bn8gltUu5cVg37qROGha38dUZv20gI2GeNugXAkoPzJYEcxlI1D/26tkevJ/7VaFUr9PTk+iKmaA==} resolution: {integrity: sha512-Z58f2tIdz9bn8gltUu5cVg37qROGha38dUZv20gI2GeNugXAkoPzJYEcxlI1D/26tkevJ/7VaFUr9PTk+iKmaA==}
'@lexical/extension@0.38.2':
resolution: {integrity: sha512-qbUNxEVjAC0kxp7hEMTzktj0/51SyJoIJWK6Gm790b4yNBq82fEPkksfuLkRg9VQUteD0RT1Nkjy8pho8nNamw==}
'@lexical/hashtag@0.36.2': '@lexical/hashtag@0.36.2':
resolution: {integrity: sha512-WdmKtzXFcahQT3ShFDeHF6LCR5C8yvFCj3ImI09rZwICrYeonbMrzsBUxS1joBz0HQ+ufF9Tx+RxLvGWx6WxzQ==} resolution: {integrity: sha512-WdmKtzXFcahQT3ShFDeHF6LCR5C8yvFCj3ImI09rZwICrYeonbMrzsBUxS1joBz0HQ+ufF9Tx+RxLvGWx6WxzQ==}
@ -2039,6 +2045,9 @@ packages:
'@lexical/html@0.37.0': '@lexical/html@0.37.0':
resolution: {integrity: sha512-oTsBc45eL8/lmF7fqGR+UCjrJYP04gumzf5nk4TczrxWL2pM4GIMLLKG1mpQI2H1MDiRLzq3T/xdI7Gh74z7Zw==} resolution: {integrity: sha512-oTsBc45eL8/lmF7fqGR+UCjrJYP04gumzf5nk4TczrxWL2pM4GIMLLKG1mpQI2H1MDiRLzq3T/xdI7Gh74z7Zw==}
'@lexical/html@0.38.2':
resolution: {integrity: sha512-pC5AV+07bmHistRwgG3NJzBMlIzSdxYO6rJU4eBNzyR4becdiLsI4iuv+aY7PhfSv+SCs7QJ9oc4i5caq48Pkg==}
'@lexical/link@0.36.2': '@lexical/link@0.36.2':
resolution: {integrity: sha512-Zb+DeHA1po8VMiOAAXsBmAHhfWmQttsUkI5oiZUmOXJruRuQ2rVr01NoxHpoEpLwHOABVNzD3PMbwov+g3c7lg==} resolution: {integrity: sha512-Zb+DeHA1po8VMiOAAXsBmAHhfWmQttsUkI5oiZUmOXJruRuQ2rVr01NoxHpoEpLwHOABVNzD3PMbwov+g3c7lg==}
@ -2048,6 +2057,9 @@ packages:
'@lexical/list@0.37.0': '@lexical/list@0.37.0':
resolution: {integrity: sha512-AOC6yAA3mfNvJKbwo+kvAbPJI+13yF2ISA65vbA578CugvJ08zIVgM+pSzxquGhD0ioJY3cXVW7+gdkCP1qu5g==} resolution: {integrity: sha512-AOC6yAA3mfNvJKbwo+kvAbPJI+13yF2ISA65vbA578CugvJ08zIVgM+pSzxquGhD0ioJY3cXVW7+gdkCP1qu5g==}
'@lexical/list@0.38.2':
resolution: {integrity: sha512-OQm9TzatlMrDZGxMxbozZEHzMJhKxAbH1TOnOGyFfzpfjbnFK2y8oLeVsfQZfZRmiqQS4Qc/rpFnRP2Ax5dsbA==}
'@lexical/mark@0.36.2': '@lexical/mark@0.36.2':
resolution: {integrity: sha512-n0MNXtGH+1i43hglgHjpQV0093HmIiFR7Budg2BJb8ZNzO1KZRqeXAHlA5ZzJ698FkAnS4R5bqG9tZ0JJHgAuA==} resolution: {integrity: sha512-n0MNXtGH+1i43hglgHjpQV0093HmIiFR7Budg2BJb8ZNzO1KZRqeXAHlA5ZzJ698FkAnS4R5bqG9tZ0JJHgAuA==}
@ -2078,21 +2090,33 @@ packages:
'@lexical/selection@0.37.0': '@lexical/selection@0.37.0':
resolution: {integrity: sha512-Lix1s2r71jHfsTEs4q/YqK2s3uXKOnyA3fd1VDMWysO+bZzRwEO5+qyDvENZ0WrXSDCnlibNFV1HttWX9/zqyw==} resolution: {integrity: sha512-Lix1s2r71jHfsTEs4q/YqK2s3uXKOnyA3fd1VDMWysO+bZzRwEO5+qyDvENZ0WrXSDCnlibNFV1HttWX9/zqyw==}
'@lexical/selection@0.38.2':
resolution: {integrity: sha512-eMFiWlBH6bEX9U9sMJ6PXPxVXTrihQfFeiIlWLuTpEIDF2HRz7Uo1KFRC/yN6q0DQaj7d9NZYA6Mei5DoQuz5w==}
'@lexical/table@0.36.2': '@lexical/table@0.36.2':
resolution: {integrity: sha512-96rNNPiVbC65i+Jn1QzIsehCS7UVUc69ovrh9Bt4+pXDebZSdZai153Q7RUq8q3AQ5ocK4/SA2kLQfMu0grj3Q==} resolution: {integrity: sha512-96rNNPiVbC65i+Jn1QzIsehCS7UVUc69ovrh9Bt4+pXDebZSdZai153Q7RUq8q3AQ5ocK4/SA2kLQfMu0grj3Q==}
'@lexical/table@0.37.0': '@lexical/table@0.37.0':
resolution: {integrity: sha512-g7S8ml8kIujEDLWlzYKETgPCQ2U9oeWqdytRuHjHGi/rjAAGHSej5IRqTPIMxNP3VVQHnBoQ+Y9hBtjiuddhgQ==} resolution: {integrity: sha512-g7S8ml8kIujEDLWlzYKETgPCQ2U9oeWqdytRuHjHGi/rjAAGHSej5IRqTPIMxNP3VVQHnBoQ+Y9hBtjiuddhgQ==}
'@lexical/table@0.38.2':
resolution: {integrity: sha512-uu0i7yz0nbClmHOO5ZFsinRJE6vQnFz2YPblYHAlNigiBedhqMwSv5bedrzDq8nTTHwych3mC63tcyKIrM+I1g==}
'@lexical/text@0.36.2': '@lexical/text@0.36.2':
resolution: {integrity: sha512-IbbqgRdMAD6Uk9b2+qSVoy+8RVcczrz6OgXvg39+EYD+XEC7Rbw7kDTWzuNSJJpP7vxSO8YDZSaIlP5gNH3qKA==} resolution: {integrity: sha512-IbbqgRdMAD6Uk9b2+qSVoy+8RVcczrz6OgXvg39+EYD+XEC7Rbw7kDTWzuNSJJpP7vxSO8YDZSaIlP5gNH3qKA==}
'@lexical/text@0.38.2':
resolution: {integrity: sha512-+juZxUugtC4T37aE3P0l4I9tsWbogDUnTI/mgYk4Ht9g+gLJnhQkzSA8chIyfTxbj5i0A8yWrUUSw+/xA7lKUQ==}
'@lexical/utils@0.36.2': '@lexical/utils@0.36.2':
resolution: {integrity: sha512-P9+t2Ob10YNGYT/PWEER+1EqH8SAjCNRn+7SBvKbr0IdleGF2JvzbJwAWaRwZs1c18P11XdQZ779dGvWlfwBIw==} resolution: {integrity: sha512-P9+t2Ob10YNGYT/PWEER+1EqH8SAjCNRn+7SBvKbr0IdleGF2JvzbJwAWaRwZs1c18P11XdQZ779dGvWlfwBIw==}
'@lexical/utils@0.37.0': '@lexical/utils@0.37.0':
resolution: {integrity: sha512-CFp4diY/kR5RqhzQSl/7SwsMod1sgLpI1FBifcOuJ6L/S6YywGpEB4B7aV5zqW21A/jU2T+2NZtxSUn6S+9gMg==} resolution: {integrity: sha512-CFp4diY/kR5RqhzQSl/7SwsMod1sgLpI1FBifcOuJ6L/S6YywGpEB4B7aV5zqW21A/jU2T+2NZtxSUn6S+9gMg==}
'@lexical/utils@0.38.2':
resolution: {integrity: sha512-y+3rw15r4oAWIEXicUdNjfk8018dbKl7dWHqGHVEtqzAYefnEYdfD2FJ5KOTXfeoYfxi8yOW7FvzS4NZDi8Bfw==}
'@lexical/yjs@0.36.2': '@lexical/yjs@0.36.2':
resolution: {integrity: sha512-gZ66Mw+uKXTO8KeX/hNKAinXbFg3gnNYraG76lBXCwb/Ka3q34upIY9FUeGOwGVaau3iIDQhE49I+6MugAX2FQ==} resolution: {integrity: sha512-gZ66Mw+uKXTO8KeX/hNKAinXbFg3gnNYraG76lBXCwb/Ka3q34upIY9FUeGOwGVaau3iIDQhE49I+6MugAX2FQ==}
peerDependencies: peerDependencies:
@ -4586,10 +4610,10 @@ packages:
duplexer@0.1.2: duplexer@0.1.2:
resolution: {integrity: sha512-jtD6YG370ZCIi/9GTaJKQxWTZD045+4R4hTk/x1UyoqadyJ9x9CgSi1RlVDQF8U2sxLLSnFkCaMihqljHIWgMg==} resolution: {integrity: sha512-jtD6YG370ZCIi/9GTaJKQxWTZD045+4R4hTk/x1UyoqadyJ9x9CgSi1RlVDQF8U2sxLLSnFkCaMihqljHIWgMg==}
echarts-for-react@3.0.2: echarts-for-react@3.0.5:
resolution: {integrity: sha512-DRwIiTzx8JfwPOVgGttDytBqdp5VzCSyMRIxubgU/g2n9y3VLUmF2FK7Icmg/sNVkv4+rktmrLN9w22U2yy3fA==} resolution: {integrity: sha512-YpEI5Ty7O/2nvCfQ7ybNa+S90DwE8KYZWacGvJW4luUqywP7qStQ+pxDlYOmr4jGDu10mhEkiAuMKcUlT4W5vg==}
peerDependencies: peerDependencies:
echarts: ^3.0.0 || ^4.0.0 || ^5.0.0 echarts: ^3.0.0 || ^4.0.0 || ^5.0.0 || ^6.0.0
react: ^15.0.0 || >=16.0.0 react: ^15.0.0 || >=16.0.0
echarts@5.6.0: echarts@5.6.0:
@ -8445,6 +8469,24 @@ packages:
react: react:
optional: true optional: true
zustand@5.0.9:
resolution: {integrity: sha512-ALBtUj0AfjJt3uNRQoL1tL2tMvj6Gp/6e39dnfT6uzpelGru8v1tPOGBzayOWbPJvujM8JojDk3E1LxeFisBNg==}
engines: {node: '>=12.20.0'}
peerDependencies:
'@types/react': ~19.1.17
immer: '>=9.0.6'
react: '>=18.0.0'
use-sync-external-store: '>=1.2.0'
peerDependenciesMeta:
'@types/react':
optional: true
immer:
optional: true
react:
optional: true
use-sync-external-store:
optional: true
zwitch@2.0.4: zwitch@2.0.4:
resolution: {integrity: sha512-bXE4cR/kVZhKZX/RjPEflHaKVhUVl85noU3v6b8apfQEc1x4A+zBxjZ4lN8LqGd6WZ3dl98pY4o717VFmoPp+A==} resolution: {integrity: sha512-bXE4cR/kVZhKZX/RjPEflHaKVhUVl85noU3v6b8apfQEc1x4A+zBxjZ4lN8LqGd6WZ3dl98pY4o717VFmoPp+A==}
@ -10200,6 +10242,14 @@ snapshots:
'@lexical/utils': 0.37.0 '@lexical/utils': 0.37.0
lexical: 0.37.0 lexical: 0.37.0
'@lexical/clipboard@0.38.2':
dependencies:
'@lexical/html': 0.38.2
'@lexical/list': 0.38.2
'@lexical/selection': 0.38.2
'@lexical/utils': 0.38.2
lexical: 0.37.0
'@lexical/code@0.36.2': '@lexical/code@0.36.2':
dependencies: dependencies:
'@lexical/utils': 0.36.2 '@lexical/utils': 0.36.2
@ -10234,6 +10284,12 @@ snapshots:
'@preact/signals-core': 1.12.1 '@preact/signals-core': 1.12.1
lexical: 0.37.0 lexical: 0.37.0
'@lexical/extension@0.38.2':
dependencies:
'@lexical/utils': 0.38.2
'@preact/signals-core': 1.12.1
lexical: 0.37.0
'@lexical/hashtag@0.36.2': '@lexical/hashtag@0.36.2':
dependencies: dependencies:
'@lexical/text': 0.36.2 '@lexical/text': 0.36.2
@ -10258,6 +10314,12 @@ snapshots:
'@lexical/utils': 0.37.0 '@lexical/utils': 0.37.0
lexical: 0.37.0 lexical: 0.37.0
'@lexical/html@0.38.2':
dependencies:
'@lexical/selection': 0.38.2
'@lexical/utils': 0.38.2
lexical: 0.37.0
'@lexical/link@0.36.2': '@lexical/link@0.36.2':
dependencies: dependencies:
'@lexical/extension': 0.36.2 '@lexical/extension': 0.36.2
@ -10278,6 +10340,13 @@ snapshots:
'@lexical/utils': 0.37.0 '@lexical/utils': 0.37.0
lexical: 0.37.0 lexical: 0.37.0
'@lexical/list@0.38.2':
dependencies:
'@lexical/extension': 0.38.2
'@lexical/selection': 0.38.2
'@lexical/utils': 0.38.2
lexical: 0.37.0
'@lexical/mark@0.36.2': '@lexical/mark@0.36.2':
dependencies: dependencies:
'@lexical/utils': 0.36.2 '@lexical/utils': 0.36.2
@ -10351,6 +10420,10 @@ snapshots:
dependencies: dependencies:
lexical: 0.37.0 lexical: 0.37.0
'@lexical/selection@0.38.2':
dependencies:
lexical: 0.37.0
'@lexical/table@0.36.2': '@lexical/table@0.36.2':
dependencies: dependencies:
'@lexical/clipboard': 0.36.2 '@lexical/clipboard': 0.36.2
@ -10365,10 +10438,21 @@ snapshots:
'@lexical/utils': 0.37.0 '@lexical/utils': 0.37.0
lexical: 0.37.0 lexical: 0.37.0
'@lexical/table@0.38.2':
dependencies:
'@lexical/clipboard': 0.38.2
'@lexical/extension': 0.38.2
'@lexical/utils': 0.38.2
lexical: 0.37.0
'@lexical/text@0.36.2': '@lexical/text@0.36.2':
dependencies: dependencies:
lexical: 0.37.0 lexical: 0.37.0
'@lexical/text@0.38.2':
dependencies:
lexical: 0.37.0
'@lexical/utils@0.36.2': '@lexical/utils@0.36.2':
dependencies: dependencies:
'@lexical/list': 0.36.2 '@lexical/list': 0.36.2
@ -10383,6 +10467,13 @@ snapshots:
'@lexical/table': 0.37.0 '@lexical/table': 0.37.0
lexical: 0.37.0 lexical: 0.37.0
'@lexical/utils@0.38.2':
dependencies:
'@lexical/list': 0.38.2
'@lexical/selection': 0.38.2
'@lexical/table': 0.38.2
lexical: 0.37.0
'@lexical/yjs@0.36.2(yjs@13.6.27)': '@lexical/yjs@0.36.2(yjs@13.6.27)':
dependencies: dependencies:
'@lexical/offset': 0.36.2 '@lexical/offset': 0.36.2
@ -13098,7 +13189,7 @@ snapshots:
duplexer@0.1.2: {} duplexer@0.1.2: {}
echarts-for-react@3.0.2(echarts@5.6.0)(react@19.1.1): echarts-for-react@3.0.5(echarts@5.6.0)(react@19.1.1):
dependencies: dependencies:
echarts: 5.6.0 echarts: 5.6.0
fast-deep-equal: 3.1.3 fast-deep-equal: 3.1.3
@ -17931,9 +18022,9 @@ snapshots:
dependencies: dependencies:
tslib: 2.3.0 tslib: 2.3.0
zundo@2.3.0(zustand@4.5.7(@types/react@19.1.17)(immer@10.1.3)(react@19.1.1)): zundo@2.3.0(zustand@5.0.9(@types/react@19.1.17)(immer@10.1.3)(react@19.1.1)(use-sync-external-store@1.6.0(react@19.1.1))):
dependencies: dependencies:
zustand: 4.5.7(@types/react@19.1.17)(immer@10.1.3)(react@19.1.1) zustand: 5.0.9(@types/react@19.1.17)(immer@10.1.3)(react@19.1.1)(use-sync-external-store@1.6.0(react@19.1.1))
zustand@4.5.7(@types/react@19.1.17)(immer@10.1.3)(react@19.1.1): zustand@4.5.7(@types/react@19.1.17)(immer@10.1.3)(react@19.1.1):
dependencies: dependencies:
@ -17943,4 +18034,11 @@ snapshots:
immer: 10.1.3 immer: 10.1.3
react: 19.1.1 react: 19.1.1
zustand@5.0.9(@types/react@19.1.17)(immer@10.1.3)(react@19.1.1)(use-sync-external-store@1.6.0(react@19.1.1)):
optionalDependencies:
'@types/react': 19.1.17
immer: 10.1.3
react: 19.1.1
use-sync-external-store: 1.6.0(react@19.1.1)
zwitch@2.0.4: {} zwitch@2.0.4: {}