mirror of https://github.com/langgenius/dify.git
Merge branch 'main' into feat/memory-orchestration-fed
This commit is contained in:
commit
4e037d14d1
|
|
@ -0,0 +1,226 @@
|
|||
# CODEOWNERS
|
||||
# This file defines code ownership for the Dify project.
|
||||
# Each line is a file pattern followed by one or more owners.
|
||||
# Owners can be @username, @org/team-name, or email addresses.
|
||||
# For more information, see: https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-code-owners
|
||||
|
||||
* @crazywoola @laipz8200 @Yeuoly
|
||||
|
||||
# Backend (default owner, more specific rules below will override)
|
||||
api/ @QuantumGhost
|
||||
|
||||
# Backend - Workflow - Engine (Core graph execution engine)
|
||||
api/core/workflow/graph_engine/ @laipz8200 @QuantumGhost
|
||||
api/core/workflow/runtime/ @laipz8200 @QuantumGhost
|
||||
api/core/workflow/graph/ @laipz8200 @QuantumGhost
|
||||
api/core/workflow/graph_events/ @laipz8200 @QuantumGhost
|
||||
api/core/workflow/node_events/ @laipz8200 @QuantumGhost
|
||||
api/core/model_runtime/ @laipz8200 @QuantumGhost
|
||||
|
||||
# Backend - Workflow - Nodes (Agent, Iteration, Loop, LLM)
|
||||
api/core/workflow/nodes/agent/ @Nov1c444
|
||||
api/core/workflow/nodes/iteration/ @Nov1c444
|
||||
api/core/workflow/nodes/loop/ @Nov1c444
|
||||
api/core/workflow/nodes/llm/ @Nov1c444
|
||||
|
||||
# Backend - RAG (Retrieval Augmented Generation)
|
||||
api/core/rag/ @JohnJyong
|
||||
api/services/rag_pipeline/ @JohnJyong
|
||||
api/services/dataset_service.py @JohnJyong
|
||||
api/services/knowledge_service.py @JohnJyong
|
||||
api/services/external_knowledge_service.py @JohnJyong
|
||||
api/services/hit_testing_service.py @JohnJyong
|
||||
api/services/metadata_service.py @JohnJyong
|
||||
api/services/vector_service.py @JohnJyong
|
||||
api/services/entities/knowledge_entities/ @JohnJyong
|
||||
api/services/entities/external_knowledge_entities/ @JohnJyong
|
||||
api/controllers/console/datasets/ @JohnJyong
|
||||
api/controllers/service_api/dataset/ @JohnJyong
|
||||
api/models/dataset.py @JohnJyong
|
||||
api/tasks/rag_pipeline/ @JohnJyong
|
||||
api/tasks/add_document_to_index_task.py @JohnJyong
|
||||
api/tasks/batch_clean_document_task.py @JohnJyong
|
||||
api/tasks/clean_document_task.py @JohnJyong
|
||||
api/tasks/clean_notion_document_task.py @JohnJyong
|
||||
api/tasks/document_indexing_task.py @JohnJyong
|
||||
api/tasks/document_indexing_sync_task.py @JohnJyong
|
||||
api/tasks/document_indexing_update_task.py @JohnJyong
|
||||
api/tasks/duplicate_document_indexing_task.py @JohnJyong
|
||||
api/tasks/recover_document_indexing_task.py @JohnJyong
|
||||
api/tasks/remove_document_from_index_task.py @JohnJyong
|
||||
api/tasks/retry_document_indexing_task.py @JohnJyong
|
||||
api/tasks/sync_website_document_indexing_task.py @JohnJyong
|
||||
api/tasks/batch_create_segment_to_index_task.py @JohnJyong
|
||||
api/tasks/create_segment_to_index_task.py @JohnJyong
|
||||
api/tasks/delete_segment_from_index_task.py @JohnJyong
|
||||
api/tasks/disable_segment_from_index_task.py @JohnJyong
|
||||
api/tasks/disable_segments_from_index_task.py @JohnJyong
|
||||
api/tasks/enable_segment_to_index_task.py @JohnJyong
|
||||
api/tasks/enable_segments_to_index_task.py @JohnJyong
|
||||
api/tasks/clean_dataset_task.py @JohnJyong
|
||||
api/tasks/deal_dataset_index_update_task.py @JohnJyong
|
||||
api/tasks/deal_dataset_vector_index_task.py @JohnJyong
|
||||
|
||||
# Backend - Plugins
|
||||
api/core/plugin/ @Mairuis @Yeuoly @Stream29
|
||||
api/services/plugin/ @Mairuis @Yeuoly @Stream29
|
||||
api/controllers/console/workspace/plugin.py @Mairuis @Yeuoly @Stream29
|
||||
api/controllers/inner_api/plugin/ @Mairuis @Yeuoly @Stream29
|
||||
api/tasks/process_tenant_plugin_autoupgrade_check_task.py @Mairuis @Yeuoly @Stream29
|
||||
|
||||
# Backend - Trigger/Schedule/Webhook
|
||||
api/controllers/trigger/ @Mairuis @Yeuoly
|
||||
api/controllers/console/app/workflow_trigger.py @Mairuis @Yeuoly
|
||||
api/controllers/console/workspace/trigger_providers.py @Mairuis @Yeuoly
|
||||
api/core/trigger/ @Mairuis @Yeuoly
|
||||
api/core/app/layers/trigger_post_layer.py @Mairuis @Yeuoly
|
||||
api/services/trigger/ @Mairuis @Yeuoly
|
||||
api/models/trigger.py @Mairuis @Yeuoly
|
||||
api/fields/workflow_trigger_fields.py @Mairuis @Yeuoly
|
||||
api/repositories/workflow_trigger_log_repository.py @Mairuis @Yeuoly
|
||||
api/repositories/sqlalchemy_workflow_trigger_log_repository.py @Mairuis @Yeuoly
|
||||
api/libs/schedule_utils.py @Mairuis @Yeuoly
|
||||
api/services/workflow/scheduler.py @Mairuis @Yeuoly
|
||||
api/schedule/trigger_provider_refresh_task.py @Mairuis @Yeuoly
|
||||
api/schedule/workflow_schedule_task.py @Mairuis @Yeuoly
|
||||
api/tasks/trigger_processing_tasks.py @Mairuis @Yeuoly
|
||||
api/tasks/trigger_subscription_refresh_tasks.py @Mairuis @Yeuoly
|
||||
api/tasks/workflow_schedule_tasks.py @Mairuis @Yeuoly
|
||||
api/tasks/workflow_cfs_scheduler/ @Mairuis @Yeuoly
|
||||
api/events/event_handlers/sync_plugin_trigger_when_app_created.py @Mairuis @Yeuoly
|
||||
api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py @Mairuis @Yeuoly
|
||||
api/events/event_handlers/sync_workflow_schedule_when_app_published.py @Mairuis @Yeuoly
|
||||
api/events/event_handlers/sync_webhook_when_app_created.py @Mairuis @Yeuoly
|
||||
|
||||
# Backend - Async Workflow
|
||||
api/services/async_workflow_service.py @Mairuis @Yeuoly
|
||||
api/tasks/async_workflow_tasks.py @Mairuis @Yeuoly
|
||||
|
||||
# Backend - Billing
|
||||
api/services/billing_service.py @hj24 @zyssyz123
|
||||
api/controllers/console/billing/ @hj24 @zyssyz123
|
||||
|
||||
# Backend - Enterprise
|
||||
api/configs/enterprise/ @GarfieldDai @GareArc
|
||||
api/services/enterprise/ @GarfieldDai @GareArc
|
||||
api/services/feature_service.py @GarfieldDai @GareArc
|
||||
api/controllers/console/feature.py @GarfieldDai @GareArc
|
||||
api/controllers/web/feature.py @GarfieldDai @GareArc
|
||||
|
||||
# Backend - Database Migrations
|
||||
api/migrations/ @snakevash @laipz8200
|
||||
|
||||
# Frontend
|
||||
web/ @iamjoel
|
||||
|
||||
# Frontend - App - Orchestration
|
||||
web/app/components/workflow/ @iamjoel @zxhlyh
|
||||
web/app/components/workflow-app/ @iamjoel @zxhlyh
|
||||
web/app/components/app/configuration/ @iamjoel @zxhlyh
|
||||
web/app/components/app/app-publisher/ @iamjoel @zxhlyh
|
||||
|
||||
# Frontend - WebApp - Chat
|
||||
web/app/components/base/chat/ @iamjoel @zxhlyh
|
||||
|
||||
# Frontend - WebApp - Completion
|
||||
web/app/components/share/text-generation/ @iamjoel @zxhlyh
|
||||
|
||||
# Frontend - App - List and Creation
|
||||
web/app/components/apps/ @JzoNgKVO @iamjoel
|
||||
web/app/components/app/create-app-dialog/ @JzoNgKVO @iamjoel
|
||||
web/app/components/app/create-app-modal/ @JzoNgKVO @iamjoel
|
||||
web/app/components/app/create-from-dsl-modal/ @JzoNgKVO @iamjoel
|
||||
|
||||
# Frontend - App - API Documentation
|
||||
web/app/components/develop/ @JzoNgKVO @iamjoel
|
||||
|
||||
# Frontend - App - Logs and Annotations
|
||||
web/app/components/app/workflow-log/ @JzoNgKVO @iamjoel
|
||||
web/app/components/app/log/ @JzoNgKVO @iamjoel
|
||||
web/app/components/app/log-annotation/ @JzoNgKVO @iamjoel
|
||||
web/app/components/app/annotation/ @JzoNgKVO @iamjoel
|
||||
|
||||
# Frontend - App - Monitoring
|
||||
web/app/(commonLayout)/app/(appDetailLayout)/\[appId\]/overview/ @JzoNgKVO @iamjoel
|
||||
web/app/components/app/overview/ @JzoNgKVO @iamjoel
|
||||
|
||||
# Frontend - App - Settings
|
||||
web/app/components/app-sidebar/ @JzoNgKVO @iamjoel
|
||||
|
||||
# Frontend - RAG - Hit Testing
|
||||
web/app/components/datasets/hit-testing/ @JzoNgKVO @iamjoel
|
||||
|
||||
# Frontend - RAG - List and Creation
|
||||
web/app/components/datasets/list/ @iamjoel @WTW0313
|
||||
web/app/components/datasets/create/ @iamjoel @WTW0313
|
||||
web/app/components/datasets/create-from-pipeline/ @iamjoel @WTW0313
|
||||
web/app/components/datasets/external-knowledge-base/ @iamjoel @WTW0313
|
||||
|
||||
# Frontend - RAG - Orchestration (general rule first, specific rules below override)
|
||||
web/app/components/rag-pipeline/ @iamjoel @WTW0313
|
||||
web/app/components/rag-pipeline/components/rag-pipeline-main.tsx @iamjoel @zxhlyh
|
||||
web/app/components/rag-pipeline/store/ @iamjoel @zxhlyh
|
||||
|
||||
# Frontend - RAG - Documents List
|
||||
web/app/components/datasets/documents/list.tsx @iamjoel @WTW0313
|
||||
web/app/components/datasets/documents/create-from-pipeline/ @iamjoel @WTW0313
|
||||
|
||||
# Frontend - RAG - Segments List
|
||||
web/app/components/datasets/documents/detail/ @iamjoel @WTW0313
|
||||
|
||||
# Frontend - RAG - Settings
|
||||
web/app/components/datasets/settings/ @iamjoel @WTW0313
|
||||
|
||||
# Frontend - Ecosystem - Plugins
|
||||
web/app/components/plugins/ @iamjoel @zhsama
|
||||
|
||||
# Frontend - Ecosystem - Tools
|
||||
web/app/components/tools/ @iamjoel @Yessenia-d
|
||||
|
||||
# Frontend - Ecosystem - MarketPlace
|
||||
web/app/components/plugins/marketplace/ @iamjoel @Yessenia-d
|
||||
|
||||
# Frontend - Login and Registration
|
||||
web/app/signin/ @douxc @iamjoel
|
||||
web/app/signup/ @douxc @iamjoel
|
||||
web/app/reset-password/ @douxc @iamjoel
|
||||
web/app/install/ @douxc @iamjoel
|
||||
web/app/init/ @douxc @iamjoel
|
||||
web/app/forgot-password/ @douxc @iamjoel
|
||||
web/app/account/ @douxc @iamjoel
|
||||
|
||||
# Frontend - Service Authentication
|
||||
web/service/base.ts @douxc @iamjoel
|
||||
|
||||
# Frontend - WebApp Authentication and Access Control
|
||||
web/app/(shareLayout)/components/ @douxc @iamjoel
|
||||
web/app/(shareLayout)/webapp-signin/ @douxc @iamjoel
|
||||
web/app/(shareLayout)/webapp-reset-password/ @douxc @iamjoel
|
||||
web/app/components/app/app-access-control/ @douxc @iamjoel
|
||||
|
||||
# Frontend - Explore Page
|
||||
web/app/components/explore/ @CodingOnStar @iamjoel
|
||||
|
||||
# Frontend - Personal Settings
|
||||
web/app/components/header/account-setting/ @CodingOnStar @iamjoel
|
||||
web/app/components/header/account-dropdown/ @CodingOnStar @iamjoel
|
||||
|
||||
# Frontend - Analytics
|
||||
web/app/components/base/ga/ @CodingOnStar @iamjoel
|
||||
|
||||
# Frontend - Base Components
|
||||
web/app/components/base/ @iamjoel @zxhlyh
|
||||
|
||||
# Frontend - Utils and Hooks
|
||||
web/utils/classnames.ts @iamjoel @zxhlyh
|
||||
web/utils/time.ts @iamjoel @zxhlyh
|
||||
web/utils/format.ts @iamjoel @zxhlyh
|
||||
web/utils/clipboard.ts @iamjoel @zxhlyh
|
||||
web/hooks/use-document-title.ts @iamjoel @zxhlyh
|
||||
|
||||
# Frontend - Billing and Education
|
||||
web/app/components/billing/ @iamjoel @zxhlyh
|
||||
web/app/education-apply/ @iamjoel @zxhlyh
|
||||
|
||||
# Frontend - Workspace
|
||||
web/app/components/header/account-dropdown/workplace-selector/ @iamjoel @zxhlyh
|
||||
|
|
@ -51,6 +51,7 @@ def initialize_extensions(app: DifyApp):
|
|||
ext_commands,
|
||||
ext_compress,
|
||||
ext_database,
|
||||
ext_forward_refs,
|
||||
ext_hosting_provider,
|
||||
ext_import_modules,
|
||||
ext_logging,
|
||||
|
|
@ -75,6 +76,7 @@ def initialize_extensions(app: DifyApp):
|
|||
ext_warnings,
|
||||
ext_import_modules,
|
||||
ext_orjson,
|
||||
ext_forward_refs,
|
||||
ext_set_secretkey,
|
||||
ext_compress,
|
||||
ext_code_based_extension,
|
||||
|
|
|
|||
|
|
@ -1,16 +1,23 @@
|
|||
from flask_restx import Resource, fields, reqparse
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.login import login_required
|
||||
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("app_mode", type=str, required=True, location="args", help="Application mode")
|
||||
.add_argument("model_mode", type=str, required=True, location="args", help="Model mode")
|
||||
.add_argument("has_context", type=str, required=False, default="true", location="args", help="Whether has context")
|
||||
.add_argument("model_name", type=str, required=True, location="args", help="Model name")
|
||||
|
||||
class AdvancedPromptTemplateQuery(BaseModel):
|
||||
app_mode: str = Field(..., description="Application mode")
|
||||
model_mode: str = Field(..., description="Model mode")
|
||||
has_context: str = Field(default="true", description="Whether has context")
|
||||
model_name: str = Field(..., description="Model name")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
AdvancedPromptTemplateQuery.__name__,
|
||||
AdvancedPromptTemplateQuery.model_json_schema(ref_template="#/definitions/{model}"),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -18,7 +25,7 @@ parser = (
|
|||
class AdvancedPromptTemplateList(Resource):
|
||||
@console_ns.doc("get_advanced_prompt_templates")
|
||||
@console_ns.doc(description="Get advanced prompt templates based on app mode and model configuration")
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.expect(console_ns.models[AdvancedPromptTemplateQuery.__name__])
|
||||
@console_ns.response(
|
||||
200, "Prompt templates retrieved successfully", fields.List(fields.Raw(description="Prompt template data"))
|
||||
)
|
||||
|
|
@ -27,6 +34,6 @@ class AdvancedPromptTemplateList(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
args = parser.parse_args()
|
||||
args = AdvancedPromptTemplateQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
return AdvancedPromptTemplateService.get_prompt(args)
|
||||
return AdvancedPromptTemplateService.get_prompt(args.model_dump())
|
||||
|
|
|
|||
|
|
@ -1,9 +1,12 @@
|
|||
import uuid
|
||||
from typing import Literal
|
||||
|
||||
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, abort
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
|
|
@ -36,6 +39,130 @@ from services.enterprise.enterprise_service import EnterpriseService
|
|||
from services.feature_service import FeatureService
|
||||
|
||||
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class AppListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")
|
||||
mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"] = Field(
|
||||
default="all", description="App mode filter"
|
||||
)
|
||||
name: str | None = Field(default=None, description="Filter by app name")
|
||||
tag_ids: list[str] | None = Field(default=None, description="Comma-separated tag IDs")
|
||||
is_created_by_me: bool | None = Field(default=None, description="Filter by creator")
|
||||
|
||||
@field_validator("tag_ids", mode="before")
|
||||
@classmethod
|
||||
def validate_tag_ids(cls, value: str | list[str] | None) -> list[str] | None:
|
||||
if not value:
|
||||
return None
|
||||
|
||||
if isinstance(value, str):
|
||||
items = [item.strip() for item in value.split(",") if item.strip()]
|
||||
elif isinstance(value, list):
|
||||
items = [str(item).strip() for item in value if item and str(item).strip()]
|
||||
else:
|
||||
raise TypeError("Unsupported tag_ids type.")
|
||||
|
||||
if not items:
|
||||
return None
|
||||
|
||||
try:
|
||||
return [str(uuid.UUID(item)) for item in items]
|
||||
except ValueError as exc:
|
||||
raise ValueError("Invalid UUID format in tag_ids.") from exc
|
||||
|
||||
|
||||
class CreateAppPayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, description="App name")
|
||||
description: str | None = Field(default=None, description="App description (max 400 chars)")
|
||||
mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode")
|
||||
icon_type: str | None = Field(default=None, description="Icon type")
|
||||
icon: str | None = Field(default=None, description="Icon")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
|
||||
@field_validator("description")
|
||||
@classmethod
|
||||
def validate_description(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return validate_description_length(value)
|
||||
|
||||
|
||||
class UpdateAppPayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, description="App name")
|
||||
description: str | None = Field(default=None, description="App description (max 400 chars)")
|
||||
icon_type: str | None = Field(default=None, description="Icon type")
|
||||
icon: str | None = Field(default=None, description="Icon")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon")
|
||||
max_active_requests: int | None = Field(default=None, description="Maximum active requests")
|
||||
|
||||
@field_validator("description")
|
||||
@classmethod
|
||||
def validate_description(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return validate_description_length(value)
|
||||
|
||||
|
||||
class CopyAppPayload(BaseModel):
|
||||
name: str | None = Field(default=None, description="Name for the copied app")
|
||||
description: str | None = Field(default=None, description="Description for the copied app")
|
||||
icon_type: str | None = Field(default=None, description="Icon type")
|
||||
icon: str | None = Field(default=None, description="Icon")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
|
||||
@field_validator("description")
|
||||
@classmethod
|
||||
def validate_description(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return validate_description_length(value)
|
||||
|
||||
|
||||
class AppExportQuery(BaseModel):
|
||||
include_secret: bool = Field(default=False, description="Include secrets in export")
|
||||
workflow_id: str | None = Field(default=None, description="Specific workflow ID to export")
|
||||
|
||||
|
||||
class AppNamePayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, description="Name to check")
|
||||
|
||||
|
||||
class AppIconPayload(BaseModel):
|
||||
icon: str | None = Field(default=None, description="Icon data")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
|
||||
|
||||
class AppSiteStatusPayload(BaseModel):
|
||||
enable_site: bool = Field(..., description="Enable or disable site")
|
||||
|
||||
|
||||
class AppApiStatusPayload(BaseModel):
|
||||
enable_api: bool = Field(..., description="Enable or disable API")
|
||||
|
||||
|
||||
class AppTracePayload(BaseModel):
|
||||
enabled: bool = Field(..., description="Enable or disable tracing")
|
||||
tracing_provider: str = Field(..., description="Tracing provider")
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(AppListQuery)
|
||||
reg(CreateAppPayload)
|
||||
reg(UpdateAppPayload)
|
||||
reg(CopyAppPayload)
|
||||
reg(AppExportQuery)
|
||||
reg(AppNamePayload)
|
||||
reg(AppIconPayload)
|
||||
reg(AppSiteStatusPayload)
|
||||
reg(AppApiStatusPayload)
|
||||
reg(AppTracePayload)
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register base models first
|
||||
|
|
@ -147,22 +274,7 @@ app_pagination_model = console_ns.model(
|
|||
class AppListApi(Resource):
|
||||
@console_ns.doc("list_apps")
|
||||
@console_ns.doc(description="Get list of applications with pagination and filtering")
|
||||
@console_ns.expect(
|
||||
console_ns.parser()
|
||||
.add_argument("page", type=int, location="args", help="Page number (1-99999)", default=1)
|
||||
.add_argument("limit", type=int, location="args", help="Page size (1-100)", default=20)
|
||||
.add_argument(
|
||||
"mode",
|
||||
type=str,
|
||||
location="args",
|
||||
choices=["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"],
|
||||
default="all",
|
||||
help="App mode filter",
|
||||
)
|
||||
.add_argument("name", type=str, location="args", help="Filter by app name")
|
||||
.add_argument("tag_ids", type=str, location="args", help="Comma-separated tag IDs")
|
||||
.add_argument("is_created_by_me", type=bool, location="args", help="Filter by creator")
|
||||
)
|
||||
@console_ns.expect(console_ns.models[AppListQuery.__name__])
|
||||
@console_ns.response(200, "Success", app_pagination_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -172,42 +284,12 @@ class AppListApi(Resource):
|
|||
"""Get app list"""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
def uuid_list(value):
|
||||
try:
|
||||
return [str(uuid.UUID(v)) for v in value.split(",")]
|
||||
except ValueError:
|
||||
abort(400, message="Invalid UUID format in tag_ids.")
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
|
||||
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||
.add_argument(
|
||||
"mode",
|
||||
type=str,
|
||||
choices=[
|
||||
"completion",
|
||||
"chat",
|
||||
"advanced-chat",
|
||||
"workflow",
|
||||
"agent-chat",
|
||||
"channel",
|
||||
"all",
|
||||
],
|
||||
default="all",
|
||||
location="args",
|
||||
required=False,
|
||||
)
|
||||
.add_argument("name", type=str, location="args", required=False)
|
||||
.add_argument("tag_ids", type=uuid_list, location="args", required=False)
|
||||
.add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False)
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
args = AppListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args_dict = args.model_dump()
|
||||
|
||||
# get app list
|
||||
app_service = AppService()
|
||||
app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args)
|
||||
app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args_dict)
|
||||
if not app_pagination:
|
||||
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
|
||||
|
||||
|
|
@ -254,19 +336,7 @@ class AppListApi(Resource):
|
|||
|
||||
@console_ns.doc("create_app")
|
||||
@console_ns.doc(description="Create a new application")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"CreateAppRequest",
|
||||
{
|
||||
"name": fields.String(required=True, description="App name"),
|
||||
"description": fields.String(description="App description (max 400 chars)"),
|
||||
"mode": fields.String(required=True, enum=ALLOW_CREATE_APP_MODES, description="App mode"),
|
||||
"icon_type": fields.String(description="Icon type"),
|
||||
"icon": fields.String(description="Icon"),
|
||||
"icon_background": fields.String(description="Icon background color"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[CreateAppPayload.__name__])
|
||||
@console_ns.response(201, "App created successfully", app_detail_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
|
|
@ -279,22 +349,10 @@ class AppListApi(Resource):
|
|||
def post(self):
|
||||
"""Create app"""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=str, required=True, location="json")
|
||||
.add_argument("description", type=validate_description_length, location="json")
|
||||
.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json")
|
||||
.add_argument("icon_type", type=str, location="json")
|
||||
.add_argument("icon", type=str, location="json")
|
||||
.add_argument("icon_background", type=str, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if "mode" not in args or args["mode"] is None:
|
||||
raise BadRequest("mode is required")
|
||||
args = CreateAppPayload.model_validate(console_ns.payload)
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(current_tenant_id, args, current_user)
|
||||
app = app_service.create_app(current_tenant_id, args.model_dump(), current_user)
|
||||
|
||||
return app, 201
|
||||
|
||||
|
|
@ -326,20 +384,7 @@ class AppApi(Resource):
|
|||
@console_ns.doc("update_app")
|
||||
@console_ns.doc(description="Update application details")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"UpdateAppRequest",
|
||||
{
|
||||
"name": fields.String(required=True, description="App name"),
|
||||
"description": fields.String(description="App description (max 400 chars)"),
|
||||
"icon_type": fields.String(description="Icon type"),
|
||||
"icon": fields.String(description="Icon"),
|
||||
"icon_background": fields.String(description="Icon background color"),
|
||||
"use_icon_as_answer_icon": fields.Boolean(description="Use icon as answer icon"),
|
||||
"max_active_requests": fields.Integer(description="Maximum active requests"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[UpdateAppPayload.__name__])
|
||||
@console_ns.response(200, "App updated successfully", app_detail_with_site_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
|
|
@ -351,28 +396,18 @@ class AppApi(Resource):
|
|||
@marshal_with(app_detail_with_site_model)
|
||||
def put(self, app_model):
|
||||
"""Update app"""
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("description", type=validate_description_length, location="json")
|
||||
.add_argument("icon_type", type=str, location="json")
|
||||
.add_argument("icon", type=str, location="json")
|
||||
.add_argument("icon_background", type=str, location="json")
|
||||
.add_argument("use_icon_as_answer_icon", type=bool, location="json")
|
||||
.add_argument("max_active_requests", type=int, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = UpdateAppPayload.model_validate(console_ns.payload)
|
||||
|
||||
app_service = AppService()
|
||||
|
||||
args_dict: AppService.ArgsDict = {
|
||||
"name": args["name"],
|
||||
"description": args.get("description", ""),
|
||||
"icon_type": args.get("icon_type", ""),
|
||||
"icon": args.get("icon", ""),
|
||||
"icon_background": args.get("icon_background", ""),
|
||||
"use_icon_as_answer_icon": args.get("use_icon_as_answer_icon", False),
|
||||
"max_active_requests": args.get("max_active_requests", 0),
|
||||
"name": args.name,
|
||||
"description": args.description or "",
|
||||
"icon_type": args.icon_type or "",
|
||||
"icon": args.icon or "",
|
||||
"icon_background": args.icon_background or "",
|
||||
"use_icon_as_answer_icon": args.use_icon_as_answer_icon or False,
|
||||
"max_active_requests": args.max_active_requests or 0,
|
||||
}
|
||||
app_model = app_service.update_app(app_model, args_dict)
|
||||
|
||||
|
|
@ -401,18 +436,7 @@ class AppCopyApi(Resource):
|
|||
@console_ns.doc("copy_app")
|
||||
@console_ns.doc(description="Create a copy of an existing application")
|
||||
@console_ns.doc(params={"app_id": "Application ID to copy"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"CopyAppRequest",
|
||||
{
|
||||
"name": fields.String(description="Name for the copied app"),
|
||||
"description": fields.String(description="Description for the copied app"),
|
||||
"icon_type": fields.String(description="Icon type"),
|
||||
"icon": fields.String(description="Icon"),
|
||||
"icon_background": fields.String(description="Icon background color"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[CopyAppPayload.__name__])
|
||||
@console_ns.response(201, "App copied successfully", app_detail_with_site_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
|
|
@ -426,15 +450,7 @@ class AppCopyApi(Resource):
|
|||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=str, location="json")
|
||||
.add_argument("description", type=validate_description_length, location="json")
|
||||
.add_argument("icon_type", type=str, location="json")
|
||||
.add_argument("icon", type=str, location="json")
|
||||
.add_argument("icon_background", type=str, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = CopyAppPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
with Session(db.engine) as session:
|
||||
import_service = AppDslService(session)
|
||||
|
|
@ -443,11 +459,11 @@ class AppCopyApi(Resource):
|
|||
account=current_user,
|
||||
import_mode=ImportMode.YAML_CONTENT,
|
||||
yaml_content=yaml_content,
|
||||
name=args.get("name"),
|
||||
description=args.get("description"),
|
||||
icon_type=args.get("icon_type"),
|
||||
icon=args.get("icon"),
|
||||
icon_background=args.get("icon_background"),
|
||||
name=args.name,
|
||||
description=args.description,
|
||||
icon_type=args.icon_type,
|
||||
icon=args.icon,
|
||||
icon_background=args.icon_background,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
|
|
@ -462,11 +478,7 @@ class AppExportApi(Resource):
|
|||
@console_ns.doc("export_app")
|
||||
@console_ns.doc(description="Export application configuration as DSL")
|
||||
@console_ns.doc(params={"app_id": "Application ID to export"})
|
||||
@console_ns.expect(
|
||||
console_ns.parser()
|
||||
.add_argument("include_secret", type=bool, location="args", default=False, help="Include secrets in export")
|
||||
.add_argument("workflow_id", type=str, location="args", help="Specific workflow ID to export")
|
||||
)
|
||||
@console_ns.expect(console_ns.models[AppExportQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"App exported successfully",
|
||||
|
|
@ -480,30 +492,23 @@ class AppExportApi(Resource):
|
|||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
"""Export app"""
|
||||
# Add include_secret params
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("include_secret", type=inputs.boolean, default=False, location="args")
|
||||
.add_argument("workflow_id", type=str, location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = AppExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
return {
|
||||
"data": AppDslService.export_dsl(
|
||||
app_model=app_model, include_secret=args["include_secret"], workflow_id=args.get("workflow_id")
|
||||
app_model=app_model,
|
||||
include_secret=args.include_secret,
|
||||
workflow_id=args.workflow_id,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json", help="Name to check")
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/name")
|
||||
class AppNameApi(Resource):
|
||||
@console_ns.doc("check_app_name")
|
||||
@console_ns.doc(description="Check if app name is available")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.expect(console_ns.models[AppNamePayload.__name__])
|
||||
@console_ns.response(200, "Name availability checked")
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -512,10 +517,10 @@ class AppNameApi(Resource):
|
|||
@marshal_with(app_detail_model)
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
args = parser.parse_args()
|
||||
args = AppNamePayload.model_validate(console_ns.payload)
|
||||
|
||||
app_service = AppService()
|
||||
app_model = app_service.update_app_name(app_model, args["name"])
|
||||
app_model = app_service.update_app_name(app_model, args.name)
|
||||
|
||||
return app_model
|
||||
|
||||
|
|
@ -525,16 +530,7 @@ class AppIconApi(Resource):
|
|||
@console_ns.doc("update_app_icon")
|
||||
@console_ns.doc(description="Update application icon")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"AppIconRequest",
|
||||
{
|
||||
"icon": fields.String(required=True, description="Icon data"),
|
||||
"icon_type": fields.String(description="Icon type"),
|
||||
"icon_background": fields.String(description="Icon background color"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[AppIconPayload.__name__])
|
||||
@console_ns.response(200, "Icon updated successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
|
|
@ -544,15 +540,10 @@ class AppIconApi(Resource):
|
|||
@marshal_with(app_detail_model)
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("icon", type=str, location="json")
|
||||
.add_argument("icon_background", type=str, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = AppIconPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
app_service = AppService()
|
||||
app_model = app_service.update_app_icon(app_model, args.get("icon") or "", args.get("icon_background") or "")
|
||||
app_model = app_service.update_app_icon(app_model, args.icon or "", args.icon_background or "")
|
||||
|
||||
return app_model
|
||||
|
||||
|
|
@ -562,11 +553,7 @@ class AppSiteStatus(Resource):
|
|||
@console_ns.doc("update_app_site_status")
|
||||
@console_ns.doc(description="Enable or disable app site")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"AppSiteStatusRequest", {"enable_site": fields.Boolean(required=True, description="Enable or disable site")}
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[AppSiteStatusPayload.__name__])
|
||||
@console_ns.response(200, "Site status updated successfully", app_detail_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
|
|
@ -576,11 +563,10 @@ class AppSiteStatus(Resource):
|
|||
@marshal_with(app_detail_model)
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
parser = reqparse.RequestParser().add_argument("enable_site", type=bool, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
args = AppSiteStatusPayload.model_validate(console_ns.payload)
|
||||
|
||||
app_service = AppService()
|
||||
app_model = app_service.update_app_site_status(app_model, args["enable_site"])
|
||||
app_model = app_service.update_app_site_status(app_model, args.enable_site)
|
||||
|
||||
return app_model
|
||||
|
||||
|
|
@ -590,11 +576,7 @@ class AppApiStatus(Resource):
|
|||
@console_ns.doc("update_app_api_status")
|
||||
@console_ns.doc(description="Enable or disable app API")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"AppApiStatusRequest", {"enable_api": fields.Boolean(required=True, description="Enable or disable API")}
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[AppApiStatusPayload.__name__])
|
||||
@console_ns.response(200, "API status updated successfully", app_detail_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
|
|
@ -604,11 +586,10 @@ class AppApiStatus(Resource):
|
|||
@get_app_model
|
||||
@marshal_with(app_detail_model)
|
||||
def post(self, app_model):
|
||||
parser = reqparse.RequestParser().add_argument("enable_api", type=bool, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
args = AppApiStatusPayload.model_validate(console_ns.payload)
|
||||
|
||||
app_service = AppService()
|
||||
app_model = app_service.update_app_api_status(app_model, args["enable_api"])
|
||||
app_model = app_service.update_app_api_status(app_model, args.enable_api)
|
||||
|
||||
return app_model
|
||||
|
||||
|
|
@ -631,15 +612,7 @@ class AppTraceApi(Resource):
|
|||
@console_ns.doc("update_app_trace")
|
||||
@console_ns.doc(description="Update app tracing configuration")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"AppTraceRequest",
|
||||
{
|
||||
"enabled": fields.Boolean(required=True, description="Enable or disable tracing"),
|
||||
"tracing_provider": fields.String(required=True, description="Tracing provider"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[AppTracePayload.__name__])
|
||||
@console_ns.response(200, "Trace configuration updated successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
|
|
@ -648,17 +621,12 @@ class AppTraceApi(Resource):
|
|||
@edit_permission_required
|
||||
def post(self, app_id):
|
||||
# add app trace
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("enabled", type=bool, required=True, location="json")
|
||||
.add_argument("tracing_provider", type=str, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = AppTracePayload.model_validate(console_ns.payload)
|
||||
|
||||
OpsTraceManager.update_app_tracing_config(
|
||||
app_id=app_id,
|
||||
enabled=args["enabled"],
|
||||
tracing_provider=args["tracing_provider"],
|
||||
enabled=args.enabled,
|
||||
tracing_provider=args.tracing_provider,
|
||||
)
|
||||
|
||||
return {"result": "success"}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
import logging
|
||||
from typing import Any, Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
|
|
@ -35,6 +37,41 @@ from services.app_task_service import AppTaskService
|
|||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class BaseMessagePayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
model_config_data: dict[str, Any] = Field(..., alias="model_config")
|
||||
files: list[Any] | None = Field(default=None, description="Uploaded files")
|
||||
response_mode: Literal["blocking", "streaming"] = Field(default="blocking", description="Response mode")
|
||||
retriever_from: str = Field(default="dev", description="Retriever source")
|
||||
|
||||
|
||||
class CompletionMessagePayload(BaseMessagePayload):
|
||||
query: str = Field(default="", description="Query text")
|
||||
|
||||
|
||||
class ChatMessagePayload(BaseMessagePayload):
|
||||
query: str = Field(..., description="User query")
|
||||
conversation_id: str | None = Field(default=None, description="Conversation ID")
|
||||
parent_message_id: str | None = Field(default=None, description="Parent message ID")
|
||||
|
||||
@field_validator("conversation_id", "parent_message_id")
|
||||
@classmethod
|
||||
def validate_uuid(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
CompletionMessagePayload.__name__,
|
||||
CompletionMessagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
ChatMessagePayload.__name__, ChatMessagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
|
||||
# define completion message api for user
|
||||
|
|
@ -43,19 +80,7 @@ class CompletionMessageApi(Resource):
|
|||
@console_ns.doc("create_completion_message")
|
||||
@console_ns.doc(description="Generate completion message for debugging")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"CompletionMessageRequest",
|
||||
{
|
||||
"inputs": fields.Raw(required=True, description="Input variables"),
|
||||
"query": fields.String(description="Query text", default=""),
|
||||
"files": fields.List(fields.Raw(), description="Uploaded files"),
|
||||
"model_config": fields.Raw(required=True, description="Model configuration"),
|
||||
"response_mode": fields.String(enum=["blocking", "streaming"], description="Response mode"),
|
||||
"retriever_from": fields.String(default="dev", description="Retriever source"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[CompletionMessagePayload.__name__])
|
||||
@console_ns.response(200, "Completion generated successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@console_ns.response(404, "App not found")
|
||||
|
|
@ -64,18 +89,10 @@ class CompletionMessageApi(Resource):
|
|||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
def post(self, app_model):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, location="json")
|
||||
.add_argument("query", type=str, location="json", default="")
|
||||
.add_argument("files", type=list, required=False, location="json")
|
||||
.add_argument("model_config", type=dict, required=True, location="json")
|
||||
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||
.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args_model = CompletionMessagePayload.model_validate(console_ns.payload)
|
||||
args = args_model.model_dump(exclude_none=True, by_alias=True)
|
||||
|
||||
streaming = args["response_mode"] != "blocking"
|
||||
streaming = args_model.response_mode != "blocking"
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
try:
|
||||
|
|
@ -137,21 +154,7 @@ class ChatMessageApi(Resource):
|
|||
@console_ns.doc("create_chat_message")
|
||||
@console_ns.doc(description="Generate chat message for debugging")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"ChatMessageRequest",
|
||||
{
|
||||
"inputs": fields.Raw(required=True, description="Input variables"),
|
||||
"query": fields.String(required=True, description="User query"),
|
||||
"files": fields.List(fields.Raw(), description="Uploaded files"),
|
||||
"model_config": fields.Raw(required=True, description="Model configuration"),
|
||||
"conversation_id": fields.String(description="Conversation ID"),
|
||||
"parent_message_id": fields.String(description="Parent message ID"),
|
||||
"response_mode": fields.String(enum=["blocking", "streaming"], description="Response mode"),
|
||||
"retriever_from": fields.String(default="dev", description="Retriever source"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[ChatMessagePayload.__name__])
|
||||
@console_ns.response(200, "Chat message generated successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@console_ns.response(404, "App or conversation not found")
|
||||
|
|
@ -161,20 +164,10 @@ class ChatMessageApi(Resource):
|
|||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, location="json")
|
||||
.add_argument("query", type=str, required=True, location="json")
|
||||
.add_argument("files", type=list, required=False, location="json")
|
||||
.add_argument("model_config", type=dict, required=True, location="json")
|
||||
.add_argument("conversation_id", type=uuid_value, location="json")
|
||||
.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
||||
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||
.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args_model = ChatMessagePayload.model_validate(console_ns.payload)
|
||||
args = args_model.model_dump(exclude_none=True, by_alias=True)
|
||||
|
||||
streaming = args["response_mode"] != "blocking"
|
||||
streaming = args_model.response_mode != "blocking"
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
external_trace_id = get_external_trace_id(request)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
from typing import Literal
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask import abort
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from flask import abort, request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import func, or_
|
||||
from sqlalchemy.orm import joinedload
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
|
@ -14,13 +16,54 @@ from extensions.ext_database import db
|
|||
from fields.conversation_fields import MessageTextField
|
||||
from fields.raws import FilesContainedField
|
||||
from libs.datetime_utils import naive_utc_now, parse_time_range
|
||||
from libs.helper import DatetimeString, TimestampField
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import Conversation, EndUser, Message, MessageAnnotation
|
||||
from models.model import AppMode
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class BaseConversationQuery(BaseModel):
|
||||
keyword: str | None = Field(default=None, description="Search keyword")
|
||||
start: str | None = Field(default=None, description="Start date (YYYY-MM-DD HH:MM)")
|
||||
end: str | None = Field(default=None, description="End date (YYYY-MM-DD HH:MM)")
|
||||
annotation_status: Literal["annotated", "not_annotated", "all"] = Field(
|
||||
default="all", description="Annotation status filter"
|
||||
)
|
||||
page: int = Field(default=1, ge=1, le=99999, description="Page number")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")
|
||||
|
||||
@field_validator("start", "end", mode="before")
|
||||
@classmethod
|
||||
def blank_to_none(cls, value: str | None) -> str | None:
|
||||
if value == "":
|
||||
return None
|
||||
return value
|
||||
|
||||
|
||||
class CompletionConversationQuery(BaseConversationQuery):
|
||||
pass
|
||||
|
||||
|
||||
class ChatConversationQuery(BaseConversationQuery):
|
||||
message_count_gte: int | None = Field(default=None, ge=1, description="Minimum message count")
|
||||
sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field(
|
||||
default="-updated_at", description="Sort field and direction"
|
||||
)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
CompletionConversationQuery.__name__,
|
||||
CompletionConversationQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
ChatConversationQuery.__name__,
|
||||
ChatConversationQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register in dependency order: base models first, then dependent models
|
||||
|
||||
|
|
@ -283,22 +326,7 @@ class CompletionConversationApi(Resource):
|
|||
@console_ns.doc("list_completion_conversations")
|
||||
@console_ns.doc(description="Get completion conversations with pagination and filtering")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.parser()
|
||||
.add_argument("keyword", type=str, location="args", help="Search keyword")
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument(
|
||||
"annotation_status",
|
||||
type=str,
|
||||
location="args",
|
||||
choices=["annotated", "not_annotated", "all"],
|
||||
default="all",
|
||||
help="Annotation status filter",
|
||||
)
|
||||
.add_argument("page", type=int, location="args", default=1, help="Page number")
|
||||
.add_argument("limit", type=int, location="args", default=20, help="Page size (1-100)")
|
||||
)
|
||||
@console_ns.expect(console_ns.models[CompletionConversationQuery.__name__])
|
||||
@console_ns.response(200, "Success", conversation_pagination_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
|
|
@ -309,32 +337,17 @@ class CompletionConversationApi(Resource):
|
|||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("keyword", type=str, location="args")
|
||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
.add_argument(
|
||||
"annotation_status",
|
||||
type=str,
|
||||
choices=["annotated", "not_annotated", "all"],
|
||||
default="all",
|
||||
location="args",
|
||||
)
|
||||
.add_argument("page", type=int_range(1, 99999), default=1, location="args")
|
||||
.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = CompletionConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
query = sa.select(Conversation).where(
|
||||
Conversation.app_id == app_model.id, Conversation.mode == "completion", Conversation.is_deleted.is_(False)
|
||||
)
|
||||
|
||||
if args["keyword"]:
|
||||
if args.keyword:
|
||||
query = query.join(Message, Message.conversation_id == Conversation.id).where(
|
||||
or_(
|
||||
Message.query.ilike(f"%{args['keyword']}%"),
|
||||
Message.answer.ilike(f"%{args['keyword']}%"),
|
||||
Message.query.ilike(f"%{args.keyword}%"),
|
||||
Message.answer.ilike(f"%{args.keyword}%"),
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -342,7 +355,7 @@ class CompletionConversationApi(Resource):
|
|||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
|
|
@ -354,11 +367,11 @@ class CompletionConversationApi(Resource):
|
|||
query = query.where(Conversation.created_at < end_datetime_utc)
|
||||
|
||||
# FIXME, the type ignore in this file
|
||||
if args["annotation_status"] == "annotated":
|
||||
if args.annotation_status == "annotated":
|
||||
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
|
||||
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
|
||||
)
|
||||
elif args["annotation_status"] == "not_annotated":
|
||||
elif args.annotation_status == "not_annotated":
|
||||
query = (
|
||||
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
|
||||
.group_by(Conversation.id)
|
||||
|
|
@ -367,7 +380,7 @@ class CompletionConversationApi(Resource):
|
|||
|
||||
query = query.order_by(Conversation.created_at.desc())
|
||||
|
||||
conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False)
|
||||
conversations = db.paginate(query, page=args.page, per_page=args.limit, error_out=False)
|
||||
|
||||
return conversations
|
||||
|
||||
|
|
@ -419,31 +432,7 @@ class ChatConversationApi(Resource):
|
|||
@console_ns.doc("list_chat_conversations")
|
||||
@console_ns.doc(description="Get chat conversations with pagination, filtering and summary")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.parser()
|
||||
.add_argument("keyword", type=str, location="args", help="Search keyword")
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument(
|
||||
"annotation_status",
|
||||
type=str,
|
||||
location="args",
|
||||
choices=["annotated", "not_annotated", "all"],
|
||||
default="all",
|
||||
help="Annotation status filter",
|
||||
)
|
||||
.add_argument("message_count_gte", type=int, location="args", help="Minimum message count")
|
||||
.add_argument("page", type=int, location="args", default=1, help="Page number")
|
||||
.add_argument("limit", type=int, location="args", default=20, help="Page size (1-100)")
|
||||
.add_argument(
|
||||
"sort_by",
|
||||
type=str,
|
||||
location="args",
|
||||
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
|
||||
default="-updated_at",
|
||||
help="Sort field and direction",
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[ChatConversationQuery.__name__])
|
||||
@console_ns.response(200, "Success", conversation_with_summary_pagination_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
|
|
@ -454,31 +443,7 @@ class ChatConversationApi(Resource):
|
|||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("keyword", type=str, location="args")
|
||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
.add_argument(
|
||||
"annotation_status",
|
||||
type=str,
|
||||
choices=["annotated", "not_annotated", "all"],
|
||||
default="all",
|
||||
location="args",
|
||||
)
|
||||
.add_argument("message_count_gte", type=int_range(1, 99999), required=False, location="args")
|
||||
.add_argument("page", type=int_range(1, 99999), required=False, default=1, location="args")
|
||||
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
.add_argument(
|
||||
"sort_by",
|
||||
type=str,
|
||||
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
|
||||
required=False,
|
||||
default="-updated_at",
|
||||
location="args",
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
subquery = (
|
||||
db.session.query(
|
||||
|
|
@ -490,8 +455,8 @@ class ChatConversationApi(Resource):
|
|||
|
||||
query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))
|
||||
|
||||
if args["keyword"]:
|
||||
keyword_filter = f"%{args['keyword']}%"
|
||||
if args.keyword:
|
||||
keyword_filter = f"%{args.keyword}%"
|
||||
query = (
|
||||
query.join(
|
||||
Message,
|
||||
|
|
@ -514,12 +479,12 @@ class ChatConversationApi(Resource):
|
|||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
if start_datetime_utc:
|
||||
match args["sort_by"]:
|
||||
match args.sort_by:
|
||||
case "updated_at" | "-updated_at":
|
||||
query = query.where(Conversation.updated_at >= start_datetime_utc)
|
||||
case "created_at" | "-created_at" | _:
|
||||
|
|
@ -527,35 +492,35 @@ class ChatConversationApi(Resource):
|
|||
|
||||
if end_datetime_utc:
|
||||
end_datetime_utc = end_datetime_utc.replace(second=59)
|
||||
match args["sort_by"]:
|
||||
match args.sort_by:
|
||||
case "updated_at" | "-updated_at":
|
||||
query = query.where(Conversation.updated_at <= end_datetime_utc)
|
||||
case "created_at" | "-created_at" | _:
|
||||
query = query.where(Conversation.created_at <= end_datetime_utc)
|
||||
|
||||
if args["annotation_status"] == "annotated":
|
||||
if args.annotation_status == "annotated":
|
||||
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
|
||||
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
|
||||
)
|
||||
elif args["annotation_status"] == "not_annotated":
|
||||
elif args.annotation_status == "not_annotated":
|
||||
query = (
|
||||
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
|
||||
.group_by(Conversation.id)
|
||||
.having(func.count(MessageAnnotation.id) == 0)
|
||||
)
|
||||
|
||||
if args["message_count_gte"] and args["message_count_gte"] >= 1:
|
||||
if args.message_count_gte and args.message_count_gte >= 1:
|
||||
query = (
|
||||
query.options(joinedload(Conversation.messages)) # type: ignore
|
||||
.join(Message, Message.conversation_id == Conversation.id)
|
||||
.group_by(Conversation.id)
|
||||
.having(func.count(Message.id) >= args["message_count_gte"])
|
||||
.having(func.count(Message.id) >= args.message_count_gte)
|
||||
)
|
||||
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT:
|
||||
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER)
|
||||
|
||||
match args["sort_by"]:
|
||||
match args.sort_by:
|
||||
case "created_at":
|
||||
query = query.order_by(Conversation.created_at.asc())
|
||||
case "-created_at":
|
||||
|
|
@ -567,7 +532,7 @@ class ChatConversationApi(Resource):
|
|||
case _:
|
||||
query = query.order_by(Conversation.created_at.desc())
|
||||
|
||||
conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False)
|
||||
conversations = db.paginate(query, page=args.page, per_page=args.limit, error_out=False)
|
||||
|
||||
return conversations
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
|
@ -14,6 +16,18 @@ from libs.login import login_required
|
|||
from models import ConversationVariable
|
||||
from models.model import AppMode
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class ConversationVariablesQuery(BaseModel):
|
||||
conversation_id: str = Field(..., description="Conversation ID to filter variables")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
ConversationVariablesQuery.__name__,
|
||||
ConversationVariablesQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register base model first
|
||||
conversation_variable_model = console_ns.model("ConversationVariable", conversation_variable_fields)
|
||||
|
|
@ -33,11 +47,7 @@ class ConversationVariablesApi(Resource):
|
|||
@console_ns.doc("get_conversation_variables")
|
||||
@console_ns.doc(description="Get conversation variables for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.parser().add_argument(
|
||||
"conversation_id", type=str, location="args", help="Conversation ID to filter variables"
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[ConversationVariablesQuery.__name__])
|
||||
@console_ns.response(200, "Conversation variables retrieved successfully", paginated_conversation_variable_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -45,18 +55,14 @@ class ConversationVariablesApi(Resource):
|
|||
@get_app_model(mode=AppMode.ADVANCED_CHAT)
|
||||
@marshal_with(paginated_conversation_variable_model)
|
||||
def get(self, app_model):
|
||||
parser = reqparse.RequestParser().add_argument("conversation_id", type=str, location="args")
|
||||
args = parser.parse_args()
|
||||
args = ConversationVariablesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
stmt = (
|
||||
select(ConversationVariable)
|
||||
.where(ConversationVariable.app_id == app_model.id)
|
||||
.order_by(ConversationVariable.created_at)
|
||||
)
|
||||
if args["conversation_id"]:
|
||||
stmt = stmt.where(ConversationVariable.conversation_id == args["conversation_id"])
|
||||
else:
|
||||
raise ValueError("conversation_id is required")
|
||||
stmt = stmt.where(ConversationVariable.conversation_id == args.conversation_id)
|
||||
|
||||
# NOTE: This is a temporary solution to avoid performance issues.
|
||||
page = 1
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
|
|
@ -21,21 +23,54 @@ from libs.login import current_account_with_tenant, login_required
|
|||
from models import App
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class RuleGeneratePayload(BaseModel):
|
||||
instruction: str = Field(..., description="Rule generation instruction")
|
||||
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
|
||||
no_variable: bool = Field(default=False, description="Whether to exclude variables")
|
||||
|
||||
|
||||
class RuleCodeGeneratePayload(RuleGeneratePayload):
|
||||
code_language: str = Field(default="javascript", description="Programming language for code generation")
|
||||
|
||||
|
||||
class RuleStructuredOutputPayload(BaseModel):
|
||||
instruction: str = Field(..., description="Structured output generation instruction")
|
||||
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
|
||||
|
||||
|
||||
class InstructionGeneratePayload(BaseModel):
|
||||
flow_id: str = Field(..., description="Workflow/Flow ID")
|
||||
node_id: str = Field(default="", description="Node ID for workflow context")
|
||||
current: str = Field(default="", description="Current instruction text")
|
||||
language: str = Field(default="javascript", description="Programming language (javascript/python)")
|
||||
instruction: str = Field(..., description="Instruction for generation")
|
||||
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
|
||||
ideal_output: str = Field(default="", description="Expected ideal output")
|
||||
|
||||
|
||||
class InstructionTemplatePayload(BaseModel):
|
||||
type: str = Field(..., description="Instruction template type")
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(RuleGeneratePayload)
|
||||
reg(RuleCodeGeneratePayload)
|
||||
reg(RuleStructuredOutputPayload)
|
||||
reg(InstructionGeneratePayload)
|
||||
reg(InstructionTemplatePayload)
|
||||
|
||||
|
||||
@console_ns.route("/rule-generate")
|
||||
class RuleGenerateApi(Resource):
|
||||
@console_ns.doc("generate_rule_config")
|
||||
@console_ns.doc(description="Generate rule configuration using LLM")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"RuleGenerateRequest",
|
||||
{
|
||||
"instruction": fields.String(required=True, description="Rule generation instruction"),
|
||||
"model_config": fields.Raw(required=True, description="Model configuration"),
|
||||
"no_variable": fields.Boolean(required=True, default=False, description="Whether to exclude variables"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[RuleGeneratePayload.__name__])
|
||||
@console_ns.response(200, "Rule configuration generated successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@console_ns.response(402, "Provider quota exceeded")
|
||||
|
|
@ -43,21 +78,15 @@ class RuleGenerateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("no_variable", type=bool, required=True, default=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = RuleGeneratePayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
rules = LLMGenerator.generate_rule_config(
|
||||
tenant_id=current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
no_variable=args["no_variable"],
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
no_variable=args.no_variable,
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
|
@ -75,19 +104,7 @@ class RuleGenerateApi(Resource):
|
|||
class RuleCodeGenerateApi(Resource):
|
||||
@console_ns.doc("generate_rule_code")
|
||||
@console_ns.doc(description="Generate code rules using LLM")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"RuleCodeGenerateRequest",
|
||||
{
|
||||
"instruction": fields.String(required=True, description="Code generation instruction"),
|
||||
"model_config": fields.Raw(required=True, description="Model configuration"),
|
||||
"no_variable": fields.Boolean(required=True, default=False, description="Whether to exclude variables"),
|
||||
"code_language": fields.String(
|
||||
default="javascript", description="Programming language for code generation"
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[RuleCodeGeneratePayload.__name__])
|
||||
@console_ns.response(200, "Code rules generated successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@console_ns.response(402, "Provider quota exceeded")
|
||||
|
|
@ -95,22 +112,15 @@ class RuleCodeGenerateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("no_variable", type=bool, required=True, default=False, location="json")
|
||||
.add_argument("code_language", type=str, required=False, default="javascript", location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = RuleCodeGeneratePayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
code_result = LLMGenerator.generate_code(
|
||||
tenant_id=current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
code_language=args["code_language"],
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
code_language=args.code_language,
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
|
@ -128,15 +138,7 @@ class RuleCodeGenerateApi(Resource):
|
|||
class RuleStructuredOutputGenerateApi(Resource):
|
||||
@console_ns.doc("generate_structured_output")
|
||||
@console_ns.doc(description="Generate structured output rules using LLM")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"StructuredOutputGenerateRequest",
|
||||
{
|
||||
"instruction": fields.String(required=True, description="Structured output generation instruction"),
|
||||
"model_config": fields.Raw(required=True, description="Model configuration"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[RuleStructuredOutputPayload.__name__])
|
||||
@console_ns.response(200, "Structured output generated successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@console_ns.response(402, "Provider quota exceeded")
|
||||
|
|
@ -144,19 +146,14 @@ class RuleStructuredOutputGenerateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = RuleStructuredOutputPayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
structured_output = LLMGenerator.generate_structured_output(
|
||||
tenant_id=current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
|
@ -174,20 +171,7 @@ class RuleStructuredOutputGenerateApi(Resource):
|
|||
class InstructionGenerateApi(Resource):
|
||||
@console_ns.doc("generate_instruction")
|
||||
@console_ns.doc(description="Generate instruction for workflow nodes or general use")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"InstructionGenerateRequest",
|
||||
{
|
||||
"flow_id": fields.String(required=True, description="Workflow/Flow ID"),
|
||||
"node_id": fields.String(description="Node ID for workflow context"),
|
||||
"current": fields.String(description="Current instruction text"),
|
||||
"language": fields.String(default="javascript", description="Programming language (javascript/python)"),
|
||||
"instruction": fields.String(required=True, description="Instruction for generation"),
|
||||
"model_config": fields.Raw(required=True, description="Model configuration"),
|
||||
"ideal_output": fields.String(description="Expected ideal output"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[InstructionGeneratePayload.__name__])
|
||||
@console_ns.response(200, "Instruction generated successfully")
|
||||
@console_ns.response(400, "Invalid request parameters or flow/workflow not found")
|
||||
@console_ns.response(402, "Provider quota exceeded")
|
||||
|
|
@ -195,79 +179,69 @@ class InstructionGenerateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("flow_id", type=str, required=True, default="", location="json")
|
||||
.add_argument("node_id", type=str, required=False, default="", location="json")
|
||||
.add_argument("current", type=str, required=False, default="", location="json")
|
||||
.add_argument("language", type=str, required=False, default="javascript", location="json")
|
||||
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("ideal_output", type=str, required=False, default="", location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = InstructionGeneratePayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
|
||||
code_provider: type[CodeNodeProvider] | None = next(
|
||||
(p for p in providers if p.is_accept_language(args["language"])), None
|
||||
(p for p in providers if p.is_accept_language(args.language)), None
|
||||
)
|
||||
code_template = code_provider.get_default_code() if code_provider else ""
|
||||
try:
|
||||
# Generate from nothing for a workflow node
|
||||
if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "":
|
||||
app = db.session.query(App).where(App.id == args["flow_id"]).first()
|
||||
if (args.current in (code_template, "")) and args.node_id != "":
|
||||
app = db.session.query(App).where(App.id == args.flow_id).first()
|
||||
if not app:
|
||||
return {"error": f"app {args['flow_id']} not found"}, 400
|
||||
return {"error": f"app {args.flow_id} not found"}, 400
|
||||
workflow = WorkflowService().get_draft_workflow(app_model=app)
|
||||
if not workflow:
|
||||
return {"error": f"workflow {args['flow_id']} not found"}, 400
|
||||
return {"error": f"workflow {args.flow_id} not found"}, 400
|
||||
nodes: Sequence = workflow.graph_dict["nodes"]
|
||||
node = [node for node in nodes if node["id"] == args["node_id"]]
|
||||
node = [node for node in nodes if node["id"] == args.node_id]
|
||||
if len(node) == 0:
|
||||
return {"error": f"node {args['node_id']} not found"}, 400
|
||||
return {"error": f"node {args.node_id} not found"}, 400
|
||||
node_type = node[0]["data"]["type"]
|
||||
match node_type:
|
||||
case "llm":
|
||||
return LLMGenerator.generate_rule_config(
|
||||
current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
no_variable=True,
|
||||
)
|
||||
case "agent":
|
||||
return LLMGenerator.generate_rule_config(
|
||||
current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
no_variable=True,
|
||||
)
|
||||
case "code":
|
||||
return LLMGenerator.generate_code(
|
||||
tenant_id=current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
code_language=args["language"],
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
code_language=args.language,
|
||||
)
|
||||
case _:
|
||||
return {"error": f"invalid node type: {node_type}"}
|
||||
if args["node_id"] == "" and args["current"] != "": # For legacy app without a workflow
|
||||
if args.node_id == "" and args.current != "": # For legacy app without a workflow
|
||||
return LLMGenerator.instruction_modify_legacy(
|
||||
tenant_id=current_tenant_id,
|
||||
flow_id=args["flow_id"],
|
||||
current=args["current"],
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
ideal_output=args["ideal_output"],
|
||||
flow_id=args.flow_id,
|
||||
current=args.current,
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
ideal_output=args.ideal_output,
|
||||
)
|
||||
if args["node_id"] != "" and args["current"] != "": # For workflow node
|
||||
if args.node_id != "" and args.current != "": # For workflow node
|
||||
return LLMGenerator.instruction_modify_workflow(
|
||||
tenant_id=current_tenant_id,
|
||||
flow_id=args["flow_id"],
|
||||
node_id=args["node_id"],
|
||||
current=args["current"],
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
ideal_output=args["ideal_output"],
|
||||
flow_id=args.flow_id,
|
||||
node_id=args.node_id,
|
||||
current=args.current,
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
ideal_output=args.ideal_output,
|
||||
workflow_service=WorkflowService(),
|
||||
)
|
||||
return {"error": "incompatible parameters"}, 400
|
||||
|
|
@ -285,24 +259,15 @@ class InstructionGenerateApi(Resource):
|
|||
class InstructionGenerationTemplateApi(Resource):
|
||||
@console_ns.doc("get_instruction_template")
|
||||
@console_ns.doc(description="Get instruction generation template")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"InstructionTemplateRequest",
|
||||
{
|
||||
"instruction": fields.String(required=True, description="Template instruction"),
|
||||
"ideal_output": fields.String(description="Expected ideal output"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[InstructionTemplatePayload.__name__])
|
||||
@console_ns.response(200, "Template retrieved successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser().add_argument("type", type=str, required=True, default=False, location="json")
|
||||
args = parser.parse_args()
|
||||
match args["type"]:
|
||||
args = InstructionTemplatePayload.model_validate(console_ns.payload)
|
||||
match args.type:
|
||||
case "prompt":
|
||||
from core.llm_generator.prompts import INSTRUCTION_GENERATE_TEMPLATE_PROMPT
|
||||
|
||||
|
|
@ -312,4 +277,4 @@ class InstructionGenerationTemplateApi(Resource):
|
|||
|
||||
return {"data": INSTRUCTION_GENERATE_TEMPLATE_CODE}
|
||||
case _:
|
||||
raise ValueError(f"Invalid type: {args['type']}")
|
||||
raise ValueError(f"Invalid type: {args.type}")
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
import logging
|
||||
from typing import Literal
|
||||
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import exists, select
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
|
|
@ -33,6 +35,67 @@ from services.errors.message import MessageNotExistsError, SuggestedQuestionsAft
|
|||
from services.message_service import MessageService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class ChatMessagesQuery(BaseModel):
|
||||
conversation_id: str = Field(..., description="Conversation ID")
|
||||
first_id: str | None = Field(default=None, description="First message ID for pagination")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return (1-100)")
|
||||
|
||||
@field_validator("first_id", mode="before")
|
||||
@classmethod
|
||||
def empty_to_none(cls, value: str | None) -> str | None:
|
||||
if value == "":
|
||||
return None
|
||||
return value
|
||||
|
||||
@field_validator("conversation_id", "first_id")
|
||||
@classmethod
|
||||
def validate_uuid(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class MessageFeedbackPayload(BaseModel):
|
||||
message_id: str = Field(..., description="Message ID")
|
||||
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
|
||||
|
||||
@field_validator("message_id")
|
||||
@classmethod
|
||||
def validate_message_id(cls, value: str) -> str:
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class FeedbackExportQuery(BaseModel):
|
||||
from_source: Literal["user", "admin"] | None = Field(default=None, description="Filter by feedback source")
|
||||
rating: Literal["like", "dislike"] | None = Field(default=None, description="Filter by rating")
|
||||
has_comment: bool | None = Field(default=None, description="Only include feedback with comments")
|
||||
start_date: str | None = Field(default=None, description="Start date (YYYY-MM-DD)")
|
||||
end_date: str | None = Field(default=None, description="End date (YYYY-MM-DD)")
|
||||
format: Literal["csv", "json"] = Field(default="csv", description="Export format")
|
||||
|
||||
@field_validator("has_comment", mode="before")
|
||||
@classmethod
|
||||
def parse_bool(cls, value: bool | str | None) -> bool | None:
|
||||
if isinstance(value, bool) or value is None:
|
||||
return value
|
||||
lowered = value.lower()
|
||||
if lowered in {"true", "1", "yes", "on"}:
|
||||
return True
|
||||
if lowered in {"false", "0", "no", "off"}:
|
||||
return False
|
||||
raise ValueError("has_comment must be a boolean value")
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(ChatMessagesQuery)
|
||||
reg(MessageFeedbackPayload)
|
||||
reg(FeedbackExportQuery)
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register in dependency order: base models first, then dependent models
|
||||
|
|
@ -157,12 +220,7 @@ class ChatMessageListApi(Resource):
|
|||
@console_ns.doc("list_chat_messages")
|
||||
@console_ns.doc(description="Get chat messages for a conversation with pagination")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.parser()
|
||||
.add_argument("conversation_id", type=str, required=True, location="args", help="Conversation ID")
|
||||
.add_argument("first_id", type=str, location="args", help="First message ID for pagination")
|
||||
.add_argument("limit", type=int, location="args", default=20, help="Number of messages to return (1-100)")
|
||||
)
|
||||
@console_ns.expect(console_ns.models[ChatMessagesQuery.__name__])
|
||||
@console_ns.response(200, "Success", message_infinite_scroll_pagination_model)
|
||||
@console_ns.response(404, "Conversation not found")
|
||||
@login_required
|
||||
|
|
@ -172,27 +230,21 @@ class ChatMessageListApi(Resource):
|
|||
@marshal_with(message_infinite_scroll_pagination_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("conversation_id", required=True, type=uuid_value, location="args")
|
||||
.add_argument("first_id", type=uuid_value, location="args")
|
||||
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = ChatMessagesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
conversation = (
|
||||
db.session.query(Conversation)
|
||||
.where(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id)
|
||||
.where(Conversation.id == args.conversation_id, Conversation.app_id == app_model.id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
if args["first_id"]:
|
||||
if args.first_id:
|
||||
first_message = (
|
||||
db.session.query(Message)
|
||||
.where(Message.conversation_id == conversation.id, Message.id == args["first_id"])
|
||||
.where(Message.conversation_id == conversation.id, Message.id == args.first_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
|
|
@ -207,7 +259,7 @@ class ChatMessageListApi(Resource):
|
|||
Message.id != first_message.id,
|
||||
)
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(args["limit"])
|
||||
.limit(args.limit)
|
||||
.all()
|
||||
)
|
||||
else:
|
||||
|
|
@ -215,12 +267,12 @@ class ChatMessageListApi(Resource):
|
|||
db.session.query(Message)
|
||||
.where(Message.conversation_id == conversation.id)
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(args["limit"])
|
||||
.limit(args.limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Initialize has_more based on whether we have a full page
|
||||
if len(history_messages) == args["limit"]:
|
||||
if len(history_messages) == args.limit:
|
||||
current_page_first_message = history_messages[-1]
|
||||
# Check if there are more messages before the current page
|
||||
has_more = db.session.scalar(
|
||||
|
|
@ -238,7 +290,7 @@ class ChatMessageListApi(Resource):
|
|||
|
||||
history_messages = list(reversed(history_messages))
|
||||
|
||||
return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more)
|
||||
return InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/feedbacks")
|
||||
|
|
@ -246,15 +298,7 @@ class MessageFeedbackApi(Resource):
|
|||
@console_ns.doc("create_message_feedback")
|
||||
@console_ns.doc(description="Create or update message feedback (like/dislike)")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"MessageFeedbackRequest",
|
||||
{
|
||||
"message_id": fields.String(required=True, description="Message ID"),
|
||||
"rating": fields.String(enum=["like", "dislike"], description="Feedback rating"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[MessageFeedbackPayload.__name__])
|
||||
@console_ns.response(200, "Feedback updated successfully")
|
||||
@console_ns.response(404, "Message not found")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
|
|
@ -265,14 +309,9 @@ class MessageFeedbackApi(Resource):
|
|||
def post(self, app_model):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("message_id", required=True, type=uuid_value, location="json")
|
||||
.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = MessageFeedbackPayload.model_validate(console_ns.payload)
|
||||
|
||||
message_id = str(args["message_id"])
|
||||
message_id = str(args.message_id)
|
||||
|
||||
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
|
||||
|
||||
|
|
@ -281,18 +320,21 @@ class MessageFeedbackApi(Resource):
|
|||
|
||||
feedback = message.admin_feedback
|
||||
|
||||
if not args["rating"] and feedback:
|
||||
if not args.rating and feedback:
|
||||
db.session.delete(feedback)
|
||||
elif args["rating"] and feedback:
|
||||
feedback.rating = args["rating"]
|
||||
elif not args["rating"] and not feedback:
|
||||
elif args.rating and feedback:
|
||||
feedback.rating = args.rating
|
||||
elif not args.rating and not feedback:
|
||||
raise ValueError("rating cannot be None when feedback not exists")
|
||||
else:
|
||||
rating_value = args.rating
|
||||
if rating_value is None:
|
||||
raise ValueError("rating is required to create feedback")
|
||||
feedback = MessageFeedback(
|
||||
app_id=app_model.id,
|
||||
conversation_id=message.conversation_id,
|
||||
message_id=message.id,
|
||||
rating=args["rating"],
|
||||
rating=rating_value,
|
||||
from_source="admin",
|
||||
from_account_id=current_user.id,
|
||||
)
|
||||
|
|
@ -369,24 +411,12 @@ class MessageSuggestedQuestionApi(Resource):
|
|||
return {"data": questions}
|
||||
|
||||
|
||||
# Shared parser for feedback export (used for both documentation and runtime parsing)
|
||||
feedback_export_parser = (
|
||||
console_ns.parser()
|
||||
.add_argument("from_source", type=str, choices=["user", "admin"], location="args", help="Filter by feedback source")
|
||||
.add_argument("rating", type=str, choices=["like", "dislike"], location="args", help="Filter by rating")
|
||||
.add_argument("has_comment", type=bool, location="args", help="Only include feedback with comments")
|
||||
.add_argument("start_date", type=str, location="args", help="Start date (YYYY-MM-DD)")
|
||||
.add_argument("end_date", type=str, location="args", help="End date (YYYY-MM-DD)")
|
||||
.add_argument("format", type=str, choices=["csv", "json"], default="csv", location="args", help="Export format")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/feedbacks/export")
|
||||
class MessageFeedbackExportApi(Resource):
|
||||
@console_ns.doc("export_feedbacks")
|
||||
@console_ns.doc(description="Export user feedback data for Google Sheets")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(feedback_export_parser)
|
||||
@console_ns.expect(console_ns.models[FeedbackExportQuery.__name__])
|
||||
@console_ns.response(200, "Feedback data exported successfully")
|
||||
@console_ns.response(400, "Invalid parameters")
|
||||
@console_ns.response(500, "Internal server error")
|
||||
|
|
@ -395,7 +425,7 @@ class MessageFeedbackExportApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
args = feedback_export_parser.parse_args()
|
||||
args = FeedbackExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
# Import the service function
|
||||
from services.feedback_service import FeedbackService
|
||||
|
|
@ -403,12 +433,12 @@ class MessageFeedbackExportApi(Resource):
|
|||
try:
|
||||
export_data = FeedbackService.export_feedbacks(
|
||||
app_id=app_model.id,
|
||||
from_source=args.get("from_source"),
|
||||
rating=args.get("rating"),
|
||||
has_comment=args.get("has_comment"),
|
||||
start_date=args.get("start_date"),
|
||||
end_date=args.get("end_date"),
|
||||
format_type=args.get("format", "csv"),
|
||||
from_source=args.from_source,
|
||||
rating=args.rating,
|
||||
has_comment=args.has_comment,
|
||||
start_date=args.start_date,
|
||||
end_date=args.end_date,
|
||||
format_type=args.format,
|
||||
)
|
||||
|
||||
return export_data
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
from decimal import Decimal
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask import abort, jsonify
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from flask import abort, jsonify, request
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
|
|
@ -10,21 +11,37 @@ from controllers.console.wraps import account_initialization_required, setup_req
|
|||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import parse_time_range
|
||||
from libs.helper import DatetimeString, convert_datetime_to_date
|
||||
from libs.helper import convert_datetime_to_date
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import AppMode
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class StatisticTimeRangeQuery(BaseModel):
|
||||
start: str | None = Field(default=None, description="Start date (YYYY-MM-DD HH:MM)")
|
||||
end: str | None = Field(default=None, description="End date (YYYY-MM-DD HH:MM)")
|
||||
|
||||
@field_validator("start", "end", mode="before")
|
||||
@classmethod
|
||||
def empty_string_to_none(cls, value: str | None) -> str | None:
|
||||
if value == "":
|
||||
return None
|
||||
return value
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
StatisticTimeRangeQuery.__name__,
|
||||
StatisticTimeRangeQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-messages")
|
||||
class DailyMessageStatistic(Resource):
|
||||
@console_ns.doc("get_daily_message_statistics")
|
||||
@console_ns.doc(description="Get daily message statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Daily message statistics retrieved successfully",
|
||||
|
|
@ -37,12 +54,7 @@ class DailyMessageStatistic(Resource):
|
|||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
|
|
@ -57,7 +69,7 @@ WHERE
|
|||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
|
|
@ -81,19 +93,12 @@ WHERE
|
|||
return jsonify({"data": response_data})
|
||||
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-conversations")
|
||||
class DailyConversationStatistic(Resource):
|
||||
@console_ns.doc("get_daily_conversation_statistics")
|
||||
@console_ns.doc(description="Get daily conversation statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Daily conversation statistics retrieved successfully",
|
||||
|
|
@ -106,7 +111,7 @@ class DailyConversationStatistic(Resource):
|
|||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = parser.parse_args()
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
|
|
@ -121,7 +126,7 @@ WHERE
|
|||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
|
|
@ -149,7 +154,7 @@ class DailyTerminalsStatistic(Resource):
|
|||
@console_ns.doc("get_daily_terminals_statistics")
|
||||
@console_ns.doc(description="Get daily terminal/end-user statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Daily terminal statistics retrieved successfully",
|
||||
|
|
@ -162,7 +167,7 @@ class DailyTerminalsStatistic(Resource):
|
|||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = parser.parse_args()
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
|
|
@ -177,7 +182,7 @@ WHERE
|
|||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
|
|
@ -206,7 +211,7 @@ class DailyTokenCostStatistic(Resource):
|
|||
@console_ns.doc("get_daily_token_cost_statistics")
|
||||
@console_ns.doc(description="Get daily token cost statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Daily token cost statistics retrieved successfully",
|
||||
|
|
@ -219,7 +224,7 @@ class DailyTokenCostStatistic(Resource):
|
|||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = parser.parse_args()
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
|
|
@ -235,7 +240,7 @@ WHERE
|
|||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
|
|
@ -266,7 +271,7 @@ class AverageSessionInteractionStatistic(Resource):
|
|||
@console_ns.doc("get_average_session_interaction_statistics")
|
||||
@console_ns.doc(description="Get average session interaction statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Average session interaction statistics retrieved successfully",
|
||||
|
|
@ -279,7 +284,7 @@ class AverageSessionInteractionStatistic(Resource):
|
|||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = parser.parse_args()
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
converted_created_at = convert_datetime_to_date("c.created_at")
|
||||
sql_query = f"""SELECT
|
||||
|
|
@ -302,7 +307,7 @@ FROM
|
|||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
|
|
@ -342,7 +347,7 @@ class UserSatisfactionRateStatistic(Resource):
|
|||
@console_ns.doc("get_user_satisfaction_rate_statistics")
|
||||
@console_ns.doc(description="Get user satisfaction rate statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"User satisfaction rate statistics retrieved successfully",
|
||||
|
|
@ -355,7 +360,7 @@ class UserSatisfactionRateStatistic(Resource):
|
|||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = parser.parse_args()
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
converted_created_at = convert_datetime_to_date("m.created_at")
|
||||
sql_query = f"""SELECT
|
||||
|
|
@ -374,7 +379,7 @@ WHERE
|
|||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
|
|
@ -408,7 +413,7 @@ class AverageResponseTimeStatistic(Resource):
|
|||
@console_ns.doc("get_average_response_time_statistics")
|
||||
@console_ns.doc(description="Get average response time statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Average response time statistics retrieved successfully",
|
||||
|
|
@ -421,7 +426,7 @@ class AverageResponseTimeStatistic(Resource):
|
|||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = parser.parse_args()
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
|
|
@ -436,7 +441,7 @@ WHERE
|
|||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
|
|
@ -465,7 +470,7 @@ class TokensPerSecondStatistic(Resource):
|
|||
@console_ns.doc("get_tokens_per_second_statistics")
|
||||
@console_ns.doc(description="Get tokens per second statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Tokens per second statistics retrieved successfully",
|
||||
|
|
@ -477,7 +482,7 @@ class TokensPerSecondStatistic(Resource):
|
|||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
args = parser.parse_args()
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
|
|
@ -495,7 +500,7 @@ WHERE
|
|||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
import json
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import cast
|
||||
from typing import Any
|
||||
|
||||
from flask import abort, request
|
||||
from flask_restx import Resource, fields, inputs, marshal_with, reqparse
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
|
|
@ -49,6 +50,7 @@ from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseE
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
LISTENING_RETRY_IN = 2000
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register in dependency order: base models first, then dependent models
|
||||
|
|
@ -107,6 +109,104 @@ if workflow_run_node_execution_model is None:
|
|||
workflow_run_node_execution_model = console_ns.model("WorkflowRunNodeExecution", workflow_run_node_execution_fields)
|
||||
|
||||
|
||||
class SyncDraftWorkflowPayload(BaseModel):
|
||||
graph: dict[str, Any]
|
||||
features: dict[str, Any]
|
||||
hash: str | None = None
|
||||
environment_variables: list[dict[str, Any]] = Field(default_factory=list)
|
||||
conversation_variables: list[dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class BaseWorkflowRunPayload(BaseModel):
|
||||
files: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
class AdvancedChatWorkflowRunPayload(BaseWorkflowRunPayload):
|
||||
inputs: dict[str, Any] | None = None
|
||||
query: str = ""
|
||||
conversation_id: str | None = None
|
||||
parent_message_id: str | None = None
|
||||
|
||||
@field_validator("conversation_id", "parent_message_id")
|
||||
@classmethod
|
||||
def validate_uuid(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class IterationNodeRunPayload(BaseModel):
|
||||
inputs: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class LoopNodeRunPayload(BaseModel):
|
||||
inputs: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class DraftWorkflowRunPayload(BaseWorkflowRunPayload):
|
||||
inputs: dict[str, Any]
|
||||
|
||||
|
||||
class DraftWorkflowNodeRunPayload(BaseWorkflowRunPayload):
|
||||
inputs: dict[str, Any]
|
||||
query: str = ""
|
||||
|
||||
|
||||
class PublishWorkflowPayload(BaseModel):
|
||||
marked_name: str | None = Field(default=None, max_length=20)
|
||||
marked_comment: str | None = Field(default=None, max_length=100)
|
||||
|
||||
|
||||
class DefaultBlockConfigQuery(BaseModel):
|
||||
q: str | None = None
|
||||
|
||||
|
||||
class ConvertToWorkflowPayload(BaseModel):
|
||||
name: str | None = None
|
||||
icon_type: str | None = None
|
||||
icon: str | None = None
|
||||
icon_background: str | None = None
|
||||
|
||||
|
||||
class WorkflowListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, le=99999)
|
||||
limit: int = Field(default=10, ge=1, le=100)
|
||||
user_id: str | None = None
|
||||
named_only: bool = False
|
||||
|
||||
|
||||
class WorkflowUpdatePayload(BaseModel):
|
||||
marked_name: str | None = Field(default=None, max_length=20)
|
||||
marked_comment: str | None = Field(default=None, max_length=100)
|
||||
|
||||
|
||||
class DraftWorkflowTriggerRunPayload(BaseModel):
|
||||
node_id: str
|
||||
|
||||
|
||||
class DraftWorkflowTriggerRunAllPayload(BaseModel):
|
||||
node_ids: list[str]
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(SyncDraftWorkflowPayload)
|
||||
reg(AdvancedChatWorkflowRunPayload)
|
||||
reg(IterationNodeRunPayload)
|
||||
reg(LoopNodeRunPayload)
|
||||
reg(DraftWorkflowRunPayload)
|
||||
reg(DraftWorkflowNodeRunPayload)
|
||||
reg(PublishWorkflowPayload)
|
||||
reg(DefaultBlockConfigQuery)
|
||||
reg(ConvertToWorkflowPayload)
|
||||
reg(WorkflowListQuery)
|
||||
reg(WorkflowUpdatePayload)
|
||||
reg(DraftWorkflowTriggerRunPayload)
|
||||
reg(DraftWorkflowTriggerRunAllPayload)
|
||||
|
||||
|
||||
# TODO(QuantumGhost): Refactor existing node run API to handle file parameter parsing
|
||||
# at the controller level rather than in the workflow logic. This would improve separation
|
||||
# of concerns and make the code more maintainable.
|
||||
|
|
@ -158,18 +258,7 @@ class DraftWorkflowApi(Resource):
|
|||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@console_ns.doc("sync_draft_workflow")
|
||||
@console_ns.doc(description="Sync draft workflow configuration")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"SyncDraftWorkflowRequest",
|
||||
{
|
||||
"graph": fields.Raw(required=True, description="Workflow graph configuration"),
|
||||
"features": fields.Raw(required=True, description="Workflow features configuration"),
|
||||
"hash": fields.String(description="Workflow hash for validation"),
|
||||
"environment_variables": fields.List(fields.Raw, required=True, description="Environment variables"),
|
||||
"conversation_variables": fields.List(fields.Raw, description="Conversation variables"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[SyncDraftWorkflowPayload.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Draft workflow synced successfully",
|
||||
|
|
@ -193,36 +282,23 @@ class DraftWorkflowApi(Resource):
|
|||
|
||||
content_type = request.headers.get("Content-Type", "")
|
||||
|
||||
payload_data: dict[str, Any] | None = None
|
||||
if "application/json" in content_type:
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("graph", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("features", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("hash", type=str, required=False, location="json")
|
||||
.add_argument("environment_variables", type=list, required=True, location="json")
|
||||
.add_argument("conversation_variables", type=list, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload_data = request.get_json(silent=True)
|
||||
if not isinstance(payload_data, dict):
|
||||
return {"message": "Invalid JSON data"}, 400
|
||||
elif "text/plain" in content_type:
|
||||
try:
|
||||
data = json.loads(request.data.decode("utf-8"))
|
||||
if "graph" not in data or "features" not in data:
|
||||
raise ValueError("graph or features not found in data")
|
||||
|
||||
if not isinstance(data.get("graph"), dict) or not isinstance(data.get("features"), dict):
|
||||
raise ValueError("graph or features is not a dict")
|
||||
|
||||
args = {
|
||||
"graph": data.get("graph"),
|
||||
"features": data.get("features"),
|
||||
"hash": data.get("hash"),
|
||||
"environment_variables": data.get("environment_variables"),
|
||||
"conversation_variables": data.get("conversation_variables"),
|
||||
}
|
||||
payload_data = json.loads(request.data.decode("utf-8"))
|
||||
except json.JSONDecodeError:
|
||||
return {"message": "Invalid JSON data"}, 400
|
||||
if not isinstance(payload_data, dict):
|
||||
return {"message": "Invalid JSON data"}, 400
|
||||
else:
|
||||
abort(415)
|
||||
|
||||
args_model = SyncDraftWorkflowPayload.model_validate(payload_data)
|
||||
args = args_model.model_dump()
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
try:
|
||||
|
|
@ -258,17 +334,7 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
|
|||
@console_ns.doc("run_advanced_chat_draft_workflow")
|
||||
@console_ns.doc(description="Run draft workflow for advanced chat application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"AdvancedChatWorkflowRunRequest",
|
||||
{
|
||||
"query": fields.String(required=True, description="User query"),
|
||||
"inputs": fields.Raw(description="Input variables"),
|
||||
"files": fields.List(fields.Raw, description="File uploads"),
|
||||
"conversation_id": fields.String(description="Conversation ID"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[AdvancedChatWorkflowRunPayload.__name__])
|
||||
@console_ns.response(200, "Workflow run started successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
|
|
@ -283,16 +349,8 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
|
|||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, location="json")
|
||||
.add_argument("query", type=str, required=True, location="json", default="")
|
||||
.add_argument("files", type=list, location="json")
|
||||
.add_argument("conversation_id", type=uuid_value, location="json")
|
||||
.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
args_model = AdvancedChatWorkflowRunPayload.model_validate(console_ns.payload or {})
|
||||
args = args_model.model_dump(exclude_none=True)
|
||||
|
||||
external_trace_id = get_external_trace_id(request)
|
||||
if external_trace_id:
|
||||
|
|
@ -322,15 +380,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
|
|||
@console_ns.doc("run_advanced_chat_draft_iteration_node")
|
||||
@console_ns.doc(description="Run draft workflow iteration node for advanced chat")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"IterationNodeRunRequest",
|
||||
{
|
||||
"task_id": fields.String(required=True, description="Task ID"),
|
||||
"inputs": fields.Raw(description="Input variables"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[IterationNodeRunPayload.__name__])
|
||||
@console_ns.response(200, "Iteration node run started successfully")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Node not found")
|
||||
|
|
@ -344,8 +394,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
|
|||
Run draft workflow iteration node
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
|
||||
args = parser.parse_args()
|
||||
args = IterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate_single_iteration(
|
||||
|
|
@ -369,15 +418,7 @@ class WorkflowDraftRunIterationNodeApi(Resource):
|
|||
@console_ns.doc("run_workflow_draft_iteration_node")
|
||||
@console_ns.doc(description="Run draft workflow iteration node")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"WorkflowIterationNodeRunRequest",
|
||||
{
|
||||
"task_id": fields.String(required=True, description="Task ID"),
|
||||
"inputs": fields.Raw(description="Input variables"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[IterationNodeRunPayload.__name__])
|
||||
@console_ns.response(200, "Workflow iteration node run started successfully")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Node not found")
|
||||
|
|
@ -391,8 +432,7 @@ class WorkflowDraftRunIterationNodeApi(Resource):
|
|||
Run draft workflow iteration node
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
|
||||
args = parser.parse_args()
|
||||
args = IterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate_single_iteration(
|
||||
|
|
@ -416,15 +456,7 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
|
|||
@console_ns.doc("run_advanced_chat_draft_loop_node")
|
||||
@console_ns.doc(description="Run draft workflow loop node for advanced chat")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"LoopNodeRunRequest",
|
||||
{
|
||||
"task_id": fields.String(required=True, description="Task ID"),
|
||||
"inputs": fields.Raw(description="Input variables"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[LoopNodeRunPayload.__name__])
|
||||
@console_ns.response(200, "Loop node run started successfully")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Node not found")
|
||||
|
|
@ -438,8 +470,7 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
|
|||
Run draft workflow loop node
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
|
||||
args = parser.parse_args()
|
||||
args = LoopNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate_single_loop(
|
||||
|
|
@ -463,15 +494,7 @@ class WorkflowDraftRunLoopNodeApi(Resource):
|
|||
@console_ns.doc("run_workflow_draft_loop_node")
|
||||
@console_ns.doc(description="Run draft workflow loop node")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"WorkflowLoopNodeRunRequest",
|
||||
{
|
||||
"task_id": fields.String(required=True, description="Task ID"),
|
||||
"inputs": fields.Raw(description="Input variables"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[LoopNodeRunPayload.__name__])
|
||||
@console_ns.response(200, "Workflow loop node run started successfully")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Node not found")
|
||||
|
|
@ -485,8 +508,7 @@ class WorkflowDraftRunLoopNodeApi(Resource):
|
|||
Run draft workflow loop node
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
|
||||
args = parser.parse_args()
|
||||
args = LoopNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate_single_loop(
|
||||
|
|
@ -510,15 +532,7 @@ class DraftWorkflowRunApi(Resource):
|
|||
@console_ns.doc("run_draft_workflow")
|
||||
@console_ns.doc(description="Run draft workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"DraftWorkflowRunRequest",
|
||||
{
|
||||
"inputs": fields.Raw(required=True, description="Input variables"),
|
||||
"files": fields.List(fields.Raw, description="File uploads"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[DraftWorkflowRunPayload.__name__])
|
||||
@console_ns.response(200, "Draft workflow run started successfully")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@setup_required
|
||||
|
|
@ -531,12 +545,7 @@ class DraftWorkflowRunApi(Resource):
|
|||
Run draft workflow
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("files", type=list, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = DraftWorkflowRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
|
||||
|
||||
external_trace_id = get_external_trace_id(request)
|
||||
if external_trace_id:
|
||||
|
|
@ -588,14 +597,7 @@ class DraftWorkflowNodeRunApi(Resource):
|
|||
@console_ns.doc("run_draft_workflow_node")
|
||||
@console_ns.doc(description="Run draft workflow node")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"DraftWorkflowNodeRunRequest",
|
||||
{
|
||||
"inputs": fields.Raw(description="Input variables"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[DraftWorkflowNodeRunPayload.__name__])
|
||||
@console_ns.response(200, "Node run started successfully", workflow_run_node_execution_model)
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Node not found")
|
||||
|
|
@ -610,15 +612,10 @@ class DraftWorkflowNodeRunApi(Resource):
|
|||
Run draft workflow node
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("query", type=str, required=False, location="json", default="")
|
||||
.add_argument("files", type=list, location="json", default=[])
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args_model = DraftWorkflowNodeRunPayload.model_validate(console_ns.payload or {})
|
||||
args = args_model.model_dump(exclude_none=True)
|
||||
|
||||
user_inputs = args.get("inputs")
|
||||
user_inputs = args_model.inputs
|
||||
if user_inputs is None:
|
||||
raise ValueError("missing inputs")
|
||||
|
||||
|
|
@ -643,13 +640,6 @@ class DraftWorkflowNodeRunApi(Resource):
|
|||
return workflow_node_execution
|
||||
|
||||
|
||||
parser_publish = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("marked_name", type=str, required=False, default="", location="json")
|
||||
.add_argument("marked_comment", type=str, required=False, default="", location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/publish")
|
||||
class PublishedWorkflowApi(Resource):
|
||||
@console_ns.doc("get_published_workflow")
|
||||
|
|
@ -674,7 +664,7 @@ class PublishedWorkflowApi(Resource):
|
|||
# return workflow, if not found, return None
|
||||
return workflow
|
||||
|
||||
@console_ns.expect(parser_publish)
|
||||
@console_ns.expect(console_ns.models[PublishWorkflowPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -686,13 +676,7 @@ class PublishedWorkflowApi(Resource):
|
|||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = parser_publish.parse_args()
|
||||
|
||||
# Validate name and comment length
|
||||
if args.marked_name and len(args.marked_name) > 20:
|
||||
raise ValueError("Marked name cannot exceed 20 characters")
|
||||
if args.marked_comment and len(args.marked_comment) > 100:
|
||||
raise ValueError("Marked comment cannot exceed 100 characters")
|
||||
args = PublishWorkflowPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
with Session(db.engine) as session:
|
||||
|
|
@ -741,9 +725,6 @@ class DefaultBlockConfigsApi(Resource):
|
|||
return workflow_service.get_default_block_configs()
|
||||
|
||||
|
||||
parser_block = reqparse.RequestParser().add_argument("q", type=str, location="args")
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>")
|
||||
class DefaultBlockConfigApi(Resource):
|
||||
@console_ns.doc("get_default_block_config")
|
||||
|
|
@ -751,7 +732,7 @@ class DefaultBlockConfigApi(Resource):
|
|||
@console_ns.doc(params={"app_id": "Application ID", "block_type": "Block type"})
|
||||
@console_ns.response(200, "Default block configuration retrieved successfully")
|
||||
@console_ns.response(404, "Block type not found")
|
||||
@console_ns.expect(parser_block)
|
||||
@console_ns.expect(console_ns.models[DefaultBlockConfigQuery.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -761,14 +742,12 @@ class DefaultBlockConfigApi(Resource):
|
|||
"""
|
||||
Get default block config
|
||||
"""
|
||||
args = parser_block.parse_args()
|
||||
|
||||
q = args.get("q")
|
||||
args = DefaultBlockConfigQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
filters = None
|
||||
if q:
|
||||
if args.q:
|
||||
try:
|
||||
filters = json.loads(args.get("q", ""))
|
||||
filters = json.loads(args.q)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("Invalid filters")
|
||||
|
||||
|
|
@ -777,18 +756,9 @@ class DefaultBlockConfigApi(Resource):
|
|||
return workflow_service.get_default_block_config(node_type=block_type, filters=filters)
|
||||
|
||||
|
||||
parser_convert = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("icon_type", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("icon", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/convert-to-workflow")
|
||||
class ConvertToWorkflowApi(Resource):
|
||||
@console_ns.expect(parser_convert)
|
||||
@console_ns.expect(console_ns.models[ConvertToWorkflowPayload.__name__])
|
||||
@console_ns.doc("convert_to_workflow")
|
||||
@console_ns.doc(description="Convert application to workflow mode")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
|
|
@ -808,10 +778,8 @@ class ConvertToWorkflowApi(Resource):
|
|||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
if request.data:
|
||||
args = parser_convert.parse_args()
|
||||
else:
|
||||
args = {}
|
||||
payload = console_ns.payload or {}
|
||||
args = ConvertToWorkflowPayload.model_validate(payload).model_dump(exclude_none=True)
|
||||
|
||||
# convert to workflow mode
|
||||
workflow_service = WorkflowService()
|
||||
|
|
@ -823,18 +791,9 @@ class ConvertToWorkflowApi(Resource):
|
|||
}
|
||||
|
||||
|
||||
parser_workflows = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
|
||||
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=10, location="args")
|
||||
.add_argument("user_id", type=str, required=False, location="args")
|
||||
.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows")
|
||||
class PublishedAllWorkflowApi(Resource):
|
||||
@console_ns.expect(parser_workflows)
|
||||
@console_ns.expect(console_ns.models[WorkflowListQuery.__name__])
|
||||
@console_ns.doc("get_all_published_workflows")
|
||||
@console_ns.doc(description="Get all published workflows for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
|
|
@ -851,16 +810,15 @@ class PublishedAllWorkflowApi(Resource):
|
|||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = parser_workflows.parse_args()
|
||||
page = args["page"]
|
||||
limit = args["limit"]
|
||||
user_id = args.get("user_id")
|
||||
named_only = args.get("named_only", False)
|
||||
args = WorkflowListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
page = args.page
|
||||
limit = args.limit
|
||||
user_id = args.user_id
|
||||
named_only = args.named_only
|
||||
|
||||
if user_id:
|
||||
if user_id != current_user.id:
|
||||
raise Forbidden()
|
||||
user_id = cast(str, user_id)
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
with Session(db.engine) as session:
|
||||
|
|
@ -886,15 +844,7 @@ class WorkflowByIdApi(Resource):
|
|||
@console_ns.doc("update_workflow_by_id")
|
||||
@console_ns.doc(description="Update workflow by ID")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "workflow_id": "Workflow ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"UpdateWorkflowRequest",
|
||||
{
|
||||
"environment_variables": fields.List(fields.Raw, description="Environment variables"),
|
||||
"conversation_variables": fields.List(fields.Raw, description="Conversation variables"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[WorkflowUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Workflow updated successfully", workflow_model)
|
||||
@console_ns.response(404, "Workflow not found")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
|
|
@ -909,25 +859,14 @@ class WorkflowByIdApi(Resource):
|
|||
Update workflow attributes
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("marked_name", type=str, required=False, location="json")
|
||||
.add_argument("marked_comment", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Validate name and comment length
|
||||
if args.marked_name and len(args.marked_name) > 20:
|
||||
raise ValueError("Marked name cannot exceed 20 characters")
|
||||
if args.marked_comment and len(args.marked_comment) > 100:
|
||||
raise ValueError("Marked comment cannot exceed 100 characters")
|
||||
args = WorkflowUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
# Prepare update data
|
||||
update_data = {}
|
||||
if args.get("marked_name") is not None:
|
||||
update_data["marked_name"] = args["marked_name"]
|
||||
if args.get("marked_comment") is not None:
|
||||
update_data["marked_comment"] = args["marked_comment"]
|
||||
if args.marked_name is not None:
|
||||
update_data["marked_name"] = args.marked_name
|
||||
if args.marked_comment is not None:
|
||||
update_data["marked_comment"] = args.marked_comment
|
||||
|
||||
if not update_data:
|
||||
return {"message": "No valid fields to update"}, 400
|
||||
|
|
@ -1040,11 +979,8 @@ class DraftWorkflowTriggerRunApi(Resource):
|
|||
Poll for trigger events and execute full workflow when event arrives
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"node_id", type=str, required=True, location="json", nullable=False
|
||||
)
|
||||
args = parser.parse_args()
|
||||
node_id = args["node_id"]
|
||||
args = DraftWorkflowTriggerRunPayload.model_validate(console_ns.payload or {})
|
||||
node_id = args.node_id
|
||||
workflow_service = WorkflowService()
|
||||
draft_workflow = workflow_service.get_draft_workflow(app_model)
|
||||
if not draft_workflow:
|
||||
|
|
@ -1172,14 +1108,7 @@ class DraftWorkflowTriggerRunAllApi(Resource):
|
|||
@console_ns.doc("draft_workflow_trigger_run_all")
|
||||
@console_ns.doc(description="Full workflow debug when the start node is a trigger")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"DraftWorkflowTriggerRunAllRequest",
|
||||
{
|
||||
"node_ids": fields.List(fields.String, required=True, description="Node IDs"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[DraftWorkflowTriggerRunAllPayload.__name__])
|
||||
@console_ns.response(200, "Workflow executed successfully")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(500, "Internal server error")
|
||||
|
|
@ -1194,11 +1123,8 @@ class DraftWorkflowTriggerRunAllApi(Resource):
|
|||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"node_ids", type=list, required=True, location="json", nullable=False
|
||||
)
|
||||
args = parser.parse_args()
|
||||
node_ids = args["node_ids"]
|
||||
args = DraftWorkflowTriggerRunAllPayload.model_validate(console_ns.payload or {})
|
||||
node_ids = args.node_ids
|
||||
workflow_service = WorkflowService()
|
||||
draft_workflow = workflow_service.get_draft_workflow(app_model)
|
||||
if not draft_workflow:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,9 @@
|
|||
from datetime import datetime
|
||||
|
||||
from dateutil.parser import isoparse
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from flask import request
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console import console_ns
|
||||
|
|
@ -14,6 +17,48 @@ from models import App
|
|||
from models.model import AppMode
|
||||
from services.workflow_app_service import WorkflowAppService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class WorkflowAppLogQuery(BaseModel):
|
||||
keyword: str | None = Field(default=None, description="Search keyword for filtering logs")
|
||||
status: WorkflowExecutionStatus | None = Field(
|
||||
default=None, description="Execution status filter (succeeded, failed, stopped, partial-succeeded)"
|
||||
)
|
||||
created_at__before: datetime | None = Field(default=None, description="Filter logs created before this timestamp")
|
||||
created_at__after: datetime | None = Field(default=None, description="Filter logs created after this timestamp")
|
||||
created_by_end_user_session_id: str | None = Field(default=None, description="Filter by end user session ID")
|
||||
created_by_account: str | None = Field(default=None, description="Filter by account")
|
||||
detail: bool = Field(default=False, description="Whether to return detailed logs")
|
||||
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Number of items per page (1-100)")
|
||||
|
||||
@field_validator("created_at__before", "created_at__after", mode="before")
|
||||
@classmethod
|
||||
def parse_datetime(cls, value: str | None) -> datetime | None:
|
||||
if value in (None, ""):
|
||||
return None
|
||||
return isoparse(value) # type: ignore
|
||||
|
||||
@field_validator("detail", mode="before")
|
||||
@classmethod
|
||||
def parse_bool(cls, value: bool | str | None) -> bool:
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if value is None:
|
||||
return False
|
||||
lowered = value.lower()
|
||||
if lowered in {"1", "true", "yes", "on"}:
|
||||
return True
|
||||
if lowered in {"0", "false", "no", "off"}:
|
||||
return False
|
||||
raise ValueError("Invalid boolean value for detail")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
WorkflowAppLogQuery.__name__, WorkflowAppLogQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
# Register model for flask_restx to avoid dict type issues in Swagger
|
||||
workflow_app_log_pagination_model = build_workflow_app_log_pagination_model(console_ns)
|
||||
|
||||
|
|
@ -23,19 +68,7 @@ class WorkflowAppLogApi(Resource):
|
|||
@console_ns.doc("get_workflow_app_logs")
|
||||
@console_ns.doc(description="Get workflow application execution logs")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"keyword": "Search keyword for filtering logs",
|
||||
"status": "Filter by execution status (succeeded, failed, stopped, partial-succeeded)",
|
||||
"created_at__before": "Filter logs created before this timestamp",
|
||||
"created_at__after": "Filter logs created after this timestamp",
|
||||
"created_by_end_user_session_id": "Filter by end user session ID",
|
||||
"created_by_account": "Filter by account",
|
||||
"detail": "Whether to return detailed logs",
|
||||
"page": "Page number (1-99999)",
|
||||
"limit": "Number of items per page (1-100)",
|
||||
}
|
||||
)
|
||||
@console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__])
|
||||
@console_ns.response(200, "Workflow app logs retrieved successfully", workflow_app_log_pagination_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -46,44 +79,7 @@ class WorkflowAppLogApi(Resource):
|
|||
"""
|
||||
Get workflow app logs
|
||||
"""
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("keyword", type=str, location="args")
|
||||
.add_argument(
|
||||
"status", type=str, choices=["succeeded", "failed", "stopped", "partial-succeeded"], location="args"
|
||||
)
|
||||
.add_argument(
|
||||
"created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
|
||||
)
|
||||
.add_argument(
|
||||
"created_at__after", type=str, location="args", help="Filter logs created after this timestamp"
|
||||
)
|
||||
.add_argument(
|
||||
"created_by_end_user_session_id",
|
||||
type=str,
|
||||
location="args",
|
||||
required=False,
|
||||
default=None,
|
||||
)
|
||||
.add_argument(
|
||||
"created_by_account",
|
||||
type=str,
|
||||
location="args",
|
||||
required=False,
|
||||
default=None,
|
||||
)
|
||||
.add_argument("detail", type=bool, location="args", required=False, default=False)
|
||||
.add_argument("page", type=int_range(1, 99999), default=1, location="args")
|
||||
.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
args.status = WorkflowExecutionStatus(args.status) if args.status else None
|
||||
if args.created_at__before:
|
||||
args.created_at__before = isoparse(args.created_at__before)
|
||||
|
||||
if args.created_at__after:
|
||||
args.created_at__after = isoparse(args.created_at__after)
|
||||
args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
# get paginate workflow app logs
|
||||
workflow_app_service = WorkflowAppService()
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
from typing import cast
|
||||
from typing import Literal, cast
|
||||
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
|
|
@ -92,70 +93,51 @@ workflow_run_node_execution_list_model = console_ns.model(
|
|||
"WorkflowRunNodeExecutionList", workflow_run_node_execution_list_fields_copy
|
||||
)
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
def _parse_workflow_run_list_args():
|
||||
"""
|
||||
Parse common arguments for workflow run list endpoints.
|
||||
|
||||
Returns:
|
||||
Parsed arguments containing last_id, limit, status, and triggered_from filters
|
||||
"""
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("last_id", type=uuid_value, location="args")
|
||||
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
.add_argument(
|
||||
"status",
|
||||
type=str,
|
||||
choices=WORKFLOW_RUN_STATUS_CHOICES,
|
||||
location="args",
|
||||
required=False,
|
||||
)
|
||||
.add_argument(
|
||||
"triggered_from",
|
||||
type=str,
|
||||
choices=["debugging", "app-run"],
|
||||
location="args",
|
||||
required=False,
|
||||
help="Filter by trigger source: debugging or app-run",
|
||||
)
|
||||
class WorkflowRunListQuery(BaseModel):
|
||||
last_id: str | None = Field(default=None, description="Last run ID for pagination")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Number of items per page (1-100)")
|
||||
status: Literal["running", "succeeded", "failed", "stopped", "partial-succeeded"] | None = Field(
|
||||
default=None, description="Workflow run status filter"
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def _parse_workflow_run_count_args():
|
||||
"""
|
||||
Parse common arguments for workflow run count endpoints.
|
||||
|
||||
Returns:
|
||||
Parsed arguments containing status, time_range, and triggered_from filters
|
||||
"""
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"status",
|
||||
type=str,
|
||||
choices=WORKFLOW_RUN_STATUS_CHOICES,
|
||||
location="args",
|
||||
required=False,
|
||||
)
|
||||
.add_argument(
|
||||
"time_range",
|
||||
type=time_duration,
|
||||
location="args",
|
||||
required=False,
|
||||
help="Time range filter (e.g., 7d, 4h, 30m, 30s)",
|
||||
)
|
||||
.add_argument(
|
||||
"triggered_from",
|
||||
type=str,
|
||||
choices=["debugging", "app-run"],
|
||||
location="args",
|
||||
required=False,
|
||||
help="Filter by trigger source: debugging or app-run",
|
||||
)
|
||||
triggered_from: Literal["debugging", "app-run"] | None = Field(
|
||||
default=None, description="Filter by trigger source: debugging or app-run"
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
@field_validator("last_id")
|
||||
@classmethod
|
||||
def validate_last_id(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class WorkflowRunCountQuery(BaseModel):
|
||||
status: Literal["running", "succeeded", "failed", "stopped", "partial-succeeded"] | None = Field(
|
||||
default=None, description="Workflow run status filter"
|
||||
)
|
||||
time_range: str | None = Field(default=None, description="Time range filter (e.g., 7d, 4h, 30m, 30s)")
|
||||
triggered_from: Literal["debugging", "app-run"] | None = Field(
|
||||
default=None, description="Filter by trigger source: debugging or app-run"
|
||||
)
|
||||
|
||||
@field_validator("time_range")
|
||||
@classmethod
|
||||
def validate_time_range(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return time_duration(value)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
WorkflowRunListQuery.__name__, WorkflowRunListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
console_ns.schema_model(
|
||||
WorkflowRunCountQuery.__name__,
|
||||
WorkflowRunCountQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs")
|
||||
|
|
@ -170,6 +152,7 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
|
|||
@console_ns.doc(
|
||||
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
|
||||
)
|
||||
@console_ns.expect(console_ns.models[WorkflowRunListQuery.__name__])
|
||||
@console_ns.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -180,12 +163,13 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
|
|||
"""
|
||||
Get advanced chat app workflow run list
|
||||
"""
|
||||
args = _parse_workflow_run_list_args()
|
||||
args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = args_model.model_dump(exclude_none=True)
|
||||
|
||||
# Default to DEBUGGING if not specified
|
||||
triggered_from = (
|
||||
WorkflowRunTriggeredFrom(args.get("triggered_from"))
|
||||
if args.get("triggered_from")
|
||||
WorkflowRunTriggeredFrom(args_model.triggered_from)
|
||||
if args_model.triggered_from
|
||||
else WorkflowRunTriggeredFrom.DEBUGGING
|
||||
)
|
||||
|
||||
|
|
@ -217,6 +201,7 @@ class AdvancedChatAppWorkflowRunCountApi(Resource):
|
|||
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
|
||||
)
|
||||
@console_ns.response(200, "Workflow runs count retrieved successfully", workflow_run_count_model)
|
||||
@console_ns.expect(console_ns.models[WorkflowRunCountQuery.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -226,12 +211,13 @@ class AdvancedChatAppWorkflowRunCountApi(Resource):
|
|||
"""
|
||||
Get advanced chat workflow runs count statistics
|
||||
"""
|
||||
args = _parse_workflow_run_count_args()
|
||||
args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = args_model.model_dump(exclude_none=True)
|
||||
|
||||
# Default to DEBUGGING if not specified
|
||||
triggered_from = (
|
||||
WorkflowRunTriggeredFrom(args.get("triggered_from"))
|
||||
if args.get("triggered_from")
|
||||
WorkflowRunTriggeredFrom(args_model.triggered_from)
|
||||
if args_model.triggered_from
|
||||
else WorkflowRunTriggeredFrom.DEBUGGING
|
||||
)
|
||||
|
||||
|
|
@ -259,6 +245,7 @@ class WorkflowRunListApi(Resource):
|
|||
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
|
||||
)
|
||||
@console_ns.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_model)
|
||||
@console_ns.expect(console_ns.models[WorkflowRunListQuery.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -268,12 +255,13 @@ class WorkflowRunListApi(Resource):
|
|||
"""
|
||||
Get workflow run list
|
||||
"""
|
||||
args = _parse_workflow_run_list_args()
|
||||
args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = args_model.model_dump(exclude_none=True)
|
||||
|
||||
# Default to DEBUGGING for workflow if not specified (backward compatibility)
|
||||
triggered_from = (
|
||||
WorkflowRunTriggeredFrom(args.get("triggered_from"))
|
||||
if args.get("triggered_from")
|
||||
WorkflowRunTriggeredFrom(args_model.triggered_from)
|
||||
if args_model.triggered_from
|
||||
else WorkflowRunTriggeredFrom.DEBUGGING
|
||||
)
|
||||
|
||||
|
|
@ -305,6 +293,7 @@ class WorkflowRunCountApi(Resource):
|
|||
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
|
||||
)
|
||||
@console_ns.response(200, "Workflow runs count retrieved successfully", workflow_run_count_model)
|
||||
@console_ns.expect(console_ns.models[WorkflowRunCountQuery.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -314,12 +303,13 @@ class WorkflowRunCountApi(Resource):
|
|||
"""
|
||||
Get workflow runs count statistics
|
||||
"""
|
||||
args = _parse_workflow_run_count_args()
|
||||
args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = args_model.model_dump(exclude_none=True)
|
||||
|
||||
# Default to DEBUGGING for workflow if not specified (backward compatibility)
|
||||
triggered_from = (
|
||||
WorkflowRunTriggeredFrom(args.get("triggered_from"))
|
||||
if args.get("triggered_from")
|
||||
WorkflowRunTriggeredFrom(args_model.triggered_from)
|
||||
if args_model.triggered_from
|
||||
else WorkflowRunTriggeredFrom.DEBUGGING
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from flask import abort, jsonify
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask import abort, jsonify, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from controllers.console import console_ns
|
||||
|
|
@ -7,12 +8,31 @@ from controllers.console.app.wraps import get_app_model
|
|||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import parse_time_range
|
||||
from libs.helper import DatetimeString
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.model import AppMode
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class WorkflowStatisticQuery(BaseModel):
|
||||
start: str | None = Field(default=None, description="Start date and time (YYYY-MM-DD HH:MM)")
|
||||
end: str | None = Field(default=None, description="End date and time (YYYY-MM-DD HH:MM)")
|
||||
|
||||
@field_validator("start", "end", mode="before")
|
||||
@classmethod
|
||||
def blank_to_none(cls, value: str | None) -> str | None:
|
||||
if value == "":
|
||||
return None
|
||||
return value
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
WorkflowStatisticQuery.__name__,
|
||||
WorkflowStatisticQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/daily-conversations")
|
||||
class WorkflowDailyRunsStatistic(Resource):
|
||||
|
|
@ -24,9 +44,7 @@ class WorkflowDailyRunsStatistic(Resource):
|
|||
@console_ns.doc("get_workflow_daily_runs_statistic")
|
||||
@console_ns.doc(description="Get workflow daily runs statistics")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.doc(
|
||||
params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}
|
||||
)
|
||||
@console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
|
||||
@console_ns.response(200, "Daily runs statistics retrieved successfully")
|
||||
@get_app_model
|
||||
@setup_required
|
||||
|
|
@ -35,17 +53,12 @@ class WorkflowDailyRunsStatistic(Resource):
|
|||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
|
|
@ -71,9 +84,7 @@ class WorkflowDailyTerminalsStatistic(Resource):
|
|||
@console_ns.doc("get_workflow_daily_terminals_statistic")
|
||||
@console_ns.doc(description="Get workflow daily terminals statistics")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.doc(
|
||||
params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}
|
||||
)
|
||||
@console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
|
||||
@console_ns.response(200, "Daily terminals statistics retrieved successfully")
|
||||
@get_app_model
|
||||
@setup_required
|
||||
|
|
@ -82,17 +93,12 @@ class WorkflowDailyTerminalsStatistic(Resource):
|
|||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
|
|
@ -118,9 +124,7 @@ class WorkflowDailyTokenCostStatistic(Resource):
|
|||
@console_ns.doc("get_workflow_daily_token_cost_statistic")
|
||||
@console_ns.doc(description="Get workflow daily token cost statistics")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.doc(
|
||||
params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}
|
||||
)
|
||||
@console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
|
||||
@console_ns.response(200, "Daily token cost statistics retrieved successfully")
|
||||
@get_app_model
|
||||
@setup_required
|
||||
|
|
@ -129,17 +133,12 @@ class WorkflowDailyTokenCostStatistic(Resource):
|
|||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
|
|
@ -165,9 +164,7 @@ class WorkflowAverageAppInteractionStatistic(Resource):
|
|||
@console_ns.doc("get_workflow_average_app_interaction_statistic")
|
||||
@console_ns.doc(description="Get workflow average app interaction statistics")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.doc(
|
||||
params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}
|
||||
)
|
||||
@console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
|
||||
@console_ns.response(200, "Average app interaction statistics retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -176,17 +173,12 @@ class WorkflowAverageAppInteractionStatistic(Resource):
|
|||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
|
|
|
|||
|
|
@ -58,7 +58,7 @@ class VersionApi(Resource):
|
|||
response = httpx.get(
|
||||
check_update_url,
|
||||
params={"current_version": args["current_version"]},
|
||||
timeout=httpx.Timeout(connect=3, read=10),
|
||||
timeout=httpx.Timeout(timeout=10.0, connect=3.0),
|
||||
)
|
||||
except Exception as error:
|
||||
logger.warning("Check update version error: %s.", str(error))
|
||||
|
|
|
|||
|
|
@ -174,63 +174,25 @@ class CheckEmailUniquePayload(BaseModel):
|
|||
return email(value)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
AccountInitPayload.__name__, AccountInitPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
console_ns.schema_model(
|
||||
AccountNamePayload.__name__, AccountNamePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
console_ns.schema_model(
|
||||
AccountAvatarPayload.__name__, AccountAvatarPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
console_ns.schema_model(
|
||||
AccountInterfaceLanguagePayload.__name__,
|
||||
AccountInterfaceLanguagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
AccountInterfaceThemePayload.__name__,
|
||||
AccountInterfaceThemePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
AccountTimezonePayload.__name__,
|
||||
AccountTimezonePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
AccountPasswordPayload.__name__,
|
||||
AccountPasswordPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
AccountDeletePayload.__name__,
|
||||
AccountDeletePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
AccountDeletionFeedbackPayload.__name__,
|
||||
AccountDeletionFeedbackPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
EducationActivatePayload.__name__,
|
||||
EducationActivatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
EducationAutocompleteQuery.__name__,
|
||||
EducationAutocompleteQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
ChangeEmailSendPayload.__name__,
|
||||
ChangeEmailSendPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
ChangeEmailValidityPayload.__name__,
|
||||
ChangeEmailValidityPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
ChangeEmailResetPayload.__name__,
|
||||
ChangeEmailResetPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
CheckEmailUniquePayload.__name__,
|
||||
CheckEmailUniquePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(AccountInitPayload)
|
||||
reg(AccountNamePayload)
|
||||
reg(AccountAvatarPayload)
|
||||
reg(AccountInterfaceLanguagePayload)
|
||||
reg(AccountInterfaceThemePayload)
|
||||
reg(AccountTimezonePayload)
|
||||
reg(AccountPasswordPayload)
|
||||
reg(AccountDeletePayload)
|
||||
reg(AccountDeletionFeedbackPayload)
|
||||
reg(EducationActivatePayload)
|
||||
reg(EducationAutocompleteQuery)
|
||||
reg(ChangeEmailSendPayload)
|
||||
reg(ChangeEmailValidityPayload)
|
||||
reg(ChangeEmailResetPayload)
|
||||
reg(CheckEmailUniquePayload)
|
||||
|
||||
|
||||
@console_ns.route("/account/init")
|
||||
|
|
|
|||
|
|
@ -1,4 +1,8 @@
|
|||
from flask_restx import Resource, fields, reqparse
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
|
|
@ -7,21 +11,49 @@ from core.plugin.impl.exc import PluginPermissionDeniedError
|
|||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.plugin.endpoint_service import EndpointService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class EndpointCreatePayload(BaseModel):
|
||||
plugin_unique_identifier: str
|
||||
settings: dict[str, Any]
|
||||
name: str = Field(min_length=1)
|
||||
|
||||
|
||||
class EndpointIdPayload(BaseModel):
|
||||
endpoint_id: str
|
||||
|
||||
|
||||
class EndpointUpdatePayload(EndpointIdPayload):
|
||||
settings: dict[str, Any]
|
||||
name: str = Field(min_length=1)
|
||||
|
||||
|
||||
class EndpointListQuery(BaseModel):
|
||||
page: int = Field(ge=1)
|
||||
page_size: int = Field(gt=0)
|
||||
|
||||
|
||||
class EndpointListForPluginQuery(EndpointListQuery):
|
||||
plugin_id: str
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(EndpointCreatePayload)
|
||||
reg(EndpointIdPayload)
|
||||
reg(EndpointUpdatePayload)
|
||||
reg(EndpointListQuery)
|
||||
reg(EndpointListForPluginQuery)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/create")
|
||||
class EndpointCreateApi(Resource):
|
||||
@console_ns.doc("create_endpoint")
|
||||
@console_ns.doc(description="Create a new plugin endpoint")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"EndpointCreateRequest",
|
||||
{
|
||||
"plugin_unique_identifier": fields.String(required=True, description="Plugin unique identifier"),
|
||||
"settings": fields.Raw(required=True, description="Endpoint settings"),
|
||||
"name": fields.String(required=True, description="Endpoint name"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[EndpointCreatePayload.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint created successfully",
|
||||
|
|
@ -35,26 +67,16 @@ class EndpointCreateApi(Resource):
|
|||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("plugin_unique_identifier", type=str, required=True)
|
||||
.add_argument("settings", type=dict, required=True)
|
||||
.add_argument("name", type=str, required=True)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
plugin_unique_identifier = args["plugin_unique_identifier"]
|
||||
settings = args["settings"]
|
||||
name = args["name"]
|
||||
args = EndpointCreatePayload.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
return {
|
||||
"success": EndpointService.create_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
plugin_unique_identifier=plugin_unique_identifier,
|
||||
name=name,
|
||||
settings=settings,
|
||||
plugin_unique_identifier=args.plugin_unique_identifier,
|
||||
name=args.name,
|
||||
settings=args.settings,
|
||||
)
|
||||
}
|
||||
except PluginPermissionDeniedError as e:
|
||||
|
|
@ -65,11 +87,7 @@ class EndpointCreateApi(Resource):
|
|||
class EndpointListApi(Resource):
|
||||
@console_ns.doc("list_endpoints")
|
||||
@console_ns.doc(description="List plugin endpoints with pagination")
|
||||
@console_ns.expect(
|
||||
console_ns.parser()
|
||||
.add_argument("page", type=int, required=True, location="args", help="Page number")
|
||||
.add_argument("page_size", type=int, required=True, location="args", help="Page size")
|
||||
)
|
||||
@console_ns.expect(console_ns.models[EndpointListQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
|
|
@ -83,15 +101,10 @@ class EndpointListApi(Resource):
|
|||
def get(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("page", type=int, required=True, location="args")
|
||||
.add_argument("page_size", type=int, required=True, location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = EndpointListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
page = args["page"]
|
||||
page_size = args["page_size"]
|
||||
page = args.page
|
||||
page_size = args.page_size
|
||||
|
||||
return jsonable_encoder(
|
||||
{
|
||||
|
|
@ -109,12 +122,7 @@ class EndpointListApi(Resource):
|
|||
class EndpointListForSinglePluginApi(Resource):
|
||||
@console_ns.doc("list_plugin_endpoints")
|
||||
@console_ns.doc(description="List endpoints for a specific plugin")
|
||||
@console_ns.expect(
|
||||
console_ns.parser()
|
||||
.add_argument("page", type=int, required=True, location="args", help="Page number")
|
||||
.add_argument("page_size", type=int, required=True, location="args", help="Page size")
|
||||
.add_argument("plugin_id", type=str, required=True, location="args", help="Plugin ID")
|
||||
)
|
||||
@console_ns.expect(console_ns.models[EndpointListForPluginQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
|
|
@ -128,17 +136,11 @@ class EndpointListForSinglePluginApi(Resource):
|
|||
def get(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("page", type=int, required=True, location="args")
|
||||
.add_argument("page_size", type=int, required=True, location="args")
|
||||
.add_argument("plugin_id", type=str, required=True, location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = EndpointListForPluginQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
page = args["page"]
|
||||
page_size = args["page_size"]
|
||||
plugin_id = args["plugin_id"]
|
||||
page = args.page
|
||||
page_size = args.page_size
|
||||
plugin_id = args.plugin_id
|
||||
|
||||
return jsonable_encoder(
|
||||
{
|
||||
|
|
@ -157,11 +159,7 @@ class EndpointListForSinglePluginApi(Resource):
|
|||
class EndpointDeleteApi(Resource):
|
||||
@console_ns.doc("delete_endpoint")
|
||||
@console_ns.doc(description="Delete a plugin endpoint")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"EndpointDeleteRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")}
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[EndpointIdPayload.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint deleted successfully",
|
||||
|
|
@ -175,13 +173,12 @@ class EndpointDeleteApi(Resource):
|
|||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
endpoint_id = args["endpoint_id"]
|
||||
args = EndpointIdPayload.model_validate(console_ns.payload)
|
||||
|
||||
return {
|
||||
"success": EndpointService.delete_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
|
||||
"success": EndpointService.delete_endpoint(
|
||||
tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -189,16 +186,7 @@ class EndpointDeleteApi(Resource):
|
|||
class EndpointUpdateApi(Resource):
|
||||
@console_ns.doc("update_endpoint")
|
||||
@console_ns.doc(description="Update a plugin endpoint")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"EndpointUpdateRequest",
|
||||
{
|
||||
"endpoint_id": fields.String(required=True, description="Endpoint ID"),
|
||||
"settings": fields.Raw(required=True, description="Updated settings"),
|
||||
"name": fields.String(required=True, description="Updated name"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[EndpointUpdatePayload.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint updated successfully",
|
||||
|
|
@ -212,25 +200,15 @@ class EndpointUpdateApi(Resource):
|
|||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("endpoint_id", type=str, required=True)
|
||||
.add_argument("settings", type=dict, required=True)
|
||||
.add_argument("name", type=str, required=True)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
endpoint_id = args["endpoint_id"]
|
||||
settings = args["settings"]
|
||||
name = args["name"]
|
||||
args = EndpointUpdatePayload.model_validate(console_ns.payload)
|
||||
|
||||
return {
|
||||
"success": EndpointService.update_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
endpoint_id=endpoint_id,
|
||||
name=name,
|
||||
settings=settings,
|
||||
endpoint_id=args.endpoint_id,
|
||||
name=args.name,
|
||||
settings=args.settings,
|
||||
)
|
||||
}
|
||||
|
||||
|
|
@ -239,11 +217,7 @@ class EndpointUpdateApi(Resource):
|
|||
class EndpointEnableApi(Resource):
|
||||
@console_ns.doc("enable_endpoint")
|
||||
@console_ns.doc(description="Enable a plugin endpoint")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"EndpointEnableRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")}
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[EndpointIdPayload.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint enabled successfully",
|
||||
|
|
@ -257,13 +231,12 @@ class EndpointEnableApi(Resource):
|
|||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
endpoint_id = args["endpoint_id"]
|
||||
args = EndpointIdPayload.model_validate(console_ns.payload)
|
||||
|
||||
return {
|
||||
"success": EndpointService.enable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
|
||||
"success": EndpointService.enable_endpoint(
|
||||
tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -271,11 +244,7 @@ class EndpointEnableApi(Resource):
|
|||
class EndpointDisableApi(Resource):
|
||||
@console_ns.doc("disable_endpoint")
|
||||
@console_ns.doc(description="Disable a plugin endpoint")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"EndpointDisableRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")}
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[EndpointIdPayload.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint disabled successfully",
|
||||
|
|
@ -289,11 +258,10 @@ class EndpointDisableApi(Resource):
|
|||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
endpoint_id = args["endpoint_id"]
|
||||
args = EndpointIdPayload.model_validate(console_ns.payload)
|
||||
|
||||
return {
|
||||
"success": EndpointService.disable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
|
||||
"success": EndpointService.disable_endpoint(
|
||||
tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -58,26 +58,15 @@ class OwnerTransferPayload(BaseModel):
|
|||
token: str
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
MemberInvitePayload.__name__,
|
||||
MemberInvitePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
MemberRoleUpdatePayload.__name__,
|
||||
MemberRoleUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
OwnerTransferEmailPayload.__name__,
|
||||
OwnerTransferEmailPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
OwnerTransferCheckPayload.__name__,
|
||||
OwnerTransferCheckPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
OwnerTransferPayload.__name__,
|
||||
OwnerTransferPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(MemberInvitePayload)
|
||||
reg(MemberRoleUpdatePayload)
|
||||
reg(OwnerTransferEmailPayload)
|
||||
reg(OwnerTransferCheckPayload)
|
||||
reg(OwnerTransferPayload)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/members")
|
||||
|
|
|
|||
|
|
@ -75,44 +75,18 @@ class ParserPreferredProviderType(BaseModel):
|
|||
preferred_provider_type: Literal["system", "custom"]
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserModelList.__name__, ParserModelList.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserCredentialId.__name__,
|
||||
ParserCredentialId.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserCredentialCreate.__name__,
|
||||
ParserCredentialCreate.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserCredentialUpdate.__name__,
|
||||
ParserCredentialUpdate.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserCredentialDelete.__name__,
|
||||
ParserCredentialDelete.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserCredentialSwitch.__name__,
|
||||
ParserCredentialSwitch.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserCredentialValidate.__name__,
|
||||
ParserCredentialValidate.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserPreferredProviderType.__name__,
|
||||
ParserPreferredProviderType.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
reg(ParserModelList)
|
||||
reg(ParserCredentialId)
|
||||
reg(ParserCredentialCreate)
|
||||
reg(ParserCredentialUpdate)
|
||||
reg(ParserCredentialDelete)
|
||||
reg(ParserCredentialSwitch)
|
||||
reg(ParserCredentialValidate)
|
||||
reg(ParserPreferredProviderType)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers")
|
||||
|
|
|
|||
|
|
@ -32,25 +32,11 @@ class ParserPostDefault(BaseModel):
|
|||
model_settings: list[Inner]
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserGetDefault.__name__, ParserGetDefault.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserPostDefault.__name__, ParserPostDefault.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
|
||||
class ParserDeleteModels(BaseModel):
|
||||
model: str
|
||||
model_type: ModelType
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserDeleteModels.__name__, ParserDeleteModels.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
|
||||
class LoadBalancingPayload(BaseModel):
|
||||
configs: list[dict[str, Any]] | None = None
|
||||
enabled: bool | None = None
|
||||
|
|
@ -119,33 +105,19 @@ class ParserParameter(BaseModel):
|
|||
model: str
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserPostModels.__name__, ParserPostModels.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserGetCredentials.__name__,
|
||||
ParserGetCredentials.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserCreateCredential.__name__,
|
||||
ParserCreateCredential.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserUpdateCredential.__name__,
|
||||
ParserUpdateCredential.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserDeleteCredential.__name__,
|
||||
ParserDeleteCredential.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserParameter.__name__, ParserParameter.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
reg(ParserGetDefault)
|
||||
reg(ParserPostDefault)
|
||||
reg(ParserDeleteModels)
|
||||
reg(ParserPostModels)
|
||||
reg(ParserGetCredentials)
|
||||
reg(ParserCreateCredential)
|
||||
reg(ParserUpdateCredential)
|
||||
reg(ParserDeleteCredential)
|
||||
reg(ParserParameter)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/default-model")
|
||||
|
|
|
|||
|
|
@ -22,6 +22,10 @@ from services.plugin.plugin_service import PluginService
|
|||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/debugging-key")
|
||||
class PluginDebuggingKeyApi(Resource):
|
||||
@setup_required
|
||||
|
|
@ -46,9 +50,7 @@ class ParserList(BaseModel):
|
|||
page_size: int = Field(default=256)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserList.__name__, ParserList.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
reg(ParserList)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/list")
|
||||
|
|
@ -72,11 +74,6 @@ class ParserLatest(BaseModel):
|
|||
plugin_ids: list[str]
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserLatest.__name__, ParserLatest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
|
||||
class ParserIcon(BaseModel):
|
||||
tenant_id: str
|
||||
filename: str
|
||||
|
|
@ -173,72 +170,22 @@ class ParserReadme(BaseModel):
|
|||
language: str = Field(default="en-US")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserIcon.__name__, ParserIcon.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserAsset.__name__, ParserAsset.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserGithubUpload.__name__, ParserGithubUpload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserPluginIdentifiers.__name__,
|
||||
ParserPluginIdentifiers.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserGithubInstall.__name__, ParserGithubInstall.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserPluginIdentifierQuery.__name__,
|
||||
ParserPluginIdentifierQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserTasks.__name__, ParserTasks.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserMarketplaceUpgrade.__name__,
|
||||
ParserMarketplaceUpgrade.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserGithubUpgrade.__name__, ParserGithubUpgrade.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserUninstall.__name__, ParserUninstall.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserPermissionChange.__name__,
|
||||
ParserPermissionChange.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserDynamicOptions.__name__,
|
||||
ParserDynamicOptions.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserPreferencesChange.__name__,
|
||||
ParserPreferencesChange.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserExcludePlugin.__name__,
|
||||
ParserExcludePlugin.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserReadme.__name__, ParserReadme.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
reg(ParserLatest)
|
||||
reg(ParserIcon)
|
||||
reg(ParserAsset)
|
||||
reg(ParserGithubUpload)
|
||||
reg(ParserPluginIdentifiers)
|
||||
reg(ParserGithubInstall)
|
||||
reg(ParserPluginIdentifierQuery)
|
||||
reg(ParserTasks)
|
||||
reg(ParserMarketplaceUpgrade)
|
||||
reg(ParserGithubUpgrade)
|
||||
reg(ParserUninstall)
|
||||
reg(ParserPermissionChange)
|
||||
reg(ParserDynamicOptions)
|
||||
reg(ParserPreferencesChange)
|
||||
reg(ParserExcludePlugin)
|
||||
reg(ParserReadme)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/list/latest-versions")
|
||||
|
|
|
|||
|
|
@ -54,25 +54,14 @@ class WorkspaceInfoPayload(BaseModel):
|
|||
name: str
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
WorkspaceListQuery.__name__, WorkspaceListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
console_ns.schema_model(
|
||||
SwitchWorkspacePayload.__name__,
|
||||
SwitchWorkspacePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
WorkspaceCustomConfigPayload.__name__,
|
||||
WorkspaceCustomConfigPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
WorkspaceInfoPayload.__name__,
|
||||
WorkspaceInfoPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
reg(WorkspaceListQuery)
|
||||
reg(SwitchWorkspacePayload)
|
||||
reg(WorkspaceCustomConfigPayload)
|
||||
reg(WorkspaceInfoPayload)
|
||||
|
||||
provider_fields = {
|
||||
"provider_name": fields.String,
|
||||
|
|
|
|||
|
|
@ -4,15 +4,15 @@ from typing import TYPE_CHECKING, Any, Optional
|
|||
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
|
||||
from constants import UUID_NIL
|
||||
from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
|
||||
from core.entities.provider_configuration import ProviderModelBundle
|
||||
from core.file import File, FileUploadConfig
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
|
||||
|
||||
class InvokeFrom(StrEnum):
|
||||
"""
|
||||
|
|
@ -275,10 +275,8 @@ class RagPipelineGenerateEntity(WorkflowAppGenerateEntity):
|
|||
start_node_id: str | None = None
|
||||
|
||||
|
||||
# Import TraceQueueManager at runtime to resolve forward references
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
|
||||
# Rebuild models that use forward references
|
||||
AppGenerateEntity.model_rebuild()
|
||||
EasyUIBasedAppGenerateEntity.model_rebuild()
|
||||
ConversationAppGenerateEntity.model_rebuild()
|
||||
|
|
|
|||
|
|
@ -58,11 +58,39 @@ class OceanBaseVector(BaseVector):
|
|||
password=self._config.password,
|
||||
db_name=self._config.database,
|
||||
)
|
||||
self._fields: list[str] = [] # List of fields in the collection
|
||||
if self._client.check_table_exists(collection_name):
|
||||
self._load_collection_fields()
|
||||
self._hybrid_search_enabled = self._check_hybrid_search_support() # Check if hybrid search is supported
|
||||
|
||||
def get_type(self) -> str:
|
||||
return VectorType.OCEANBASE
|
||||
|
||||
def _load_collection_fields(self):
|
||||
"""
|
||||
Load collection fields from the database table.
|
||||
This method populates the _fields list with column names from the table.
|
||||
"""
|
||||
try:
|
||||
if self._collection_name in self._client.metadata_obj.tables:
|
||||
table = self._client.metadata_obj.tables[self._collection_name]
|
||||
# Store all column names except 'id' (primary key)
|
||||
self._fields = [column.name for column in table.columns if column.name != "id"]
|
||||
logger.debug("Loaded fields for collection '%s': %s", self._collection_name, self._fields)
|
||||
else:
|
||||
logger.warning("Collection '%s' not found in metadata", self._collection_name)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to load collection fields for '%s': %s", self._collection_name, str(e))
|
||||
|
||||
def field_exists(self, field: str) -> bool:
|
||||
"""
|
||||
Check if a field exists in the collection.
|
||||
|
||||
:param field: Field name to check
|
||||
:return: True if field exists, False otherwise
|
||||
"""
|
||||
return field in self._fields
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
self._vec_dim = len(embeddings[0])
|
||||
self._create_collection()
|
||||
|
|
@ -151,6 +179,7 @@ class OceanBaseVector(BaseVector):
|
|||
logger.debug("DEBUG: Hybrid search is NOT enabled for '%s'", self._collection_name)
|
||||
|
||||
self._client.refresh_metadata([self._collection_name])
|
||||
self._load_collection_fields()
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
def _check_hybrid_search_support(self) -> bool:
|
||||
|
|
@ -177,42 +206,134 @@ class OceanBaseVector(BaseVector):
|
|||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
ids = self._get_uuids(documents)
|
||||
for id, doc, emb in zip(ids, documents, embeddings):
|
||||
self._client.insert(
|
||||
table_name=self._collection_name,
|
||||
data={
|
||||
"id": id,
|
||||
"vector": emb,
|
||||
"text": doc.page_content,
|
||||
"metadata": doc.metadata,
|
||||
},
|
||||
)
|
||||
try:
|
||||
self._client.insert(
|
||||
table_name=self._collection_name,
|
||||
data={
|
||||
"id": id,
|
||||
"vector": emb,
|
||||
"text": doc.page_content,
|
||||
"metadata": doc.metadata,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Failed to insert document with id '%s' in collection '%s'",
|
||||
id,
|
||||
self._collection_name,
|
||||
)
|
||||
raise Exception(f"Failed to insert document with id '{id}'") from e
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
cur = self._client.get(table_name=self._collection_name, ids=id)
|
||||
return bool(cur.rowcount != 0)
|
||||
try:
|
||||
cur = self._client.get(table_name=self._collection_name, ids=id)
|
||||
return bool(cur.rowcount != 0)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Failed to check if text exists with id '%s' in collection '%s'",
|
||||
id,
|
||||
self._collection_name,
|
||||
)
|
||||
raise Exception(f"Failed to check text existence for id '{id}'") from e
|
||||
|
||||
def delete_by_ids(self, ids: list[str]):
|
||||
if not ids:
|
||||
return
|
||||
self._client.delete(table_name=self._collection_name, ids=ids)
|
||||
try:
|
||||
self._client.delete(table_name=self._collection_name, ids=ids)
|
||||
logger.debug("Deleted %d documents from collection '%s'", len(ids), self._collection_name)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Failed to delete %d documents from collection '%s'",
|
||||
len(ids),
|
||||
self._collection_name,
|
||||
)
|
||||
raise Exception(f"Failed to delete documents from collection '{self._collection_name}'") from e
|
||||
|
||||
def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]:
|
||||
from sqlalchemy import text
|
||||
try:
|
||||
import re
|
||||
|
||||
cur = self._client.get(
|
||||
table_name=self._collection_name,
|
||||
ids=None,
|
||||
where_clause=[text(f"metadata->>'$.{key}' = '{value}'")],
|
||||
output_column_name=["id"],
|
||||
)
|
||||
return [row[0] for row in cur]
|
||||
from sqlalchemy import text
|
||||
|
||||
# Validate key to prevent injection in JSON path
|
||||
if not re.match(r"^[a-zA-Z0-9_.]+$", key):
|
||||
raise ValueError(f"Invalid characters in metadata key: {key}")
|
||||
|
||||
# Use parameterized query to prevent SQL injection
|
||||
sql = text(f"SELECT id FROM `{self._collection_name}` WHERE metadata->>'$.{key}' = :value")
|
||||
|
||||
with self._client.engine.connect() as conn:
|
||||
result = conn.execute(sql, {"value": value})
|
||||
ids = [row[0] for row in result]
|
||||
|
||||
logger.debug(
|
||||
"Found %d documents with metadata field '%s'='%s' in collection '%s'",
|
||||
len(ids),
|
||||
key,
|
||||
value,
|
||||
self._collection_name,
|
||||
)
|
||||
return ids
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Failed to get IDs by metadata field '%s'='%s' in collection '%s'",
|
||||
key,
|
||||
value,
|
||||
self._collection_name,
|
||||
)
|
||||
raise Exception(f"Failed to query documents by metadata field '{key}'") from e
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
ids = self.get_ids_by_metadata_field(key, value)
|
||||
self.delete_by_ids(ids)
|
||||
if ids:
|
||||
self.delete_by_ids(ids)
|
||||
else:
|
||||
logger.debug("No documents found to delete with metadata field '%s'='%s'", key, value)
|
||||
|
||||
def _process_search_results(
|
||||
self, results: list[tuple], score_threshold: float = 0.0, score_key: str = "score"
|
||||
) -> list[Document]:
|
||||
"""
|
||||
Common method to process search results
|
||||
|
||||
:param results: Search results as list of tuples (text, metadata, score)
|
||||
:param score_threshold: Score threshold for filtering
|
||||
:param score_key: Key name for score in metadata
|
||||
:return: List of documents
|
||||
"""
|
||||
docs = []
|
||||
for row in results:
|
||||
text, metadata_str, score = row[0], row[1], row[2]
|
||||
|
||||
# Parse metadata JSON
|
||||
try:
|
||||
metadata = json.loads(metadata_str) if isinstance(metadata_str, str) else metadata_str
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Invalid JSON metadata: %s", metadata_str)
|
||||
metadata = {}
|
||||
|
||||
# Add score to metadata
|
||||
metadata[score_key] = score
|
||||
|
||||
# Filter by score threshold
|
||||
if score >= score_threshold:
|
||||
docs.append(Document(page_content=text, metadata=metadata))
|
||||
|
||||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
if not self._hybrid_search_enabled:
|
||||
logger.warning(
|
||||
"Full-text search is disabled: set OCEANBASE_ENABLE_HYBRID_SEARCH=true (requires OceanBase >= 4.3.5.1)."
|
||||
)
|
||||
return []
|
||||
if not self.field_exists("text"):
|
||||
logger.warning(
|
||||
"Full-text search unavailable: collection '%s' missing 'text' field; "
|
||||
"recreate the collection after enabling OCEANBASE_ENABLE_HYBRID_SEARCH to add fulltext index.",
|
||||
self._collection_name,
|
||||
)
|
||||
return []
|
||||
|
||||
try:
|
||||
|
|
@ -220,13 +341,24 @@ class OceanBaseVector(BaseVector):
|
|||
if not isinstance(top_k, int) or top_k <= 0:
|
||||
raise ValueError("top_k must be a positive integer")
|
||||
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
where_clause = ""
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause = f" AND metadata->>'$.document_id' IN ({document_ids})"
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
|
||||
full_sql = f"""SELECT metadata, text, MATCH (text) AGAINST (:query) AS score
|
||||
# Build parameterized query to prevent SQL injection
|
||||
from sqlalchemy import text
|
||||
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
params = {"query": query}
|
||||
where_clause = ""
|
||||
|
||||
if document_ids_filter:
|
||||
# Create parameterized placeholders for document IDs
|
||||
placeholders = ", ".join(f":doc_id_{i}" for i in range(len(document_ids_filter)))
|
||||
where_clause = f" AND metadata->>'$.document_id' IN ({placeholders})"
|
||||
# Add document IDs to parameters
|
||||
for i, doc_id in enumerate(document_ids_filter):
|
||||
params[f"doc_id_{i}"] = doc_id
|
||||
|
||||
full_sql = f"""SELECT text, metadata, MATCH (text) AGAINST (:query) AS score
|
||||
FROM {self._collection_name}
|
||||
WHERE MATCH (text) AGAINST (:query) > 0
|
||||
{where_clause}
|
||||
|
|
@ -235,41 +367,45 @@ class OceanBaseVector(BaseVector):
|
|||
|
||||
with self._client.engine.connect() as conn:
|
||||
with conn.begin():
|
||||
from sqlalchemy import text
|
||||
|
||||
result = conn.execute(text(full_sql), {"query": query})
|
||||
result = conn.execute(text(full_sql), params)
|
||||
rows = result.fetchall()
|
||||
|
||||
docs = []
|
||||
for row in rows:
|
||||
metadata_str, _text, score = row
|
||||
try:
|
||||
metadata = json.loads(metadata_str)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Invalid JSON metadata: %s", metadata_str)
|
||||
metadata = {}
|
||||
metadata["score"] = score
|
||||
docs.append(Document(page_content=_text, metadata=metadata))
|
||||
|
||||
return docs
|
||||
return self._process_search_results(rows, score_threshold=score_threshold)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to fulltext search: %s.", str(e))
|
||||
return []
|
||||
logger.exception(
|
||||
"Failed to perform full-text search on collection '%s' with query '%s'",
|
||||
self._collection_name,
|
||||
query,
|
||||
)
|
||||
raise Exception(f"Full-text search failed for collection '{self._collection_name}'") from e
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
from sqlalchemy import text
|
||||
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
_where_clause = None
|
||||
if document_ids_filter:
|
||||
# Validate document IDs to prevent SQL injection
|
||||
# Document IDs should be alphanumeric with hyphens and underscores
|
||||
import re
|
||||
|
||||
for doc_id in document_ids_filter:
|
||||
if not isinstance(doc_id, str) or not re.match(r"^[a-zA-Z0-9_-]+$", doc_id):
|
||||
raise ValueError(f"Invalid document ID format: {doc_id}")
|
||||
|
||||
# Safe to use in query after validation
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause = f"metadata->>'$.document_id' in ({document_ids})"
|
||||
from sqlalchemy import text
|
||||
|
||||
_where_clause = [text(where_clause)]
|
||||
ef_search = kwargs.get("ef_search", self._hnsw_ef_search)
|
||||
if ef_search != self._hnsw_ef_search:
|
||||
self._client.set_ob_hnsw_ef_search(ef_search)
|
||||
self._hnsw_ef_search = ef_search
|
||||
topk = kwargs.get("top_k", 10)
|
||||
try:
|
||||
score_threshold = float(val) if (val := kwargs.get("score_threshold")) is not None else 0.0
|
||||
except (ValueError, TypeError) as e:
|
||||
raise ValueError(f"Invalid score_threshold parameter: {e}") from e
|
||||
try:
|
||||
cur = self._client.ann_search(
|
||||
table_name=self._collection_name,
|
||||
|
|
@ -282,21 +418,27 @@ class OceanBaseVector(BaseVector):
|
|||
where_clause=_where_clause,
|
||||
)
|
||||
except Exception as e:
|
||||
raise Exception("Failed to search by vector. ", e)
|
||||
docs = []
|
||||
for _text, metadata, distance in cur:
|
||||
metadata = json.loads(metadata)
|
||||
metadata["score"] = 1 - distance / math.sqrt(2)
|
||||
docs.append(
|
||||
Document(
|
||||
page_content=_text,
|
||||
metadata=metadata,
|
||||
)
|
||||
logger.exception(
|
||||
"Failed to perform vector search on collection '%s'",
|
||||
self._collection_name,
|
||||
)
|
||||
return docs
|
||||
raise Exception(f"Vector search failed for collection '{self._collection_name}'") from e
|
||||
|
||||
# Convert distance to score and prepare results for processing
|
||||
results = []
|
||||
for _text, metadata_str, distance in cur:
|
||||
score = 1 - distance / math.sqrt(2)
|
||||
results.append((_text, metadata_str, score))
|
||||
|
||||
return self._process_search_results(results, score_threshold=score_threshold)
|
||||
|
||||
def delete(self):
|
||||
self._client.drop_table_if_exist(self._collection_name)
|
||||
try:
|
||||
self._client.drop_table_if_exist(self._collection_name)
|
||||
logger.debug("Dropped collection '%s'", self._collection_name)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to delete collection '%s'", self._collection_name)
|
||||
raise Exception(f"Failed to delete collection '{self._collection_name}'") from e
|
||||
|
||||
|
||||
class OceanBaseVectorFactory(AbstractVectorFactory):
|
||||
|
|
|
|||
|
|
@ -54,6 +54,8 @@ class ToolProviderApiEntity(BaseModel):
|
|||
configuration: MCPConfiguration | None = Field(
|
||||
default=None, description="The timeout and sse_read_timeout of the MCP tool"
|
||||
)
|
||||
# Workflow
|
||||
workflow_app_id: str | None = Field(default=None, description="The app id of the workflow tool")
|
||||
|
||||
@field_validator("tools", mode="before")
|
||||
@classmethod
|
||||
|
|
@ -87,6 +89,8 @@ class ToolProviderApiEntity(BaseModel):
|
|||
optional_fields.update(self.optional_field("is_dynamic_registration", self.is_dynamic_registration))
|
||||
optional_fields.update(self.optional_field("masked_headers", self.masked_headers))
|
||||
optional_fields.update(self.optional_field("original_headers", self.original_headers))
|
||||
elif self.type == ToolProviderType.WORKFLOW:
|
||||
optional_fields.update(self.optional_field("workflow_app_id", self.workflow_app_id))
|
||||
return {
|
||||
"id": self.id,
|
||||
"author": self.author,
|
||||
|
|
|
|||
|
|
@ -1,7 +1,11 @@
|
|||
import importlib
|
||||
import logging
|
||||
import operator
|
||||
import pkgutil
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from functools import singledispatchmethod
|
||||
from types import MappingProxyType
|
||||
from typing import Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin
|
||||
from uuid import uuid4
|
||||
|
||||
|
|
@ -134,6 +138,34 @@ class Node(Generic[NodeDataT]):
|
|||
|
||||
cls._node_data_type = node_data_type
|
||||
|
||||
# Skip base class itself
|
||||
if cls is Node:
|
||||
return
|
||||
# Only register production node implementations defined under core.workflow.nodes.*
|
||||
# This prevents test helper subclasses from polluting the global registry and
|
||||
# accidentally overriding real node types (e.g., a test Answer node).
|
||||
module_name = getattr(cls, "__module__", "")
|
||||
# Only register concrete subclasses that define node_type and version()
|
||||
node_type = cls.node_type
|
||||
version = cls.version()
|
||||
bucket = Node._registry.setdefault(node_type, {})
|
||||
if module_name.startswith("core.workflow.nodes."):
|
||||
# Production node definitions take precedence and may override
|
||||
bucket[version] = cls # type: ignore[index]
|
||||
else:
|
||||
# External/test subclasses may register but must not override production
|
||||
bucket.setdefault(version, cls) # type: ignore[index]
|
||||
# Maintain a "latest" pointer preferring numeric versions; fallback to lexicographic
|
||||
version_keys = [v for v in bucket if v != "latest"]
|
||||
numeric_pairs: list[tuple[str, int]] = []
|
||||
for v in version_keys:
|
||||
numeric_pairs.append((v, int(v)))
|
||||
if numeric_pairs:
|
||||
latest_key = max(numeric_pairs, key=operator.itemgetter(1))[0]
|
||||
else:
|
||||
latest_key = max(version_keys) if version_keys else version
|
||||
bucket["latest"] = bucket[latest_key]
|
||||
|
||||
@classmethod
|
||||
def _extract_node_data_type_from_generic(cls) -> type[BaseNodeData] | None:
|
||||
"""
|
||||
|
|
@ -165,6 +197,9 @@ class Node(Generic[NodeDataT]):
|
|||
|
||||
return None
|
||||
|
||||
# Global registry populated via __init_subclass__
|
||||
_registry: ClassVar[dict["NodeType", dict[str, type["Node"]]]] = {}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
|
|
@ -240,23 +275,23 @@ class Node(Generic[NodeDataT]):
|
|||
from core.workflow.nodes.tool.tool_node import ToolNode
|
||||
|
||||
if isinstance(self, ToolNode):
|
||||
start_event.provider_id = getattr(self.get_base_node_data(), "provider_id", "")
|
||||
start_event.provider_type = getattr(self.get_base_node_data(), "provider_type", "")
|
||||
start_event.provider_id = getattr(self.node_data, "provider_id", "")
|
||||
start_event.provider_type = getattr(self.node_data, "provider_type", "")
|
||||
|
||||
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
|
||||
|
||||
if isinstance(self, DatasourceNode):
|
||||
plugin_id = getattr(self.get_base_node_data(), "plugin_id", "")
|
||||
provider_name = getattr(self.get_base_node_data(), "provider_name", "")
|
||||
plugin_id = getattr(self.node_data, "plugin_id", "")
|
||||
provider_name = getattr(self.node_data, "provider_name", "")
|
||||
|
||||
start_event.provider_id = f"{plugin_id}/{provider_name}"
|
||||
start_event.provider_type = getattr(self.get_base_node_data(), "provider_type", "")
|
||||
start_event.provider_type = getattr(self.node_data, "provider_type", "")
|
||||
|
||||
from core.workflow.nodes.trigger_plugin.trigger_event_node import TriggerEventNode
|
||||
|
||||
if isinstance(self, TriggerEventNode):
|
||||
start_event.provider_id = getattr(self.get_base_node_data(), "provider_id", "")
|
||||
start_event.provider_type = getattr(self.get_base_node_data(), "provider_type", "")
|
||||
start_event.provider_id = getattr(self.node_data, "provider_id", "")
|
||||
start_event.provider_type = getattr(self.node_data, "provider_type", "")
|
||||
|
||||
from typing import cast
|
||||
|
||||
|
|
@ -265,7 +300,7 @@ class Node(Generic[NodeDataT]):
|
|||
|
||||
if isinstance(self, AgentNode):
|
||||
start_event.agent_strategy = AgentNodeStrategyInit(
|
||||
name=cast(AgentNodeData, self.get_base_node_data()).agent_strategy_name,
|
||||
name=cast(AgentNodeData, self.node_data).agent_strategy_name,
|
||||
icon=self.agent_strategy_icon,
|
||||
)
|
||||
|
||||
|
|
@ -395,6 +430,29 @@ class Node(Generic[NodeDataT]):
|
|||
# in `api/core/workflow/nodes/__init__.py`.
|
||||
raise NotImplementedError("subclasses of BaseNode must implement `version` method.")
|
||||
|
||||
@classmethod
|
||||
def get_node_type_classes_mapping(cls) -> Mapping["NodeType", Mapping[str, type["Node"]]]:
|
||||
"""Return mapping of NodeType -> {version -> Node subclass} using __init_subclass__ registry.
|
||||
|
||||
Import all modules under core.workflow.nodes so subclasses register themselves on import.
|
||||
Then we return a readonly view of the registry to avoid accidental mutation.
|
||||
"""
|
||||
# Import all node modules to ensure they are loaded (thus registered)
|
||||
import core.workflow.nodes as _nodes_pkg
|
||||
|
||||
for _, _modname, _ in pkgutil.walk_packages(_nodes_pkg.__path__, _nodes_pkg.__name__ + "."):
|
||||
# Avoid importing modules that depend on the registry to prevent circular imports
|
||||
# e.g. node_factory imports node_mapping which builds the mapping here.
|
||||
if _modname in {
|
||||
"core.workflow.nodes.node_factory",
|
||||
"core.workflow.nodes.node_mapping",
|
||||
}:
|
||||
continue
|
||||
importlib.import_module(_modname)
|
||||
|
||||
# Return a readonly view so callers can't mutate the registry by accident
|
||||
return {nt: MappingProxyType(ver_map) for nt, ver_map in cls._registry.items()}
|
||||
|
||||
@property
|
||||
def retry(self) -> bool:
|
||||
return False
|
||||
|
|
@ -419,10 +477,6 @@ class Node(Generic[NodeDataT]):
|
|||
"""Get the default values dictionary for this node."""
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
"""Get the BaseNodeData object for this node."""
|
||||
return self._node_data
|
||||
|
||||
# Public interface properties that delegate to abstract methods
|
||||
@property
|
||||
def error_strategy(self) -> ErrorStrategy | None:
|
||||
|
|
@ -548,7 +602,7 @@ class Node(Generic[NodeDataT]):
|
|||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.get_base_node_data().title,
|
||||
node_title=self.node_data.title,
|
||||
start_at=event.start_at,
|
||||
inputs=event.inputs,
|
||||
metadata=event.metadata,
|
||||
|
|
@ -561,7 +615,7 @@ class Node(Generic[NodeDataT]):
|
|||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.get_base_node_data().title,
|
||||
node_title=self.node_data.title,
|
||||
index=event.index,
|
||||
pre_loop_output=event.pre_loop_output,
|
||||
)
|
||||
|
|
@ -572,7 +626,7 @@ class Node(Generic[NodeDataT]):
|
|||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.get_base_node_data().title,
|
||||
node_title=self.node_data.title,
|
||||
start_at=event.start_at,
|
||||
inputs=event.inputs,
|
||||
outputs=event.outputs,
|
||||
|
|
@ -586,7 +640,7 @@ class Node(Generic[NodeDataT]):
|
|||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.get_base_node_data().title,
|
||||
node_title=self.node_data.title,
|
||||
start_at=event.start_at,
|
||||
inputs=event.inputs,
|
||||
outputs=event.outputs,
|
||||
|
|
@ -601,7 +655,7 @@ class Node(Generic[NodeDataT]):
|
|||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.get_base_node_data().title,
|
||||
node_title=self.node_data.title,
|
||||
start_at=event.start_at,
|
||||
inputs=event.inputs,
|
||||
metadata=event.metadata,
|
||||
|
|
@ -614,7 +668,7 @@ class Node(Generic[NodeDataT]):
|
|||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.get_base_node_data().title,
|
||||
node_title=self.node_data.title,
|
||||
index=event.index,
|
||||
pre_iteration_output=event.pre_iteration_output,
|
||||
)
|
||||
|
|
@ -625,7 +679,7 @@ class Node(Generic[NodeDataT]):
|
|||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.get_base_node_data().title,
|
||||
node_title=self.node_data.title,
|
||||
start_at=event.start_at,
|
||||
inputs=event.inputs,
|
||||
outputs=event.outputs,
|
||||
|
|
@ -639,7 +693,7 @@ class Node(Generic[NodeDataT]):
|
|||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.get_base_node_data().title,
|
||||
node_title=self.node_data.title,
|
||||
start_at=event.start_at,
|
||||
inputs=event.inputs,
|
||||
outputs=event.outputs,
|
||||
|
|
|
|||
|
|
@ -1,165 +1,9 @@
|
|||
from collections.abc import Mapping
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.agent.agent_node import AgentNode
|
||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.code import CodeNode
|
||||
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
|
||||
from core.workflow.nodes.document_extractor import DocumentExtractorNode
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.http_request import HttpRequestNode
|
||||
from core.workflow.nodes.human_input import HumanInputNode
|
||||
from core.workflow.nodes.if_else import IfElseNode
|
||||
from core.workflow.nodes.iteration import IterationNode, IterationStartNode
|
||||
from core.workflow.nodes.knowledge_index import KnowledgeIndexNode
|
||||
from core.workflow.nodes.knowledge_retrieval import KnowledgeRetrievalNode
|
||||
from core.workflow.nodes.list_operator import ListOperatorNode
|
||||
from core.workflow.nodes.llm import LLMNode
|
||||
from core.workflow.nodes.loop import LoopEndNode, LoopNode, LoopStartNode
|
||||
from core.workflow.nodes.parameter_extractor import ParameterExtractorNode
|
||||
from core.workflow.nodes.question_classifier import QuestionClassifierNode
|
||||
from core.workflow.nodes.start import StartNode
|
||||
from core.workflow.nodes.template_transform import TemplateTransformNode
|
||||
from core.workflow.nodes.tool import ToolNode
|
||||
from core.workflow.nodes.trigger_plugin import TriggerEventNode
|
||||
from core.workflow.nodes.trigger_schedule import TriggerScheduleNode
|
||||
from core.workflow.nodes.trigger_webhook import TriggerWebhookNode
|
||||
from core.workflow.nodes.variable_aggregator import VariableAggregatorNode
|
||||
from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode as VariableAssignerNodeV1
|
||||
from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode as VariableAssignerNodeV2
|
||||
|
||||
LATEST_VERSION = "latest"
|
||||
|
||||
# NOTE(QuantumGhost): This should be in sync with subclasses of BaseNode.
|
||||
# Specifically, if you have introduced new node types, you should add them here.
|
||||
#
|
||||
# TODO(QuantumGhost): This could be automated with either metaclass or `__init_subclass__`
|
||||
# hook. Try to avoid duplication of node information.
|
||||
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = {
|
||||
NodeType.START: {
|
||||
LATEST_VERSION: StartNode,
|
||||
"1": StartNode,
|
||||
},
|
||||
NodeType.END: {
|
||||
LATEST_VERSION: EndNode,
|
||||
"1": EndNode,
|
||||
},
|
||||
NodeType.ANSWER: {
|
||||
LATEST_VERSION: AnswerNode,
|
||||
"1": AnswerNode,
|
||||
},
|
||||
NodeType.LLM: {
|
||||
LATEST_VERSION: LLMNode,
|
||||
"1": LLMNode,
|
||||
},
|
||||
NodeType.KNOWLEDGE_RETRIEVAL: {
|
||||
LATEST_VERSION: KnowledgeRetrievalNode,
|
||||
"1": KnowledgeRetrievalNode,
|
||||
},
|
||||
NodeType.IF_ELSE: {
|
||||
LATEST_VERSION: IfElseNode,
|
||||
"1": IfElseNode,
|
||||
},
|
||||
NodeType.CODE: {
|
||||
LATEST_VERSION: CodeNode,
|
||||
"1": CodeNode,
|
||||
},
|
||||
NodeType.TEMPLATE_TRANSFORM: {
|
||||
LATEST_VERSION: TemplateTransformNode,
|
||||
"1": TemplateTransformNode,
|
||||
},
|
||||
NodeType.QUESTION_CLASSIFIER: {
|
||||
LATEST_VERSION: QuestionClassifierNode,
|
||||
"1": QuestionClassifierNode,
|
||||
},
|
||||
NodeType.HTTP_REQUEST: {
|
||||
LATEST_VERSION: HttpRequestNode,
|
||||
"1": HttpRequestNode,
|
||||
},
|
||||
NodeType.TOOL: {
|
||||
LATEST_VERSION: ToolNode,
|
||||
# This is an issue that caused problems before.
|
||||
# Logically, we shouldn't use two different versions to point to the same class here,
|
||||
# but in order to maintain compatibility with historical data, this approach has been retained.
|
||||
"2": ToolNode,
|
||||
"1": ToolNode,
|
||||
},
|
||||
NodeType.VARIABLE_AGGREGATOR: {
|
||||
LATEST_VERSION: VariableAggregatorNode,
|
||||
"1": VariableAggregatorNode,
|
||||
},
|
||||
NodeType.LEGACY_VARIABLE_AGGREGATOR: {
|
||||
LATEST_VERSION: VariableAggregatorNode,
|
||||
"1": VariableAggregatorNode,
|
||||
}, # original name of VARIABLE_AGGREGATOR
|
||||
NodeType.ITERATION: {
|
||||
LATEST_VERSION: IterationNode,
|
||||
"1": IterationNode,
|
||||
},
|
||||
NodeType.ITERATION_START: {
|
||||
LATEST_VERSION: IterationStartNode,
|
||||
"1": IterationStartNode,
|
||||
},
|
||||
NodeType.LOOP: {
|
||||
LATEST_VERSION: LoopNode,
|
||||
"1": LoopNode,
|
||||
},
|
||||
NodeType.LOOP_START: {
|
||||
LATEST_VERSION: LoopStartNode,
|
||||
"1": LoopStartNode,
|
||||
},
|
||||
NodeType.LOOP_END: {
|
||||
LATEST_VERSION: LoopEndNode,
|
||||
"1": LoopEndNode,
|
||||
},
|
||||
NodeType.PARAMETER_EXTRACTOR: {
|
||||
LATEST_VERSION: ParameterExtractorNode,
|
||||
"1": ParameterExtractorNode,
|
||||
},
|
||||
NodeType.VARIABLE_ASSIGNER: {
|
||||
LATEST_VERSION: VariableAssignerNodeV2,
|
||||
"1": VariableAssignerNodeV1,
|
||||
"2": VariableAssignerNodeV2,
|
||||
},
|
||||
NodeType.DOCUMENT_EXTRACTOR: {
|
||||
LATEST_VERSION: DocumentExtractorNode,
|
||||
"1": DocumentExtractorNode,
|
||||
},
|
||||
NodeType.LIST_OPERATOR: {
|
||||
LATEST_VERSION: ListOperatorNode,
|
||||
"1": ListOperatorNode,
|
||||
},
|
||||
NodeType.AGENT: {
|
||||
LATEST_VERSION: AgentNode,
|
||||
# This is an issue that caused problems before.
|
||||
# Logically, we shouldn't use two different versions to point to the same class here,
|
||||
# but in order to maintain compatibility with historical data, this approach has been retained.
|
||||
"2": AgentNode,
|
||||
"1": AgentNode,
|
||||
},
|
||||
NodeType.HUMAN_INPUT: {
|
||||
LATEST_VERSION: HumanInputNode,
|
||||
"1": HumanInputNode,
|
||||
},
|
||||
NodeType.DATASOURCE: {
|
||||
LATEST_VERSION: DatasourceNode,
|
||||
"1": DatasourceNode,
|
||||
},
|
||||
NodeType.KNOWLEDGE_INDEX: {
|
||||
LATEST_VERSION: KnowledgeIndexNode,
|
||||
"1": KnowledgeIndexNode,
|
||||
},
|
||||
NodeType.TRIGGER_WEBHOOK: {
|
||||
LATEST_VERSION: TriggerWebhookNode,
|
||||
"1": TriggerWebhookNode,
|
||||
},
|
||||
NodeType.TRIGGER_PLUGIN: {
|
||||
LATEST_VERSION: TriggerEventNode,
|
||||
"1": TriggerEventNode,
|
||||
},
|
||||
NodeType.TRIGGER_SCHEDULE: {
|
||||
LATEST_VERSION: TriggerScheduleNode,
|
||||
"1": TriggerScheduleNode,
|
||||
},
|
||||
}
|
||||
# Mapping is built by Node.get_node_type_classes_mapping(), which imports and walks core.workflow.nodes
|
||||
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = Node.get_node_type_classes_mapping()
|
||||
|
|
|
|||
|
|
@ -12,7 +12,6 @@ from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
|||
from core.tools.errors import ToolInvokeError
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from core.variables.segments import ArrayAnySegment, ArrayFileSegment
|
||||
from core.variables.variables import ArrayAnyVariable
|
||||
from core.workflow.enums import (
|
||||
|
|
@ -430,7 +429,7 @@ class ToolNode(Node[ToolNodeData]):
|
|||
metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = {
|
||||
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
|
||||
}
|
||||
if usage.total_tokens > 0:
|
||||
if isinstance(usage.total_tokens, int) and usage.total_tokens > 0:
|
||||
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS] = usage.total_tokens
|
||||
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_PRICE] = usage.total_price
|
||||
metadata[WorkflowNodeExecutionMetadataKey.CURRENCY] = usage.currency
|
||||
|
|
@ -449,8 +448,17 @@ class ToolNode(Node[ToolNodeData]):
|
|||
|
||||
@staticmethod
|
||||
def _extract_tool_usage(tool_runtime: Tool) -> LLMUsage:
|
||||
if isinstance(tool_runtime, WorkflowTool):
|
||||
return tool_runtime.latest_usage
|
||||
# Avoid importing WorkflowTool at module import time; rely on duck typing
|
||||
# Some runtimes expose `latest_usage`; mocks may synthesize arbitrary attributes.
|
||||
latest = getattr(tool_runtime, "latest_usage", None)
|
||||
# Normalize into a concrete LLMUsage. MagicMock returns truthy attribute objects
|
||||
# for any name, so we must type-check here.
|
||||
if isinstance(latest, LLMUsage):
|
||||
return latest
|
||||
if isinstance(latest, dict):
|
||||
# Allow dict payloads from external runtimes
|
||||
return LLMUsage.model_validate(latest)
|
||||
# Fallback to empty usage when attribute is missing or not a valid payload
|
||||
return LLMUsage.empty_usage()
|
||||
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -0,0 +1,49 @@
|
|||
import logging
|
||||
|
||||
from dify_app import DifyApp
|
||||
|
||||
|
||||
def is_enabled() -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def init_app(app: DifyApp):
|
||||
"""Resolve Pydantic forward refs that would otherwise cause circular imports.
|
||||
|
||||
Rebuilds models in core.app.entities.app_invoke_entities with the real TraceQueueManager type.
|
||||
Safe to run multiple times.
|
||||
"""
|
||||
logger = logging.getLogger(__name__)
|
||||
try:
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AdvancedChatAppGenerateEntity,
|
||||
AgentChatAppGenerateEntity,
|
||||
AppGenerateEntity,
|
||||
ChatAppGenerateEntity,
|
||||
CompletionAppGenerateEntity,
|
||||
ConversationAppGenerateEntity,
|
||||
EasyUIBasedAppGenerateEntity,
|
||||
RagPipelineGenerateEntity,
|
||||
WorkflowAppGenerateEntity,
|
||||
)
|
||||
from core.ops.ops_trace_manager import TraceQueueManager # heavy import, do it at startup only
|
||||
|
||||
ns = {"TraceQueueManager": TraceQueueManager}
|
||||
for Model in (
|
||||
AppGenerateEntity,
|
||||
EasyUIBasedAppGenerateEntity,
|
||||
ConversationAppGenerateEntity,
|
||||
ChatAppGenerateEntity,
|
||||
CompletionAppGenerateEntity,
|
||||
AgentChatAppGenerateEntity,
|
||||
AdvancedChatAppGenerateEntity,
|
||||
WorkflowAppGenerateEntity,
|
||||
RagPipelineGenerateEntity,
|
||||
):
|
||||
try:
|
||||
Model.model_rebuild(_types_namespace=ns)
|
||||
except Exception as e:
|
||||
logger.debug("model_rebuild skipped for %s: %s", Model.__name__, e)
|
||||
except Exception as e:
|
||||
# Don't block app startup; just log at debug level.
|
||||
logger.debug("ext_forward_refs init skipped: %s", e)
|
||||
|
|
@ -111,7 +111,7 @@ package = false
|
|||
dev = [
|
||||
"coverage~=7.2.4",
|
||||
"dotenv-linter~=0.5.0",
|
||||
"faker~=32.1.0",
|
||||
"faker~=38.2.0",
|
||||
"lxml-stubs~=0.5.1",
|
||||
"ty~=0.0.1a19",
|
||||
"basedpyright~=1.31.0",
|
||||
|
|
|
|||
|
|
@ -201,7 +201,9 @@ class ToolTransformService:
|
|||
|
||||
@staticmethod
|
||||
def workflow_provider_to_user_provider(
|
||||
provider_controller: WorkflowToolProviderController, labels: list[str] | None = None
|
||||
provider_controller: WorkflowToolProviderController,
|
||||
labels: list[str] | None = None,
|
||||
workflow_app_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
convert provider controller to user provider
|
||||
|
|
@ -221,6 +223,7 @@ class ToolTransformService:
|
|||
plugin_unique_identifier=None,
|
||||
tools=[],
|
||||
labels=labels or [],
|
||||
workflow_app_id=workflow_app_id,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -189,6 +189,9 @@ class WorkflowToolManageService:
|
|||
select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id)
|
||||
).all()
|
||||
|
||||
# Create a mapping from provider_id to app_id
|
||||
provider_id_to_app_id = {provider.id: provider.app_id for provider in db_tools}
|
||||
|
||||
tools: list[WorkflowToolProviderController] = []
|
||||
for provider in db_tools:
|
||||
try:
|
||||
|
|
@ -202,8 +205,11 @@ class WorkflowToolManageService:
|
|||
result = []
|
||||
|
||||
for tool in tools:
|
||||
workflow_app_id = provider_id_to_app_id.get(tool.provider_id)
|
||||
user_tool_provider = ToolTransformService.workflow_provider_to_user_provider(
|
||||
provider_controller=tool, labels=labels.get(tool.provider_id, [])
|
||||
provider_controller=tool,
|
||||
labels=labels.get(tool.provider_id, []),
|
||||
workflow_app_id=workflow_app_id,
|
||||
)
|
||||
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_tool_provider)
|
||||
user_tool_provider.tools = [
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,100 @@
|
|||
"""
|
||||
Unit tests for ToolProviderApiEntity workflow_app_id field.
|
||||
|
||||
This test suite covers:
|
||||
- ToolProviderApiEntity workflow_app_id field creation and default value
|
||||
- ToolProviderApiEntity.to_dict() method behavior with workflow_app_id
|
||||
"""
|
||||
|
||||
from core.tools.entities.api_entities import ToolProviderApiEntity
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
|
||||
|
||||
class TestToolProviderApiEntityWorkflowAppId:
|
||||
"""Test suite for ToolProviderApiEntity workflow_app_id field."""
|
||||
|
||||
def test_workflow_app_id_field_default_none(self):
|
||||
"""Test that workflow_app_id defaults to None when not provided."""
|
||||
entity = ToolProviderApiEntity(
|
||||
id="test_id",
|
||||
author="test_author",
|
||||
name="test_name",
|
||||
description=I18nObject(en_US="Test description"),
|
||||
icon="test_icon",
|
||||
label=I18nObject(en_US="Test label"),
|
||||
type=ToolProviderType.WORKFLOW,
|
||||
)
|
||||
|
||||
assert entity.workflow_app_id is None
|
||||
|
||||
def test_to_dict_includes_workflow_app_id_when_workflow_type_and_has_value(self):
|
||||
"""Test that to_dict() includes workflow_app_id when type is WORKFLOW and value is set."""
|
||||
workflow_app_id = "app_123"
|
||||
entity = ToolProviderApiEntity(
|
||||
id="test_id",
|
||||
author="test_author",
|
||||
name="test_name",
|
||||
description=I18nObject(en_US="Test description"),
|
||||
icon="test_icon",
|
||||
label=I18nObject(en_US="Test label"),
|
||||
type=ToolProviderType.WORKFLOW,
|
||||
workflow_app_id=workflow_app_id,
|
||||
)
|
||||
|
||||
result = entity.to_dict()
|
||||
|
||||
assert "workflow_app_id" in result
|
||||
assert result["workflow_app_id"] == workflow_app_id
|
||||
|
||||
def test_to_dict_excludes_workflow_app_id_when_workflow_type_and_none(self):
|
||||
"""Test that to_dict() excludes workflow_app_id when type is WORKFLOW but value is None."""
|
||||
entity = ToolProviderApiEntity(
|
||||
id="test_id",
|
||||
author="test_author",
|
||||
name="test_name",
|
||||
description=I18nObject(en_US="Test description"),
|
||||
icon="test_icon",
|
||||
label=I18nObject(en_US="Test label"),
|
||||
type=ToolProviderType.WORKFLOW,
|
||||
workflow_app_id=None,
|
||||
)
|
||||
|
||||
result = entity.to_dict()
|
||||
|
||||
assert "workflow_app_id" not in result
|
||||
|
||||
def test_to_dict_excludes_workflow_app_id_when_not_workflow_type(self):
|
||||
"""Test that to_dict() excludes workflow_app_id when type is not WORKFLOW."""
|
||||
workflow_app_id = "app_123"
|
||||
entity = ToolProviderApiEntity(
|
||||
id="test_id",
|
||||
author="test_author",
|
||||
name="test_name",
|
||||
description=I18nObject(en_US="Test description"),
|
||||
icon="test_icon",
|
||||
label=I18nObject(en_US="Test label"),
|
||||
type=ToolProviderType.BUILT_IN,
|
||||
workflow_app_id=workflow_app_id,
|
||||
)
|
||||
|
||||
result = entity.to_dict()
|
||||
|
||||
assert "workflow_app_id" not in result
|
||||
|
||||
def test_to_dict_includes_workflow_app_id_for_workflow_type_with_empty_string(self):
|
||||
"""Test that to_dict() excludes workflow_app_id when value is empty string (falsy)."""
|
||||
entity = ToolProviderApiEntity(
|
||||
id="test_id",
|
||||
author="test_author",
|
||||
name="test_name",
|
||||
description=I18nObject(en_US="Test description"),
|
||||
icon="test_icon",
|
||||
label=I18nObject(en_US="Test label"),
|
||||
type=ToolProviderType.WORKFLOW,
|
||||
workflow_app_id="",
|
||||
)
|
||||
|
||||
result = entity.to_dict()
|
||||
|
||||
assert "workflow_app_id" not in result
|
||||
|
|
@ -29,7 +29,7 @@ class _TestNode(Node[_TestNodeData]):
|
|||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "test"
|
||||
return "1"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -744,7 +744,7 @@ def test_graph_run_emits_partial_success_when_node_failure_recovered():
|
|||
)
|
||||
|
||||
llm_node = graph.nodes["llm"]
|
||||
base_node_data = llm_node.get_base_node_data()
|
||||
base_node_data = llm_node.node_data
|
||||
base_node_data.error_strategy = ErrorStrategy.DEFAULT_VALUE
|
||||
base_node_data.default_value = [DefaultValue(key="text", value="fallback response", type=DefaultValueType.STRING)]
|
||||
|
||||
|
|
|
|||
|
|
@ -92,7 +92,7 @@ class MockLLMNode(MockNodeMixin, LLMNode):
|
|||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of this mock node."""
|
||||
return "mock-1"
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""Execute mock LLM node."""
|
||||
|
|
@ -189,7 +189,7 @@ class MockAgentNode(MockNodeMixin, AgentNode):
|
|||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of this mock node."""
|
||||
return "mock-1"
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""Execute mock agent node."""
|
||||
|
|
@ -241,7 +241,7 @@ class MockToolNode(MockNodeMixin, ToolNode):
|
|||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of this mock node."""
|
||||
return "mock-1"
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""Execute mock tool node."""
|
||||
|
|
@ -294,7 +294,7 @@ class MockKnowledgeRetrievalNode(MockNodeMixin, KnowledgeRetrievalNode):
|
|||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of this mock node."""
|
||||
return "mock-1"
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""Execute mock knowledge retrieval node."""
|
||||
|
|
@ -351,7 +351,7 @@ class MockHttpRequestNode(MockNodeMixin, HttpRequestNode):
|
|||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of this mock node."""
|
||||
return "mock-1"
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""Execute mock HTTP request node."""
|
||||
|
|
@ -404,7 +404,7 @@ class MockQuestionClassifierNode(MockNodeMixin, QuestionClassifierNode):
|
|||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of this mock node."""
|
||||
return "mock-1"
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""Execute mock question classifier node."""
|
||||
|
|
@ -452,7 +452,7 @@ class MockParameterExtractorNode(MockNodeMixin, ParameterExtractorNode):
|
|||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of this mock node."""
|
||||
return "mock-1"
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""Execute mock parameter extractor node."""
|
||||
|
|
@ -502,7 +502,7 @@ class MockDocumentExtractorNode(MockNodeMixin, DocumentExtractorNode):
|
|||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of this mock node."""
|
||||
return "mock-1"
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""Execute mock document extractor node."""
|
||||
|
|
@ -557,7 +557,7 @@ class MockIterationNode(MockNodeMixin, IterationNode):
|
|||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of this mock node."""
|
||||
return "mock-1"
|
||||
return "1"
|
||||
|
||||
def _create_graph_engine(self, index: int, item: Any):
|
||||
"""Create a graph engine with MockNodeFactory instead of DifyNodeFactory."""
|
||||
|
|
@ -632,7 +632,7 @@ class MockLoopNode(MockNodeMixin, LoopNode):
|
|||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of this mock node."""
|
||||
return "mock-1"
|
||||
return "1"
|
||||
|
||||
def _create_graph_engine(self, start_at, root_node_id: str):
|
||||
"""Create a graph engine with MockNodeFactory instead of DifyNodeFactory."""
|
||||
|
|
@ -694,7 +694,7 @@ class MockTemplateTransformNode(MockNodeMixin, TemplateTransformNode):
|
|||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of this mock node."""
|
||||
return "mock-1"
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""Execute mock template transform node."""
|
||||
|
|
@ -780,7 +780,7 @@ class MockCodeNode(MockNodeMixin, CodeNode):
|
|||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of this mock node."""
|
||||
return "mock-1"
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""Execute mock code node."""
|
||||
|
|
|
|||
|
|
@ -33,6 +33,10 @@ def test_ensure_subclasses_of_base_node_has_node_type_and_version_method_defined
|
|||
type_version_set: set[tuple[NodeType, str]] = set()
|
||||
|
||||
for cls in classes:
|
||||
# Only validate production node classes; skip test-defined subclasses and external helpers
|
||||
module_name = getattr(cls, "__module__", "")
|
||||
if not module_name.startswith("core."):
|
||||
continue
|
||||
# Validate that 'version' is directly defined in the class (not inherited) by checking the class's __dict__
|
||||
assert "version" in cls.__dict__, f"class {cls} should have version method defined (NOT INHERITED.)"
|
||||
node_type = cls.node_type
|
||||
|
|
|
|||
|
|
@ -0,0 +1,84 @@
|
|||
import types
|
||||
from collections.abc import Mapping
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.base.entities import BaseNodeData
|
||||
from core.workflow.nodes.base.node import Node
|
||||
|
||||
# Import concrete nodes we will assert on (numeric version path)
|
||||
from core.workflow.nodes.variable_assigner.v1.node import (
|
||||
VariableAssignerNode as VariableAssignerV1,
|
||||
)
|
||||
from core.workflow.nodes.variable_assigner.v2.node import (
|
||||
VariableAssignerNode as VariableAssignerV2,
|
||||
)
|
||||
|
||||
|
||||
def test_variable_assigner_latest_prefers_highest_numeric_version():
|
||||
# Act
|
||||
mapping: Mapping[NodeType, Mapping[str, type[Node]]] = Node.get_node_type_classes_mapping()
|
||||
|
||||
# Assert basic presence
|
||||
assert NodeType.VARIABLE_ASSIGNER in mapping
|
||||
va_versions = mapping[NodeType.VARIABLE_ASSIGNER]
|
||||
|
||||
# Both concrete versions must be present
|
||||
assert va_versions.get("1") is VariableAssignerV1
|
||||
assert va_versions.get("2") is VariableAssignerV2
|
||||
|
||||
# And latest should point to numerically-highest version ("2")
|
||||
assert va_versions.get("latest") is VariableAssignerV2
|
||||
|
||||
|
||||
def test_latest_prefers_highest_numeric_version():
|
||||
# Arrange: define two ephemeral subclasses with numeric versions under a NodeType
|
||||
# that has no concrete implementations in production to avoid interference.
|
||||
class _Version1(Node[BaseNodeData]): # type: ignore[misc]
|
||||
node_type = NodeType.LEGACY_VARIABLE_AGGREGATOR
|
||||
|
||||
def init_node_data(self, data):
|
||||
pass
|
||||
|
||||
def _run(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _get_error_strategy(self):
|
||||
return None
|
||||
|
||||
def _get_retry_config(self):
|
||||
return types.SimpleNamespace() # not used
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return "version1"
|
||||
|
||||
def _get_description(self):
|
||||
return None
|
||||
|
||||
def _get_default_value_dict(self):
|
||||
return {}
|
||||
|
||||
def get_base_node_data(self):
|
||||
return types.SimpleNamespace(title="version1")
|
||||
|
||||
class _Version2(_Version1): # type: ignore[misc]
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "2"
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return "version2"
|
||||
|
||||
# Act: build a fresh mapping (it should now see our ephemeral subclasses)
|
||||
mapping: Mapping[NodeType, Mapping[str, type[Node]]] = Node.get_node_type_classes_mapping()
|
||||
|
||||
# Assert: both numeric versions exist for this NodeType; 'latest' points to the higher numeric version
|
||||
assert NodeType.LEGACY_VARIABLE_AGGREGATOR in mapping
|
||||
legacy_versions = mapping[NodeType.LEGACY_VARIABLE_AGGREGATOR]
|
||||
|
||||
assert legacy_versions.get("1") is _Version1
|
||||
assert legacy_versions.get("2") is _Version2
|
||||
assert legacy_versions.get("latest") is _Version2
|
||||
|
|
@ -471,8 +471,8 @@ class TestCodeNodeInitialization:
|
|||
|
||||
assert node._get_description() is None
|
||||
|
||||
def test_get_base_node_data(self):
|
||||
"""Test get_base_node_data returns node data."""
|
||||
def test_node_data_property(self):
|
||||
"""Test node_data property returns node data."""
|
||||
node = CodeNode.__new__(CodeNode)
|
||||
node._node_data = CodeNodeData(
|
||||
title="Base Test",
|
||||
|
|
@ -482,7 +482,7 @@ class TestCodeNodeInitialization:
|
|||
outputs={},
|
||||
)
|
||||
|
||||
result = node.get_base_node_data()
|
||||
result = node.node_data
|
||||
|
||||
assert result == node._node_data
|
||||
assert result.title == "Base Test"
|
||||
|
|
|
|||
|
|
@ -240,8 +240,8 @@ class TestIterationNodeInitialization:
|
|||
|
||||
assert node._get_description() == "This is a description"
|
||||
|
||||
def test_get_base_node_data(self):
|
||||
"""Test get_base_node_data returns node data."""
|
||||
def test_node_data_property(self):
|
||||
"""Test node_data property returns node data."""
|
||||
node = IterationNode.__new__(IterationNode)
|
||||
node._node_data = IterationNodeData(
|
||||
title="Base Test",
|
||||
|
|
@ -249,7 +249,7 @@ class TestIterationNodeInitialization:
|
|||
output_selector=["y"],
|
||||
)
|
||||
|
||||
result = node.get_base_node_data()
|
||||
result = node.node_data
|
||||
|
||||
assert result == node._node_data
|
||||
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ class _SampleNode(Node[_SampleNodeData]):
|
|||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "sample-test"
|
||||
return "1"
|
||||
|
||||
def _run(self):
|
||||
raise NotImplementedError
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,718 @@
|
|||
"""
|
||||
Comprehensive unit tests for AudioService.
|
||||
|
||||
This test suite provides complete coverage of audio processing operations in Dify,
|
||||
following TDD principles with the Arrange-Act-Assert pattern.
|
||||
|
||||
## Test Coverage
|
||||
|
||||
### 1. Speech-to-Text (ASR) Operations (TestAudioServiceASR)
|
||||
Tests audio transcription functionality:
|
||||
- Successful transcription for different app modes
|
||||
- File validation (size, type, presence)
|
||||
- Feature flag validation (speech-to-text enabled)
|
||||
- Error handling for various failure scenarios
|
||||
- Model instance availability checks
|
||||
|
||||
### 2. Text-to-Speech (TTS) Operations (TestAudioServiceTTS)
|
||||
Tests text-to-audio conversion:
|
||||
- TTS with text input
|
||||
- TTS with message ID
|
||||
- Voice selection (explicit and default)
|
||||
- Feature flag validation (text-to-speech enabled)
|
||||
- Draft workflow handling
|
||||
- Streaming response handling
|
||||
- Error handling for missing/invalid inputs
|
||||
|
||||
### 3. TTS Voice Listing (TestAudioServiceTTSVoices)
|
||||
Tests available voice retrieval:
|
||||
- Get available voices for a tenant
|
||||
- Language filtering
|
||||
- Error handling for missing provider
|
||||
|
||||
## Testing Approach
|
||||
|
||||
- **Mocking Strategy**: All external dependencies (ModelManager, db, FileStorage) are mocked
|
||||
for fast, isolated unit tests
|
||||
- **Factory Pattern**: AudioServiceTestDataFactory provides consistent test data
|
||||
- **Fixtures**: Mock objects are configured per test method
|
||||
- **Assertions**: Each test verifies return values, side effects, and error conditions
|
||||
|
||||
## Key Concepts
|
||||
|
||||
**Audio Formats:**
|
||||
- Supported: mp3, wav, m4a, flac, ogg, opus, webm
|
||||
- File size limit: 30 MB
|
||||
|
||||
**App Modes:**
|
||||
- ADVANCED_CHAT/WORKFLOW: Use workflow features
|
||||
- CHAT/COMPLETION: Use app_model_config
|
||||
|
||||
**Feature Flags:**
|
||||
- speech_to_text: Enables ASR functionality
|
||||
- text_to_speech: Enables TTS functionality
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, Mock, create_autospec, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.datastructures import FileStorage
|
||||
|
||||
from models.enums import MessageStatus
|
||||
from models.model import App, AppMode, AppModelConfig, Message
|
||||
from models.workflow import Workflow
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
NoAudioUploadedServiceError,
|
||||
ProviderNotSupportSpeechToTextServiceError,
|
||||
ProviderNotSupportTextToSpeechServiceError,
|
||||
UnsupportedAudioTypeServiceError,
|
||||
)
|
||||
|
||||
|
||||
class AudioServiceTestDataFactory:
|
||||
"""
|
||||
Factory for creating test data and mock objects.
|
||||
|
||||
Provides reusable methods to create consistent mock objects for testing
|
||||
audio-related operations.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def create_app_mock(
|
||||
app_id: str = "app-123",
|
||||
mode: AppMode = AppMode.CHAT,
|
||||
tenant_id: str = "tenant-123",
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""
|
||||
Create a mock App object.
|
||||
|
||||
Args:
|
||||
app_id: Unique identifier for the app
|
||||
mode: App mode (CHAT, ADVANCED_CHAT, WORKFLOW, etc.)
|
||||
tenant_id: Tenant identifier
|
||||
**kwargs: Additional attributes to set on the mock
|
||||
|
||||
Returns:
|
||||
Mock App object with specified attributes
|
||||
"""
|
||||
app = create_autospec(App, instance=True)
|
||||
app.id = app_id
|
||||
app.mode = mode
|
||||
app.tenant_id = tenant_id
|
||||
app.workflow = kwargs.get("workflow")
|
||||
app.app_model_config = kwargs.get("app_model_config")
|
||||
for key, value in kwargs.items():
|
||||
setattr(app, key, value)
|
||||
return app
|
||||
|
||||
@staticmethod
|
||||
def create_workflow_mock(features_dict: dict | None = None, **kwargs) -> Mock:
|
||||
"""
|
||||
Create a mock Workflow object.
|
||||
|
||||
Args:
|
||||
features_dict: Dictionary of workflow features
|
||||
**kwargs: Additional attributes to set on the mock
|
||||
|
||||
Returns:
|
||||
Mock Workflow object with specified attributes
|
||||
"""
|
||||
workflow = create_autospec(Workflow, instance=True)
|
||||
workflow.features_dict = features_dict or {}
|
||||
for key, value in kwargs.items():
|
||||
setattr(workflow, key, value)
|
||||
return workflow
|
||||
|
||||
@staticmethod
|
||||
def create_app_model_config_mock(
|
||||
speech_to_text_dict: dict | None = None,
|
||||
text_to_speech_dict: dict | None = None,
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""
|
||||
Create a mock AppModelConfig object.
|
||||
|
||||
Args:
|
||||
speech_to_text_dict: Speech-to-text configuration
|
||||
text_to_speech_dict: Text-to-speech configuration
|
||||
**kwargs: Additional attributes to set on the mock
|
||||
|
||||
Returns:
|
||||
Mock AppModelConfig object with specified attributes
|
||||
"""
|
||||
config = create_autospec(AppModelConfig, instance=True)
|
||||
config.speech_to_text_dict = speech_to_text_dict or {"enabled": False}
|
||||
config.text_to_speech_dict = text_to_speech_dict or {"enabled": False}
|
||||
for key, value in kwargs.items():
|
||||
setattr(config, key, value)
|
||||
return config
|
||||
|
||||
@staticmethod
|
||||
def create_file_storage_mock(
|
||||
filename: str = "test.mp3",
|
||||
mimetype: str = "audio/mp3",
|
||||
content: bytes = b"fake audio content",
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""
|
||||
Create a mock FileStorage object.
|
||||
|
||||
Args:
|
||||
filename: Name of the file
|
||||
mimetype: MIME type of the file
|
||||
content: File content as bytes
|
||||
**kwargs: Additional attributes to set on the mock
|
||||
|
||||
Returns:
|
||||
Mock FileStorage object with specified attributes
|
||||
"""
|
||||
file = Mock(spec=FileStorage)
|
||||
file.filename = filename
|
||||
file.mimetype = mimetype
|
||||
file.read = Mock(return_value=content)
|
||||
for key, value in kwargs.items():
|
||||
setattr(file, key, value)
|
||||
return file
|
||||
|
||||
@staticmethod
|
||||
def create_message_mock(
|
||||
message_id: str = "msg-123",
|
||||
answer: str = "Test answer",
|
||||
status: MessageStatus = MessageStatus.NORMAL,
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""
|
||||
Create a mock Message object.
|
||||
|
||||
Args:
|
||||
message_id: Unique identifier for the message
|
||||
answer: Message answer text
|
||||
status: Message status
|
||||
**kwargs: Additional attributes to set on the mock
|
||||
|
||||
Returns:
|
||||
Mock Message object with specified attributes
|
||||
"""
|
||||
message = create_autospec(Message, instance=True)
|
||||
message.id = message_id
|
||||
message.answer = answer
|
||||
message.status = status
|
||||
for key, value in kwargs.items():
|
||||
setattr(message, key, value)
|
||||
return message
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def factory():
|
||||
"""Provide the test data factory to all tests."""
|
||||
return AudioServiceTestDataFactory
|
||||
|
||||
|
||||
class TestAudioServiceASR:
|
||||
"""Test speech-to-text (ASR) operations."""
|
||||
|
||||
@patch("services.audio_service.ModelManager")
|
||||
def test_transcript_asr_success_chat_mode(self, mock_model_manager_class, factory):
|
||||
"""Test successful ASR transcription in CHAT mode."""
|
||||
# Arrange
|
||||
app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True})
|
||||
app = factory.create_app_mock(
|
||||
mode=AppMode.CHAT,
|
||||
app_model_config=app_model_config,
|
||||
)
|
||||
file = factory.create_file_storage_mock()
|
||||
|
||||
# Mock ModelManager
|
||||
mock_model_manager = MagicMock()
|
||||
mock_model_manager_class.return_value = mock_model_manager
|
||||
|
||||
mock_model_instance = MagicMock()
|
||||
mock_model_instance.invoke_speech2text.return_value = "Transcribed text"
|
||||
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
|
||||
|
||||
# Act
|
||||
result = AudioService.transcript_asr(app_model=app, file=file, end_user="user-123")
|
||||
|
||||
# Assert
|
||||
assert result == {"text": "Transcribed text"}
|
||||
mock_model_instance.invoke_speech2text.assert_called_once()
|
||||
call_args = mock_model_instance.invoke_speech2text.call_args
|
||||
assert call_args.kwargs["user"] == "user-123"
|
||||
|
||||
@patch("services.audio_service.ModelManager")
|
||||
def test_transcript_asr_success_advanced_chat_mode(self, mock_model_manager_class, factory):
|
||||
"""Test successful ASR transcription in ADVANCED_CHAT mode."""
|
||||
# Arrange
|
||||
workflow = factory.create_workflow_mock(features_dict={"speech_to_text": {"enabled": True}})
|
||||
app = factory.create_app_mock(
|
||||
mode=AppMode.ADVANCED_CHAT,
|
||||
workflow=workflow,
|
||||
)
|
||||
file = factory.create_file_storage_mock()
|
||||
|
||||
# Mock ModelManager
|
||||
mock_model_manager = MagicMock()
|
||||
mock_model_manager_class.return_value = mock_model_manager
|
||||
|
||||
mock_model_instance = MagicMock()
|
||||
mock_model_instance.invoke_speech2text.return_value = "Workflow transcribed text"
|
||||
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
|
||||
|
||||
# Act
|
||||
result = AudioService.transcript_asr(app_model=app, file=file)
|
||||
|
||||
# Assert
|
||||
assert result == {"text": "Workflow transcribed text"}
|
||||
|
||||
def test_transcript_asr_raises_error_when_feature_disabled_chat_mode(self, factory):
|
||||
"""Test that ASR raises error when speech-to-text is disabled in CHAT mode."""
|
||||
# Arrange
|
||||
app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": False})
|
||||
app = factory.create_app_mock(
|
||||
mode=AppMode.CHAT,
|
||||
app_model_config=app_model_config,
|
||||
)
|
||||
file = factory.create_file_storage_mock()
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Speech to text is not enabled"):
|
||||
AudioService.transcript_asr(app_model=app, file=file)
|
||||
|
||||
def test_transcript_asr_raises_error_when_feature_disabled_workflow_mode(self, factory):
|
||||
"""Test that ASR raises error when speech-to-text is disabled in WORKFLOW mode."""
|
||||
# Arrange
|
||||
workflow = factory.create_workflow_mock(features_dict={"speech_to_text": {"enabled": False}})
|
||||
app = factory.create_app_mock(
|
||||
mode=AppMode.WORKFLOW,
|
||||
workflow=workflow,
|
||||
)
|
||||
file = factory.create_file_storage_mock()
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Speech to text is not enabled"):
|
||||
AudioService.transcript_asr(app_model=app, file=file)
|
||||
|
||||
def test_transcript_asr_raises_error_when_workflow_missing(self, factory):
|
||||
"""Test that ASR raises error when workflow is missing in WORKFLOW mode."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock(
|
||||
mode=AppMode.WORKFLOW,
|
||||
workflow=None,
|
||||
)
|
||||
file = factory.create_file_storage_mock()
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Speech to text is not enabled"):
|
||||
AudioService.transcript_asr(app_model=app, file=file)
|
||||
|
||||
def test_transcript_asr_raises_error_when_no_file_uploaded(self, factory):
|
||||
"""Test that ASR raises error when no file is uploaded."""
|
||||
# Arrange
|
||||
app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True})
|
||||
app = factory.create_app_mock(
|
||||
mode=AppMode.CHAT,
|
||||
app_model_config=app_model_config,
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(NoAudioUploadedServiceError):
|
||||
AudioService.transcript_asr(app_model=app, file=None)
|
||||
|
||||
def test_transcript_asr_raises_error_for_unsupported_audio_type(self, factory):
|
||||
"""Test that ASR raises error for unsupported audio file types."""
|
||||
# Arrange
|
||||
app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True})
|
||||
app = factory.create_app_mock(
|
||||
mode=AppMode.CHAT,
|
||||
app_model_config=app_model_config,
|
||||
)
|
||||
file = factory.create_file_storage_mock(mimetype="video/mp4")
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(UnsupportedAudioTypeServiceError):
|
||||
AudioService.transcript_asr(app_model=app, file=file)
|
||||
|
||||
def test_transcript_asr_raises_error_for_large_file(self, factory):
|
||||
"""Test that ASR raises error when file exceeds size limit (30MB)."""
|
||||
# Arrange
|
||||
app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True})
|
||||
app = factory.create_app_mock(
|
||||
mode=AppMode.CHAT,
|
||||
app_model_config=app_model_config,
|
||||
)
|
||||
# Create file larger than 30MB
|
||||
large_content = b"x" * (31 * 1024 * 1024)
|
||||
file = factory.create_file_storage_mock(content=large_content)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(AudioTooLargeServiceError, match="Audio size larger than 30 mb"):
|
||||
AudioService.transcript_asr(app_model=app, file=file)
|
||||
|
||||
@patch("services.audio_service.ModelManager")
|
||||
def test_transcript_asr_raises_error_when_no_model_instance(self, mock_model_manager_class, factory):
|
||||
"""Test that ASR raises error when no model instance is available."""
|
||||
# Arrange
|
||||
app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True})
|
||||
app = factory.create_app_mock(
|
||||
mode=AppMode.CHAT,
|
||||
app_model_config=app_model_config,
|
||||
)
|
||||
file = factory.create_file_storage_mock()
|
||||
|
||||
# Mock ModelManager to return None
|
||||
mock_model_manager = MagicMock()
|
||||
mock_model_manager_class.return_value = mock_model_manager
|
||||
mock_model_manager.get_default_model_instance.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ProviderNotSupportSpeechToTextServiceError):
|
||||
AudioService.transcript_asr(app_model=app, file=file)
|
||||
|
||||
|
||||
class TestAudioServiceTTS:
|
||||
"""Test text-to-speech (TTS) operations."""
|
||||
|
||||
@patch("services.audio_service.ModelManager")
|
||||
def test_transcript_tts_with_text_success(self, mock_model_manager_class, factory):
|
||||
"""Test successful TTS with text input."""
|
||||
# Arrange
|
||||
app_model_config = factory.create_app_model_config_mock(
|
||||
text_to_speech_dict={"enabled": True, "voice": "en-US-Neural"}
|
||||
)
|
||||
app = factory.create_app_mock(
|
||||
mode=AppMode.CHAT,
|
||||
app_model_config=app_model_config,
|
||||
)
|
||||
|
||||
# Mock ModelManager
|
||||
mock_model_manager = MagicMock()
|
||||
mock_model_manager_class.return_value = mock_model_manager
|
||||
|
||||
mock_model_instance = MagicMock()
|
||||
mock_model_instance.invoke_tts.return_value = b"audio data"
|
||||
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
|
||||
|
||||
# Act
|
||||
result = AudioService.transcript_tts(
|
||||
app_model=app,
|
||||
text="Hello world",
|
||||
voice="en-US-Neural",
|
||||
end_user="user-123",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == b"audio data"
|
||||
mock_model_instance.invoke_tts.assert_called_once_with(
|
||||
content_text="Hello world",
|
||||
user="user-123",
|
||||
tenant_id=app.tenant_id,
|
||||
voice="en-US-Neural",
|
||||
)
|
||||
|
||||
@patch("services.audio_service.db.session")
|
||||
@patch("services.audio_service.ModelManager")
|
||||
def test_transcript_tts_with_message_id_success(self, mock_model_manager_class, mock_db_session, factory):
|
||||
"""Test successful TTS with message ID."""
|
||||
# Arrange
|
||||
app_model_config = factory.create_app_model_config_mock(
|
||||
text_to_speech_dict={"enabled": True, "voice": "en-US-Neural"}
|
||||
)
|
||||
app = factory.create_app_mock(
|
||||
mode=AppMode.CHAT,
|
||||
app_model_config=app_model_config,
|
||||
)
|
||||
|
||||
message = factory.create_message_mock(
|
||||
message_id="550e8400-e29b-41d4-a716-446655440000",
|
||||
answer="Message answer text",
|
||||
)
|
||||
|
||||
# Mock database query
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.first.return_value = message
|
||||
|
||||
# Mock ModelManager
|
||||
mock_model_manager = MagicMock()
|
||||
mock_model_manager_class.return_value = mock_model_manager
|
||||
|
||||
mock_model_instance = MagicMock()
|
||||
mock_model_instance.invoke_tts.return_value = b"audio from message"
|
||||
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
|
||||
|
||||
# Act
|
||||
result = AudioService.transcript_tts(
|
||||
app_model=app,
|
||||
message_id="550e8400-e29b-41d4-a716-446655440000",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == b"audio from message"
|
||||
mock_model_instance.invoke_tts.assert_called_once()
|
||||
|
||||
@patch("services.audio_service.ModelManager")
|
||||
def test_transcript_tts_with_default_voice(self, mock_model_manager_class, factory):
|
||||
"""Test TTS uses default voice when none specified."""
|
||||
# Arrange
|
||||
app_model_config = factory.create_app_model_config_mock(
|
||||
text_to_speech_dict={"enabled": True, "voice": "default-voice"}
|
||||
)
|
||||
app = factory.create_app_mock(
|
||||
mode=AppMode.CHAT,
|
||||
app_model_config=app_model_config,
|
||||
)
|
||||
|
||||
# Mock ModelManager
|
||||
mock_model_manager = MagicMock()
|
||||
mock_model_manager_class.return_value = mock_model_manager
|
||||
|
||||
mock_model_instance = MagicMock()
|
||||
mock_model_instance.invoke_tts.return_value = b"audio data"
|
||||
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
|
||||
|
||||
# Act
|
||||
result = AudioService.transcript_tts(
|
||||
app_model=app,
|
||||
text="Test",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == b"audio data"
|
||||
# Verify default voice was used
|
||||
call_args = mock_model_instance.invoke_tts.call_args
|
||||
assert call_args.kwargs["voice"] == "default-voice"
|
||||
|
||||
@patch("services.audio_service.ModelManager")
|
||||
def test_transcript_tts_gets_first_available_voice_when_none_configured(self, mock_model_manager_class, factory):
|
||||
"""Test TTS gets first available voice when none is configured."""
|
||||
# Arrange
|
||||
app_model_config = factory.create_app_model_config_mock(
|
||||
text_to_speech_dict={"enabled": True} # No voice specified
|
||||
)
|
||||
app = factory.create_app_mock(
|
||||
mode=AppMode.CHAT,
|
||||
app_model_config=app_model_config,
|
||||
)
|
||||
|
||||
# Mock ModelManager
|
||||
mock_model_manager = MagicMock()
|
||||
mock_model_manager_class.return_value = mock_model_manager
|
||||
|
||||
mock_model_instance = MagicMock()
|
||||
mock_model_instance.get_tts_voices.return_value = [{"value": "auto-voice"}]
|
||||
mock_model_instance.invoke_tts.return_value = b"audio data"
|
||||
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
|
||||
|
||||
# Act
|
||||
result = AudioService.transcript_tts(
|
||||
app_model=app,
|
||||
text="Test",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == b"audio data"
|
||||
call_args = mock_model_instance.invoke_tts.call_args
|
||||
assert call_args.kwargs["voice"] == "auto-voice"
|
||||
|
||||
@patch("services.audio_service.WorkflowService")
|
||||
@patch("services.audio_service.ModelManager")
|
||||
def test_transcript_tts_workflow_mode_with_draft(
|
||||
self, mock_model_manager_class, mock_workflow_service_class, factory
|
||||
):
|
||||
"""Test TTS in WORKFLOW mode with draft workflow."""
|
||||
# Arrange
|
||||
draft_workflow = factory.create_workflow_mock(
|
||||
features_dict={"text_to_speech": {"enabled": True, "voice": "draft-voice"}}
|
||||
)
|
||||
app = factory.create_app_mock(
|
||||
mode=AppMode.WORKFLOW,
|
||||
)
|
||||
|
||||
# Mock WorkflowService
|
||||
mock_workflow_service = MagicMock()
|
||||
mock_workflow_service_class.return_value = mock_workflow_service
|
||||
mock_workflow_service.get_draft_workflow.return_value = draft_workflow
|
||||
|
||||
# Mock ModelManager
|
||||
mock_model_manager = MagicMock()
|
||||
mock_model_manager_class.return_value = mock_model_manager
|
||||
|
||||
mock_model_instance = MagicMock()
|
||||
mock_model_instance.invoke_tts.return_value = b"draft audio"
|
||||
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
|
||||
|
||||
# Act
|
||||
result = AudioService.transcript_tts(
|
||||
app_model=app,
|
||||
text="Draft test",
|
||||
is_draft=True,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == b"draft audio"
|
||||
mock_workflow_service.get_draft_workflow.assert_called_once_with(app_model=app)
|
||||
|
||||
def test_transcript_tts_raises_error_when_text_missing(self, factory):
|
||||
"""Test that TTS raises error when text is missing."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Text is required"):
|
||||
AudioService.transcript_tts(app_model=app, text=None)
|
||||
|
||||
@patch("services.audio_service.db.session")
|
||||
def test_transcript_tts_returns_none_for_invalid_message_id(self, mock_db_session, factory):
|
||||
"""Test that TTS returns None for invalid message ID format."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
|
||||
# Act
|
||||
result = AudioService.transcript_tts(
|
||||
app_model=app,
|
||||
message_id="invalid-uuid",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
@patch("services.audio_service.db.session")
|
||||
def test_transcript_tts_returns_none_for_nonexistent_message(self, mock_db_session, factory):
|
||||
"""Test that TTS returns None when message doesn't exist."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
|
||||
# Mock database query returning None
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
|
||||
# Act
|
||||
result = AudioService.transcript_tts(
|
||||
app_model=app,
|
||||
message_id="550e8400-e29b-41d4-a716-446655440000",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
@patch("services.audio_service.db.session")
|
||||
def test_transcript_tts_returns_none_for_empty_message_answer(self, mock_db_session, factory):
|
||||
"""Test that TTS returns None when message answer is empty."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
|
||||
message = factory.create_message_mock(
|
||||
answer="",
|
||||
status=MessageStatus.NORMAL,
|
||||
)
|
||||
|
||||
# Mock database query
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.first.return_value = message
|
||||
|
||||
# Act
|
||||
result = AudioService.transcript_tts(
|
||||
app_model=app,
|
||||
message_id="550e8400-e29b-41d4-a716-446655440000",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
@patch("services.audio_service.ModelManager")
|
||||
def test_transcript_tts_raises_error_when_no_voices_available(self, mock_model_manager_class, factory):
|
||||
"""Test that TTS raises error when no voices are available."""
|
||||
# Arrange
|
||||
app_model_config = factory.create_app_model_config_mock(
|
||||
text_to_speech_dict={"enabled": True} # No voice specified
|
||||
)
|
||||
app = factory.create_app_mock(
|
||||
mode=AppMode.CHAT,
|
||||
app_model_config=app_model_config,
|
||||
)
|
||||
|
||||
# Mock ModelManager
|
||||
mock_model_manager = MagicMock()
|
||||
mock_model_manager_class.return_value = mock_model_manager
|
||||
|
||||
mock_model_instance = MagicMock()
|
||||
mock_model_instance.get_tts_voices.return_value = [] # No voices available
|
||||
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Sorry, no voice available"):
|
||||
AudioService.transcript_tts(app_model=app, text="Test")
|
||||
|
||||
|
||||
class TestAudioServiceTTSVoices:
|
||||
"""Test TTS voice listing operations."""
|
||||
|
||||
@patch("services.audio_service.ModelManager")
|
||||
def test_transcript_tts_voices_success(self, mock_model_manager_class, factory):
|
||||
"""Test successful retrieval of TTS voices."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
language = "en-US"
|
||||
|
||||
expected_voices = [
|
||||
{"name": "Voice 1", "value": "voice-1"},
|
||||
{"name": "Voice 2", "value": "voice-2"},
|
||||
]
|
||||
|
||||
# Mock ModelManager
|
||||
mock_model_manager = MagicMock()
|
||||
mock_model_manager_class.return_value = mock_model_manager
|
||||
|
||||
mock_model_instance = MagicMock()
|
||||
mock_model_instance.get_tts_voices.return_value = expected_voices
|
||||
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
|
||||
|
||||
# Act
|
||||
result = AudioService.transcript_tts_voices(tenant_id=tenant_id, language=language)
|
||||
|
||||
# Assert
|
||||
assert result == expected_voices
|
||||
mock_model_instance.get_tts_voices.assert_called_once_with(language)
|
||||
|
||||
@patch("services.audio_service.ModelManager")
|
||||
def test_transcript_tts_voices_raises_error_when_no_model_instance(self, mock_model_manager_class, factory):
|
||||
"""Test that TTS voices raises error when no model instance is available."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
language = "en-US"
|
||||
|
||||
# Mock ModelManager to return None
|
||||
mock_model_manager = MagicMock()
|
||||
mock_model_manager_class.return_value = mock_model_manager
|
||||
mock_model_manager.get_default_model_instance.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ProviderNotSupportTextToSpeechServiceError):
|
||||
AudioService.transcript_tts_voices(tenant_id=tenant_id, language=language)
|
||||
|
||||
@patch("services.audio_service.ModelManager")
|
||||
def test_transcript_tts_voices_propagates_exceptions(self, mock_model_manager_class, factory):
|
||||
"""Test that TTS voices propagates exceptions from model instance."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
language = "en-US"
|
||||
|
||||
# Mock ModelManager
|
||||
mock_model_manager = MagicMock()
|
||||
mock_model_manager_class.return_value = mock_model_manager
|
||||
|
||||
mock_model_instance = MagicMock()
|
||||
mock_model_instance.get_tts_voices.side_effect = RuntimeError("Model error")
|
||||
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(RuntimeError, match="Model error"):
|
||||
AudioService.transcript_tts_voices(tenant_id=tenant_id, language=language)
|
||||
|
|
@ -0,0 +1,494 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from models.model import App, DefaultEndUserSessionID, EndUser
|
||||
from services.end_user_service import EndUserService
|
||||
|
||||
|
||||
class TestEndUserServiceFactory:
|
||||
"""Factory class for creating test data and mock objects for end user service tests."""
|
||||
|
||||
@staticmethod
|
||||
def create_app_mock(
|
||||
app_id: str = "app-123",
|
||||
tenant_id: str = "tenant-456",
|
||||
name: str = "Test App",
|
||||
) -> MagicMock:
|
||||
"""Create a mock App object."""
|
||||
app = MagicMock(spec=App)
|
||||
app.id = app_id
|
||||
app.tenant_id = tenant_id
|
||||
app.name = name
|
||||
return app
|
||||
|
||||
@staticmethod
|
||||
def create_end_user_mock(
|
||||
user_id: str = "user-789",
|
||||
tenant_id: str = "tenant-456",
|
||||
app_id: str = "app-123",
|
||||
session_id: str = "session-001",
|
||||
type: InvokeFrom = InvokeFrom.SERVICE_API,
|
||||
is_anonymous: bool = False,
|
||||
) -> MagicMock:
|
||||
"""Create a mock EndUser object."""
|
||||
end_user = MagicMock(spec=EndUser)
|
||||
end_user.id = user_id
|
||||
end_user.tenant_id = tenant_id
|
||||
end_user.app_id = app_id
|
||||
end_user.session_id = session_id
|
||||
end_user.type = type
|
||||
end_user.is_anonymous = is_anonymous
|
||||
end_user.external_user_id = session_id
|
||||
return end_user
|
||||
|
||||
|
||||
class TestEndUserServiceGetOrCreateEndUser:
|
||||
"""
|
||||
Unit tests for EndUserService.get_or_create_end_user method.
|
||||
|
||||
This test suite covers:
|
||||
- Creating new end users
|
||||
- Retrieving existing end users
|
||||
- Default session ID handling
|
||||
- Anonymous user creation
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def factory(self):
|
||||
"""Provide test data factory."""
|
||||
return TestEndUserServiceFactory()
|
||||
|
||||
# Test 01: Get or create with custom user_id
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
def test_get_or_create_end_user_with_custom_user_id(self, mock_db, mock_session_class, factory):
|
||||
"""Test getting or creating end user with custom user_id."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user_id = "custom-user-123"
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.first.return_value = None # No existing user
|
||||
|
||||
# Act
|
||||
result = EndUserService.get_or_create_end_user(app_model=app, user_id=user_id)
|
||||
|
||||
# Assert
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
# Verify the created user has correct attributes
|
||||
added_user = mock_session.add.call_args[0][0]
|
||||
assert added_user.tenant_id == app.tenant_id
|
||||
assert added_user.app_id == app.id
|
||||
assert added_user.session_id == user_id
|
||||
assert added_user.type == InvokeFrom.SERVICE_API
|
||||
assert added_user.is_anonymous is False
|
||||
|
||||
# Test 02: Get or create without user_id (default session)
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
def test_get_or_create_end_user_without_user_id(self, mock_db, mock_session_class, factory):
|
||||
"""Test getting or creating end user without user_id uses default session."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.first.return_value = None # No existing user
|
||||
|
||||
# Act
|
||||
result = EndUserService.get_or_create_end_user(app_model=app, user_id=None)
|
||||
|
||||
# Assert
|
||||
mock_session.add.assert_called_once()
|
||||
added_user = mock_session.add.call_args[0][0]
|
||||
assert added_user.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||
# Verify _is_anonymous is set correctly (property always returns False)
|
||||
assert added_user._is_anonymous is True
|
||||
|
||||
# Test 03: Get existing end user
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
def test_get_existing_end_user(self, mock_db, mock_session_class, factory):
|
||||
"""Test retrieving an existing end user."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user_id = "existing-user-123"
|
||||
existing_user = factory.create_end_user_mock(
|
||||
tenant_id=app.tenant_id,
|
||||
app_id=app.id,
|
||||
session_id=user_id,
|
||||
type=InvokeFrom.SERVICE_API,
|
||||
)
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.first.return_value = existing_user
|
||||
|
||||
# Act
|
||||
result = EndUserService.get_or_create_end_user(app_model=app, user_id=user_id)
|
||||
|
||||
# Assert
|
||||
assert result == existing_user
|
||||
mock_session.add.assert_not_called() # Should not create new user
|
||||
|
||||
|
||||
class TestEndUserServiceGetOrCreateEndUserByType:
|
||||
"""
|
||||
Unit tests for EndUserService.get_or_create_end_user_by_type method.
|
||||
|
||||
This test suite covers:
|
||||
- Creating end users with different InvokeFrom types
|
||||
- Type migration for legacy users
|
||||
- Query ordering and prioritization
|
||||
- Session management
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def factory(self):
|
||||
"""Provide test data factory."""
|
||||
return TestEndUserServiceFactory()
|
||||
|
||||
# Test 04: Create new end user with SERVICE_API type
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
def test_create_end_user_service_api_type(self, mock_db, mock_session_class, factory):
|
||||
"""Test creating new end user with SERVICE_API type."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
app_id = "app-456"
|
||||
user_id = "user-789"
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
|
||||
# Act
|
||||
result = EndUserService.get_or_create_end_user_by_type(
|
||||
type=InvokeFrom.SERVICE_API,
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
added_user = mock_session.add.call_args[0][0]
|
||||
assert added_user.type == InvokeFrom.SERVICE_API
|
||||
assert added_user.tenant_id == tenant_id
|
||||
assert added_user.app_id == app_id
|
||||
assert added_user.session_id == user_id
|
||||
|
||||
# Test 05: Create new end user with WEB_APP type
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
def test_create_end_user_web_app_type(self, mock_db, mock_session_class, factory):
|
||||
"""Test creating new end user with WEB_APP type."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
app_id = "app-456"
|
||||
user_id = "user-789"
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
|
||||
# Act
|
||||
result = EndUserService.get_or_create_end_user_by_type(
|
||||
type=InvokeFrom.WEB_APP,
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_session.add.assert_called_once()
|
||||
added_user = mock_session.add.call_args[0][0]
|
||||
assert added_user.type == InvokeFrom.WEB_APP
|
||||
|
||||
# Test 06: Upgrade legacy end user type
|
||||
@patch("services.end_user_service.logger")
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
def test_upgrade_legacy_end_user_type(self, mock_db, mock_session_class, mock_logger, factory):
|
||||
"""Test upgrading legacy end user with different type."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
app_id = "app-456"
|
||||
user_id = "user-789"
|
||||
|
||||
# Existing user with old type
|
||||
existing_user = factory.create_end_user_mock(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
session_id=user_id,
|
||||
type=InvokeFrom.SERVICE_API,
|
||||
)
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.first.return_value = existing_user
|
||||
|
||||
# Act - Request with different type
|
||||
result = EndUserService.get_or_create_end_user_by_type(
|
||||
type=InvokeFrom.WEB_APP,
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == existing_user
|
||||
assert existing_user.type == InvokeFrom.WEB_APP # Type should be updated
|
||||
mock_session.commit.assert_called_once()
|
||||
mock_logger.info.assert_called_once()
|
||||
# Verify log message contains upgrade info
|
||||
log_call = mock_logger.info.call_args[0][0]
|
||||
assert "Upgrading legacy EndUser" in log_call
|
||||
|
||||
# Test 07: Get existing end user with matching type (no upgrade needed)
|
||||
@patch("services.end_user_service.logger")
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
def test_get_existing_end_user_matching_type(self, mock_db, mock_session_class, mock_logger, factory):
|
||||
"""Test retrieving existing end user with matching type."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
app_id = "app-456"
|
||||
user_id = "user-789"
|
||||
|
||||
existing_user = factory.create_end_user_mock(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
session_id=user_id,
|
||||
type=InvokeFrom.SERVICE_API,
|
||||
)
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.first.return_value = existing_user
|
||||
|
||||
# Act - Request with same type
|
||||
result = EndUserService.get_or_create_end_user_by_type(
|
||||
type=InvokeFrom.SERVICE_API,
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == existing_user
|
||||
assert existing_user.type == InvokeFrom.SERVICE_API
|
||||
# No commit should be called (no type update needed)
|
||||
mock_session.commit.assert_not_called()
|
||||
mock_logger.info.assert_not_called()
|
||||
|
||||
# Test 08: Create anonymous user with default session ID
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
def test_create_anonymous_user_with_default_session(self, mock_db, mock_session_class, factory):
|
||||
"""Test creating anonymous user when user_id is None."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
app_id = "app-456"
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
|
||||
# Act
|
||||
result = EndUserService.get_or_create_end_user_by_type(
|
||||
type=InvokeFrom.SERVICE_API,
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_session.add.assert_called_once()
|
||||
added_user = mock_session.add.call_args[0][0]
|
||||
assert added_user.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||
# Verify _is_anonymous is set correctly (property always returns False)
|
||||
assert added_user._is_anonymous is True
|
||||
assert added_user.external_user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||
|
||||
# Test 09: Query ordering prioritizes matching type
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
def test_query_ordering_prioritizes_matching_type(self, mock_db, mock_session_class, factory):
|
||||
"""Test that query ordering prioritizes records with matching type."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
app_id = "app-456"
|
||||
user_id = "user-789"
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
|
||||
# Act
|
||||
EndUserService.get_or_create_end_user_by_type(
|
||||
type=InvokeFrom.SERVICE_API,
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Assert
|
||||
# Verify order_by was called (for type prioritization)
|
||||
mock_query.order_by.assert_called_once()
|
||||
|
||||
# Test 10: Session context manager properly closes
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
def test_session_context_manager_closes(self, mock_db, mock_session_class, factory):
|
||||
"""Test that Session context manager is properly used."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
app_id = "app-456"
|
||||
user_id = "user-789"
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_context = MagicMock()
|
||||
mock_context.__enter__.return_value = mock_session
|
||||
mock_session_class.return_value = mock_context
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
|
||||
# Act
|
||||
EndUserService.get_or_create_end_user_by_type(
|
||||
type=InvokeFrom.SERVICE_API,
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Assert
|
||||
# Verify context manager was entered and exited
|
||||
mock_context.__enter__.assert_called_once()
|
||||
mock_context.__exit__.assert_called_once()
|
||||
|
||||
# Test 11: External user ID matches session ID
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
def test_external_user_id_matches_session_id(self, mock_db, mock_session_class, factory):
|
||||
"""Test that external_user_id is set to match session_id."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
app_id = "app-456"
|
||||
user_id = "custom-external-id"
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
|
||||
# Act
|
||||
result = EndUserService.get_or_create_end_user_by_type(
|
||||
type=InvokeFrom.SERVICE_API,
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Assert
|
||||
added_user = mock_session.add.call_args[0][0]
|
||||
assert added_user.external_user_id == user_id
|
||||
assert added_user.session_id == user_id
|
||||
|
||||
# Test 12: Different InvokeFrom types
|
||||
@pytest.mark.parametrize(
|
||||
"invoke_type",
|
||||
[
|
||||
InvokeFrom.SERVICE_API,
|
||||
InvokeFrom.WEB_APP,
|
||||
InvokeFrom.EXPLORE,
|
||||
InvokeFrom.DEBUGGER,
|
||||
],
|
||||
)
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
def test_create_end_user_with_different_invoke_types(self, mock_db, mock_session_class, invoke_type, factory):
|
||||
"""Test creating end users with different InvokeFrom types."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
app_id = "app-456"
|
||||
user_id = "user-789"
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
|
||||
# Act
|
||||
result = EndUserService.get_or_create_end_user_by_type(
|
||||
type=invoke_type,
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Assert
|
||||
added_user = mock_session.add.call_args[0][0]
|
||||
assert added_user.type == invoke_type
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,649 @@
|
|||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from models.model import App, AppMode, EndUser, Message
|
||||
from services.errors.message import FirstMessageNotExistsError, LastMessageNotExistsError
|
||||
from services.message_service import MessageService
|
||||
|
||||
|
||||
class TestMessageServiceFactory:
|
||||
"""Factory class for creating test data and mock objects for message service tests."""
|
||||
|
||||
@staticmethod
|
||||
def create_app_mock(
|
||||
app_id: str = "app-123",
|
||||
mode: str = AppMode.ADVANCED_CHAT.value,
|
||||
name: str = "Test App",
|
||||
) -> MagicMock:
|
||||
"""Create a mock App object."""
|
||||
app = MagicMock(spec=App)
|
||||
app.id = app_id
|
||||
app.mode = mode
|
||||
app.name = name
|
||||
return app
|
||||
|
||||
@staticmethod
|
||||
def create_end_user_mock(
|
||||
user_id: str = "user-456",
|
||||
session_id: str = "session-789",
|
||||
) -> MagicMock:
|
||||
"""Create a mock EndUser object."""
|
||||
user = MagicMock(spec=EndUser)
|
||||
user.id = user_id
|
||||
user.session_id = session_id
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
def create_conversation_mock(
|
||||
conversation_id: str = "conv-001",
|
||||
app_id: str = "app-123",
|
||||
) -> MagicMock:
|
||||
"""Create a mock Conversation object."""
|
||||
conversation = MagicMock()
|
||||
conversation.id = conversation_id
|
||||
conversation.app_id = app_id
|
||||
return conversation
|
||||
|
||||
@staticmethod
|
||||
def create_message_mock(
|
||||
message_id: str = "msg-001",
|
||||
conversation_id: str = "conv-001",
|
||||
query: str = "What is AI?",
|
||||
answer: str = "AI stands for Artificial Intelligence.",
|
||||
created_at: datetime | None = None,
|
||||
) -> MagicMock:
|
||||
"""Create a mock Message object."""
|
||||
message = MagicMock(spec=Message)
|
||||
message.id = message_id
|
||||
message.conversation_id = conversation_id
|
||||
message.query = query
|
||||
message.answer = answer
|
||||
message.created_at = created_at or datetime.now()
|
||||
return message
|
||||
|
||||
|
||||
class TestMessageServicePaginationByFirstId:
|
||||
"""
|
||||
Unit tests for MessageService.pagination_by_first_id method.
|
||||
|
||||
This test suite covers:
|
||||
- Basic pagination with and without first_id
|
||||
- Order handling (asc/desc)
|
||||
- Edge cases (no user, no conversation, invalid first_id)
|
||||
- Has_more flag logic
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def factory(self):
|
||||
"""Provide test data factory."""
|
||||
return TestMessageServiceFactory()
|
||||
|
||||
# Test 01: No user provided
|
||||
def test_pagination_by_first_id_no_user(self, factory):
|
||||
"""Test pagination returns empty result when no user is provided."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
|
||||
# Act
|
||||
result = MessageService.pagination_by_first_id(
|
||||
app_model=app,
|
||||
user=None,
|
||||
conversation_id="conv-001",
|
||||
first_id=None,
|
||||
limit=10,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, InfiniteScrollPagination)
|
||||
assert result.data == []
|
||||
assert result.limit == 10
|
||||
assert result.has_more is False
|
||||
|
||||
# Test 02: No conversation_id provided
|
||||
def test_pagination_by_first_id_no_conversation(self, factory):
|
||||
"""Test pagination returns empty result when no conversation_id is provided."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user = factory.create_end_user_mock()
|
||||
|
||||
# Act
|
||||
result = MessageService.pagination_by_first_id(
|
||||
app_model=app,
|
||||
user=user,
|
||||
conversation_id="",
|
||||
first_id=None,
|
||||
limit=10,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, InfiniteScrollPagination)
|
||||
assert result.data == []
|
||||
assert result.limit == 10
|
||||
assert result.has_more is False
|
||||
|
||||
# Test 03: Basic pagination without first_id (desc order)
|
||||
@patch("services.message_service.db")
|
||||
@patch("services.message_service.ConversationService")
|
||||
def test_pagination_by_first_id_without_first_id_desc(self, mock_conversation_service, mock_db, factory):
|
||||
"""Test basic pagination without first_id in descending order."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user = factory.create_end_user_mock()
|
||||
conversation = factory.create_conversation_mock()
|
||||
|
||||
mock_conversation_service.get_conversation.return_value = conversation
|
||||
|
||||
# Create 5 messages
|
||||
messages = [
|
||||
factory.create_message_mock(
|
||||
message_id=f"msg-{i:03d}",
|
||||
created_at=datetime(2024, 1, 1, 12, i),
|
||||
)
|
||||
for i in range(5)
|
||||
]
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.limit.return_value = mock_query
|
||||
mock_query.all.return_value = messages
|
||||
|
||||
# Act
|
||||
result = MessageService.pagination_by_first_id(
|
||||
app_model=app,
|
||||
user=user,
|
||||
conversation_id="conv-001",
|
||||
first_id=None,
|
||||
limit=10,
|
||||
order="desc",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result.data) == 5
|
||||
assert result.has_more is False
|
||||
assert result.limit == 10
|
||||
# Messages should remain in desc order (not reversed)
|
||||
assert result.data[0].id == "msg-000"
|
||||
|
||||
# Test 04: Basic pagination without first_id (asc order)
|
||||
@patch("services.message_service.db")
|
||||
@patch("services.message_service.ConversationService")
|
||||
def test_pagination_by_first_id_without_first_id_asc(self, mock_conversation_service, mock_db, factory):
|
||||
"""Test basic pagination without first_id in ascending order."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user = factory.create_end_user_mock()
|
||||
conversation = factory.create_conversation_mock()
|
||||
|
||||
mock_conversation_service.get_conversation.return_value = conversation
|
||||
|
||||
# Create 5 messages (returned in desc order from DB)
|
||||
messages = [
|
||||
factory.create_message_mock(
|
||||
message_id=f"msg-{i:03d}",
|
||||
created_at=datetime(2024, 1, 1, 12, 4 - i), # Descending timestamps
|
||||
)
|
||||
for i in range(5)
|
||||
]
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.limit.return_value = mock_query
|
||||
mock_query.all.return_value = messages
|
||||
|
||||
# Act
|
||||
result = MessageService.pagination_by_first_id(
|
||||
app_model=app,
|
||||
user=user,
|
||||
conversation_id="conv-001",
|
||||
first_id=None,
|
||||
limit=10,
|
||||
order="asc",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result.data) == 5
|
||||
assert result.has_more is False
|
||||
# Messages should be reversed to asc order
|
||||
assert result.data[0].id == "msg-004"
|
||||
assert result.data[4].id == "msg-000"
|
||||
|
||||
# Test 05: Pagination with first_id
|
||||
@patch("services.message_service.db")
|
||||
@patch("services.message_service.ConversationService")
|
||||
def test_pagination_by_first_id_with_first_id(self, mock_conversation_service, mock_db, factory):
|
||||
"""Test pagination with first_id to get messages before a specific message."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user = factory.create_end_user_mock()
|
||||
conversation = factory.create_conversation_mock()
|
||||
|
||||
mock_conversation_service.get_conversation.return_value = conversation
|
||||
|
||||
first_message = factory.create_message_mock(
|
||||
message_id="msg-005",
|
||||
created_at=datetime(2024, 1, 1, 12, 5),
|
||||
)
|
||||
|
||||
# Messages before first_message
|
||||
history_messages = [
|
||||
factory.create_message_mock(
|
||||
message_id=f"msg-{i:03d}",
|
||||
created_at=datetime(2024, 1, 1, 12, i),
|
||||
)
|
||||
for i in range(5)
|
||||
]
|
||||
|
||||
# Setup query mocks
|
||||
mock_query_first = MagicMock()
|
||||
mock_query_history = MagicMock()
|
||||
|
||||
def query_side_effect(*args):
|
||||
if args[0] == Message:
|
||||
# First call returns mock for first_message query
|
||||
if not hasattr(query_side_effect, "call_count"):
|
||||
query_side_effect.call_count = 0
|
||||
query_side_effect.call_count += 1
|
||||
|
||||
if query_side_effect.call_count == 1:
|
||||
return mock_query_first
|
||||
else:
|
||||
return mock_query_history
|
||||
|
||||
mock_db.session.query.side_effect = [mock_query_first, mock_query_history]
|
||||
|
||||
# Setup first message query
|
||||
mock_query_first.where.return_value = mock_query_first
|
||||
mock_query_first.first.return_value = first_message
|
||||
|
||||
# Setup history messages query
|
||||
mock_query_history.where.return_value = mock_query_history
|
||||
mock_query_history.order_by.return_value = mock_query_history
|
||||
mock_query_history.limit.return_value = mock_query_history
|
||||
mock_query_history.all.return_value = history_messages
|
||||
|
||||
# Act
|
||||
result = MessageService.pagination_by_first_id(
|
||||
app_model=app,
|
||||
user=user,
|
||||
conversation_id="conv-001",
|
||||
first_id="msg-005",
|
||||
limit=10,
|
||||
order="desc",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result.data) == 5
|
||||
assert result.has_more is False
|
||||
mock_query_first.where.assert_called_once()
|
||||
mock_query_history.where.assert_called_once()
|
||||
|
||||
# Test 06: First message not found
|
||||
@patch("services.message_service.db")
|
||||
@patch("services.message_service.ConversationService")
|
||||
def test_pagination_by_first_id_first_message_not_exists(self, mock_conversation_service, mock_db, factory):
|
||||
"""Test error handling when first_id doesn't exist."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user = factory.create_end_user_mock()
|
||||
conversation = factory.create_conversation_mock()
|
||||
|
||||
mock_conversation_service.get_conversation.return_value = conversation
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.first.return_value = None # Message not found
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(FirstMessageNotExistsError):
|
||||
MessageService.pagination_by_first_id(
|
||||
app_model=app,
|
||||
user=user,
|
||||
conversation_id="conv-001",
|
||||
first_id="nonexistent-msg",
|
||||
limit=10,
|
||||
)
|
||||
|
||||
# Test 07: Has_more flag when results exceed limit
|
||||
@patch("services.message_service.db")
|
||||
@patch("services.message_service.ConversationService")
|
||||
def test_pagination_by_first_id_has_more_true(self, mock_conversation_service, mock_db, factory):
|
||||
"""Test has_more flag is True when results exceed limit."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user = factory.create_end_user_mock()
|
||||
conversation = factory.create_conversation_mock()
|
||||
|
||||
mock_conversation_service.get_conversation.return_value = conversation
|
||||
|
||||
# Create limit+1 messages (11 messages for limit=10)
|
||||
messages = [
|
||||
factory.create_message_mock(
|
||||
message_id=f"msg-{i:03d}",
|
||||
created_at=datetime(2024, 1, 1, 12, i),
|
||||
)
|
||||
for i in range(11)
|
||||
]
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.limit.return_value = mock_query
|
||||
mock_query.all.return_value = messages
|
||||
|
||||
# Act
|
||||
result = MessageService.pagination_by_first_id(
|
||||
app_model=app,
|
||||
user=user,
|
||||
conversation_id="conv-001",
|
||||
first_id=None,
|
||||
limit=10,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result.data) == 10 # Last message trimmed
|
||||
assert result.has_more is True
|
||||
assert result.limit == 10
|
||||
|
||||
# Test 08: Empty conversation
|
||||
@patch("services.message_service.db")
|
||||
@patch("services.message_service.ConversationService")
|
||||
def test_pagination_by_first_id_empty_conversation(self, mock_conversation_service, mock_db, factory):
|
||||
"""Test pagination with conversation that has no messages."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user = factory.create_end_user_mock()
|
||||
conversation = factory.create_conversation_mock()
|
||||
|
||||
mock_conversation_service.get_conversation.return_value = conversation
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.limit.return_value = mock_query
|
||||
mock_query.all.return_value = []
|
||||
|
||||
# Act
|
||||
result = MessageService.pagination_by_first_id(
|
||||
app_model=app,
|
||||
user=user,
|
||||
conversation_id="conv-001",
|
||||
first_id=None,
|
||||
limit=10,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result.data) == 0
|
||||
assert result.has_more is False
|
||||
assert result.limit == 10
|
||||
|
||||
|
||||
class TestMessageServicePaginationByLastId:
|
||||
"""
|
||||
Unit tests for MessageService.pagination_by_last_id method.
|
||||
|
||||
This test suite covers:
|
||||
- Basic pagination with and without last_id
|
||||
- Conversation filtering
|
||||
- Include_ids filtering
|
||||
- Edge cases (no user, invalid last_id)
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def factory(self):
|
||||
"""Provide test data factory."""
|
||||
return TestMessageServiceFactory()
|
||||
|
||||
# Test 09: No user provided
|
||||
def test_pagination_by_last_id_no_user(self, factory):
|
||||
"""Test pagination returns empty result when no user is provided."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
|
||||
# Act
|
||||
result = MessageService.pagination_by_last_id(
|
||||
app_model=app,
|
||||
user=None,
|
||||
last_id=None,
|
||||
limit=10,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, InfiniteScrollPagination)
|
||||
assert result.data == []
|
||||
assert result.limit == 10
|
||||
assert result.has_more is False
|
||||
|
||||
# Test 10: Basic pagination without last_id
|
||||
@patch("services.message_service.db")
|
||||
def test_pagination_by_last_id_without_last_id(self, mock_db, factory):
|
||||
"""Test basic pagination without last_id."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user = factory.create_end_user_mock()
|
||||
|
||||
messages = [
|
||||
factory.create_message_mock(
|
||||
message_id=f"msg-{i:03d}",
|
||||
created_at=datetime(2024, 1, 1, 12, i),
|
||||
)
|
||||
for i in range(5)
|
||||
]
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.limit.return_value = mock_query
|
||||
mock_query.all.return_value = messages
|
||||
|
||||
# Act
|
||||
result = MessageService.pagination_by_last_id(
|
||||
app_model=app,
|
||||
user=user,
|
||||
last_id=None,
|
||||
limit=10,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result.data) == 5
|
||||
assert result.has_more is False
|
||||
assert result.limit == 10
|
||||
|
||||
# Test 11: Pagination with last_id
|
||||
@patch("services.message_service.db")
|
||||
def test_pagination_by_last_id_with_last_id(self, mock_db, factory):
|
||||
"""Test pagination with last_id to get messages after a specific message."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user = factory.create_end_user_mock()
|
||||
|
||||
last_message = factory.create_message_mock(
|
||||
message_id="msg-005",
|
||||
created_at=datetime(2024, 1, 1, 12, 5),
|
||||
)
|
||||
|
||||
# Messages after last_message
|
||||
new_messages = [
|
||||
factory.create_message_mock(
|
||||
message_id=f"msg-{i:03d}",
|
||||
created_at=datetime(2024, 1, 1, 12, i),
|
||||
)
|
||||
for i in range(6, 10)
|
||||
]
|
||||
|
||||
# Setup base query mock that returns itself for chaining
|
||||
mock_base_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_base_query
|
||||
|
||||
# First where() call for last_id lookup
|
||||
mock_query_last = MagicMock()
|
||||
mock_query_last.first.return_value = last_message
|
||||
|
||||
# Second where() call for history messages
|
||||
mock_query_history = MagicMock()
|
||||
mock_query_history.order_by.return_value = mock_query_history
|
||||
mock_query_history.limit.return_value = mock_query_history
|
||||
mock_query_history.all.return_value = new_messages
|
||||
|
||||
# Setup where() to return different mocks on consecutive calls
|
||||
mock_base_query.where.side_effect = [mock_query_last, mock_query_history]
|
||||
|
||||
# Act
|
||||
result = MessageService.pagination_by_last_id(
|
||||
app_model=app,
|
||||
user=user,
|
||||
last_id="msg-005",
|
||||
limit=10,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result.data) == 4
|
||||
assert result.has_more is False
|
||||
|
||||
# Test 12: Last message not found
|
||||
@patch("services.message_service.db")
|
||||
def test_pagination_by_last_id_last_message_not_exists(self, mock_db, factory):
|
||||
"""Test error handling when last_id doesn't exist."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user = factory.create_end_user_mock()
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.first.return_value = None # Message not found
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(LastMessageNotExistsError):
|
||||
MessageService.pagination_by_last_id(
|
||||
app_model=app,
|
||||
user=user,
|
||||
last_id="nonexistent-msg",
|
||||
limit=10,
|
||||
)
|
||||
|
||||
# Test 13: Pagination with conversation_id filter
|
||||
@patch("services.message_service.ConversationService")
|
||||
@patch("services.message_service.db")
|
||||
def test_pagination_by_last_id_with_conversation_filter(self, mock_db, mock_conversation_service, factory):
|
||||
"""Test pagination filtered by conversation_id."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user = factory.create_end_user_mock()
|
||||
conversation = factory.create_conversation_mock(conversation_id="conv-001")
|
||||
|
||||
mock_conversation_service.get_conversation.return_value = conversation
|
||||
|
||||
messages = [
|
||||
factory.create_message_mock(
|
||||
message_id=f"msg-{i:03d}",
|
||||
conversation_id="conv-001",
|
||||
created_at=datetime(2024, 1, 1, 12, i),
|
||||
)
|
||||
for i in range(5)
|
||||
]
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.limit.return_value = mock_query
|
||||
mock_query.all.return_value = messages
|
||||
|
||||
# Act
|
||||
result = MessageService.pagination_by_last_id(
|
||||
app_model=app,
|
||||
user=user,
|
||||
last_id=None,
|
||||
limit=10,
|
||||
conversation_id="conv-001",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result.data) == 5
|
||||
assert result.has_more is False
|
||||
# Verify conversation_id was used in query
|
||||
mock_query.where.assert_called()
|
||||
mock_conversation_service.get_conversation.assert_called_once()
|
||||
|
||||
# Test 14: Pagination with include_ids filter
|
||||
@patch("services.message_service.db")
|
||||
def test_pagination_by_last_id_with_include_ids(self, mock_db, factory):
|
||||
"""Test pagination filtered by include_ids."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user = factory.create_end_user_mock()
|
||||
|
||||
# Only messages with IDs in include_ids should be returned
|
||||
messages = [
|
||||
factory.create_message_mock(message_id="msg-001"),
|
||||
factory.create_message_mock(message_id="msg-003"),
|
||||
]
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.limit.return_value = mock_query
|
||||
mock_query.all.return_value = messages
|
||||
|
||||
# Act
|
||||
result = MessageService.pagination_by_last_id(
|
||||
app_model=app,
|
||||
user=user,
|
||||
last_id=None,
|
||||
limit=10,
|
||||
include_ids=["msg-001", "msg-003"],
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result.data) == 2
|
||||
assert result.data[0].id == "msg-001"
|
||||
assert result.data[1].id == "msg-003"
|
||||
|
||||
# Test 15: Has_more flag when results exceed limit
|
||||
@patch("services.message_service.db")
|
||||
def test_pagination_by_last_id_has_more_true(self, mock_db, factory):
|
||||
"""Test has_more flag is True when results exceed limit."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user = factory.create_end_user_mock()
|
||||
|
||||
# Create limit+1 messages (11 messages for limit=10)
|
||||
messages = [
|
||||
factory.create_message_mock(
|
||||
message_id=f"msg-{i:03d}",
|
||||
created_at=datetime(2024, 1, 1, 12, i),
|
||||
)
|
||||
for i in range(11)
|
||||
]
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.limit.return_value = mock_query
|
||||
mock_query.all.return_value = messages
|
||||
|
||||
# Act
|
||||
result = MessageService.pagination_by_last_id(
|
||||
app_model=app,
|
||||
user=user,
|
||||
last_id=None,
|
||||
limit=10,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result.data) == 10 # Last message trimmed
|
||||
assert result.has_more is True
|
||||
assert result.limit == 10
|
||||
|
|
@ -0,0 +1,440 @@
|
|||
"""
|
||||
Comprehensive unit tests for RecommendedAppService.
|
||||
|
||||
This test suite provides complete coverage of recommended app operations in Dify,
|
||||
following TDD principles with the Arrange-Act-Assert pattern.
|
||||
|
||||
## Test Coverage
|
||||
|
||||
### 1. Get Recommended Apps and Categories (TestRecommendedAppServiceGetApps)
|
||||
Tests fetching recommended apps with categories:
|
||||
- Successful retrieval with recommended apps
|
||||
- Fallback to builtin when no recommended apps
|
||||
- Different language support
|
||||
- Factory mode selection (remote, builtin, db)
|
||||
- Empty result handling
|
||||
|
||||
### 2. Get Recommend App Detail (TestRecommendedAppServiceGetDetail)
|
||||
Tests fetching individual app details:
|
||||
- Successful app detail retrieval
|
||||
- Different factory modes
|
||||
- App not found scenarios
|
||||
- Language-specific details
|
||||
|
||||
## Testing Approach
|
||||
|
||||
- **Mocking Strategy**: All external dependencies (dify_config, RecommendAppRetrievalFactory)
|
||||
are mocked for fast, isolated unit tests
|
||||
- **Factory Pattern**: Tests verify correct factory selection based on mode
|
||||
- **Fixtures**: Mock objects are configured per test method
|
||||
- **Assertions**: Each test verifies return values and factory method calls
|
||||
|
||||
## Key Concepts
|
||||
|
||||
**Factory Modes:**
|
||||
- remote: Fetch from remote API
|
||||
- builtin: Use built-in templates
|
||||
- db: Fetch from database
|
||||
|
||||
**Fallback Logic:**
|
||||
- If remote/db returns no apps, fallback to builtin en-US templates
|
||||
- Ensures users always see some recommended apps
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from services.recommended_app_service import RecommendedAppService
|
||||
|
||||
|
||||
class RecommendedAppServiceTestDataFactory:
|
||||
"""
|
||||
Factory for creating test data and mock objects.
|
||||
|
||||
Provides reusable methods to create consistent mock objects for testing
|
||||
recommended app operations.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def create_recommended_apps_response(
|
||||
recommended_apps: list[dict] | None = None,
|
||||
categories: list[str] | None = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Create a mock response for recommended apps.
|
||||
|
||||
Args:
|
||||
recommended_apps: List of recommended app dictionaries
|
||||
categories: List of category names
|
||||
|
||||
Returns:
|
||||
Dictionary with recommended_apps and categories
|
||||
"""
|
||||
if recommended_apps is None:
|
||||
recommended_apps = [
|
||||
{
|
||||
"id": "app-1",
|
||||
"name": "Test App 1",
|
||||
"description": "Test description 1",
|
||||
"category": "productivity",
|
||||
},
|
||||
{
|
||||
"id": "app-2",
|
||||
"name": "Test App 2",
|
||||
"description": "Test description 2",
|
||||
"category": "communication",
|
||||
},
|
||||
]
|
||||
if categories is None:
|
||||
categories = ["productivity", "communication", "utilities"]
|
||||
|
||||
return {
|
||||
"recommended_apps": recommended_apps,
|
||||
"categories": categories,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def create_app_detail_response(
|
||||
app_id: str = "app-123",
|
||||
name: str = "Test App",
|
||||
description: str = "Test description",
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
"""
|
||||
Create a mock response for app detail.
|
||||
|
||||
Args:
|
||||
app_id: App identifier
|
||||
name: App name
|
||||
description: App description
|
||||
**kwargs: Additional fields
|
||||
|
||||
Returns:
|
||||
Dictionary with app details
|
||||
"""
|
||||
detail = {
|
||||
"id": app_id,
|
||||
"name": name,
|
||||
"description": description,
|
||||
"category": kwargs.get("category", "productivity"),
|
||||
"icon": kwargs.get("icon", "🚀"),
|
||||
"model_config": kwargs.get("model_config", {}),
|
||||
}
|
||||
detail.update(kwargs)
|
||||
return detail
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def factory():
|
||||
"""Provide the test data factory to all tests."""
|
||||
return RecommendedAppServiceTestDataFactory
|
||||
|
||||
|
||||
class TestRecommendedAppServiceGetApps:
|
||||
"""Test get_recommended_apps_and_categories operations."""
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
|
||||
@patch("services.recommended_app_service.dify_config")
|
||||
def test_get_recommended_apps_success_with_apps(self, mock_config, mock_factory_class, factory):
|
||||
"""Test successful retrieval of recommended apps when apps are returned."""
|
||||
# Arrange
|
||||
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
|
||||
|
||||
expected_response = factory.create_recommended_apps_response()
|
||||
|
||||
# Mock factory and retrieval instance
|
||||
mock_retrieval_instance = MagicMock()
|
||||
mock_retrieval_instance.get_recommended_apps_and_categories.return_value = expected_response
|
||||
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.return_value = mock_retrieval_instance
|
||||
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
|
||||
|
||||
# Act
|
||||
result = RecommendedAppService.get_recommended_apps_and_categories("en-US")
|
||||
|
||||
# Assert
|
||||
assert result == expected_response
|
||||
assert len(result["recommended_apps"]) == 2
|
||||
assert len(result["categories"]) == 3
|
||||
mock_factory_class.get_recommend_app_factory.assert_called_once_with("remote")
|
||||
mock_retrieval_instance.get_recommended_apps_and_categories.assert_called_once_with("en-US")
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
|
||||
@patch("services.recommended_app_service.dify_config")
|
||||
def test_get_recommended_apps_fallback_to_builtin_when_empty(self, mock_config, mock_factory_class, factory):
|
||||
"""Test fallback to builtin when no recommended apps are returned."""
|
||||
# Arrange
|
||||
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
|
||||
|
||||
# Remote returns empty recommended_apps
|
||||
empty_response = {"recommended_apps": [], "categories": []}
|
||||
|
||||
# Builtin fallback response
|
||||
builtin_response = factory.create_recommended_apps_response(
|
||||
recommended_apps=[{"id": "builtin-1", "name": "Builtin App", "category": "default"}]
|
||||
)
|
||||
|
||||
# Mock remote retrieval instance (returns empty)
|
||||
mock_remote_instance = MagicMock()
|
||||
mock_remote_instance.get_recommended_apps_and_categories.return_value = empty_response
|
||||
|
||||
mock_remote_factory = MagicMock()
|
||||
mock_remote_factory.return_value = mock_remote_instance
|
||||
mock_factory_class.get_recommend_app_factory.return_value = mock_remote_factory
|
||||
|
||||
# Mock builtin retrieval instance
|
||||
mock_builtin_instance = MagicMock()
|
||||
mock_builtin_instance.fetch_recommended_apps_from_builtin.return_value = builtin_response
|
||||
mock_factory_class.get_buildin_recommend_app_retrieval.return_value = mock_builtin_instance
|
||||
|
||||
# Act
|
||||
result = RecommendedAppService.get_recommended_apps_and_categories("zh-CN")
|
||||
|
||||
# Assert
|
||||
assert result == builtin_response
|
||||
assert len(result["recommended_apps"]) == 1
|
||||
assert result["recommended_apps"][0]["id"] == "builtin-1"
|
||||
# Verify fallback was called with en-US (hardcoded)
|
||||
mock_builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once_with("en-US")
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
|
||||
@patch("services.recommended_app_service.dify_config")
|
||||
def test_get_recommended_apps_fallback_when_none_recommended_apps(self, mock_config, mock_factory_class, factory):
|
||||
"""Test fallback when recommended_apps key is None."""
|
||||
# Arrange
|
||||
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "db"
|
||||
|
||||
# Response with None recommended_apps
|
||||
none_response = {"recommended_apps": None, "categories": ["test"]}
|
||||
|
||||
# Builtin fallback response
|
||||
builtin_response = factory.create_recommended_apps_response()
|
||||
|
||||
# Mock db retrieval instance (returns None)
|
||||
mock_db_instance = MagicMock()
|
||||
mock_db_instance.get_recommended_apps_and_categories.return_value = none_response
|
||||
|
||||
mock_db_factory = MagicMock()
|
||||
mock_db_factory.return_value = mock_db_instance
|
||||
mock_factory_class.get_recommend_app_factory.return_value = mock_db_factory
|
||||
|
||||
# Mock builtin retrieval instance
|
||||
mock_builtin_instance = MagicMock()
|
||||
mock_builtin_instance.fetch_recommended_apps_from_builtin.return_value = builtin_response
|
||||
mock_factory_class.get_buildin_recommend_app_retrieval.return_value = mock_builtin_instance
|
||||
|
||||
# Act
|
||||
result = RecommendedAppService.get_recommended_apps_and_categories("en-US")
|
||||
|
||||
# Assert
|
||||
assert result == builtin_response
|
||||
mock_builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once()
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
|
||||
@patch("services.recommended_app_service.dify_config")
|
||||
def test_get_recommended_apps_with_different_languages(self, mock_config, mock_factory_class, factory):
|
||||
"""Test retrieval with different language codes."""
|
||||
# Arrange
|
||||
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "builtin"
|
||||
|
||||
languages = ["en-US", "zh-CN", "ja-JP", "fr-FR"]
|
||||
|
||||
for language in languages:
|
||||
# Create language-specific response
|
||||
lang_response = factory.create_recommended_apps_response(
|
||||
recommended_apps=[{"id": f"app-{language}", "name": f"App {language}", "category": "test"}]
|
||||
)
|
||||
|
||||
# Mock retrieval instance
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.get_recommended_apps_and_categories.return_value = lang_response
|
||||
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.return_value = mock_instance
|
||||
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
|
||||
|
||||
# Act
|
||||
result = RecommendedAppService.get_recommended_apps_and_categories(language)
|
||||
|
||||
# Assert
|
||||
assert result["recommended_apps"][0]["id"] == f"app-{language}"
|
||||
mock_instance.get_recommended_apps_and_categories.assert_called_with(language)
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
|
||||
@patch("services.recommended_app_service.dify_config")
|
||||
def test_get_recommended_apps_uses_correct_factory_mode(self, mock_config, mock_factory_class, factory):
|
||||
"""Test that correct factory is selected based on mode."""
|
||||
# Arrange
|
||||
modes = ["remote", "builtin", "db"]
|
||||
|
||||
for mode in modes:
|
||||
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = mode
|
||||
|
||||
response = factory.create_recommended_apps_response()
|
||||
|
||||
# Mock retrieval instance
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.get_recommended_apps_and_categories.return_value = response
|
||||
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.return_value = mock_instance
|
||||
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
|
||||
|
||||
# Act
|
||||
RecommendedAppService.get_recommended_apps_and_categories("en-US")
|
||||
|
||||
# Assert
|
||||
mock_factory_class.get_recommend_app_factory.assert_called_with(mode)
|
||||
|
||||
|
||||
class TestRecommendedAppServiceGetDetail:
|
||||
"""Test get_recommend_app_detail operations."""
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
|
||||
@patch("services.recommended_app_service.dify_config")
|
||||
def test_get_recommend_app_detail_success(self, mock_config, mock_factory_class, factory):
|
||||
"""Test successful retrieval of app detail."""
|
||||
# Arrange
|
||||
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
|
||||
app_id = "app-123"
|
||||
|
||||
expected_detail = factory.create_app_detail_response(
|
||||
app_id=app_id,
|
||||
name="Productivity App",
|
||||
description="A great productivity app",
|
||||
category="productivity",
|
||||
)
|
||||
|
||||
# Mock retrieval instance
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.get_recommend_app_detail.return_value = expected_detail
|
||||
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.return_value = mock_instance
|
||||
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
|
||||
|
||||
# Act
|
||||
result = RecommendedAppService.get_recommend_app_detail(app_id)
|
||||
|
||||
# Assert
|
||||
assert result == expected_detail
|
||||
assert result["id"] == app_id
|
||||
assert result["name"] == "Productivity App"
|
||||
mock_instance.get_recommend_app_detail.assert_called_once_with(app_id)
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
|
||||
@patch("services.recommended_app_service.dify_config")
|
||||
def test_get_recommend_app_detail_with_different_modes(self, mock_config, mock_factory_class, factory):
|
||||
"""Test app detail retrieval with different factory modes."""
|
||||
# Arrange
|
||||
modes = ["remote", "builtin", "db"]
|
||||
app_id = "test-app"
|
||||
|
||||
for mode in modes:
|
||||
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = mode
|
||||
|
||||
detail = factory.create_app_detail_response(app_id=app_id, name=f"App from {mode}")
|
||||
|
||||
# Mock retrieval instance
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.get_recommend_app_detail.return_value = detail
|
||||
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.return_value = mock_instance
|
||||
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
|
||||
|
||||
# Act
|
||||
result = RecommendedAppService.get_recommend_app_detail(app_id)
|
||||
|
||||
# Assert
|
||||
assert result["name"] == f"App from {mode}"
|
||||
mock_factory_class.get_recommend_app_factory.assert_called_with(mode)
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
|
||||
@patch("services.recommended_app_service.dify_config")
|
||||
def test_get_recommend_app_detail_returns_none_when_not_found(self, mock_config, mock_factory_class, factory):
|
||||
"""Test that None is returned when app is not found."""
|
||||
# Arrange
|
||||
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
|
||||
app_id = "nonexistent-app"
|
||||
|
||||
# Mock retrieval instance returning None
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.get_recommend_app_detail.return_value = None
|
||||
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.return_value = mock_instance
|
||||
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
|
||||
|
||||
# Act
|
||||
result = RecommendedAppService.get_recommend_app_detail(app_id)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
mock_instance.get_recommend_app_detail.assert_called_once_with(app_id)
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
|
||||
@patch("services.recommended_app_service.dify_config")
|
||||
def test_get_recommend_app_detail_returns_empty_dict(self, mock_config, mock_factory_class, factory):
|
||||
"""Test handling of empty dict response."""
|
||||
# Arrange
|
||||
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "builtin"
|
||||
app_id = "app-empty"
|
||||
|
||||
# Mock retrieval instance returning empty dict
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.get_recommend_app_detail.return_value = {}
|
||||
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.return_value = mock_instance
|
||||
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
|
||||
|
||||
# Act
|
||||
result = RecommendedAppService.get_recommend_app_detail(app_id)
|
||||
|
||||
# Assert
|
||||
assert result == {}
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
|
||||
@patch("services.recommended_app_service.dify_config")
|
||||
def test_get_recommend_app_detail_with_complex_model_config(self, mock_config, mock_factory_class, factory):
|
||||
"""Test app detail with complex model configuration."""
|
||||
# Arrange
|
||||
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
|
||||
app_id = "complex-app"
|
||||
|
||||
complex_model_config = {
|
||||
"provider": "openai",
|
||||
"model": "gpt-4",
|
||||
"parameters": {
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 2000,
|
||||
"top_p": 1.0,
|
||||
},
|
||||
}
|
||||
|
||||
expected_detail = factory.create_app_detail_response(
|
||||
app_id=app_id,
|
||||
name="Complex App",
|
||||
model_config=complex_model_config,
|
||||
workflows=["workflow-1", "workflow-2"],
|
||||
tools=["tool-1", "tool-2", "tool-3"],
|
||||
)
|
||||
|
||||
# Mock retrieval instance
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.get_recommend_app_detail.return_value = expected_detail
|
||||
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.return_value = mock_instance
|
||||
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
|
||||
|
||||
# Act
|
||||
result = RecommendedAppService.get_recommend_app_detail(app_id)
|
||||
|
||||
# Assert
|
||||
assert result["model_config"] == complex_model_config
|
||||
assert len(result["workflows"]) == 2
|
||||
assert len(result["tools"]) == 3
|
||||
|
|
@ -0,0 +1,626 @@
|
|||
"""
|
||||
Comprehensive unit tests for SavedMessageService.
|
||||
|
||||
This test suite provides complete coverage of saved message operations in Dify,
|
||||
following TDD principles with the Arrange-Act-Assert pattern.
|
||||
|
||||
## Test Coverage
|
||||
|
||||
### 1. Pagination (TestSavedMessageServicePagination)
|
||||
Tests saved message listing and pagination:
|
||||
- Pagination with valid user (Account and EndUser)
|
||||
- Pagination without user raises ValueError
|
||||
- Pagination with last_id parameter
|
||||
- Empty results when no saved messages exist
|
||||
- Integration with MessageService pagination
|
||||
|
||||
### 2. Save Operations (TestSavedMessageServiceSave)
|
||||
Tests saving messages:
|
||||
- Save message for Account user
|
||||
- Save message for EndUser
|
||||
- Save without user (no-op)
|
||||
- Prevent duplicate saves (idempotent)
|
||||
- Message validation through MessageService
|
||||
|
||||
### 3. Delete Operations (TestSavedMessageServiceDelete)
|
||||
Tests deleting saved messages:
|
||||
- Delete saved message for Account user
|
||||
- Delete saved message for EndUser
|
||||
- Delete without user (no-op)
|
||||
- Delete non-existent saved message (no-op)
|
||||
- Proper database cleanup
|
||||
|
||||
## Testing Approach
|
||||
|
||||
- **Mocking Strategy**: All external dependencies (database, MessageService) are mocked
|
||||
for fast, isolated unit tests
|
||||
- **Factory Pattern**: SavedMessageServiceTestDataFactory provides consistent test data
|
||||
- **Fixtures**: Mock objects are configured per test method
|
||||
- **Assertions**: Each test verifies return values and side effects
|
||||
(database operations, method calls)
|
||||
|
||||
## Key Concepts
|
||||
|
||||
**User Types:**
|
||||
- Account: Workspace members (console users)
|
||||
- EndUser: API users (end users)
|
||||
|
||||
**Saved Messages:**
|
||||
- Users can save messages for later reference
|
||||
- Each user has their own saved message list
|
||||
- Saving is idempotent (duplicate saves ignored)
|
||||
- Deletion is safe (non-existent deletes ignored)
|
||||
"""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import MagicMock, Mock, create_autospec, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from models import Account
|
||||
from models.model import App, EndUser, Message
|
||||
from models.web import SavedMessage
|
||||
from services.saved_message_service import SavedMessageService
|
||||
|
||||
|
||||
class SavedMessageServiceTestDataFactory:
|
||||
"""
|
||||
Factory for creating test data and mock objects.
|
||||
|
||||
Provides reusable methods to create consistent mock objects for testing
|
||||
saved message operations.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def create_account_mock(account_id: str = "account-123", **kwargs) -> Mock:
|
||||
"""
|
||||
Create a mock Account object.
|
||||
|
||||
Args:
|
||||
account_id: Unique identifier for the account
|
||||
**kwargs: Additional attributes to set on the mock
|
||||
|
||||
Returns:
|
||||
Mock Account object with specified attributes
|
||||
"""
|
||||
account = create_autospec(Account, instance=True)
|
||||
account.id = account_id
|
||||
for key, value in kwargs.items():
|
||||
setattr(account, key, value)
|
||||
return account
|
||||
|
||||
@staticmethod
|
||||
def create_end_user_mock(user_id: str = "user-123", **kwargs) -> Mock:
|
||||
"""
|
||||
Create a mock EndUser object.
|
||||
|
||||
Args:
|
||||
user_id: Unique identifier for the end user
|
||||
**kwargs: Additional attributes to set on the mock
|
||||
|
||||
Returns:
|
||||
Mock EndUser object with specified attributes
|
||||
"""
|
||||
user = create_autospec(EndUser, instance=True)
|
||||
user.id = user_id
|
||||
for key, value in kwargs.items():
|
||||
setattr(user, key, value)
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
def create_app_mock(app_id: str = "app-123", tenant_id: str = "tenant-123", **kwargs) -> Mock:
|
||||
"""
|
||||
Create a mock App object.
|
||||
|
||||
Args:
|
||||
app_id: Unique identifier for the app
|
||||
tenant_id: Tenant/workspace identifier
|
||||
**kwargs: Additional attributes to set on the mock
|
||||
|
||||
Returns:
|
||||
Mock App object with specified attributes
|
||||
"""
|
||||
app = create_autospec(App, instance=True)
|
||||
app.id = app_id
|
||||
app.tenant_id = tenant_id
|
||||
app.name = kwargs.get("name", "Test App")
|
||||
app.mode = kwargs.get("mode", "chat")
|
||||
for key, value in kwargs.items():
|
||||
setattr(app, key, value)
|
||||
return app
|
||||
|
||||
@staticmethod
|
||||
def create_message_mock(
|
||||
message_id: str = "msg-123",
|
||||
app_id: str = "app-123",
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""
|
||||
Create a mock Message object.
|
||||
|
||||
Args:
|
||||
message_id: Unique identifier for the message
|
||||
app_id: Associated app identifier
|
||||
**kwargs: Additional attributes to set on the mock
|
||||
|
||||
Returns:
|
||||
Mock Message object with specified attributes
|
||||
"""
|
||||
message = create_autospec(Message, instance=True)
|
||||
message.id = message_id
|
||||
message.app_id = app_id
|
||||
message.query = kwargs.get("query", "Test query")
|
||||
message.answer = kwargs.get("answer", "Test answer")
|
||||
message.created_at = kwargs.get("created_at", datetime.now(UTC))
|
||||
for key, value in kwargs.items():
|
||||
setattr(message, key, value)
|
||||
return message
|
||||
|
||||
@staticmethod
|
||||
def create_saved_message_mock(
|
||||
saved_message_id: str = "saved-123",
|
||||
app_id: str = "app-123",
|
||||
message_id: str = "msg-123",
|
||||
created_by: str = "user-123",
|
||||
created_by_role: str = "account",
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""
|
||||
Create a mock SavedMessage object.
|
||||
|
||||
Args:
|
||||
saved_message_id: Unique identifier for the saved message
|
||||
app_id: Associated app identifier
|
||||
message_id: Associated message identifier
|
||||
created_by: User who saved the message
|
||||
created_by_role: Role of the user ('account' or 'end_user')
|
||||
**kwargs: Additional attributes to set on the mock
|
||||
|
||||
Returns:
|
||||
Mock SavedMessage object with specified attributes
|
||||
"""
|
||||
saved_message = create_autospec(SavedMessage, instance=True)
|
||||
saved_message.id = saved_message_id
|
||||
saved_message.app_id = app_id
|
||||
saved_message.message_id = message_id
|
||||
saved_message.created_by = created_by
|
||||
saved_message.created_by_role = created_by_role
|
||||
saved_message.created_at = kwargs.get("created_at", datetime.now(UTC))
|
||||
for key, value in kwargs.items():
|
||||
setattr(saved_message, key, value)
|
||||
return saved_message
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def factory():
|
||||
"""Provide the test data factory to all tests."""
|
||||
return SavedMessageServiceTestDataFactory
|
||||
|
||||
|
||||
class TestSavedMessageServicePagination:
|
||||
"""Test saved message pagination operations."""
|
||||
|
||||
@patch("services.saved_message_service.MessageService.pagination_by_last_id")
|
||||
@patch("services.saved_message_service.db.session")
|
||||
def test_pagination_with_account_user(self, mock_db_session, mock_message_pagination, factory):
|
||||
"""Test pagination with an Account user."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user = factory.create_account_mock()
|
||||
|
||||
# Create saved messages for this user
|
||||
saved_messages = [
|
||||
factory.create_saved_message_mock(
|
||||
saved_message_id=f"saved-{i}",
|
||||
app_id=app.id,
|
||||
message_id=f"msg-{i}",
|
||||
created_by=user.id,
|
||||
created_by_role="account",
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
# Mock database query
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.all.return_value = saved_messages
|
||||
|
||||
# Mock MessageService pagination response
|
||||
expected_pagination = InfiniteScrollPagination(data=[], limit=20, has_more=False)
|
||||
mock_message_pagination.return_value = expected_pagination
|
||||
|
||||
# Act
|
||||
result = SavedMessageService.pagination_by_last_id(app_model=app, user=user, last_id=None, limit=20)
|
||||
|
||||
# Assert
|
||||
assert result == expected_pagination
|
||||
mock_db_session.query.assert_called_once_with(SavedMessage)
|
||||
# Verify MessageService was called with correct message IDs
|
||||
mock_message_pagination.assert_called_once_with(
|
||||
app_model=app,
|
||||
user=user,
|
||||
last_id=None,
|
||||
limit=20,
|
||||
include_ids=["msg-0", "msg-1", "msg-2"],
|
||||
)
|
||||
|
||||
@patch("services.saved_message_service.MessageService.pagination_by_last_id")
|
||||
@patch("services.saved_message_service.db.session")
|
||||
def test_pagination_with_end_user(self, mock_db_session, mock_message_pagination, factory):
|
||||
"""Test pagination with an EndUser."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user = factory.create_end_user_mock()
|
||||
|
||||
# Create saved messages for this end user
|
||||
saved_messages = [
|
||||
factory.create_saved_message_mock(
|
||||
saved_message_id=f"saved-{i}",
|
||||
app_id=app.id,
|
||||
message_id=f"msg-{i}",
|
||||
created_by=user.id,
|
||||
created_by_role="end_user",
|
||||
)
|
||||
for i in range(2)
|
||||
]
|
||||
|
||||
# Mock database query
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.all.return_value = saved_messages
|
||||
|
||||
# Mock MessageService pagination response
|
||||
expected_pagination = InfiniteScrollPagination(data=[], limit=10, has_more=False)
|
||||
mock_message_pagination.return_value = expected_pagination
|
||||
|
||||
# Act
|
||||
result = SavedMessageService.pagination_by_last_id(app_model=app, user=user, last_id=None, limit=10)
|
||||
|
||||
# Assert
|
||||
assert result == expected_pagination
|
||||
# Verify correct role was used in query
|
||||
mock_message_pagination.assert_called_once_with(
|
||||
app_model=app,
|
||||
user=user,
|
||||
last_id=None,
|
||||
limit=10,
|
||||
include_ids=["msg-0", "msg-1"],
|
||||
)
|
||||
|
||||
def test_pagination_without_user_raises_error(self, factory):
|
||||
"""Test that pagination without user raises ValueError."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="User is required"):
|
||||
SavedMessageService.pagination_by_last_id(app_model=app, user=None, last_id=None, limit=20)
|
||||
|
||||
@patch("services.saved_message_service.MessageService.pagination_by_last_id")
|
||||
@patch("services.saved_message_service.db.session")
|
||||
def test_pagination_with_last_id(self, mock_db_session, mock_message_pagination, factory):
|
||||
"""Test pagination with last_id parameter."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user = factory.create_account_mock()
|
||||
last_id = "msg-last"
|
||||
|
||||
saved_messages = [
|
||||
factory.create_saved_message_mock(
|
||||
message_id=f"msg-{i}",
|
||||
app_id=app.id,
|
||||
created_by=user.id,
|
||||
)
|
||||
for i in range(5)
|
||||
]
|
||||
|
||||
# Mock database query
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.all.return_value = saved_messages
|
||||
|
||||
# Mock MessageService pagination response
|
||||
expected_pagination = InfiniteScrollPagination(data=[], limit=10, has_more=True)
|
||||
mock_message_pagination.return_value = expected_pagination
|
||||
|
||||
# Act
|
||||
result = SavedMessageService.pagination_by_last_id(app_model=app, user=user, last_id=last_id, limit=10)
|
||||
|
||||
# Assert
|
||||
assert result == expected_pagination
|
||||
# Verify last_id was passed to MessageService
|
||||
mock_message_pagination.assert_called_once()
|
||||
call_args = mock_message_pagination.call_args
|
||||
assert call_args.kwargs["last_id"] == last_id
|
||||
|
||||
@patch("services.saved_message_service.MessageService.pagination_by_last_id")
|
||||
@patch("services.saved_message_service.db.session")
|
||||
def test_pagination_with_empty_saved_messages(self, mock_db_session, mock_message_pagination, factory):
|
||||
"""Test pagination when user has no saved messages."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user = factory.create_account_mock()
|
||||
|
||||
# Mock database query returning empty list
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.all.return_value = []
|
||||
|
||||
# Mock MessageService pagination response
|
||||
expected_pagination = InfiniteScrollPagination(data=[], limit=20, has_more=False)
|
||||
mock_message_pagination.return_value = expected_pagination
|
||||
|
||||
# Act
|
||||
result = SavedMessageService.pagination_by_last_id(app_model=app, user=user, last_id=None, limit=20)
|
||||
|
||||
# Assert
|
||||
assert result == expected_pagination
|
||||
# Verify MessageService was called with empty include_ids
|
||||
mock_message_pagination.assert_called_once_with(
|
||||
app_model=app,
|
||||
user=user,
|
||||
last_id=None,
|
||||
limit=20,
|
||||
include_ids=[],
|
||||
)
|
||||
|
||||
|
||||
class TestSavedMessageServiceSave:
|
||||
"""Test save message operations."""
|
||||
|
||||
@patch("services.saved_message_service.MessageService.get_message")
|
||||
@patch("services.saved_message_service.db.session")
|
||||
def test_save_message_for_account(self, mock_db_session, mock_get_message, factory):
|
||||
"""Test saving a message for an Account user."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user = factory.create_account_mock()
|
||||
message = factory.create_message_mock(message_id="msg-123", app_id=app.id)
|
||||
|
||||
# Mock database query - no existing saved message
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
|
||||
# Mock MessageService.get_message
|
||||
mock_get_message.return_value = message
|
||||
|
||||
# Act
|
||||
SavedMessageService.save(app_model=app, user=user, message_id=message.id)
|
||||
|
||||
# Assert
|
||||
mock_db_session.add.assert_called_once()
|
||||
saved_message = mock_db_session.add.call_args[0][0]
|
||||
assert saved_message.app_id == app.id
|
||||
assert saved_message.message_id == message.id
|
||||
assert saved_message.created_by == user.id
|
||||
assert saved_message.created_by_role == "account"
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
@patch("services.saved_message_service.MessageService.get_message")
|
||||
@patch("services.saved_message_service.db.session")
|
||||
def test_save_message_for_end_user(self, mock_db_session, mock_get_message, factory):
|
||||
"""Test saving a message for an EndUser."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user = factory.create_end_user_mock()
|
||||
message = factory.create_message_mock(message_id="msg-456", app_id=app.id)
|
||||
|
||||
# Mock database query - no existing saved message
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
|
||||
# Mock MessageService.get_message
|
||||
mock_get_message.return_value = message
|
||||
|
||||
# Act
|
||||
SavedMessageService.save(app_model=app, user=user, message_id=message.id)
|
||||
|
||||
# Assert
|
||||
mock_db_session.add.assert_called_once()
|
||||
saved_message = mock_db_session.add.call_args[0][0]
|
||||
assert saved_message.app_id == app.id
|
||||
assert saved_message.message_id == message.id
|
||||
assert saved_message.created_by == user.id
|
||||
assert saved_message.created_by_role == "end_user"
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
@patch("services.saved_message_service.db.session")
|
||||
def test_save_without_user_does_nothing(self, mock_db_session, factory):
|
||||
"""Test that saving without user is a no-op."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
|
||||
# Act
|
||||
SavedMessageService.save(app_model=app, user=None, message_id="msg-123")
|
||||
|
||||
# Assert
|
||||
mock_db_session.query.assert_not_called()
|
||||
mock_db_session.add.assert_not_called()
|
||||
mock_db_session.commit.assert_not_called()
|
||||
|
||||
@patch("services.saved_message_service.MessageService.get_message")
|
||||
@patch("services.saved_message_service.db.session")
|
||||
def test_save_duplicate_message_is_idempotent(self, mock_db_session, mock_get_message, factory):
|
||||
"""Test that saving an already saved message is idempotent."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user = factory.create_account_mock()
|
||||
message_id = "msg-789"
|
||||
|
||||
# Mock database query - existing saved message found
|
||||
existing_saved = factory.create_saved_message_mock(
|
||||
app_id=app.id,
|
||||
message_id=message_id,
|
||||
created_by=user.id,
|
||||
created_by_role="account",
|
||||
)
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.first.return_value = existing_saved
|
||||
|
||||
# Act
|
||||
SavedMessageService.save(app_model=app, user=user, message_id=message_id)
|
||||
|
||||
# Assert - no new saved message created
|
||||
mock_db_session.add.assert_not_called()
|
||||
mock_db_session.commit.assert_not_called()
|
||||
mock_get_message.assert_not_called()
|
||||
|
||||
@patch("services.saved_message_service.MessageService.get_message")
|
||||
@patch("services.saved_message_service.db.session")
|
||||
def test_save_validates_message_exists(self, mock_db_session, mock_get_message, factory):
|
||||
"""Test that save validates message exists through MessageService."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user = factory.create_account_mock()
|
||||
message = factory.create_message_mock()
|
||||
|
||||
# Mock database query - no existing saved message
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
|
||||
# Mock MessageService.get_message
|
||||
mock_get_message.return_value = message
|
||||
|
||||
# Act
|
||||
SavedMessageService.save(app_model=app, user=user, message_id=message.id)
|
||||
|
||||
# Assert - MessageService.get_message was called for validation
|
||||
mock_get_message.assert_called_once_with(app_model=app, user=user, message_id=message.id)
|
||||
|
||||
|
||||
class TestSavedMessageServiceDelete:
|
||||
"""Test delete saved message operations."""
|
||||
|
||||
@patch("services.saved_message_service.db.session")
|
||||
def test_delete_saved_message_for_account(self, mock_db_session, factory):
|
||||
"""Test deleting a saved message for an Account user."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user = factory.create_account_mock()
|
||||
message_id = "msg-123"
|
||||
|
||||
# Mock database query - existing saved message found
|
||||
saved_message = factory.create_saved_message_mock(
|
||||
app_id=app.id,
|
||||
message_id=message_id,
|
||||
created_by=user.id,
|
||||
created_by_role="account",
|
||||
)
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.first.return_value = saved_message
|
||||
|
||||
# Act
|
||||
SavedMessageService.delete(app_model=app, user=user, message_id=message_id)
|
||||
|
||||
# Assert
|
||||
mock_db_session.delete.assert_called_once_with(saved_message)
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
@patch("services.saved_message_service.db.session")
|
||||
def test_delete_saved_message_for_end_user(self, mock_db_session, factory):
|
||||
"""Test deleting a saved message for an EndUser."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user = factory.create_end_user_mock()
|
||||
message_id = "msg-456"
|
||||
|
||||
# Mock database query - existing saved message found
|
||||
saved_message = factory.create_saved_message_mock(
|
||||
app_id=app.id,
|
||||
message_id=message_id,
|
||||
created_by=user.id,
|
||||
created_by_role="end_user",
|
||||
)
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.first.return_value = saved_message
|
||||
|
||||
# Act
|
||||
SavedMessageService.delete(app_model=app, user=user, message_id=message_id)
|
||||
|
||||
# Assert
|
||||
mock_db_session.delete.assert_called_once_with(saved_message)
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
@patch("services.saved_message_service.db.session")
|
||||
def test_delete_without_user_does_nothing(self, mock_db_session, factory):
|
||||
"""Test that deleting without user is a no-op."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
|
||||
# Act
|
||||
SavedMessageService.delete(app_model=app, user=None, message_id="msg-123")
|
||||
|
||||
# Assert
|
||||
mock_db_session.query.assert_not_called()
|
||||
mock_db_session.delete.assert_not_called()
|
||||
mock_db_session.commit.assert_not_called()
|
||||
|
||||
@patch("services.saved_message_service.db.session")
|
||||
def test_delete_non_existent_saved_message_does_nothing(self, mock_db_session, factory):
|
||||
"""Test that deleting a non-existent saved message is a no-op."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user = factory.create_account_mock()
|
||||
message_id = "msg-nonexistent"
|
||||
|
||||
# Mock database query - no saved message found
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
|
||||
# Act
|
||||
SavedMessageService.delete(app_model=app, user=user, message_id=message_id)
|
||||
|
||||
# Assert - no deletion occurred
|
||||
mock_db_session.delete.assert_not_called()
|
||||
mock_db_session.commit.assert_not_called()
|
||||
|
||||
@patch("services.saved_message_service.db.session")
|
||||
def test_delete_only_affects_user_own_saved_messages(self, mock_db_session, factory):
|
||||
"""Test that delete only removes the user's own saved message."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user1 = factory.create_account_mock(account_id="user-1")
|
||||
message_id = "msg-shared"
|
||||
|
||||
# Mock database query - finds user1's saved message
|
||||
saved_message = factory.create_saved_message_mock(
|
||||
app_id=app.id,
|
||||
message_id=message_id,
|
||||
created_by=user1.id,
|
||||
created_by_role="account",
|
||||
)
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.first.return_value = saved_message
|
||||
|
||||
# Act
|
||||
SavedMessageService.delete(app_model=app, user=user1, message_id=message_id)
|
||||
|
||||
# Assert - only user1's saved message is deleted
|
||||
mock_db_session.delete.assert_called_once_with(saved_message)
|
||||
# Verify the query filters by user
|
||||
assert mock_query.where.called
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,9 +1,9 @@
|
|||
from unittest.mock import Mock
|
||||
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.api_entities import ToolApiEntity
|
||||
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolParameter
|
||||
from core.tools.entities.tool_entities import ToolParameter, ToolProviderType
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
|
||||
|
|
@ -299,3 +299,154 @@ class TestToolTransformService:
|
|||
param2 = result.parameters[1]
|
||||
assert param2.name == "param2"
|
||||
assert param2.label == "Runtime Param 2"
|
||||
|
||||
|
||||
class TestWorkflowProviderToUserProvider:
|
||||
"""Test cases for ToolTransformService.workflow_provider_to_user_provider method"""
|
||||
|
||||
def test_workflow_provider_to_user_provider_with_workflow_app_id(self):
|
||||
"""Test that workflow_provider_to_user_provider correctly sets workflow_app_id."""
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
|
||||
# Create mock workflow tool provider controller
|
||||
workflow_app_id = "app_123"
|
||||
provider_id = "provider_123"
|
||||
mock_controller = Mock(spec=WorkflowToolProviderController)
|
||||
mock_controller.provider_id = provider_id
|
||||
mock_controller.entity = Mock()
|
||||
mock_controller.entity.identity = Mock()
|
||||
mock_controller.entity.identity.author = "test_author"
|
||||
mock_controller.entity.identity.name = "test_workflow_tool"
|
||||
mock_controller.entity.identity.description = I18nObject(en_US="Test description")
|
||||
mock_controller.entity.identity.icon = {"type": "emoji", "content": "🔧"}
|
||||
mock_controller.entity.identity.icon_dark = None
|
||||
mock_controller.entity.identity.label = I18nObject(en_US="Test Workflow Tool")
|
||||
|
||||
# Call the method
|
||||
result = ToolTransformService.workflow_provider_to_user_provider(
|
||||
provider_controller=mock_controller,
|
||||
labels=["label1", "label2"],
|
||||
workflow_app_id=workflow_app_id,
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, ToolProviderApiEntity)
|
||||
assert result.id == provider_id
|
||||
assert result.author == "test_author"
|
||||
assert result.name == "test_workflow_tool"
|
||||
assert result.type == ToolProviderType.WORKFLOW
|
||||
assert result.workflow_app_id == workflow_app_id
|
||||
assert result.labels == ["label1", "label2"]
|
||||
assert result.is_team_authorization is True
|
||||
assert result.plugin_id is None
|
||||
assert result.plugin_unique_identifier is None
|
||||
assert result.tools == []
|
||||
|
||||
def test_workflow_provider_to_user_provider_without_workflow_app_id(self):
|
||||
"""Test that workflow_provider_to_user_provider works when workflow_app_id is not provided."""
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
|
||||
# Create mock workflow tool provider controller
|
||||
provider_id = "provider_123"
|
||||
mock_controller = Mock(spec=WorkflowToolProviderController)
|
||||
mock_controller.provider_id = provider_id
|
||||
mock_controller.entity = Mock()
|
||||
mock_controller.entity.identity = Mock()
|
||||
mock_controller.entity.identity.author = "test_author"
|
||||
mock_controller.entity.identity.name = "test_workflow_tool"
|
||||
mock_controller.entity.identity.description = I18nObject(en_US="Test description")
|
||||
mock_controller.entity.identity.icon = {"type": "emoji", "content": "🔧"}
|
||||
mock_controller.entity.identity.icon_dark = None
|
||||
mock_controller.entity.identity.label = I18nObject(en_US="Test Workflow Tool")
|
||||
|
||||
# Call the method without workflow_app_id
|
||||
result = ToolTransformService.workflow_provider_to_user_provider(
|
||||
provider_controller=mock_controller,
|
||||
labels=["label1"],
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, ToolProviderApiEntity)
|
||||
assert result.id == provider_id
|
||||
assert result.workflow_app_id is None
|
||||
assert result.labels == ["label1"]
|
||||
|
||||
def test_workflow_provider_to_user_provider_workflow_app_id_none(self):
|
||||
"""Test that workflow_provider_to_user_provider handles None workflow_app_id explicitly."""
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
|
||||
# Create mock workflow tool provider controller
|
||||
provider_id = "provider_123"
|
||||
mock_controller = Mock(spec=WorkflowToolProviderController)
|
||||
mock_controller.provider_id = provider_id
|
||||
mock_controller.entity = Mock()
|
||||
mock_controller.entity.identity = Mock()
|
||||
mock_controller.entity.identity.author = "test_author"
|
||||
mock_controller.entity.identity.name = "test_workflow_tool"
|
||||
mock_controller.entity.identity.description = I18nObject(en_US="Test description")
|
||||
mock_controller.entity.identity.icon = {"type": "emoji", "content": "🔧"}
|
||||
mock_controller.entity.identity.icon_dark = None
|
||||
mock_controller.entity.identity.label = I18nObject(en_US="Test Workflow Tool")
|
||||
|
||||
# Call the method with explicit None values
|
||||
result = ToolTransformService.workflow_provider_to_user_provider(
|
||||
provider_controller=mock_controller,
|
||||
labels=None,
|
||||
workflow_app_id=None,
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, ToolProviderApiEntity)
|
||||
assert result.id == provider_id
|
||||
assert result.workflow_app_id is None
|
||||
assert result.labels == []
|
||||
|
||||
def test_workflow_provider_to_user_provider_preserves_other_fields(self):
|
||||
"""Test that workflow_provider_to_user_provider preserves all other entity fields."""
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
|
||||
# Create mock workflow tool provider controller with various fields
|
||||
workflow_app_id = "app_456"
|
||||
provider_id = "provider_456"
|
||||
mock_controller = Mock(spec=WorkflowToolProviderController)
|
||||
mock_controller.provider_id = provider_id
|
||||
mock_controller.entity = Mock()
|
||||
mock_controller.entity.identity = Mock()
|
||||
mock_controller.entity.identity.author = "another_author"
|
||||
mock_controller.entity.identity.name = "another_workflow_tool"
|
||||
mock_controller.entity.identity.description = I18nObject(
|
||||
en_US="Another description", zh_Hans="Another description"
|
||||
)
|
||||
mock_controller.entity.identity.icon = {"type": "emoji", "content": "⚙️"}
|
||||
mock_controller.entity.identity.icon_dark = {"type": "emoji", "content": "🔧"}
|
||||
mock_controller.entity.identity.label = I18nObject(
|
||||
en_US="Another Workflow Tool", zh_Hans="Another Workflow Tool"
|
||||
)
|
||||
|
||||
# Call the method
|
||||
result = ToolTransformService.workflow_provider_to_user_provider(
|
||||
provider_controller=mock_controller,
|
||||
labels=["automation", "workflow"],
|
||||
workflow_app_id=workflow_app_id,
|
||||
)
|
||||
|
||||
# Verify all fields are preserved correctly
|
||||
assert isinstance(result, ToolProviderApiEntity)
|
||||
assert result.id == provider_id
|
||||
assert result.author == "another_author"
|
||||
assert result.name == "another_workflow_tool"
|
||||
assert result.description.en_US == "Another description"
|
||||
assert result.description.zh_Hans == "Another description"
|
||||
assert result.icon == {"type": "emoji", "content": "⚙️"}
|
||||
assert result.icon_dark == {"type": "emoji", "content": "🔧"}
|
||||
assert result.label.en_US == "Another Workflow Tool"
|
||||
assert result.label.zh_Hans == "Another Workflow Tool"
|
||||
assert result.type == ToolProviderType.WORKFLOW
|
||||
assert result.workflow_app_id == workflow_app_id
|
||||
assert result.labels == ["automation", "workflow"]
|
||||
assert result.masked_credentials == {}
|
||||
assert result.is_team_authorization is True
|
||||
assert result.allow_delete is True
|
||||
assert result.plugin_id is None
|
||||
assert result.plugin_unique_identifier is None
|
||||
assert result.tools == []
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
11
api/uv.lock
11
api/uv.lock
|
|
@ -1628,7 +1628,7 @@ dev = [
|
|||
{ name = "celery-types", specifier = ">=0.23.0" },
|
||||
{ name = "coverage", specifier = "~=7.2.4" },
|
||||
{ name = "dotenv-linter", specifier = "~=0.5.0" },
|
||||
{ name = "faker", specifier = "~=32.1.0" },
|
||||
{ name = "faker", specifier = "~=38.2.0" },
|
||||
{ name = "hypothesis", specifier = ">=6.131.15" },
|
||||
{ name = "import-linter", specifier = ">=2.3" },
|
||||
{ name = "lxml-stubs", specifier = "~=0.5.1" },
|
||||
|
|
@ -1859,15 +1859,14 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "faker"
|
||||
version = "32.1.0"
|
||||
version = "38.2.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "python-dateutil" },
|
||||
{ name = "typing-extensions" },
|
||||
{ name = "tzdata" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/1c/2a/dd2c8f55d69013d0eee30ec4c998250fb7da957f5fe860ed077b3df1725b/faker-32.1.0.tar.gz", hash = "sha256:aac536ba04e6b7beb2332c67df78485fc29c1880ff723beac6d1efd45e2f10f5", size = 1850193, upload-time = "2024-11-12T22:04:34.812Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/64/27/022d4dbd4c20567b4c294f79a133cc2f05240ea61e0d515ead18c995c249/faker-38.2.0.tar.gz", hash = "sha256:20672803db9c7cb97f9b56c18c54b915b6f1d8991f63d1d673642dc43f5ce7ab", size = 1941469, upload-time = "2025-11-19T16:37:31.892Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/7e/fa/4a82dea32d6262a96e6841cdd4a45c11ac09eecdff018e745565410ac70e/Faker-32.1.0-py3-none-any.whl", hash = "sha256:c77522577863c264bdc9dad3a2a750ad3f7ee43ff8185072e482992288898814", size = 1889123, upload-time = "2024-11-12T22:04:32.298Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/17/93/00c94d45f55c336434a15f98d906387e87ce28f9918e4444829a8fda432d/faker-38.2.0-py3-none-any.whl", hash = "sha256:35fe4a0a79dee0dc4103a6083ee9224941e7d3594811a50e3969e547b0d2ee65", size = 1980505, upload-time = "2025-11-19T16:37:30.208Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
|
|||
|
|
@ -123,7 +123,7 @@ services:
|
|||
|
||||
# plugin daemon
|
||||
plugin_daemon:
|
||||
image: langgenius/dify-plugin-daemon:0.4.0-local
|
||||
image: langgenius/dify-plugin-daemon:0.4.1-local
|
||||
restart: always
|
||||
env_file:
|
||||
- ./middleware.env
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
'use client'
|
||||
import type { FC, PropsWithChildren } from 'react'
|
||||
import useAccessControlStore from '../../../../context/access-control-store'
|
||||
import useAccessControlStore from '@/context/access-control-store'
|
||||
import type { AccessMode } from '@/models/access-control'
|
||||
|
||||
type AccessControlItemProps = PropsWithChildren<{
|
||||
|
|
@ -8,7 +8,8 @@ type AccessControlItemProps = PropsWithChildren<{
|
|||
}>
|
||||
|
||||
const AccessControlItem: FC<AccessControlItemProps> = ({ type, children }) => {
|
||||
const { currentMenu, setCurrentMenu } = useAccessControlStore(s => ({ currentMenu: s.currentMenu, setCurrentMenu: s.setCurrentMenu }))
|
||||
const currentMenu = useAccessControlStore(s => s.currentMenu)
|
||||
const setCurrentMenu = useAccessControlStore(s => s.setCurrentMenu)
|
||||
if (currentMenu !== type) {
|
||||
return <div
|
||||
className="cursor-pointer rounded-[10px] border-[1px]
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ const Empty = () => {
|
|||
return (
|
||||
<>
|
||||
<DefaultCards />
|
||||
<div className='absolute bottom-0 left-0 right-0 top-0 flex items-center justify-center bg-gradient-to-t from-background-body to-transparent'>
|
||||
<div className='absolute inset-0 z-20 flex items-center justify-center bg-gradient-to-t from-background-body to-transparent pointer-events-none'>
|
||||
<span className='system-md-medium text-text-tertiary'>
|
||||
{t('app.newApp.noAppsFound')}
|
||||
</span>
|
||||
|
|
|
|||
|
|
@ -187,6 +187,19 @@ const GotoAnything: FC<Props> = ({
|
|||
}, {} as { [key: string]: SearchResult[] }),
|
||||
[searchResults])
|
||||
|
||||
useEffect(() => {
|
||||
if (isCommandsMode)
|
||||
return
|
||||
|
||||
if (!searchResults.length)
|
||||
return
|
||||
|
||||
const currentValueExists = searchResults.some(result => `${result.type}-${result.id}` === cmdVal)
|
||||
|
||||
if (!currentValueExists)
|
||||
setCmdVal(`${searchResults[0].type}-${searchResults[0].id}`)
|
||||
}, [isCommandsMode, searchResults, cmdVal])
|
||||
|
||||
const emptyResult = useMemo(() => {
|
||||
if (searchResults.length || !searchQuery.trim() || isLoading || isCommandsMode)
|
||||
return null
|
||||
|
|
@ -386,7 +399,7 @@ const GotoAnything: FC<Props> = ({
|
|||
<Command.Item
|
||||
key={`${result.type}-${result.id}`}
|
||||
value={`${result.type}-${result.id}`}
|
||||
className='flex cursor-pointer items-center gap-3 rounded-md p-3 will-change-[background-color] aria-[selected=true]:bg-state-base-hover data-[selected=true]:bg-state-base-hover'
|
||||
className='flex cursor-pointer items-center gap-3 rounded-md p-3 will-change-[background-color] hover:bg-state-base-hover aria-[selected=true]:bg-state-base-hover-alt data-[selected=true]:bg-state-base-hover-alt'
|
||||
onSelect={() => handleNavigate(result)}
|
||||
>
|
||||
{result.icon}
|
||||
|
|
|
|||
|
|
@ -52,7 +52,12 @@ const Nav = ({
|
|||
`}>
|
||||
<Link href={link + (linkLastSearchParams && `?${linkLastSearchParams}`)}>
|
||||
<div
|
||||
onClick={() => setAppDetail()}
|
||||
onClick={(e) => {
|
||||
// Don't clear state if opening in new tab/window
|
||||
if (e.metaKey || e.ctrlKey || e.shiftKey || e.button !== 0)
|
||||
return
|
||||
setAppDetail()
|
||||
}}
|
||||
className={classNames(
|
||||
'flex h-7 cursor-pointer items-center rounded-[10px] px-2.5',
|
||||
isActivated ? 'text-components-main-nav-nav-button-text-active' : 'text-components-main-nav-nav-button-text',
|
||||
|
|
|
|||
|
|
@ -77,6 +77,8 @@ export type Collection = {
|
|||
timeout?: number
|
||||
sse_read_timeout?: number
|
||||
}
|
||||
// Workflow
|
||||
workflow_app_id?: string
|
||||
}
|
||||
|
||||
export type ToolParameter = {
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import {
|
||||
memo,
|
||||
useMemo,
|
||||
} from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useEdges } from 'reactflow'
|
||||
|
|
@ -16,6 +17,10 @@ import {
|
|||
} from '@/app/components/workflow/hooks'
|
||||
import ShortcutsName from '@/app/components/workflow/shortcuts-name'
|
||||
import type { Node } from '@/app/components/workflow/types'
|
||||
import { BlockEnum } from '@/app/components/workflow/types'
|
||||
import { CollectionType } from '@/app/components/tools/types'
|
||||
import { useAllWorkflowTools } from '@/service/use-tools'
|
||||
import { canFindTool } from '@/utils'
|
||||
|
||||
type PanelOperatorPopupProps = {
|
||||
id: string
|
||||
|
|
@ -45,6 +50,14 @@ const PanelOperatorPopup = ({
|
|||
const showChangeBlock = !nodeMetaData.isTypeFixed && !nodesReadOnly
|
||||
const isChildNode = !!(data.isInIteration || data.isInLoop)
|
||||
|
||||
const { data: workflowTools } = useAllWorkflowTools()
|
||||
const isWorkflowTool = data.type === BlockEnum.Tool && data.provider_type === CollectionType.workflow
|
||||
const workflowAppId = useMemo(() => {
|
||||
if (!isWorkflowTool || !workflowTools || !data.provider_id) return undefined
|
||||
const workflowTool = workflowTools.find(item => canFindTool(item.id, data.provider_id))
|
||||
return workflowTool?.workflow_app_id
|
||||
}, [isWorkflowTool, workflowTools, data.provider_id])
|
||||
|
||||
return (
|
||||
<div className='w-[240px] rounded-lg border-[0.5px] border-components-panel-border bg-components-panel-bg shadow-xl'>
|
||||
{
|
||||
|
|
@ -137,6 +150,22 @@ const PanelOperatorPopup = ({
|
|||
</>
|
||||
)
|
||||
}
|
||||
{
|
||||
isWorkflowTool && workflowAppId && (
|
||||
<>
|
||||
<div className='p-1'>
|
||||
<a
|
||||
href={`/app/${workflowAppId}/workflow`}
|
||||
target='_blank'
|
||||
className='flex h-8 cursor-pointer items-center rounded-lg px-3 text-sm text-text-secondary hover:bg-state-base-hover'
|
||||
>
|
||||
{t('workflow.panel.openWorkflow')}
|
||||
</a>
|
||||
</div>
|
||||
<div className='h-px bg-divider-regular'></div>
|
||||
</>
|
||||
)
|
||||
}
|
||||
{
|
||||
showHelpLink && nodeMetaData.helpLinkUri && (
|
||||
<>
|
||||
|
|
|
|||
|
|
@ -47,10 +47,8 @@ const ChatWrapper = (
|
|||
const startVariables = startNode?.data.variables
|
||||
const appDetail = useAppStore(s => s.appDetail)
|
||||
const workflowStore = useWorkflowStore()
|
||||
const { inputs, setInputs } = useStore(s => ({
|
||||
inputs: s.inputs,
|
||||
setInputs: s.setInputs,
|
||||
}))
|
||||
const inputs = useStore(s => s.inputs)
|
||||
const setInputs = useStore(s => s.setInputs)
|
||||
|
||||
const initialInputs = useMemo(() => {
|
||||
const initInputs: Record<string, any> = {}
|
||||
|
|
|
|||
|
|
@ -32,10 +32,7 @@ type Props = {
|
|||
const InputsPanel = ({ onRun }: Props) => {
|
||||
const { t } = useTranslation()
|
||||
const workflowStore = useWorkflowStore()
|
||||
const { inputs } = useStore(s => ({
|
||||
inputs: s.inputs,
|
||||
setInputs: s.setInputs,
|
||||
}))
|
||||
const inputs = useStore(s => s.inputs)
|
||||
const fileSettings = useHooksStore(s => s.configsMap?.fileSettings)
|
||||
const nodes = useNodes<StartNodeType>()
|
||||
const files = useStore(s => s.files)
|
||||
|
|
|
|||
|
|
@ -116,7 +116,7 @@ describe('useTabSearchParams', () => {
|
|||
setActiveTab('settings')
|
||||
})
|
||||
|
||||
expect(mockPush).toHaveBeenCalledWith('/test-path?category=settings')
|
||||
expect(mockPush).toHaveBeenCalledWith('/test-path?category=settings', { scroll: false })
|
||||
expect(mockReplace).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
|
|
@ -137,7 +137,7 @@ describe('useTabSearchParams', () => {
|
|||
setActiveTab('settings')
|
||||
})
|
||||
|
||||
expect(mockReplace).toHaveBeenCalledWith('/test-path?category=settings')
|
||||
expect(mockReplace).toHaveBeenCalledWith('/test-path?category=settings', { scroll: false })
|
||||
expect(mockPush).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
|
|
@ -157,6 +157,7 @@ describe('useTabSearchParams', () => {
|
|||
|
||||
expect(mockPush).toHaveBeenCalledWith(
|
||||
'/test-path?category=settings%20%26%20config',
|
||||
{ scroll: false },
|
||||
)
|
||||
})
|
||||
|
||||
|
|
@ -211,7 +212,7 @@ describe('useTabSearchParams', () => {
|
|||
setActiveTab('profile')
|
||||
})
|
||||
|
||||
expect(mockPush).toHaveBeenCalledWith('/test-path?tab=profile')
|
||||
expect(mockPush).toHaveBeenCalledWith('/test-path?tab=profile', { scroll: false })
|
||||
})
|
||||
})
|
||||
|
||||
|
|
@ -294,7 +295,7 @@ describe('useTabSearchParams', () => {
|
|||
|
||||
const [activeTab] = result.current
|
||||
expect(activeTab).toBe('')
|
||||
expect(mockPush).toHaveBeenCalledWith('/test-path?category=')
|
||||
expect(mockPush).toHaveBeenCalledWith('/test-path?category=', { scroll: false })
|
||||
})
|
||||
|
||||
/**
|
||||
|
|
@ -345,7 +346,7 @@ describe('useTabSearchParams', () => {
|
|||
setActiveTab('settings')
|
||||
})
|
||||
|
||||
expect(mockPush).toHaveBeenCalledWith('/fallback-path?category=settings')
|
||||
expect(mockPush).toHaveBeenCalledWith('/fallback-path?category=settings', { scroll: false })
|
||||
|
||||
// Restore mock
|
||||
;(usePathname as jest.Mock).mockReturnValue(mockPathname)
|
||||
|
|
@ -400,7 +401,7 @@ describe('useTabSearchParams', () => {
|
|||
})
|
||||
|
||||
expect(result.current[0]).toBe('settings')
|
||||
expect(mockPush).toHaveBeenCalledWith('/test-path?category=settings')
|
||||
expect(mockPush).toHaveBeenCalledWith('/test-path?category=settings', { scroll: false })
|
||||
|
||||
// Change to profile tab
|
||||
act(() => {
|
||||
|
|
@ -409,7 +410,7 @@ describe('useTabSearchParams', () => {
|
|||
})
|
||||
|
||||
expect(result.current[0]).toBe('profile')
|
||||
expect(mockPush).toHaveBeenCalledWith('/test-path?category=profile')
|
||||
expect(mockPush).toHaveBeenCalledWith('/test-path?category=profile', { scroll: false })
|
||||
|
||||
// Verify push was called twice
|
||||
expect(mockPush).toHaveBeenCalledTimes(2)
|
||||
|
|
@ -431,7 +432,7 @@ describe('useTabSearchParams', () => {
|
|||
setActiveTab('advanced')
|
||||
})
|
||||
|
||||
expect(mockPush).toHaveBeenCalledWith('/app/123/settings?category=advanced')
|
||||
expect(mockPush).toHaveBeenCalledWith('/app/123/settings?category=advanced', { scroll: false })
|
||||
|
||||
// Restore mock
|
||||
;(usePathname as jest.Mock).mockReturnValue(mockPathname)
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ export const useTabSearchParams = ({
|
|||
setTab(newActiveTab)
|
||||
if (disableSearchParams)
|
||||
return
|
||||
router[`${routingBehavior}`](`${pathName}?${searchParamName}=${encodeURIComponent(newActiveTab)}`)
|
||||
router[`${routingBehavior}`](`${pathName}?${searchParamName}=${encodeURIComponent(newActiveTab)}`, { scroll: false })
|
||||
}
|
||||
|
||||
return [activeTab, setActiveTab] as const
|
||||
|
|
|
|||
|
|
@ -374,6 +374,7 @@ const translation = {
|
|||
optional_and_hidden: '(optional & hidden)',
|
||||
goTo: 'Gehe zu',
|
||||
startNode: 'Startknoten',
|
||||
openWorkflow: 'Workflow öffnen',
|
||||
},
|
||||
nodes: {
|
||||
common: {
|
||||
|
|
|
|||
|
|
@ -400,6 +400,7 @@ const translation = {
|
|||
userInputField: 'User Input Field',
|
||||
changeBlock: 'Change Node',
|
||||
helpLink: 'View Docs',
|
||||
openWorkflow: 'Open Workflow',
|
||||
about: 'About',
|
||||
createdBy: 'Created By ',
|
||||
nextStep: 'Next Step',
|
||||
|
|
|
|||
|
|
@ -374,6 +374,7 @@ const translation = {
|
|||
optional_and_hidden: '(opcional y oculto)',
|
||||
goTo: 'Ir a',
|
||||
startNode: 'Nodo de inicio',
|
||||
openWorkflow: 'Abrir flujo de trabajo',
|
||||
},
|
||||
nodes: {
|
||||
common: {
|
||||
|
|
|
|||
|
|
@ -374,6 +374,7 @@ const translation = {
|
|||
optional_and_hidden: '(اختیاری و پنهان)',
|
||||
goTo: 'برو به',
|
||||
startNode: 'گره شروع',
|
||||
openWorkflow: 'باز کردن جریان کاری',
|
||||
},
|
||||
nodes: {
|
||||
common: {
|
||||
|
|
|
|||
|
|
@ -374,6 +374,7 @@ const translation = {
|
|||
optional_and_hidden: '(optionnel et caché)',
|
||||
goTo: 'Aller à',
|
||||
startNode: 'Nœud de départ',
|
||||
openWorkflow: 'Ouvrir le flux de travail',
|
||||
},
|
||||
nodes: {
|
||||
common: {
|
||||
|
|
|
|||
|
|
@ -386,6 +386,7 @@ const translation = {
|
|||
optional_and_hidden: '(वैकल्पिक और छिपा हुआ)',
|
||||
goTo: 'जाओ',
|
||||
startNode: 'प्रारंभ नोड',
|
||||
openWorkflow: 'वर्कफ़्लो खोलें',
|
||||
},
|
||||
nodes: {
|
||||
common: {
|
||||
|
|
|
|||
|
|
@ -381,6 +381,7 @@ const translation = {
|
|||
goTo: 'Pergi ke',
|
||||
startNode: 'Mulai Node',
|
||||
scrollToSelectedNode: 'Gulir ke node yang dipilih',
|
||||
openWorkflow: 'Buka Alur Kerja',
|
||||
},
|
||||
nodes: {
|
||||
common: {
|
||||
|
|
|
|||
|
|
@ -389,6 +389,7 @@ const translation = {
|
|||
optional_and_hidden: '(opzionale e nascosto)',
|
||||
goTo: 'Vai a',
|
||||
startNode: 'Nodo iniziale',
|
||||
openWorkflow: 'Apri flusso di lavoro',
|
||||
},
|
||||
nodes: {
|
||||
common: {
|
||||
|
|
|
|||
|
|
@ -401,6 +401,7 @@ const translation = {
|
|||
minimize: '全画面を終了する',
|
||||
scrollToSelectedNode: '選択したノードまでスクロール',
|
||||
optional_and_hidden: '(オプションおよび非表示)',
|
||||
openWorkflow: 'ワークフローを開く',
|
||||
},
|
||||
nodes: {
|
||||
common: {
|
||||
|
|
|
|||
|
|
@ -395,6 +395,7 @@ const translation = {
|
|||
optional_and_hidden: '(선택 사항 및 숨김)',
|
||||
goTo: '로 이동',
|
||||
startNode: '시작 노드',
|
||||
openWorkflow: '워크플로 열기',
|
||||
},
|
||||
nodes: {
|
||||
common: {
|
||||
|
|
|
|||
|
|
@ -374,6 +374,7 @@ const translation = {
|
|||
optional_and_hidden: '(opcjonalne i ukryte)',
|
||||
goTo: 'Idź do',
|
||||
startNode: 'Węzeł początkowy',
|
||||
openWorkflow: 'Otwórz przepływ pracy',
|
||||
},
|
||||
nodes: {
|
||||
common: {
|
||||
|
|
|
|||
|
|
@ -374,6 +374,7 @@ const translation = {
|
|||
optional_and_hidden: '(opcional & oculto)',
|
||||
goTo: 'Ir para',
|
||||
startNode: 'Iniciar Nó',
|
||||
openWorkflow: 'Abrir fluxo de trabalho',
|
||||
},
|
||||
nodes: {
|
||||
common: {
|
||||
|
|
|
|||
|
|
@ -374,6 +374,7 @@ const translation = {
|
|||
optional_and_hidden: '(opțional și ascuns)',
|
||||
goTo: 'Du-te la',
|
||||
startNode: 'Nod de start',
|
||||
openWorkflow: 'Deschide fluxul de lucru',
|
||||
},
|
||||
nodes: {
|
||||
common: {
|
||||
|
|
|
|||
|
|
@ -374,6 +374,7 @@ const translation = {
|
|||
optional_and_hidden: '(необязательно и скрыто)',
|
||||
goTo: 'Перейти к',
|
||||
startNode: 'Начальный узел',
|
||||
openWorkflow: 'Открыть рабочий процесс',
|
||||
},
|
||||
nodes: {
|
||||
common: {
|
||||
|
|
|
|||
|
|
@ -381,6 +381,7 @@ const translation = {
|
|||
optional_and_hidden: '(neobvezno in skrito)',
|
||||
goTo: 'Pojdi na',
|
||||
startNode: 'Začetni vozel',
|
||||
openWorkflow: 'Odpri delovni tok',
|
||||
},
|
||||
nodes: {
|
||||
common: {
|
||||
|
|
|
|||
|
|
@ -374,6 +374,7 @@ const translation = {
|
|||
optional_and_hidden: '(ตัวเลือก & ซ่อน)',
|
||||
goTo: 'ไปที่',
|
||||
startNode: 'เริ่มต้นโหนด',
|
||||
openWorkflow: 'เปิดเวิร์กโฟลว์',
|
||||
},
|
||||
nodes: {
|
||||
common: {
|
||||
|
|
|
|||
|
|
@ -374,6 +374,7 @@ const translation = {
|
|||
optional_and_hidden: '(isteğe bağlı ve gizli)',
|
||||
goTo: 'Git',
|
||||
startNode: 'Başlangıç Düğümü',
|
||||
openWorkflow: 'İş Akışını Aç',
|
||||
},
|
||||
nodes: {
|
||||
common: {
|
||||
|
|
|
|||
|
|
@ -374,6 +374,7 @@ const translation = {
|
|||
optional_and_hidden: '(необов\'язково & приховано)',
|
||||
goTo: 'Перейти до',
|
||||
startNode: 'Початковий вузол',
|
||||
openWorkflow: 'Відкрити робочий процес',
|
||||
},
|
||||
nodes: {
|
||||
common: {
|
||||
|
|
|
|||
|
|
@ -374,6 +374,7 @@ const translation = {
|
|||
optional_and_hidden: '(tùy chọn & ẩn)',
|
||||
goTo: 'Đi tới',
|
||||
startNode: 'Nút Bắt đầu',
|
||||
openWorkflow: 'Mở quy trình làm việc',
|
||||
},
|
||||
nodes: {
|
||||
common: {
|
||||
|
|
|
|||
|
|
@ -400,6 +400,7 @@ const translation = {
|
|||
userInputField: '用户输入字段',
|
||||
changeBlock: '更改节点',
|
||||
helpLink: '查看帮助文档',
|
||||
openWorkflow: '打开工作流',
|
||||
about: '关于',
|
||||
createdBy: '作者',
|
||||
nextStep: '下一步',
|
||||
|
|
|
|||
|
|
@ -379,6 +379,7 @@ const translation = {
|
|||
optional_and_hidden: '(可選且隱藏)',
|
||||
goTo: '前往',
|
||||
startNode: '起始節點',
|
||||
openWorkflow: '打開工作流程',
|
||||
},
|
||||
nodes: {
|
||||
common: {
|
||||
|
|
|
|||
|
|
@ -53,10 +53,10 @@
|
|||
"@hookform/resolvers": "^3.10.0",
|
||||
"@lexical/code": "^0.36.2",
|
||||
"@lexical/link": "^0.36.2",
|
||||
"@lexical/list": "^0.36.2",
|
||||
"@lexical/list": "^0.38.2",
|
||||
"@lexical/react": "^0.36.2",
|
||||
"@lexical/selection": "^0.37.0",
|
||||
"@lexical/text": "^0.36.2",
|
||||
"@lexical/text": "^0.38.2",
|
||||
"@lexical/utils": "^0.37.0",
|
||||
"@monaco-editor/react": "^4.7.0",
|
||||
"@octokit/core": "^6.1.6",
|
||||
|
|
@ -79,7 +79,7 @@
|
|||
"decimal.js": "^10.6.0",
|
||||
"dompurify": "^3.3.0",
|
||||
"echarts": "^5.6.0",
|
||||
"echarts-for-react": "^3.0.2",
|
||||
"echarts-for-react": "^3.0.5",
|
||||
"elkjs": "^0.9.3",
|
||||
"emoji-mart": "^5.6.0",
|
||||
"fast-deep-equal": "^3.1.3",
|
||||
|
|
@ -141,7 +141,7 @@
|
|||
"uuid": "^10.0.0",
|
||||
"zod": "^3.25.76",
|
||||
"zundo": "^2.3.0",
|
||||
"zustand": "^4.5.7"
|
||||
"zustand": "^5.0.9"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@antfu/eslint-config": "^5.4.1",
|
||||
|
|
|
|||
|
|
@ -85,8 +85,8 @@ importers:
|
|||
specifier: ^0.36.2
|
||||
version: 0.36.2
|
||||
'@lexical/list':
|
||||
specifier: ^0.36.2
|
||||
version: 0.36.2
|
||||
specifier: ^0.38.2
|
||||
version: 0.38.2
|
||||
'@lexical/react':
|
||||
specifier: ^0.36.2
|
||||
version: 0.36.2(react-dom@19.1.1(react@19.1.1))(react@19.1.1)(yjs@13.6.27)
|
||||
|
|
@ -94,8 +94,8 @@ importers:
|
|||
specifier: ^0.37.0
|
||||
version: 0.37.0
|
||||
'@lexical/text':
|
||||
specifier: ^0.36.2
|
||||
version: 0.36.2
|
||||
specifier: ^0.38.2
|
||||
version: 0.38.2
|
||||
'@lexical/utils':
|
||||
specifier: ^0.37.0
|
||||
version: 0.37.0
|
||||
|
|
@ -163,8 +163,8 @@ importers:
|
|||
specifier: ^5.6.0
|
||||
version: 5.6.0
|
||||
echarts-for-react:
|
||||
specifier: ^3.0.2
|
||||
version: 3.0.2(echarts@5.6.0)(react@19.1.1)
|
||||
specifier: ^3.0.5
|
||||
version: 3.0.5(echarts@5.6.0)(react@19.1.1)
|
||||
elkjs:
|
||||
specifier: ^0.9.3
|
||||
version: 0.9.3
|
||||
|
|
@ -347,10 +347,10 @@ importers:
|
|||
version: 3.25.76
|
||||
zundo:
|
||||
specifier: ^2.3.0
|
||||
version: 2.3.0(zustand@4.5.7(@types/react@19.1.17)(immer@10.1.3)(react@19.1.1))
|
||||
version: 2.3.0(zustand@5.0.9(@types/react@19.1.17)(immer@10.1.3)(react@19.1.1)(use-sync-external-store@1.6.0(react@19.1.1)))
|
||||
zustand:
|
||||
specifier: ^4.5.7
|
||||
version: 4.5.7(@types/react@19.1.17)(immer@10.1.3)(react@19.1.1)
|
||||
specifier: ^5.0.9
|
||||
version: 5.0.9(@types/react@19.1.17)(immer@10.1.3)(react@19.1.1)(use-sync-external-store@1.6.0(react@19.1.1))
|
||||
devDependencies:
|
||||
'@antfu/eslint-config':
|
||||
specifier: ^5.4.1
|
||||
|
|
@ -2009,6 +2009,9 @@ packages:
|
|||
'@lexical/clipboard@0.37.0':
|
||||
resolution: {integrity: sha512-hRwASFX/ilaI5r8YOcZuQgONFshRgCPfdxfofNL7uruSFYAO6LkUhsjzZwUgf0DbmCJmbBADFw15FSthgCUhGA==}
|
||||
|
||||
'@lexical/clipboard@0.38.2':
|
||||
resolution: {integrity: sha512-dDShUplCu8/o6BB9ousr3uFZ9bltR+HtleF/Tl8FXFNPpZ4AXhbLKUoJuucRuIr+zqT7RxEv/3M6pk/HEoE6NQ==}
|
||||
|
||||
'@lexical/code@0.36.2':
|
||||
resolution: {integrity: sha512-dfS62rNo3uKwNAJQ39zC+8gYX0k8UAoW7u+JPIqx+K2VPukZlvpsPLNGft15pdWBkHc7Pv+o9gJlB6gGv+EBfA==}
|
||||
|
||||
|
|
@ -2027,6 +2030,9 @@ packages:
|
|||
'@lexical/extension@0.37.0':
|
||||
resolution: {integrity: sha512-Z58f2tIdz9bn8gltUu5cVg37qROGha38dUZv20gI2GeNugXAkoPzJYEcxlI1D/26tkevJ/7VaFUr9PTk+iKmaA==}
|
||||
|
||||
'@lexical/extension@0.38.2':
|
||||
resolution: {integrity: sha512-qbUNxEVjAC0kxp7hEMTzktj0/51SyJoIJWK6Gm790b4yNBq82fEPkksfuLkRg9VQUteD0RT1Nkjy8pho8nNamw==}
|
||||
|
||||
'@lexical/hashtag@0.36.2':
|
||||
resolution: {integrity: sha512-WdmKtzXFcahQT3ShFDeHF6LCR5C8yvFCj3ImI09rZwICrYeonbMrzsBUxS1joBz0HQ+ufF9Tx+RxLvGWx6WxzQ==}
|
||||
|
||||
|
|
@ -2039,6 +2045,9 @@ packages:
|
|||
'@lexical/html@0.37.0':
|
||||
resolution: {integrity: sha512-oTsBc45eL8/lmF7fqGR+UCjrJYP04gumzf5nk4TczrxWL2pM4GIMLLKG1mpQI2H1MDiRLzq3T/xdI7Gh74z7Zw==}
|
||||
|
||||
'@lexical/html@0.38.2':
|
||||
resolution: {integrity: sha512-pC5AV+07bmHistRwgG3NJzBMlIzSdxYO6rJU4eBNzyR4becdiLsI4iuv+aY7PhfSv+SCs7QJ9oc4i5caq48Pkg==}
|
||||
|
||||
'@lexical/link@0.36.2':
|
||||
resolution: {integrity: sha512-Zb+DeHA1po8VMiOAAXsBmAHhfWmQttsUkI5oiZUmOXJruRuQ2rVr01NoxHpoEpLwHOABVNzD3PMbwov+g3c7lg==}
|
||||
|
||||
|
|
@ -2048,6 +2057,9 @@ packages:
|
|||
'@lexical/list@0.37.0':
|
||||
resolution: {integrity: sha512-AOC6yAA3mfNvJKbwo+kvAbPJI+13yF2ISA65vbA578CugvJ08zIVgM+pSzxquGhD0ioJY3cXVW7+gdkCP1qu5g==}
|
||||
|
||||
'@lexical/list@0.38.2':
|
||||
resolution: {integrity: sha512-OQm9TzatlMrDZGxMxbozZEHzMJhKxAbH1TOnOGyFfzpfjbnFK2y8oLeVsfQZfZRmiqQS4Qc/rpFnRP2Ax5dsbA==}
|
||||
|
||||
'@lexical/mark@0.36.2':
|
||||
resolution: {integrity: sha512-n0MNXtGH+1i43hglgHjpQV0093HmIiFR7Budg2BJb8ZNzO1KZRqeXAHlA5ZzJ698FkAnS4R5bqG9tZ0JJHgAuA==}
|
||||
|
||||
|
|
@ -2078,21 +2090,33 @@ packages:
|
|||
'@lexical/selection@0.37.0':
|
||||
resolution: {integrity: sha512-Lix1s2r71jHfsTEs4q/YqK2s3uXKOnyA3fd1VDMWysO+bZzRwEO5+qyDvENZ0WrXSDCnlibNFV1HttWX9/zqyw==}
|
||||
|
||||
'@lexical/selection@0.38.2':
|
||||
resolution: {integrity: sha512-eMFiWlBH6bEX9U9sMJ6PXPxVXTrihQfFeiIlWLuTpEIDF2HRz7Uo1KFRC/yN6q0DQaj7d9NZYA6Mei5DoQuz5w==}
|
||||
|
||||
'@lexical/table@0.36.2':
|
||||
resolution: {integrity: sha512-96rNNPiVbC65i+Jn1QzIsehCS7UVUc69ovrh9Bt4+pXDebZSdZai153Q7RUq8q3AQ5ocK4/SA2kLQfMu0grj3Q==}
|
||||
|
||||
'@lexical/table@0.37.0':
|
||||
resolution: {integrity: sha512-g7S8ml8kIujEDLWlzYKETgPCQ2U9oeWqdytRuHjHGi/rjAAGHSej5IRqTPIMxNP3VVQHnBoQ+Y9hBtjiuddhgQ==}
|
||||
|
||||
'@lexical/table@0.38.2':
|
||||
resolution: {integrity: sha512-uu0i7yz0nbClmHOO5ZFsinRJE6vQnFz2YPblYHAlNigiBedhqMwSv5bedrzDq8nTTHwych3mC63tcyKIrM+I1g==}
|
||||
|
||||
'@lexical/text@0.36.2':
|
||||
resolution: {integrity: sha512-IbbqgRdMAD6Uk9b2+qSVoy+8RVcczrz6OgXvg39+EYD+XEC7Rbw7kDTWzuNSJJpP7vxSO8YDZSaIlP5gNH3qKA==}
|
||||
|
||||
'@lexical/text@0.38.2':
|
||||
resolution: {integrity: sha512-+juZxUugtC4T37aE3P0l4I9tsWbogDUnTI/mgYk4Ht9g+gLJnhQkzSA8chIyfTxbj5i0A8yWrUUSw+/xA7lKUQ==}
|
||||
|
||||
'@lexical/utils@0.36.2':
|
||||
resolution: {integrity: sha512-P9+t2Ob10YNGYT/PWEER+1EqH8SAjCNRn+7SBvKbr0IdleGF2JvzbJwAWaRwZs1c18P11XdQZ779dGvWlfwBIw==}
|
||||
|
||||
'@lexical/utils@0.37.0':
|
||||
resolution: {integrity: sha512-CFp4diY/kR5RqhzQSl/7SwsMod1sgLpI1FBifcOuJ6L/S6YywGpEB4B7aV5zqW21A/jU2T+2NZtxSUn6S+9gMg==}
|
||||
|
||||
'@lexical/utils@0.38.2':
|
||||
resolution: {integrity: sha512-y+3rw15r4oAWIEXicUdNjfk8018dbKl7dWHqGHVEtqzAYefnEYdfD2FJ5KOTXfeoYfxi8yOW7FvzS4NZDi8Bfw==}
|
||||
|
||||
'@lexical/yjs@0.36.2':
|
||||
resolution: {integrity: sha512-gZ66Mw+uKXTO8KeX/hNKAinXbFg3gnNYraG76lBXCwb/Ka3q34upIY9FUeGOwGVaau3iIDQhE49I+6MugAX2FQ==}
|
||||
peerDependencies:
|
||||
|
|
@ -4586,10 +4610,10 @@ packages:
|
|||
duplexer@0.1.2:
|
||||
resolution: {integrity: sha512-jtD6YG370ZCIi/9GTaJKQxWTZD045+4R4hTk/x1UyoqadyJ9x9CgSi1RlVDQF8U2sxLLSnFkCaMihqljHIWgMg==}
|
||||
|
||||
echarts-for-react@3.0.2:
|
||||
resolution: {integrity: sha512-DRwIiTzx8JfwPOVgGttDytBqdp5VzCSyMRIxubgU/g2n9y3VLUmF2FK7Icmg/sNVkv4+rktmrLN9w22U2yy3fA==}
|
||||
echarts-for-react@3.0.5:
|
||||
resolution: {integrity: sha512-YpEI5Ty7O/2nvCfQ7ybNa+S90DwE8KYZWacGvJW4luUqywP7qStQ+pxDlYOmr4jGDu10mhEkiAuMKcUlT4W5vg==}
|
||||
peerDependencies:
|
||||
echarts: ^3.0.0 || ^4.0.0 || ^5.0.0
|
||||
echarts: ^3.0.0 || ^4.0.0 || ^5.0.0 || ^6.0.0
|
||||
react: ^15.0.0 || >=16.0.0
|
||||
|
||||
echarts@5.6.0:
|
||||
|
|
@ -8445,6 +8469,24 @@ packages:
|
|||
react:
|
||||
optional: true
|
||||
|
||||
zustand@5.0.9:
|
||||
resolution: {integrity: sha512-ALBtUj0AfjJt3uNRQoL1tL2tMvj6Gp/6e39dnfT6uzpelGru8v1tPOGBzayOWbPJvujM8JojDk3E1LxeFisBNg==}
|
||||
engines: {node: '>=12.20.0'}
|
||||
peerDependencies:
|
||||
'@types/react': ~19.1.17
|
||||
immer: '>=9.0.6'
|
||||
react: '>=18.0.0'
|
||||
use-sync-external-store: '>=1.2.0'
|
||||
peerDependenciesMeta:
|
||||
'@types/react':
|
||||
optional: true
|
||||
immer:
|
||||
optional: true
|
||||
react:
|
||||
optional: true
|
||||
use-sync-external-store:
|
||||
optional: true
|
||||
|
||||
zwitch@2.0.4:
|
||||
resolution: {integrity: sha512-bXE4cR/kVZhKZX/RjPEflHaKVhUVl85noU3v6b8apfQEc1x4A+zBxjZ4lN8LqGd6WZ3dl98pY4o717VFmoPp+A==}
|
||||
|
||||
|
|
@ -10200,6 +10242,14 @@ snapshots:
|
|||
'@lexical/utils': 0.37.0
|
||||
lexical: 0.37.0
|
||||
|
||||
'@lexical/clipboard@0.38.2':
|
||||
dependencies:
|
||||
'@lexical/html': 0.38.2
|
||||
'@lexical/list': 0.38.2
|
||||
'@lexical/selection': 0.38.2
|
||||
'@lexical/utils': 0.38.2
|
||||
lexical: 0.37.0
|
||||
|
||||
'@lexical/code@0.36.2':
|
||||
dependencies:
|
||||
'@lexical/utils': 0.36.2
|
||||
|
|
@ -10234,6 +10284,12 @@ snapshots:
|
|||
'@preact/signals-core': 1.12.1
|
||||
lexical: 0.37.0
|
||||
|
||||
'@lexical/extension@0.38.2':
|
||||
dependencies:
|
||||
'@lexical/utils': 0.38.2
|
||||
'@preact/signals-core': 1.12.1
|
||||
lexical: 0.37.0
|
||||
|
||||
'@lexical/hashtag@0.36.2':
|
||||
dependencies:
|
||||
'@lexical/text': 0.36.2
|
||||
|
|
@ -10258,6 +10314,12 @@ snapshots:
|
|||
'@lexical/utils': 0.37.0
|
||||
lexical: 0.37.0
|
||||
|
||||
'@lexical/html@0.38.2':
|
||||
dependencies:
|
||||
'@lexical/selection': 0.38.2
|
||||
'@lexical/utils': 0.38.2
|
||||
lexical: 0.37.0
|
||||
|
||||
'@lexical/link@0.36.2':
|
||||
dependencies:
|
||||
'@lexical/extension': 0.36.2
|
||||
|
|
@ -10278,6 +10340,13 @@ snapshots:
|
|||
'@lexical/utils': 0.37.0
|
||||
lexical: 0.37.0
|
||||
|
||||
'@lexical/list@0.38.2':
|
||||
dependencies:
|
||||
'@lexical/extension': 0.38.2
|
||||
'@lexical/selection': 0.38.2
|
||||
'@lexical/utils': 0.38.2
|
||||
lexical: 0.37.0
|
||||
|
||||
'@lexical/mark@0.36.2':
|
||||
dependencies:
|
||||
'@lexical/utils': 0.36.2
|
||||
|
|
@ -10351,6 +10420,10 @@ snapshots:
|
|||
dependencies:
|
||||
lexical: 0.37.0
|
||||
|
||||
'@lexical/selection@0.38.2':
|
||||
dependencies:
|
||||
lexical: 0.37.0
|
||||
|
||||
'@lexical/table@0.36.2':
|
||||
dependencies:
|
||||
'@lexical/clipboard': 0.36.2
|
||||
|
|
@ -10365,10 +10438,21 @@ snapshots:
|
|||
'@lexical/utils': 0.37.0
|
||||
lexical: 0.37.0
|
||||
|
||||
'@lexical/table@0.38.2':
|
||||
dependencies:
|
||||
'@lexical/clipboard': 0.38.2
|
||||
'@lexical/extension': 0.38.2
|
||||
'@lexical/utils': 0.38.2
|
||||
lexical: 0.37.0
|
||||
|
||||
'@lexical/text@0.36.2':
|
||||
dependencies:
|
||||
lexical: 0.37.0
|
||||
|
||||
'@lexical/text@0.38.2':
|
||||
dependencies:
|
||||
lexical: 0.37.0
|
||||
|
||||
'@lexical/utils@0.36.2':
|
||||
dependencies:
|
||||
'@lexical/list': 0.36.2
|
||||
|
|
@ -10383,6 +10467,13 @@ snapshots:
|
|||
'@lexical/table': 0.37.0
|
||||
lexical: 0.37.0
|
||||
|
||||
'@lexical/utils@0.38.2':
|
||||
dependencies:
|
||||
'@lexical/list': 0.38.2
|
||||
'@lexical/selection': 0.38.2
|
||||
'@lexical/table': 0.38.2
|
||||
lexical: 0.37.0
|
||||
|
||||
'@lexical/yjs@0.36.2(yjs@13.6.27)':
|
||||
dependencies:
|
||||
'@lexical/offset': 0.36.2
|
||||
|
|
@ -13098,7 +13189,7 @@ snapshots:
|
|||
|
||||
duplexer@0.1.2: {}
|
||||
|
||||
echarts-for-react@3.0.2(echarts@5.6.0)(react@19.1.1):
|
||||
echarts-for-react@3.0.5(echarts@5.6.0)(react@19.1.1):
|
||||
dependencies:
|
||||
echarts: 5.6.0
|
||||
fast-deep-equal: 3.1.3
|
||||
|
|
@ -17931,9 +18022,9 @@ snapshots:
|
|||
dependencies:
|
||||
tslib: 2.3.0
|
||||
|
||||
zundo@2.3.0(zustand@4.5.7(@types/react@19.1.17)(immer@10.1.3)(react@19.1.1)):
|
||||
zundo@2.3.0(zustand@5.0.9(@types/react@19.1.17)(immer@10.1.3)(react@19.1.1)(use-sync-external-store@1.6.0(react@19.1.1))):
|
||||
dependencies:
|
||||
zustand: 4.5.7(@types/react@19.1.17)(immer@10.1.3)(react@19.1.1)
|
||||
zustand: 5.0.9(@types/react@19.1.17)(immer@10.1.3)(react@19.1.1)(use-sync-external-store@1.6.0(react@19.1.1))
|
||||
|
||||
zustand@4.5.7(@types/react@19.1.17)(immer@10.1.3)(react@19.1.1):
|
||||
dependencies:
|
||||
|
|
@ -17943,4 +18034,11 @@ snapshots:
|
|||
immer: 10.1.3
|
||||
react: 19.1.1
|
||||
|
||||
zustand@5.0.9(@types/react@19.1.17)(immer@10.1.3)(react@19.1.1)(use-sync-external-store@1.6.0(react@19.1.1)):
|
||||
optionalDependencies:
|
||||
'@types/react': 19.1.17
|
||||
immer: 10.1.3
|
||||
react: 19.1.1
|
||||
use-sync-external-store: 1.6.0(react@19.1.1)
|
||||
|
||||
zwitch@2.0.4: {}
|
||||
|
|
|
|||
Loading…
Reference in New Issue