mirror of https://github.com/langgenius/dify.git
Merge branch 'origin-main' into feat/end-user-oauth
This commit is contained in:
commit
f5e36a8a2b
|
|
@ -18,10 +18,10 @@ api/core/workflow/node_events/ @laipz8200 @QuantumGhost
|
|||
api/core/model_runtime/ @laipz8200 @QuantumGhost
|
||||
|
||||
# Backend - Workflow - Nodes (Agent, Iteration, Loop, LLM)
|
||||
api/core/workflow/nodes/agent/ @Novice
|
||||
api/core/workflow/nodes/iteration/ @Novice
|
||||
api/core/workflow/nodes/loop/ @Novice
|
||||
api/core/workflow/nodes/llm/ @Novice
|
||||
api/core/workflow/nodes/agent/ @Nov1c444
|
||||
api/core/workflow/nodes/iteration/ @Nov1c444
|
||||
api/core/workflow/nodes/loop/ @Nov1c444
|
||||
api/core/workflow/nodes/llm/ @Nov1c444
|
||||
|
||||
# Backend - RAG (Retrieval Augmented Generation)
|
||||
api/core/rag/ @JohnJyong
|
||||
|
|
@ -141,7 +141,7 @@ web/app/components/app/log-annotation/ @JzoNgKVO @iamjoel
|
|||
web/app/components/app/annotation/ @JzoNgKVO @iamjoel
|
||||
|
||||
# Frontend - App - Monitoring
|
||||
web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/ @JzoNgKVO @iamjoel
|
||||
web/app/(commonLayout)/app/(appDetailLayout)/\[appId\]/overview/ @JzoNgKVO @iamjoel
|
||||
web/app/components/app/overview/ @JzoNgKVO @iamjoel
|
||||
|
||||
# Frontend - App - Settings
|
||||
|
|
|
|||
|
|
@ -1,16 +1,23 @@
|
|||
from flask_restx import Resource, fields, reqparse
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.login import login_required
|
||||
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("app_mode", type=str, required=True, location="args", help="Application mode")
|
||||
.add_argument("model_mode", type=str, required=True, location="args", help="Model mode")
|
||||
.add_argument("has_context", type=str, required=False, default="true", location="args", help="Whether has context")
|
||||
.add_argument("model_name", type=str, required=True, location="args", help="Model name")
|
||||
|
||||
class AdvancedPromptTemplateQuery(BaseModel):
|
||||
app_mode: str = Field(..., description="Application mode")
|
||||
model_mode: str = Field(..., description="Model mode")
|
||||
has_context: str = Field(default="true", description="Whether has context")
|
||||
model_name: str = Field(..., description="Model name")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
AdvancedPromptTemplateQuery.__name__,
|
||||
AdvancedPromptTemplateQuery.model_json_schema(ref_template="#/definitions/{model}"),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -18,7 +25,7 @@ parser = (
|
|||
class AdvancedPromptTemplateList(Resource):
|
||||
@console_ns.doc("get_advanced_prompt_templates")
|
||||
@console_ns.doc(description="Get advanced prompt templates based on app mode and model configuration")
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.expect(console_ns.models[AdvancedPromptTemplateQuery.__name__])
|
||||
@console_ns.response(
|
||||
200, "Prompt templates retrieved successfully", fields.List(fields.Raw(description="Prompt template data"))
|
||||
)
|
||||
|
|
@ -27,6 +34,6 @@ class AdvancedPromptTemplateList(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
args = parser.parse_args()
|
||||
args = AdvancedPromptTemplateQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
return AdvancedPromptTemplateService.get_prompt(args)
|
||||
return AdvancedPromptTemplateService.get_prompt(args.model_dump())
|
||||
|
|
|
|||
|
|
@ -1,9 +1,12 @@
|
|||
import uuid
|
||||
from typing import Literal
|
||||
|
||||
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, abort
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
|
|
@ -36,6 +39,130 @@ from services.enterprise.enterprise_service import EnterpriseService
|
|||
from services.feature_service import FeatureService
|
||||
|
||||
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class AppListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")
|
||||
mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"] = Field(
|
||||
default="all", description="App mode filter"
|
||||
)
|
||||
name: str | None = Field(default=None, description="Filter by app name")
|
||||
tag_ids: list[str] | None = Field(default=None, description="Comma-separated tag IDs")
|
||||
is_created_by_me: bool | None = Field(default=None, description="Filter by creator")
|
||||
|
||||
@field_validator("tag_ids", mode="before")
|
||||
@classmethod
|
||||
def validate_tag_ids(cls, value: str | list[str] | None) -> list[str] | None:
|
||||
if not value:
|
||||
return None
|
||||
|
||||
if isinstance(value, str):
|
||||
items = [item.strip() for item in value.split(",") if item.strip()]
|
||||
elif isinstance(value, list):
|
||||
items = [str(item).strip() for item in value if item and str(item).strip()]
|
||||
else:
|
||||
raise TypeError("Unsupported tag_ids type.")
|
||||
|
||||
if not items:
|
||||
return None
|
||||
|
||||
try:
|
||||
return [str(uuid.UUID(item)) for item in items]
|
||||
except ValueError as exc:
|
||||
raise ValueError("Invalid UUID format in tag_ids.") from exc
|
||||
|
||||
|
||||
class CreateAppPayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, description="App name")
|
||||
description: str | None = Field(default=None, description="App description (max 400 chars)")
|
||||
mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode")
|
||||
icon_type: str | None = Field(default=None, description="Icon type")
|
||||
icon: str | None = Field(default=None, description="Icon")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
|
||||
@field_validator("description")
|
||||
@classmethod
|
||||
def validate_description(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return validate_description_length(value)
|
||||
|
||||
|
||||
class UpdateAppPayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, description="App name")
|
||||
description: str | None = Field(default=None, description="App description (max 400 chars)")
|
||||
icon_type: str | None = Field(default=None, description="Icon type")
|
||||
icon: str | None = Field(default=None, description="Icon")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon")
|
||||
max_active_requests: int | None = Field(default=None, description="Maximum active requests")
|
||||
|
||||
@field_validator("description")
|
||||
@classmethod
|
||||
def validate_description(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return validate_description_length(value)
|
||||
|
||||
|
||||
class CopyAppPayload(BaseModel):
|
||||
name: str | None = Field(default=None, description="Name for the copied app")
|
||||
description: str | None = Field(default=None, description="Description for the copied app")
|
||||
icon_type: str | None = Field(default=None, description="Icon type")
|
||||
icon: str | None = Field(default=None, description="Icon")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
|
||||
@field_validator("description")
|
||||
@classmethod
|
||||
def validate_description(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return validate_description_length(value)
|
||||
|
||||
|
||||
class AppExportQuery(BaseModel):
|
||||
include_secret: bool = Field(default=False, description="Include secrets in export")
|
||||
workflow_id: str | None = Field(default=None, description="Specific workflow ID to export")
|
||||
|
||||
|
||||
class AppNamePayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, description="Name to check")
|
||||
|
||||
|
||||
class AppIconPayload(BaseModel):
|
||||
icon: str | None = Field(default=None, description="Icon data")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
|
||||
|
||||
class AppSiteStatusPayload(BaseModel):
|
||||
enable_site: bool = Field(..., description="Enable or disable site")
|
||||
|
||||
|
||||
class AppApiStatusPayload(BaseModel):
|
||||
enable_api: bool = Field(..., description="Enable or disable API")
|
||||
|
||||
|
||||
class AppTracePayload(BaseModel):
|
||||
enabled: bool = Field(..., description="Enable or disable tracing")
|
||||
tracing_provider: str = Field(..., description="Tracing provider")
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(AppListQuery)
|
||||
reg(CreateAppPayload)
|
||||
reg(UpdateAppPayload)
|
||||
reg(CopyAppPayload)
|
||||
reg(AppExportQuery)
|
||||
reg(AppNamePayload)
|
||||
reg(AppIconPayload)
|
||||
reg(AppSiteStatusPayload)
|
||||
reg(AppApiStatusPayload)
|
||||
reg(AppTracePayload)
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register base models first
|
||||
|
|
@ -147,22 +274,7 @@ app_pagination_model = console_ns.model(
|
|||
class AppListApi(Resource):
|
||||
@console_ns.doc("list_apps")
|
||||
@console_ns.doc(description="Get list of applications with pagination and filtering")
|
||||
@console_ns.expect(
|
||||
console_ns.parser()
|
||||
.add_argument("page", type=int, location="args", help="Page number (1-99999)", default=1)
|
||||
.add_argument("limit", type=int, location="args", help="Page size (1-100)", default=20)
|
||||
.add_argument(
|
||||
"mode",
|
||||
type=str,
|
||||
location="args",
|
||||
choices=["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"],
|
||||
default="all",
|
||||
help="App mode filter",
|
||||
)
|
||||
.add_argument("name", type=str, location="args", help="Filter by app name")
|
||||
.add_argument("tag_ids", type=str, location="args", help="Comma-separated tag IDs")
|
||||
.add_argument("is_created_by_me", type=bool, location="args", help="Filter by creator")
|
||||
)
|
||||
@console_ns.expect(console_ns.models[AppListQuery.__name__])
|
||||
@console_ns.response(200, "Success", app_pagination_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -172,42 +284,12 @@ class AppListApi(Resource):
|
|||
"""Get app list"""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
def uuid_list(value):
|
||||
try:
|
||||
return [str(uuid.UUID(v)) for v in value.split(",")]
|
||||
except ValueError:
|
||||
abort(400, message="Invalid UUID format in tag_ids.")
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
|
||||
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||
.add_argument(
|
||||
"mode",
|
||||
type=str,
|
||||
choices=[
|
||||
"completion",
|
||||
"chat",
|
||||
"advanced-chat",
|
||||
"workflow",
|
||||
"agent-chat",
|
||||
"channel",
|
||||
"all",
|
||||
],
|
||||
default="all",
|
||||
location="args",
|
||||
required=False,
|
||||
)
|
||||
.add_argument("name", type=str, location="args", required=False)
|
||||
.add_argument("tag_ids", type=uuid_list, location="args", required=False)
|
||||
.add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False)
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
args = AppListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args_dict = args.model_dump()
|
||||
|
||||
# get app list
|
||||
app_service = AppService()
|
||||
app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args)
|
||||
app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args_dict)
|
||||
if not app_pagination:
|
||||
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
|
||||
|
||||
|
|
@ -254,19 +336,7 @@ class AppListApi(Resource):
|
|||
|
||||
@console_ns.doc("create_app")
|
||||
@console_ns.doc(description="Create a new application")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"CreateAppRequest",
|
||||
{
|
||||
"name": fields.String(required=True, description="App name"),
|
||||
"description": fields.String(description="App description (max 400 chars)"),
|
||||
"mode": fields.String(required=True, enum=ALLOW_CREATE_APP_MODES, description="App mode"),
|
||||
"icon_type": fields.String(description="Icon type"),
|
||||
"icon": fields.String(description="Icon"),
|
||||
"icon_background": fields.String(description="Icon background color"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[CreateAppPayload.__name__])
|
||||
@console_ns.response(201, "App created successfully", app_detail_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
|
|
@ -279,22 +349,10 @@ class AppListApi(Resource):
|
|||
def post(self):
|
||||
"""Create app"""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=str, required=True, location="json")
|
||||
.add_argument("description", type=validate_description_length, location="json")
|
||||
.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json")
|
||||
.add_argument("icon_type", type=str, location="json")
|
||||
.add_argument("icon", type=str, location="json")
|
||||
.add_argument("icon_background", type=str, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if "mode" not in args or args["mode"] is None:
|
||||
raise BadRequest("mode is required")
|
||||
args = CreateAppPayload.model_validate(console_ns.payload)
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(current_tenant_id, args, current_user)
|
||||
app = app_service.create_app(current_tenant_id, args.model_dump(), current_user)
|
||||
|
||||
return app, 201
|
||||
|
||||
|
|
@ -326,20 +384,7 @@ class AppApi(Resource):
|
|||
@console_ns.doc("update_app")
|
||||
@console_ns.doc(description="Update application details")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"UpdateAppRequest",
|
||||
{
|
||||
"name": fields.String(required=True, description="App name"),
|
||||
"description": fields.String(description="App description (max 400 chars)"),
|
||||
"icon_type": fields.String(description="Icon type"),
|
||||
"icon": fields.String(description="Icon"),
|
||||
"icon_background": fields.String(description="Icon background color"),
|
||||
"use_icon_as_answer_icon": fields.Boolean(description="Use icon as answer icon"),
|
||||
"max_active_requests": fields.Integer(description="Maximum active requests"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[UpdateAppPayload.__name__])
|
||||
@console_ns.response(200, "App updated successfully", app_detail_with_site_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
|
|
@ -351,28 +396,18 @@ class AppApi(Resource):
|
|||
@marshal_with(app_detail_with_site_model)
|
||||
def put(self, app_model):
|
||||
"""Update app"""
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("description", type=validate_description_length, location="json")
|
||||
.add_argument("icon_type", type=str, location="json")
|
||||
.add_argument("icon", type=str, location="json")
|
||||
.add_argument("icon_background", type=str, location="json")
|
||||
.add_argument("use_icon_as_answer_icon", type=bool, location="json")
|
||||
.add_argument("max_active_requests", type=int, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = UpdateAppPayload.model_validate(console_ns.payload)
|
||||
|
||||
app_service = AppService()
|
||||
|
||||
args_dict: AppService.ArgsDict = {
|
||||
"name": args["name"],
|
||||
"description": args.get("description", ""),
|
||||
"icon_type": args.get("icon_type", ""),
|
||||
"icon": args.get("icon", ""),
|
||||
"icon_background": args.get("icon_background", ""),
|
||||
"use_icon_as_answer_icon": args.get("use_icon_as_answer_icon", False),
|
||||
"max_active_requests": args.get("max_active_requests", 0),
|
||||
"name": args.name,
|
||||
"description": args.description or "",
|
||||
"icon_type": args.icon_type or "",
|
||||
"icon": args.icon or "",
|
||||
"icon_background": args.icon_background or "",
|
||||
"use_icon_as_answer_icon": args.use_icon_as_answer_icon or False,
|
||||
"max_active_requests": args.max_active_requests or 0,
|
||||
}
|
||||
app_model = app_service.update_app(app_model, args_dict)
|
||||
|
||||
|
|
@ -401,18 +436,7 @@ class AppCopyApi(Resource):
|
|||
@console_ns.doc("copy_app")
|
||||
@console_ns.doc(description="Create a copy of an existing application")
|
||||
@console_ns.doc(params={"app_id": "Application ID to copy"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"CopyAppRequest",
|
||||
{
|
||||
"name": fields.String(description="Name for the copied app"),
|
||||
"description": fields.String(description="Description for the copied app"),
|
||||
"icon_type": fields.String(description="Icon type"),
|
||||
"icon": fields.String(description="Icon"),
|
||||
"icon_background": fields.String(description="Icon background color"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[CopyAppPayload.__name__])
|
||||
@console_ns.response(201, "App copied successfully", app_detail_with_site_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
|
|
@ -426,15 +450,7 @@ class AppCopyApi(Resource):
|
|||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=str, location="json")
|
||||
.add_argument("description", type=validate_description_length, location="json")
|
||||
.add_argument("icon_type", type=str, location="json")
|
||||
.add_argument("icon", type=str, location="json")
|
||||
.add_argument("icon_background", type=str, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = CopyAppPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
with Session(db.engine) as session:
|
||||
import_service = AppDslService(session)
|
||||
|
|
@ -443,11 +459,11 @@ class AppCopyApi(Resource):
|
|||
account=current_user,
|
||||
import_mode=ImportMode.YAML_CONTENT,
|
||||
yaml_content=yaml_content,
|
||||
name=args.get("name"),
|
||||
description=args.get("description"),
|
||||
icon_type=args.get("icon_type"),
|
||||
icon=args.get("icon"),
|
||||
icon_background=args.get("icon_background"),
|
||||
name=args.name,
|
||||
description=args.description,
|
||||
icon_type=args.icon_type,
|
||||
icon=args.icon,
|
||||
icon_background=args.icon_background,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
|
|
@ -462,11 +478,7 @@ class AppExportApi(Resource):
|
|||
@console_ns.doc("export_app")
|
||||
@console_ns.doc(description="Export application configuration as DSL")
|
||||
@console_ns.doc(params={"app_id": "Application ID to export"})
|
||||
@console_ns.expect(
|
||||
console_ns.parser()
|
||||
.add_argument("include_secret", type=bool, location="args", default=False, help="Include secrets in export")
|
||||
.add_argument("workflow_id", type=str, location="args", help="Specific workflow ID to export")
|
||||
)
|
||||
@console_ns.expect(console_ns.models[AppExportQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"App exported successfully",
|
||||
|
|
@ -480,30 +492,23 @@ class AppExportApi(Resource):
|
|||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
"""Export app"""
|
||||
# Add include_secret params
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("include_secret", type=inputs.boolean, default=False, location="args")
|
||||
.add_argument("workflow_id", type=str, location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = AppExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
return {
|
||||
"data": AppDslService.export_dsl(
|
||||
app_model=app_model, include_secret=args["include_secret"], workflow_id=args.get("workflow_id")
|
||||
app_model=app_model,
|
||||
include_secret=args.include_secret,
|
||||
workflow_id=args.workflow_id,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json", help="Name to check")
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/name")
|
||||
class AppNameApi(Resource):
|
||||
@console_ns.doc("check_app_name")
|
||||
@console_ns.doc(description="Check if app name is available")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.expect(console_ns.models[AppNamePayload.__name__])
|
||||
@console_ns.response(200, "Name availability checked")
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -512,10 +517,10 @@ class AppNameApi(Resource):
|
|||
@marshal_with(app_detail_model)
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
args = parser.parse_args()
|
||||
args = AppNamePayload.model_validate(console_ns.payload)
|
||||
|
||||
app_service = AppService()
|
||||
app_model = app_service.update_app_name(app_model, args["name"])
|
||||
app_model = app_service.update_app_name(app_model, args.name)
|
||||
|
||||
return app_model
|
||||
|
||||
|
|
@ -525,16 +530,7 @@ class AppIconApi(Resource):
|
|||
@console_ns.doc("update_app_icon")
|
||||
@console_ns.doc(description="Update application icon")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"AppIconRequest",
|
||||
{
|
||||
"icon": fields.String(required=True, description="Icon data"),
|
||||
"icon_type": fields.String(description="Icon type"),
|
||||
"icon_background": fields.String(description="Icon background color"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[AppIconPayload.__name__])
|
||||
@console_ns.response(200, "Icon updated successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
|
|
@ -544,15 +540,10 @@ class AppIconApi(Resource):
|
|||
@marshal_with(app_detail_model)
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("icon", type=str, location="json")
|
||||
.add_argument("icon_background", type=str, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = AppIconPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
app_service = AppService()
|
||||
app_model = app_service.update_app_icon(app_model, args.get("icon") or "", args.get("icon_background") or "")
|
||||
app_model = app_service.update_app_icon(app_model, args.icon or "", args.icon_background or "")
|
||||
|
||||
return app_model
|
||||
|
||||
|
|
@ -562,11 +553,7 @@ class AppSiteStatus(Resource):
|
|||
@console_ns.doc("update_app_site_status")
|
||||
@console_ns.doc(description="Enable or disable app site")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"AppSiteStatusRequest", {"enable_site": fields.Boolean(required=True, description="Enable or disable site")}
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[AppSiteStatusPayload.__name__])
|
||||
@console_ns.response(200, "Site status updated successfully", app_detail_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
|
|
@ -576,11 +563,10 @@ class AppSiteStatus(Resource):
|
|||
@marshal_with(app_detail_model)
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
parser = reqparse.RequestParser().add_argument("enable_site", type=bool, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
args = AppSiteStatusPayload.model_validate(console_ns.payload)
|
||||
|
||||
app_service = AppService()
|
||||
app_model = app_service.update_app_site_status(app_model, args["enable_site"])
|
||||
app_model = app_service.update_app_site_status(app_model, args.enable_site)
|
||||
|
||||
return app_model
|
||||
|
||||
|
|
@ -590,11 +576,7 @@ class AppApiStatus(Resource):
|
|||
@console_ns.doc("update_app_api_status")
|
||||
@console_ns.doc(description="Enable or disable app API")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"AppApiStatusRequest", {"enable_api": fields.Boolean(required=True, description="Enable or disable API")}
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[AppApiStatusPayload.__name__])
|
||||
@console_ns.response(200, "API status updated successfully", app_detail_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
|
|
@ -604,11 +586,10 @@ class AppApiStatus(Resource):
|
|||
@get_app_model
|
||||
@marshal_with(app_detail_model)
|
||||
def post(self, app_model):
|
||||
parser = reqparse.RequestParser().add_argument("enable_api", type=bool, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
args = AppApiStatusPayload.model_validate(console_ns.payload)
|
||||
|
||||
app_service = AppService()
|
||||
app_model = app_service.update_app_api_status(app_model, args["enable_api"])
|
||||
app_model = app_service.update_app_api_status(app_model, args.enable_api)
|
||||
|
||||
return app_model
|
||||
|
||||
|
|
@ -631,15 +612,7 @@ class AppTraceApi(Resource):
|
|||
@console_ns.doc("update_app_trace")
|
||||
@console_ns.doc(description="Update app tracing configuration")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"AppTraceRequest",
|
||||
{
|
||||
"enabled": fields.Boolean(required=True, description="Enable or disable tracing"),
|
||||
"tracing_provider": fields.String(required=True, description="Tracing provider"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[AppTracePayload.__name__])
|
||||
@console_ns.response(200, "Trace configuration updated successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
|
|
@ -648,17 +621,12 @@ class AppTraceApi(Resource):
|
|||
@edit_permission_required
|
||||
def post(self, app_id):
|
||||
# add app trace
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("enabled", type=bool, required=True, location="json")
|
||||
.add_argument("tracing_provider", type=str, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = AppTracePayload.model_validate(console_ns.payload)
|
||||
|
||||
OpsTraceManager.update_app_tracing_config(
|
||||
app_id=app_id,
|
||||
enabled=args["enabled"],
|
||||
tracing_provider=args["tracing_provider"],
|
||||
enabled=args.enabled,
|
||||
tracing_provider=args.tracing_provider,
|
||||
)
|
||||
|
||||
return {"result": "success"}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
import logging
|
||||
from typing import Any, Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
|
|
@ -35,6 +37,41 @@ from services.app_task_service import AppTaskService
|
|||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class BaseMessagePayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
model_config_data: dict[str, Any] = Field(..., alias="model_config")
|
||||
files: list[Any] | None = Field(default=None, description="Uploaded files")
|
||||
response_mode: Literal["blocking", "streaming"] = Field(default="blocking", description="Response mode")
|
||||
retriever_from: str = Field(default="dev", description="Retriever source")
|
||||
|
||||
|
||||
class CompletionMessagePayload(BaseMessagePayload):
|
||||
query: str = Field(default="", description="Query text")
|
||||
|
||||
|
||||
class ChatMessagePayload(BaseMessagePayload):
|
||||
query: str = Field(..., description="User query")
|
||||
conversation_id: str | None = Field(default=None, description="Conversation ID")
|
||||
parent_message_id: str | None = Field(default=None, description="Parent message ID")
|
||||
|
||||
@field_validator("conversation_id", "parent_message_id")
|
||||
@classmethod
|
||||
def validate_uuid(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
CompletionMessagePayload.__name__,
|
||||
CompletionMessagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
ChatMessagePayload.__name__, ChatMessagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
|
||||
# define completion message api for user
|
||||
|
|
@ -43,19 +80,7 @@ class CompletionMessageApi(Resource):
|
|||
@console_ns.doc("create_completion_message")
|
||||
@console_ns.doc(description="Generate completion message for debugging")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"CompletionMessageRequest",
|
||||
{
|
||||
"inputs": fields.Raw(required=True, description="Input variables"),
|
||||
"query": fields.String(description="Query text", default=""),
|
||||
"files": fields.List(fields.Raw(), description="Uploaded files"),
|
||||
"model_config": fields.Raw(required=True, description="Model configuration"),
|
||||
"response_mode": fields.String(enum=["blocking", "streaming"], description="Response mode"),
|
||||
"retriever_from": fields.String(default="dev", description="Retriever source"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[CompletionMessagePayload.__name__])
|
||||
@console_ns.response(200, "Completion generated successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@console_ns.response(404, "App not found")
|
||||
|
|
@ -64,18 +89,10 @@ class CompletionMessageApi(Resource):
|
|||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
def post(self, app_model):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, location="json")
|
||||
.add_argument("query", type=str, location="json", default="")
|
||||
.add_argument("files", type=list, required=False, location="json")
|
||||
.add_argument("model_config", type=dict, required=True, location="json")
|
||||
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||
.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args_model = CompletionMessagePayload.model_validate(console_ns.payload)
|
||||
args = args_model.model_dump(exclude_none=True, by_alias=True)
|
||||
|
||||
streaming = args["response_mode"] != "blocking"
|
||||
streaming = args_model.response_mode != "blocking"
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
try:
|
||||
|
|
@ -137,21 +154,7 @@ class ChatMessageApi(Resource):
|
|||
@console_ns.doc("create_chat_message")
|
||||
@console_ns.doc(description="Generate chat message for debugging")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"ChatMessageRequest",
|
||||
{
|
||||
"inputs": fields.Raw(required=True, description="Input variables"),
|
||||
"query": fields.String(required=True, description="User query"),
|
||||
"files": fields.List(fields.Raw(), description="Uploaded files"),
|
||||
"model_config": fields.Raw(required=True, description="Model configuration"),
|
||||
"conversation_id": fields.String(description="Conversation ID"),
|
||||
"parent_message_id": fields.String(description="Parent message ID"),
|
||||
"response_mode": fields.String(enum=["blocking", "streaming"], description="Response mode"),
|
||||
"retriever_from": fields.String(default="dev", description="Retriever source"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[ChatMessagePayload.__name__])
|
||||
@console_ns.response(200, "Chat message generated successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@console_ns.response(404, "App or conversation not found")
|
||||
|
|
@ -161,20 +164,10 @@ class ChatMessageApi(Resource):
|
|||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, location="json")
|
||||
.add_argument("query", type=str, required=True, location="json")
|
||||
.add_argument("files", type=list, required=False, location="json")
|
||||
.add_argument("model_config", type=dict, required=True, location="json")
|
||||
.add_argument("conversation_id", type=uuid_value, location="json")
|
||||
.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
||||
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||
.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args_model = ChatMessagePayload.model_validate(console_ns.payload)
|
||||
args = args_model.model_dump(exclude_none=True, by_alias=True)
|
||||
|
||||
streaming = args["response_mode"] != "blocking"
|
||||
streaming = args_model.response_mode != "blocking"
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
external_trace_id = get_external_trace_id(request)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
from typing import Literal
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask import abort
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from flask import abort, request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import func, or_
|
||||
from sqlalchemy.orm import joinedload
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
|
@ -14,13 +16,54 @@ from extensions.ext_database import db
|
|||
from fields.conversation_fields import MessageTextField
|
||||
from fields.raws import FilesContainedField
|
||||
from libs.datetime_utils import naive_utc_now, parse_time_range
|
||||
from libs.helper import DatetimeString, TimestampField
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import Conversation, EndUser, Message, MessageAnnotation
|
||||
from models.model import AppMode
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class BaseConversationQuery(BaseModel):
|
||||
keyword: str | None = Field(default=None, description="Search keyword")
|
||||
start: str | None = Field(default=None, description="Start date (YYYY-MM-DD HH:MM)")
|
||||
end: str | None = Field(default=None, description="End date (YYYY-MM-DD HH:MM)")
|
||||
annotation_status: Literal["annotated", "not_annotated", "all"] = Field(
|
||||
default="all", description="Annotation status filter"
|
||||
)
|
||||
page: int = Field(default=1, ge=1, le=99999, description="Page number")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")
|
||||
|
||||
@field_validator("start", "end", mode="before")
|
||||
@classmethod
|
||||
def blank_to_none(cls, value: str | None) -> str | None:
|
||||
if value == "":
|
||||
return None
|
||||
return value
|
||||
|
||||
|
||||
class CompletionConversationQuery(BaseConversationQuery):
|
||||
pass
|
||||
|
||||
|
||||
class ChatConversationQuery(BaseConversationQuery):
|
||||
message_count_gte: int | None = Field(default=None, ge=1, description="Minimum message count")
|
||||
sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field(
|
||||
default="-updated_at", description="Sort field and direction"
|
||||
)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
CompletionConversationQuery.__name__,
|
||||
CompletionConversationQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
ChatConversationQuery.__name__,
|
||||
ChatConversationQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register in dependency order: base models first, then dependent models
|
||||
|
||||
|
|
@ -283,22 +326,7 @@ class CompletionConversationApi(Resource):
|
|||
@console_ns.doc("list_completion_conversations")
|
||||
@console_ns.doc(description="Get completion conversations with pagination and filtering")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.parser()
|
||||
.add_argument("keyword", type=str, location="args", help="Search keyword")
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument(
|
||||
"annotation_status",
|
||||
type=str,
|
||||
location="args",
|
||||
choices=["annotated", "not_annotated", "all"],
|
||||
default="all",
|
||||
help="Annotation status filter",
|
||||
)
|
||||
.add_argument("page", type=int, location="args", default=1, help="Page number")
|
||||
.add_argument("limit", type=int, location="args", default=20, help="Page size (1-100)")
|
||||
)
|
||||
@console_ns.expect(console_ns.models[CompletionConversationQuery.__name__])
|
||||
@console_ns.response(200, "Success", conversation_pagination_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
|
|
@ -309,32 +337,17 @@ class CompletionConversationApi(Resource):
|
|||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("keyword", type=str, location="args")
|
||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
.add_argument(
|
||||
"annotation_status",
|
||||
type=str,
|
||||
choices=["annotated", "not_annotated", "all"],
|
||||
default="all",
|
||||
location="args",
|
||||
)
|
||||
.add_argument("page", type=int_range(1, 99999), default=1, location="args")
|
||||
.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = CompletionConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
query = sa.select(Conversation).where(
|
||||
Conversation.app_id == app_model.id, Conversation.mode == "completion", Conversation.is_deleted.is_(False)
|
||||
)
|
||||
|
||||
if args["keyword"]:
|
||||
if args.keyword:
|
||||
query = query.join(Message, Message.conversation_id == Conversation.id).where(
|
||||
or_(
|
||||
Message.query.ilike(f"%{args['keyword']}%"),
|
||||
Message.answer.ilike(f"%{args['keyword']}%"),
|
||||
Message.query.ilike(f"%{args.keyword}%"),
|
||||
Message.answer.ilike(f"%{args.keyword}%"),
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -342,7 +355,7 @@ class CompletionConversationApi(Resource):
|
|||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
|
|
@ -354,11 +367,11 @@ class CompletionConversationApi(Resource):
|
|||
query = query.where(Conversation.created_at < end_datetime_utc)
|
||||
|
||||
# FIXME, the type ignore in this file
|
||||
if args["annotation_status"] == "annotated":
|
||||
if args.annotation_status == "annotated":
|
||||
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
|
||||
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
|
||||
)
|
||||
elif args["annotation_status"] == "not_annotated":
|
||||
elif args.annotation_status == "not_annotated":
|
||||
query = (
|
||||
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
|
||||
.group_by(Conversation.id)
|
||||
|
|
@ -367,7 +380,7 @@ class CompletionConversationApi(Resource):
|
|||
|
||||
query = query.order_by(Conversation.created_at.desc())
|
||||
|
||||
conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False)
|
||||
conversations = db.paginate(query, page=args.page, per_page=args.limit, error_out=False)
|
||||
|
||||
return conversations
|
||||
|
||||
|
|
@ -419,31 +432,7 @@ class ChatConversationApi(Resource):
|
|||
@console_ns.doc("list_chat_conversations")
|
||||
@console_ns.doc(description="Get chat conversations with pagination, filtering and summary")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.parser()
|
||||
.add_argument("keyword", type=str, location="args", help="Search keyword")
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument(
|
||||
"annotation_status",
|
||||
type=str,
|
||||
location="args",
|
||||
choices=["annotated", "not_annotated", "all"],
|
||||
default="all",
|
||||
help="Annotation status filter",
|
||||
)
|
||||
.add_argument("message_count_gte", type=int, location="args", help="Minimum message count")
|
||||
.add_argument("page", type=int, location="args", default=1, help="Page number")
|
||||
.add_argument("limit", type=int, location="args", default=20, help="Page size (1-100)")
|
||||
.add_argument(
|
||||
"sort_by",
|
||||
type=str,
|
||||
location="args",
|
||||
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
|
||||
default="-updated_at",
|
||||
help="Sort field and direction",
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[ChatConversationQuery.__name__])
|
||||
@console_ns.response(200, "Success", conversation_with_summary_pagination_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
|
|
@ -454,31 +443,7 @@ class ChatConversationApi(Resource):
|
|||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("keyword", type=str, location="args")
|
||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
.add_argument(
|
||||
"annotation_status",
|
||||
type=str,
|
||||
choices=["annotated", "not_annotated", "all"],
|
||||
default="all",
|
||||
location="args",
|
||||
)
|
||||
.add_argument("message_count_gte", type=int_range(1, 99999), required=False, location="args")
|
||||
.add_argument("page", type=int_range(1, 99999), required=False, default=1, location="args")
|
||||
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
.add_argument(
|
||||
"sort_by",
|
||||
type=str,
|
||||
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
|
||||
required=False,
|
||||
default="-updated_at",
|
||||
location="args",
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
subquery = (
|
||||
db.session.query(
|
||||
|
|
@ -490,8 +455,8 @@ class ChatConversationApi(Resource):
|
|||
|
||||
query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))
|
||||
|
||||
if args["keyword"]:
|
||||
keyword_filter = f"%{args['keyword']}%"
|
||||
if args.keyword:
|
||||
keyword_filter = f"%{args.keyword}%"
|
||||
query = (
|
||||
query.join(
|
||||
Message,
|
||||
|
|
@ -514,12 +479,12 @@ class ChatConversationApi(Resource):
|
|||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
if start_datetime_utc:
|
||||
match args["sort_by"]:
|
||||
match args.sort_by:
|
||||
case "updated_at" | "-updated_at":
|
||||
query = query.where(Conversation.updated_at >= start_datetime_utc)
|
||||
case "created_at" | "-created_at" | _:
|
||||
|
|
@ -527,35 +492,35 @@ class ChatConversationApi(Resource):
|
|||
|
||||
if end_datetime_utc:
|
||||
end_datetime_utc = end_datetime_utc.replace(second=59)
|
||||
match args["sort_by"]:
|
||||
match args.sort_by:
|
||||
case "updated_at" | "-updated_at":
|
||||
query = query.where(Conversation.updated_at <= end_datetime_utc)
|
||||
case "created_at" | "-created_at" | _:
|
||||
query = query.where(Conversation.created_at <= end_datetime_utc)
|
||||
|
||||
if args["annotation_status"] == "annotated":
|
||||
if args.annotation_status == "annotated":
|
||||
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
|
||||
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
|
||||
)
|
||||
elif args["annotation_status"] == "not_annotated":
|
||||
elif args.annotation_status == "not_annotated":
|
||||
query = (
|
||||
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
|
||||
.group_by(Conversation.id)
|
||||
.having(func.count(MessageAnnotation.id) == 0)
|
||||
)
|
||||
|
||||
if args["message_count_gte"] and args["message_count_gte"] >= 1:
|
||||
if args.message_count_gte and args.message_count_gte >= 1:
|
||||
query = (
|
||||
query.options(joinedload(Conversation.messages)) # type: ignore
|
||||
.join(Message, Message.conversation_id == Conversation.id)
|
||||
.group_by(Conversation.id)
|
||||
.having(func.count(Message.id) >= args["message_count_gte"])
|
||||
.having(func.count(Message.id) >= args.message_count_gte)
|
||||
)
|
||||
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT:
|
||||
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER)
|
||||
|
||||
match args["sort_by"]:
|
||||
match args.sort_by:
|
||||
case "created_at":
|
||||
query = query.order_by(Conversation.created_at.asc())
|
||||
case "-created_at":
|
||||
|
|
@ -567,7 +532,7 @@ class ChatConversationApi(Resource):
|
|||
case _:
|
||||
query = query.order_by(Conversation.created_at.desc())
|
||||
|
||||
conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False)
|
||||
conversations = db.paginate(query, page=args.page, per_page=args.limit, error_out=False)
|
||||
|
||||
return conversations
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
|
@ -14,6 +16,18 @@ from libs.login import login_required
|
|||
from models import ConversationVariable
|
||||
from models.model import AppMode
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class ConversationVariablesQuery(BaseModel):
|
||||
conversation_id: str = Field(..., description="Conversation ID to filter variables")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
ConversationVariablesQuery.__name__,
|
||||
ConversationVariablesQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register base model first
|
||||
conversation_variable_model = console_ns.model("ConversationVariable", conversation_variable_fields)
|
||||
|
|
@ -33,11 +47,7 @@ class ConversationVariablesApi(Resource):
|
|||
@console_ns.doc("get_conversation_variables")
|
||||
@console_ns.doc(description="Get conversation variables for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.parser().add_argument(
|
||||
"conversation_id", type=str, location="args", help="Conversation ID to filter variables"
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[ConversationVariablesQuery.__name__])
|
||||
@console_ns.response(200, "Conversation variables retrieved successfully", paginated_conversation_variable_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -45,18 +55,14 @@ class ConversationVariablesApi(Resource):
|
|||
@get_app_model(mode=AppMode.ADVANCED_CHAT)
|
||||
@marshal_with(paginated_conversation_variable_model)
|
||||
def get(self, app_model):
|
||||
parser = reqparse.RequestParser().add_argument("conversation_id", type=str, location="args")
|
||||
args = parser.parse_args()
|
||||
args = ConversationVariablesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
stmt = (
|
||||
select(ConversationVariable)
|
||||
.where(ConversationVariable.app_id == app_model.id)
|
||||
.order_by(ConversationVariable.created_at)
|
||||
)
|
||||
if args["conversation_id"]:
|
||||
stmt = stmt.where(ConversationVariable.conversation_id == args["conversation_id"])
|
||||
else:
|
||||
raise ValueError("conversation_id is required")
|
||||
stmt = stmt.where(ConversationVariable.conversation_id == args.conversation_id)
|
||||
|
||||
# NOTE: This is a temporary solution to avoid performance issues.
|
||||
page = 1
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
|
|
@ -21,21 +23,54 @@ from libs.login import current_account_with_tenant, login_required
|
|||
from models import App
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class RuleGeneratePayload(BaseModel):
|
||||
instruction: str = Field(..., description="Rule generation instruction")
|
||||
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
|
||||
no_variable: bool = Field(default=False, description="Whether to exclude variables")
|
||||
|
||||
|
||||
class RuleCodeGeneratePayload(RuleGeneratePayload):
|
||||
code_language: str = Field(default="javascript", description="Programming language for code generation")
|
||||
|
||||
|
||||
class RuleStructuredOutputPayload(BaseModel):
|
||||
instruction: str = Field(..., description="Structured output generation instruction")
|
||||
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
|
||||
|
||||
|
||||
class InstructionGeneratePayload(BaseModel):
|
||||
flow_id: str = Field(..., description="Workflow/Flow ID")
|
||||
node_id: str = Field(default="", description="Node ID for workflow context")
|
||||
current: str = Field(default="", description="Current instruction text")
|
||||
language: str = Field(default="javascript", description="Programming language (javascript/python)")
|
||||
instruction: str = Field(..., description="Instruction for generation")
|
||||
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
|
||||
ideal_output: str = Field(default="", description="Expected ideal output")
|
||||
|
||||
|
||||
class InstructionTemplatePayload(BaseModel):
|
||||
type: str = Field(..., description="Instruction template type")
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(RuleGeneratePayload)
|
||||
reg(RuleCodeGeneratePayload)
|
||||
reg(RuleStructuredOutputPayload)
|
||||
reg(InstructionGeneratePayload)
|
||||
reg(InstructionTemplatePayload)
|
||||
|
||||
|
||||
@console_ns.route("/rule-generate")
|
||||
class RuleGenerateApi(Resource):
|
||||
@console_ns.doc("generate_rule_config")
|
||||
@console_ns.doc(description="Generate rule configuration using LLM")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"RuleGenerateRequest",
|
||||
{
|
||||
"instruction": fields.String(required=True, description="Rule generation instruction"),
|
||||
"model_config": fields.Raw(required=True, description="Model configuration"),
|
||||
"no_variable": fields.Boolean(required=True, default=False, description="Whether to exclude variables"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[RuleGeneratePayload.__name__])
|
||||
@console_ns.response(200, "Rule configuration generated successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@console_ns.response(402, "Provider quota exceeded")
|
||||
|
|
@ -43,21 +78,15 @@ class RuleGenerateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("no_variable", type=bool, required=True, default=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = RuleGeneratePayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
rules = LLMGenerator.generate_rule_config(
|
||||
tenant_id=current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
no_variable=args["no_variable"],
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
no_variable=args.no_variable,
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
|
@ -75,19 +104,7 @@ class RuleGenerateApi(Resource):
|
|||
class RuleCodeGenerateApi(Resource):
|
||||
@console_ns.doc("generate_rule_code")
|
||||
@console_ns.doc(description="Generate code rules using LLM")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"RuleCodeGenerateRequest",
|
||||
{
|
||||
"instruction": fields.String(required=True, description="Code generation instruction"),
|
||||
"model_config": fields.Raw(required=True, description="Model configuration"),
|
||||
"no_variable": fields.Boolean(required=True, default=False, description="Whether to exclude variables"),
|
||||
"code_language": fields.String(
|
||||
default="javascript", description="Programming language for code generation"
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[RuleCodeGeneratePayload.__name__])
|
||||
@console_ns.response(200, "Code rules generated successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@console_ns.response(402, "Provider quota exceeded")
|
||||
|
|
@ -95,22 +112,15 @@ class RuleCodeGenerateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("no_variable", type=bool, required=True, default=False, location="json")
|
||||
.add_argument("code_language", type=str, required=False, default="javascript", location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = RuleCodeGeneratePayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
code_result = LLMGenerator.generate_code(
|
||||
tenant_id=current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
code_language=args["code_language"],
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
code_language=args.code_language,
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
|
@ -128,15 +138,7 @@ class RuleCodeGenerateApi(Resource):
|
|||
class RuleStructuredOutputGenerateApi(Resource):
|
||||
@console_ns.doc("generate_structured_output")
|
||||
@console_ns.doc(description="Generate structured output rules using LLM")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"StructuredOutputGenerateRequest",
|
||||
{
|
||||
"instruction": fields.String(required=True, description="Structured output generation instruction"),
|
||||
"model_config": fields.Raw(required=True, description="Model configuration"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[RuleStructuredOutputPayload.__name__])
|
||||
@console_ns.response(200, "Structured output generated successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@console_ns.response(402, "Provider quota exceeded")
|
||||
|
|
@ -144,19 +146,14 @@ class RuleStructuredOutputGenerateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = RuleStructuredOutputPayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
structured_output = LLMGenerator.generate_structured_output(
|
||||
tenant_id=current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
|
@ -174,20 +171,7 @@ class RuleStructuredOutputGenerateApi(Resource):
|
|||
class InstructionGenerateApi(Resource):
|
||||
@console_ns.doc("generate_instruction")
|
||||
@console_ns.doc(description="Generate instruction for workflow nodes or general use")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"InstructionGenerateRequest",
|
||||
{
|
||||
"flow_id": fields.String(required=True, description="Workflow/Flow ID"),
|
||||
"node_id": fields.String(description="Node ID for workflow context"),
|
||||
"current": fields.String(description="Current instruction text"),
|
||||
"language": fields.String(default="javascript", description="Programming language (javascript/python)"),
|
||||
"instruction": fields.String(required=True, description="Instruction for generation"),
|
||||
"model_config": fields.Raw(required=True, description="Model configuration"),
|
||||
"ideal_output": fields.String(description="Expected ideal output"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[InstructionGeneratePayload.__name__])
|
||||
@console_ns.response(200, "Instruction generated successfully")
|
||||
@console_ns.response(400, "Invalid request parameters or flow/workflow not found")
|
||||
@console_ns.response(402, "Provider quota exceeded")
|
||||
|
|
@ -195,79 +179,69 @@ class InstructionGenerateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("flow_id", type=str, required=True, default="", location="json")
|
||||
.add_argument("node_id", type=str, required=False, default="", location="json")
|
||||
.add_argument("current", type=str, required=False, default="", location="json")
|
||||
.add_argument("language", type=str, required=False, default="javascript", location="json")
|
||||
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("ideal_output", type=str, required=False, default="", location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = InstructionGeneratePayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
|
||||
code_provider: type[CodeNodeProvider] | None = next(
|
||||
(p for p in providers if p.is_accept_language(args["language"])), None
|
||||
(p for p in providers if p.is_accept_language(args.language)), None
|
||||
)
|
||||
code_template = code_provider.get_default_code() if code_provider else ""
|
||||
try:
|
||||
# Generate from nothing for a workflow node
|
||||
if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "":
|
||||
app = db.session.query(App).where(App.id == args["flow_id"]).first()
|
||||
if (args.current in (code_template, "")) and args.node_id != "":
|
||||
app = db.session.query(App).where(App.id == args.flow_id).first()
|
||||
if not app:
|
||||
return {"error": f"app {args['flow_id']} not found"}, 400
|
||||
return {"error": f"app {args.flow_id} not found"}, 400
|
||||
workflow = WorkflowService().get_draft_workflow(app_model=app)
|
||||
if not workflow:
|
||||
return {"error": f"workflow {args['flow_id']} not found"}, 400
|
||||
return {"error": f"workflow {args.flow_id} not found"}, 400
|
||||
nodes: Sequence = workflow.graph_dict["nodes"]
|
||||
node = [node for node in nodes if node["id"] == args["node_id"]]
|
||||
node = [node for node in nodes if node["id"] == args.node_id]
|
||||
if len(node) == 0:
|
||||
return {"error": f"node {args['node_id']} not found"}, 400
|
||||
return {"error": f"node {args.node_id} not found"}, 400
|
||||
node_type = node[0]["data"]["type"]
|
||||
match node_type:
|
||||
case "llm":
|
||||
return LLMGenerator.generate_rule_config(
|
||||
current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
no_variable=True,
|
||||
)
|
||||
case "agent":
|
||||
return LLMGenerator.generate_rule_config(
|
||||
current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
no_variable=True,
|
||||
)
|
||||
case "code":
|
||||
return LLMGenerator.generate_code(
|
||||
tenant_id=current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
code_language=args["language"],
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
code_language=args.language,
|
||||
)
|
||||
case _:
|
||||
return {"error": f"invalid node type: {node_type}"}
|
||||
if args["node_id"] == "" and args["current"] != "": # For legacy app without a workflow
|
||||
if args.node_id == "" and args.current != "": # For legacy app without a workflow
|
||||
return LLMGenerator.instruction_modify_legacy(
|
||||
tenant_id=current_tenant_id,
|
||||
flow_id=args["flow_id"],
|
||||
current=args["current"],
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
ideal_output=args["ideal_output"],
|
||||
flow_id=args.flow_id,
|
||||
current=args.current,
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
ideal_output=args.ideal_output,
|
||||
)
|
||||
if args["node_id"] != "" and args["current"] != "": # For workflow node
|
||||
if args.node_id != "" and args.current != "": # For workflow node
|
||||
return LLMGenerator.instruction_modify_workflow(
|
||||
tenant_id=current_tenant_id,
|
||||
flow_id=args["flow_id"],
|
||||
node_id=args["node_id"],
|
||||
current=args["current"],
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
ideal_output=args["ideal_output"],
|
||||
flow_id=args.flow_id,
|
||||
node_id=args.node_id,
|
||||
current=args.current,
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
ideal_output=args.ideal_output,
|
||||
workflow_service=WorkflowService(),
|
||||
)
|
||||
return {"error": "incompatible parameters"}, 400
|
||||
|
|
@ -285,24 +259,15 @@ class InstructionGenerateApi(Resource):
|
|||
class InstructionGenerationTemplateApi(Resource):
|
||||
@console_ns.doc("get_instruction_template")
|
||||
@console_ns.doc(description="Get instruction generation template")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"InstructionTemplateRequest",
|
||||
{
|
||||
"instruction": fields.String(required=True, description="Template instruction"),
|
||||
"ideal_output": fields.String(description="Expected ideal output"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[InstructionTemplatePayload.__name__])
|
||||
@console_ns.response(200, "Template retrieved successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser().add_argument("type", type=str, required=True, default=False, location="json")
|
||||
args = parser.parse_args()
|
||||
match args["type"]:
|
||||
args = InstructionTemplatePayload.model_validate(console_ns.payload)
|
||||
match args.type:
|
||||
case "prompt":
|
||||
from core.llm_generator.prompts import INSTRUCTION_GENERATE_TEMPLATE_PROMPT
|
||||
|
||||
|
|
@ -312,4 +277,4 @@ class InstructionGenerationTemplateApi(Resource):
|
|||
|
||||
return {"data": INSTRUCTION_GENERATE_TEMPLATE_CODE}
|
||||
case _:
|
||||
raise ValueError(f"Invalid type: {args['type']}")
|
||||
raise ValueError(f"Invalid type: {args.type}")
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
import logging
|
||||
from typing import Literal
|
||||
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import exists, select
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
|
|
@ -33,6 +35,67 @@ from services.errors.message import MessageNotExistsError, SuggestedQuestionsAft
|
|||
from services.message_service import MessageService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class ChatMessagesQuery(BaseModel):
|
||||
conversation_id: str = Field(..., description="Conversation ID")
|
||||
first_id: str | None = Field(default=None, description="First message ID for pagination")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return (1-100)")
|
||||
|
||||
@field_validator("first_id", mode="before")
|
||||
@classmethod
|
||||
def empty_to_none(cls, value: str | None) -> str | None:
|
||||
if value == "":
|
||||
return None
|
||||
return value
|
||||
|
||||
@field_validator("conversation_id", "first_id")
|
||||
@classmethod
|
||||
def validate_uuid(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class MessageFeedbackPayload(BaseModel):
|
||||
message_id: str = Field(..., description="Message ID")
|
||||
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
|
||||
|
||||
@field_validator("message_id")
|
||||
@classmethod
|
||||
def validate_message_id(cls, value: str) -> str:
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class FeedbackExportQuery(BaseModel):
|
||||
from_source: Literal["user", "admin"] | None = Field(default=None, description="Filter by feedback source")
|
||||
rating: Literal["like", "dislike"] | None = Field(default=None, description="Filter by rating")
|
||||
has_comment: bool | None = Field(default=None, description="Only include feedback with comments")
|
||||
start_date: str | None = Field(default=None, description="Start date (YYYY-MM-DD)")
|
||||
end_date: str | None = Field(default=None, description="End date (YYYY-MM-DD)")
|
||||
format: Literal["csv", "json"] = Field(default="csv", description="Export format")
|
||||
|
||||
@field_validator("has_comment", mode="before")
|
||||
@classmethod
|
||||
def parse_bool(cls, value: bool | str | None) -> bool | None:
|
||||
if isinstance(value, bool) or value is None:
|
||||
return value
|
||||
lowered = value.lower()
|
||||
if lowered in {"true", "1", "yes", "on"}:
|
||||
return True
|
||||
if lowered in {"false", "0", "no", "off"}:
|
||||
return False
|
||||
raise ValueError("has_comment must be a boolean value")
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(ChatMessagesQuery)
|
||||
reg(MessageFeedbackPayload)
|
||||
reg(FeedbackExportQuery)
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register in dependency order: base models first, then dependent models
|
||||
|
|
@ -157,12 +220,7 @@ class ChatMessageListApi(Resource):
|
|||
@console_ns.doc("list_chat_messages")
|
||||
@console_ns.doc(description="Get chat messages for a conversation with pagination")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.parser()
|
||||
.add_argument("conversation_id", type=str, required=True, location="args", help="Conversation ID")
|
||||
.add_argument("first_id", type=str, location="args", help="First message ID for pagination")
|
||||
.add_argument("limit", type=int, location="args", default=20, help="Number of messages to return (1-100)")
|
||||
)
|
||||
@console_ns.expect(console_ns.models[ChatMessagesQuery.__name__])
|
||||
@console_ns.response(200, "Success", message_infinite_scroll_pagination_model)
|
||||
@console_ns.response(404, "Conversation not found")
|
||||
@login_required
|
||||
|
|
@ -172,27 +230,21 @@ class ChatMessageListApi(Resource):
|
|||
@marshal_with(message_infinite_scroll_pagination_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("conversation_id", required=True, type=uuid_value, location="args")
|
||||
.add_argument("first_id", type=uuid_value, location="args")
|
||||
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = ChatMessagesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
conversation = (
|
||||
db.session.query(Conversation)
|
||||
.where(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id)
|
||||
.where(Conversation.id == args.conversation_id, Conversation.app_id == app_model.id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
if args["first_id"]:
|
||||
if args.first_id:
|
||||
first_message = (
|
||||
db.session.query(Message)
|
||||
.where(Message.conversation_id == conversation.id, Message.id == args["first_id"])
|
||||
.where(Message.conversation_id == conversation.id, Message.id == args.first_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
|
|
@ -207,7 +259,7 @@ class ChatMessageListApi(Resource):
|
|||
Message.id != first_message.id,
|
||||
)
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(args["limit"])
|
||||
.limit(args.limit)
|
||||
.all()
|
||||
)
|
||||
else:
|
||||
|
|
@ -215,12 +267,12 @@ class ChatMessageListApi(Resource):
|
|||
db.session.query(Message)
|
||||
.where(Message.conversation_id == conversation.id)
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(args["limit"])
|
||||
.limit(args.limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Initialize has_more based on whether we have a full page
|
||||
if len(history_messages) == args["limit"]:
|
||||
if len(history_messages) == args.limit:
|
||||
current_page_first_message = history_messages[-1]
|
||||
# Check if there are more messages before the current page
|
||||
has_more = db.session.scalar(
|
||||
|
|
@ -238,7 +290,7 @@ class ChatMessageListApi(Resource):
|
|||
|
||||
history_messages = list(reversed(history_messages))
|
||||
|
||||
return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more)
|
||||
return InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/feedbacks")
|
||||
|
|
@ -246,15 +298,7 @@ class MessageFeedbackApi(Resource):
|
|||
@console_ns.doc("create_message_feedback")
|
||||
@console_ns.doc(description="Create or update message feedback (like/dislike)")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"MessageFeedbackRequest",
|
||||
{
|
||||
"message_id": fields.String(required=True, description="Message ID"),
|
||||
"rating": fields.String(enum=["like", "dislike"], description="Feedback rating"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[MessageFeedbackPayload.__name__])
|
||||
@console_ns.response(200, "Feedback updated successfully")
|
||||
@console_ns.response(404, "Message not found")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
|
|
@ -265,14 +309,9 @@ class MessageFeedbackApi(Resource):
|
|||
def post(self, app_model):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("message_id", required=True, type=uuid_value, location="json")
|
||||
.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = MessageFeedbackPayload.model_validate(console_ns.payload)
|
||||
|
||||
message_id = str(args["message_id"])
|
||||
message_id = str(args.message_id)
|
||||
|
||||
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
|
||||
|
||||
|
|
@ -281,18 +320,21 @@ class MessageFeedbackApi(Resource):
|
|||
|
||||
feedback = message.admin_feedback
|
||||
|
||||
if not args["rating"] and feedback:
|
||||
if not args.rating and feedback:
|
||||
db.session.delete(feedback)
|
||||
elif args["rating"] and feedback:
|
||||
feedback.rating = args["rating"]
|
||||
elif not args["rating"] and not feedback:
|
||||
elif args.rating and feedback:
|
||||
feedback.rating = args.rating
|
||||
elif not args.rating and not feedback:
|
||||
raise ValueError("rating cannot be None when feedback not exists")
|
||||
else:
|
||||
rating_value = args.rating
|
||||
if rating_value is None:
|
||||
raise ValueError("rating is required to create feedback")
|
||||
feedback = MessageFeedback(
|
||||
app_id=app_model.id,
|
||||
conversation_id=message.conversation_id,
|
||||
message_id=message.id,
|
||||
rating=args["rating"],
|
||||
rating=rating_value,
|
||||
from_source="admin",
|
||||
from_account_id=current_user.id,
|
||||
)
|
||||
|
|
@ -369,24 +411,12 @@ class MessageSuggestedQuestionApi(Resource):
|
|||
return {"data": questions}
|
||||
|
||||
|
||||
# Shared parser for feedback export (used for both documentation and runtime parsing)
|
||||
feedback_export_parser = (
|
||||
console_ns.parser()
|
||||
.add_argument("from_source", type=str, choices=["user", "admin"], location="args", help="Filter by feedback source")
|
||||
.add_argument("rating", type=str, choices=["like", "dislike"], location="args", help="Filter by rating")
|
||||
.add_argument("has_comment", type=bool, location="args", help="Only include feedback with comments")
|
||||
.add_argument("start_date", type=str, location="args", help="Start date (YYYY-MM-DD)")
|
||||
.add_argument("end_date", type=str, location="args", help="End date (YYYY-MM-DD)")
|
||||
.add_argument("format", type=str, choices=["csv", "json"], default="csv", location="args", help="Export format")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/feedbacks/export")
|
||||
class MessageFeedbackExportApi(Resource):
|
||||
@console_ns.doc("export_feedbacks")
|
||||
@console_ns.doc(description="Export user feedback data for Google Sheets")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(feedback_export_parser)
|
||||
@console_ns.expect(console_ns.models[FeedbackExportQuery.__name__])
|
||||
@console_ns.response(200, "Feedback data exported successfully")
|
||||
@console_ns.response(400, "Invalid parameters")
|
||||
@console_ns.response(500, "Internal server error")
|
||||
|
|
@ -395,7 +425,7 @@ class MessageFeedbackExportApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
args = feedback_export_parser.parse_args()
|
||||
args = FeedbackExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
# Import the service function
|
||||
from services.feedback_service import FeedbackService
|
||||
|
|
@ -403,12 +433,12 @@ class MessageFeedbackExportApi(Resource):
|
|||
try:
|
||||
export_data = FeedbackService.export_feedbacks(
|
||||
app_id=app_model.id,
|
||||
from_source=args.get("from_source"),
|
||||
rating=args.get("rating"),
|
||||
has_comment=args.get("has_comment"),
|
||||
start_date=args.get("start_date"),
|
||||
end_date=args.get("end_date"),
|
||||
format_type=args.get("format", "csv"),
|
||||
from_source=args.from_source,
|
||||
rating=args.rating,
|
||||
has_comment=args.has_comment,
|
||||
start_date=args.start_date,
|
||||
end_date=args.end_date,
|
||||
format_type=args.format,
|
||||
)
|
||||
|
||||
return export_data
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
from decimal import Decimal
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask import abort, jsonify
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from flask import abort, jsonify, request
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
|
|
@ -10,21 +11,37 @@ from controllers.console.wraps import account_initialization_required, setup_req
|
|||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import parse_time_range
|
||||
from libs.helper import DatetimeString, convert_datetime_to_date
|
||||
from libs.helper import convert_datetime_to_date
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import AppMode
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class StatisticTimeRangeQuery(BaseModel):
|
||||
start: str | None = Field(default=None, description="Start date (YYYY-MM-DD HH:MM)")
|
||||
end: str | None = Field(default=None, description="End date (YYYY-MM-DD HH:MM)")
|
||||
|
||||
@field_validator("start", "end", mode="before")
|
||||
@classmethod
|
||||
def empty_string_to_none(cls, value: str | None) -> str | None:
|
||||
if value == "":
|
||||
return None
|
||||
return value
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
StatisticTimeRangeQuery.__name__,
|
||||
StatisticTimeRangeQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-messages")
|
||||
class DailyMessageStatistic(Resource):
|
||||
@console_ns.doc("get_daily_message_statistics")
|
||||
@console_ns.doc(description="Get daily message statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Daily message statistics retrieved successfully",
|
||||
|
|
@ -37,12 +54,7 @@ class DailyMessageStatistic(Resource):
|
|||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
|
|
@ -57,7 +69,7 @@ WHERE
|
|||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
|
|
@ -81,19 +93,12 @@ WHERE
|
|||
return jsonify({"data": response_data})
|
||||
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-conversations")
|
||||
class DailyConversationStatistic(Resource):
|
||||
@console_ns.doc("get_daily_conversation_statistics")
|
||||
@console_ns.doc(description="Get daily conversation statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Daily conversation statistics retrieved successfully",
|
||||
|
|
@ -106,7 +111,7 @@ class DailyConversationStatistic(Resource):
|
|||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = parser.parse_args()
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
|
|
@ -121,7 +126,7 @@ WHERE
|
|||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
|
|
@ -149,7 +154,7 @@ class DailyTerminalsStatistic(Resource):
|
|||
@console_ns.doc("get_daily_terminals_statistics")
|
||||
@console_ns.doc(description="Get daily terminal/end-user statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Daily terminal statistics retrieved successfully",
|
||||
|
|
@ -162,7 +167,7 @@ class DailyTerminalsStatistic(Resource):
|
|||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = parser.parse_args()
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
|
|
@ -177,7 +182,7 @@ WHERE
|
|||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
|
|
@ -206,7 +211,7 @@ class DailyTokenCostStatistic(Resource):
|
|||
@console_ns.doc("get_daily_token_cost_statistics")
|
||||
@console_ns.doc(description="Get daily token cost statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Daily token cost statistics retrieved successfully",
|
||||
|
|
@ -219,7 +224,7 @@ class DailyTokenCostStatistic(Resource):
|
|||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = parser.parse_args()
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
|
|
@ -235,7 +240,7 @@ WHERE
|
|||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
|
|
@ -266,7 +271,7 @@ class AverageSessionInteractionStatistic(Resource):
|
|||
@console_ns.doc("get_average_session_interaction_statistics")
|
||||
@console_ns.doc(description="Get average session interaction statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Average session interaction statistics retrieved successfully",
|
||||
|
|
@ -279,7 +284,7 @@ class AverageSessionInteractionStatistic(Resource):
|
|||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = parser.parse_args()
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
converted_created_at = convert_datetime_to_date("c.created_at")
|
||||
sql_query = f"""SELECT
|
||||
|
|
@ -302,7 +307,7 @@ FROM
|
|||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
|
|
@ -342,7 +347,7 @@ class UserSatisfactionRateStatistic(Resource):
|
|||
@console_ns.doc("get_user_satisfaction_rate_statistics")
|
||||
@console_ns.doc(description="Get user satisfaction rate statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"User satisfaction rate statistics retrieved successfully",
|
||||
|
|
@ -355,7 +360,7 @@ class UserSatisfactionRateStatistic(Resource):
|
|||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = parser.parse_args()
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
converted_created_at = convert_datetime_to_date("m.created_at")
|
||||
sql_query = f"""SELECT
|
||||
|
|
@ -374,7 +379,7 @@ WHERE
|
|||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
|
|
@ -408,7 +413,7 @@ class AverageResponseTimeStatistic(Resource):
|
|||
@console_ns.doc("get_average_response_time_statistics")
|
||||
@console_ns.doc(description="Get average response time statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Average response time statistics retrieved successfully",
|
||||
|
|
@ -421,7 +426,7 @@ class AverageResponseTimeStatistic(Resource):
|
|||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = parser.parse_args()
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
|
|
@ -436,7 +441,7 @@ WHERE
|
|||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
|
|
@ -465,7 +470,7 @@ class TokensPerSecondStatistic(Resource):
|
|||
@console_ns.doc("get_tokens_per_second_statistics")
|
||||
@console_ns.doc(description="Get tokens per second statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Tokens per second statistics retrieved successfully",
|
||||
|
|
@ -477,7 +482,7 @@ class TokensPerSecondStatistic(Resource):
|
|||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
args = parser.parse_args()
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
|
|
@ -495,7 +500,7 @@ WHERE
|
|||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
import json
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import cast
|
||||
from typing import Any
|
||||
|
||||
from flask import abort, request
|
||||
from flask_restx import Resource, fields, inputs, marshal_with, reqparse
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
|
|
@ -49,6 +50,7 @@ from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseE
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
LISTENING_RETRY_IN = 2000
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register in dependency order: base models first, then dependent models
|
||||
|
|
@ -107,6 +109,104 @@ if workflow_run_node_execution_model is None:
|
|||
workflow_run_node_execution_model = console_ns.model("WorkflowRunNodeExecution", workflow_run_node_execution_fields)
|
||||
|
||||
|
||||
class SyncDraftWorkflowPayload(BaseModel):
|
||||
graph: dict[str, Any]
|
||||
features: dict[str, Any]
|
||||
hash: str | None = None
|
||||
environment_variables: list[dict[str, Any]] = Field(default_factory=list)
|
||||
conversation_variables: list[dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class BaseWorkflowRunPayload(BaseModel):
|
||||
files: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
class AdvancedChatWorkflowRunPayload(BaseWorkflowRunPayload):
|
||||
inputs: dict[str, Any] | None = None
|
||||
query: str = ""
|
||||
conversation_id: str | None = None
|
||||
parent_message_id: str | None = None
|
||||
|
||||
@field_validator("conversation_id", "parent_message_id")
|
||||
@classmethod
|
||||
def validate_uuid(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class IterationNodeRunPayload(BaseModel):
|
||||
inputs: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class LoopNodeRunPayload(BaseModel):
|
||||
inputs: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class DraftWorkflowRunPayload(BaseWorkflowRunPayload):
|
||||
inputs: dict[str, Any]
|
||||
|
||||
|
||||
class DraftWorkflowNodeRunPayload(BaseWorkflowRunPayload):
|
||||
inputs: dict[str, Any]
|
||||
query: str = ""
|
||||
|
||||
|
||||
class PublishWorkflowPayload(BaseModel):
|
||||
marked_name: str | None = Field(default=None, max_length=20)
|
||||
marked_comment: str | None = Field(default=None, max_length=100)
|
||||
|
||||
|
||||
class DefaultBlockConfigQuery(BaseModel):
|
||||
q: str | None = None
|
||||
|
||||
|
||||
class ConvertToWorkflowPayload(BaseModel):
|
||||
name: str | None = None
|
||||
icon_type: str | None = None
|
||||
icon: str | None = None
|
||||
icon_background: str | None = None
|
||||
|
||||
|
||||
class WorkflowListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, le=99999)
|
||||
limit: int = Field(default=10, ge=1, le=100)
|
||||
user_id: str | None = None
|
||||
named_only: bool = False
|
||||
|
||||
|
||||
class WorkflowUpdatePayload(BaseModel):
|
||||
marked_name: str | None = Field(default=None, max_length=20)
|
||||
marked_comment: str | None = Field(default=None, max_length=100)
|
||||
|
||||
|
||||
class DraftWorkflowTriggerRunPayload(BaseModel):
|
||||
node_id: str
|
||||
|
||||
|
||||
class DraftWorkflowTriggerRunAllPayload(BaseModel):
|
||||
node_ids: list[str]
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(SyncDraftWorkflowPayload)
|
||||
reg(AdvancedChatWorkflowRunPayload)
|
||||
reg(IterationNodeRunPayload)
|
||||
reg(LoopNodeRunPayload)
|
||||
reg(DraftWorkflowRunPayload)
|
||||
reg(DraftWorkflowNodeRunPayload)
|
||||
reg(PublishWorkflowPayload)
|
||||
reg(DefaultBlockConfigQuery)
|
||||
reg(ConvertToWorkflowPayload)
|
||||
reg(WorkflowListQuery)
|
||||
reg(WorkflowUpdatePayload)
|
||||
reg(DraftWorkflowTriggerRunPayload)
|
||||
reg(DraftWorkflowTriggerRunAllPayload)
|
||||
|
||||
|
||||
# TODO(QuantumGhost): Refactor existing node run API to handle file parameter parsing
|
||||
# at the controller level rather than in the workflow logic. This would improve separation
|
||||
# of concerns and make the code more maintainable.
|
||||
|
|
@ -158,18 +258,7 @@ class DraftWorkflowApi(Resource):
|
|||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@console_ns.doc("sync_draft_workflow")
|
||||
@console_ns.doc(description="Sync draft workflow configuration")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"SyncDraftWorkflowRequest",
|
||||
{
|
||||
"graph": fields.Raw(required=True, description="Workflow graph configuration"),
|
||||
"features": fields.Raw(required=True, description="Workflow features configuration"),
|
||||
"hash": fields.String(description="Workflow hash for validation"),
|
||||
"environment_variables": fields.List(fields.Raw, required=True, description="Environment variables"),
|
||||
"conversation_variables": fields.List(fields.Raw, description="Conversation variables"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[SyncDraftWorkflowPayload.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Draft workflow synced successfully",
|
||||
|
|
@ -193,36 +282,23 @@ class DraftWorkflowApi(Resource):
|
|||
|
||||
content_type = request.headers.get("Content-Type", "")
|
||||
|
||||
payload_data: dict[str, Any] | None = None
|
||||
if "application/json" in content_type:
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("graph", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("features", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("hash", type=str, required=False, location="json")
|
||||
.add_argument("environment_variables", type=list, required=True, location="json")
|
||||
.add_argument("conversation_variables", type=list, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload_data = request.get_json(silent=True)
|
||||
if not isinstance(payload_data, dict):
|
||||
return {"message": "Invalid JSON data"}, 400
|
||||
elif "text/plain" in content_type:
|
||||
try:
|
||||
data = json.loads(request.data.decode("utf-8"))
|
||||
if "graph" not in data or "features" not in data:
|
||||
raise ValueError("graph or features not found in data")
|
||||
|
||||
if not isinstance(data.get("graph"), dict) or not isinstance(data.get("features"), dict):
|
||||
raise ValueError("graph or features is not a dict")
|
||||
|
||||
args = {
|
||||
"graph": data.get("graph"),
|
||||
"features": data.get("features"),
|
||||
"hash": data.get("hash"),
|
||||
"environment_variables": data.get("environment_variables"),
|
||||
"conversation_variables": data.get("conversation_variables"),
|
||||
}
|
||||
payload_data = json.loads(request.data.decode("utf-8"))
|
||||
except json.JSONDecodeError:
|
||||
return {"message": "Invalid JSON data"}, 400
|
||||
if not isinstance(payload_data, dict):
|
||||
return {"message": "Invalid JSON data"}, 400
|
||||
else:
|
||||
abort(415)
|
||||
|
||||
args_model = SyncDraftWorkflowPayload.model_validate(payload_data)
|
||||
args = args_model.model_dump()
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
try:
|
||||
|
|
@ -258,17 +334,7 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
|
|||
@console_ns.doc("run_advanced_chat_draft_workflow")
|
||||
@console_ns.doc(description="Run draft workflow for advanced chat application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"AdvancedChatWorkflowRunRequest",
|
||||
{
|
||||
"query": fields.String(required=True, description="User query"),
|
||||
"inputs": fields.Raw(description="Input variables"),
|
||||
"files": fields.List(fields.Raw, description="File uploads"),
|
||||
"conversation_id": fields.String(description="Conversation ID"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[AdvancedChatWorkflowRunPayload.__name__])
|
||||
@console_ns.response(200, "Workflow run started successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
|
|
@ -283,16 +349,8 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
|
|||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, location="json")
|
||||
.add_argument("query", type=str, required=True, location="json", default="")
|
||||
.add_argument("files", type=list, location="json")
|
||||
.add_argument("conversation_id", type=uuid_value, location="json")
|
||||
.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
args_model = AdvancedChatWorkflowRunPayload.model_validate(console_ns.payload or {})
|
||||
args = args_model.model_dump(exclude_none=True)
|
||||
|
||||
external_trace_id = get_external_trace_id(request)
|
||||
if external_trace_id:
|
||||
|
|
@ -322,15 +380,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
|
|||
@console_ns.doc("run_advanced_chat_draft_iteration_node")
|
||||
@console_ns.doc(description="Run draft workflow iteration node for advanced chat")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"IterationNodeRunRequest",
|
||||
{
|
||||
"task_id": fields.String(required=True, description="Task ID"),
|
||||
"inputs": fields.Raw(description="Input variables"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[IterationNodeRunPayload.__name__])
|
||||
@console_ns.response(200, "Iteration node run started successfully")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Node not found")
|
||||
|
|
@ -344,8 +394,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
|
|||
Run draft workflow iteration node
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
|
||||
args = parser.parse_args()
|
||||
args = IterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate_single_iteration(
|
||||
|
|
@ -369,15 +418,7 @@ class WorkflowDraftRunIterationNodeApi(Resource):
|
|||
@console_ns.doc("run_workflow_draft_iteration_node")
|
||||
@console_ns.doc(description="Run draft workflow iteration node")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"WorkflowIterationNodeRunRequest",
|
||||
{
|
||||
"task_id": fields.String(required=True, description="Task ID"),
|
||||
"inputs": fields.Raw(description="Input variables"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[IterationNodeRunPayload.__name__])
|
||||
@console_ns.response(200, "Workflow iteration node run started successfully")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Node not found")
|
||||
|
|
@ -391,8 +432,7 @@ class WorkflowDraftRunIterationNodeApi(Resource):
|
|||
Run draft workflow iteration node
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
|
||||
args = parser.parse_args()
|
||||
args = IterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate_single_iteration(
|
||||
|
|
@ -416,15 +456,7 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
|
|||
@console_ns.doc("run_advanced_chat_draft_loop_node")
|
||||
@console_ns.doc(description="Run draft workflow loop node for advanced chat")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"LoopNodeRunRequest",
|
||||
{
|
||||
"task_id": fields.String(required=True, description="Task ID"),
|
||||
"inputs": fields.Raw(description="Input variables"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[LoopNodeRunPayload.__name__])
|
||||
@console_ns.response(200, "Loop node run started successfully")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Node not found")
|
||||
|
|
@ -438,8 +470,7 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
|
|||
Run draft workflow loop node
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
|
||||
args = parser.parse_args()
|
||||
args = LoopNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate_single_loop(
|
||||
|
|
@ -463,15 +494,7 @@ class WorkflowDraftRunLoopNodeApi(Resource):
|
|||
@console_ns.doc("run_workflow_draft_loop_node")
|
||||
@console_ns.doc(description="Run draft workflow loop node")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"WorkflowLoopNodeRunRequest",
|
||||
{
|
||||
"task_id": fields.String(required=True, description="Task ID"),
|
||||
"inputs": fields.Raw(description="Input variables"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[LoopNodeRunPayload.__name__])
|
||||
@console_ns.response(200, "Workflow loop node run started successfully")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Node not found")
|
||||
|
|
@ -485,8 +508,7 @@ class WorkflowDraftRunLoopNodeApi(Resource):
|
|||
Run draft workflow loop node
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
|
||||
args = parser.parse_args()
|
||||
args = LoopNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate_single_loop(
|
||||
|
|
@ -510,15 +532,7 @@ class DraftWorkflowRunApi(Resource):
|
|||
@console_ns.doc("run_draft_workflow")
|
||||
@console_ns.doc(description="Run draft workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"DraftWorkflowRunRequest",
|
||||
{
|
||||
"inputs": fields.Raw(required=True, description="Input variables"),
|
||||
"files": fields.List(fields.Raw, description="File uploads"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[DraftWorkflowRunPayload.__name__])
|
||||
@console_ns.response(200, "Draft workflow run started successfully")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@setup_required
|
||||
|
|
@ -531,12 +545,7 @@ class DraftWorkflowRunApi(Resource):
|
|||
Run draft workflow
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("files", type=list, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = DraftWorkflowRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
|
||||
|
||||
external_trace_id = get_external_trace_id(request)
|
||||
if external_trace_id:
|
||||
|
|
@ -588,14 +597,7 @@ class DraftWorkflowNodeRunApi(Resource):
|
|||
@console_ns.doc("run_draft_workflow_node")
|
||||
@console_ns.doc(description="Run draft workflow node")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"DraftWorkflowNodeRunRequest",
|
||||
{
|
||||
"inputs": fields.Raw(description="Input variables"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[DraftWorkflowNodeRunPayload.__name__])
|
||||
@console_ns.response(200, "Node run started successfully", workflow_run_node_execution_model)
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Node not found")
|
||||
|
|
@ -610,15 +612,10 @@ class DraftWorkflowNodeRunApi(Resource):
|
|||
Run draft workflow node
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("query", type=str, required=False, location="json", default="")
|
||||
.add_argument("files", type=list, location="json", default=[])
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args_model = DraftWorkflowNodeRunPayload.model_validate(console_ns.payload or {})
|
||||
args = args_model.model_dump(exclude_none=True)
|
||||
|
||||
user_inputs = args.get("inputs")
|
||||
user_inputs = args_model.inputs
|
||||
if user_inputs is None:
|
||||
raise ValueError("missing inputs")
|
||||
|
||||
|
|
@ -643,13 +640,6 @@ class DraftWorkflowNodeRunApi(Resource):
|
|||
return workflow_node_execution
|
||||
|
||||
|
||||
parser_publish = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("marked_name", type=str, required=False, default="", location="json")
|
||||
.add_argument("marked_comment", type=str, required=False, default="", location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/publish")
|
||||
class PublishedWorkflowApi(Resource):
|
||||
@console_ns.doc("get_published_workflow")
|
||||
|
|
@ -674,7 +664,7 @@ class PublishedWorkflowApi(Resource):
|
|||
# return workflow, if not found, return None
|
||||
return workflow
|
||||
|
||||
@console_ns.expect(parser_publish)
|
||||
@console_ns.expect(console_ns.models[PublishWorkflowPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -686,13 +676,7 @@ class PublishedWorkflowApi(Resource):
|
|||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = parser_publish.parse_args()
|
||||
|
||||
# Validate name and comment length
|
||||
if args.marked_name and len(args.marked_name) > 20:
|
||||
raise ValueError("Marked name cannot exceed 20 characters")
|
||||
if args.marked_comment and len(args.marked_comment) > 100:
|
||||
raise ValueError("Marked comment cannot exceed 100 characters")
|
||||
args = PublishWorkflowPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
with Session(db.engine) as session:
|
||||
|
|
@ -741,9 +725,6 @@ class DefaultBlockConfigsApi(Resource):
|
|||
return workflow_service.get_default_block_configs()
|
||||
|
||||
|
||||
parser_block = reqparse.RequestParser().add_argument("q", type=str, location="args")
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>")
|
||||
class DefaultBlockConfigApi(Resource):
|
||||
@console_ns.doc("get_default_block_config")
|
||||
|
|
@ -751,7 +732,7 @@ class DefaultBlockConfigApi(Resource):
|
|||
@console_ns.doc(params={"app_id": "Application ID", "block_type": "Block type"})
|
||||
@console_ns.response(200, "Default block configuration retrieved successfully")
|
||||
@console_ns.response(404, "Block type not found")
|
||||
@console_ns.expect(parser_block)
|
||||
@console_ns.expect(console_ns.models[DefaultBlockConfigQuery.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -761,14 +742,12 @@ class DefaultBlockConfigApi(Resource):
|
|||
"""
|
||||
Get default block config
|
||||
"""
|
||||
args = parser_block.parse_args()
|
||||
|
||||
q = args.get("q")
|
||||
args = DefaultBlockConfigQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
filters = None
|
||||
if q:
|
||||
if args.q:
|
||||
try:
|
||||
filters = json.loads(args.get("q", ""))
|
||||
filters = json.loads(args.q)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("Invalid filters")
|
||||
|
||||
|
|
@ -777,18 +756,9 @@ class DefaultBlockConfigApi(Resource):
|
|||
return workflow_service.get_default_block_config(node_type=block_type, filters=filters)
|
||||
|
||||
|
||||
parser_convert = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("icon_type", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("icon", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/convert-to-workflow")
|
||||
class ConvertToWorkflowApi(Resource):
|
||||
@console_ns.expect(parser_convert)
|
||||
@console_ns.expect(console_ns.models[ConvertToWorkflowPayload.__name__])
|
||||
@console_ns.doc("convert_to_workflow")
|
||||
@console_ns.doc(description="Convert application to workflow mode")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
|
|
@ -808,10 +778,8 @@ class ConvertToWorkflowApi(Resource):
|
|||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
if request.data:
|
||||
args = parser_convert.parse_args()
|
||||
else:
|
||||
args = {}
|
||||
payload = console_ns.payload or {}
|
||||
args = ConvertToWorkflowPayload.model_validate(payload).model_dump(exclude_none=True)
|
||||
|
||||
# convert to workflow mode
|
||||
workflow_service = WorkflowService()
|
||||
|
|
@ -823,18 +791,9 @@ class ConvertToWorkflowApi(Resource):
|
|||
}
|
||||
|
||||
|
||||
parser_workflows = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
|
||||
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=10, location="args")
|
||||
.add_argument("user_id", type=str, required=False, location="args")
|
||||
.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows")
|
||||
class PublishedAllWorkflowApi(Resource):
|
||||
@console_ns.expect(parser_workflows)
|
||||
@console_ns.expect(console_ns.models[WorkflowListQuery.__name__])
|
||||
@console_ns.doc("get_all_published_workflows")
|
||||
@console_ns.doc(description="Get all published workflows for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
|
|
@ -851,16 +810,15 @@ class PublishedAllWorkflowApi(Resource):
|
|||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = parser_workflows.parse_args()
|
||||
page = args["page"]
|
||||
limit = args["limit"]
|
||||
user_id = args.get("user_id")
|
||||
named_only = args.get("named_only", False)
|
||||
args = WorkflowListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
page = args.page
|
||||
limit = args.limit
|
||||
user_id = args.user_id
|
||||
named_only = args.named_only
|
||||
|
||||
if user_id:
|
||||
if user_id != current_user.id:
|
||||
raise Forbidden()
|
||||
user_id = cast(str, user_id)
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
with Session(db.engine) as session:
|
||||
|
|
@ -886,15 +844,7 @@ class WorkflowByIdApi(Resource):
|
|||
@console_ns.doc("update_workflow_by_id")
|
||||
@console_ns.doc(description="Update workflow by ID")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "workflow_id": "Workflow ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"UpdateWorkflowRequest",
|
||||
{
|
||||
"environment_variables": fields.List(fields.Raw, description="Environment variables"),
|
||||
"conversation_variables": fields.List(fields.Raw, description="Conversation variables"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[WorkflowUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Workflow updated successfully", workflow_model)
|
||||
@console_ns.response(404, "Workflow not found")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
|
|
@ -909,25 +859,14 @@ class WorkflowByIdApi(Resource):
|
|||
Update workflow attributes
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("marked_name", type=str, required=False, location="json")
|
||||
.add_argument("marked_comment", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Validate name and comment length
|
||||
if args.marked_name and len(args.marked_name) > 20:
|
||||
raise ValueError("Marked name cannot exceed 20 characters")
|
||||
if args.marked_comment and len(args.marked_comment) > 100:
|
||||
raise ValueError("Marked comment cannot exceed 100 characters")
|
||||
args = WorkflowUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
# Prepare update data
|
||||
update_data = {}
|
||||
if args.get("marked_name") is not None:
|
||||
update_data["marked_name"] = args["marked_name"]
|
||||
if args.get("marked_comment") is not None:
|
||||
update_data["marked_comment"] = args["marked_comment"]
|
||||
if args.marked_name is not None:
|
||||
update_data["marked_name"] = args.marked_name
|
||||
if args.marked_comment is not None:
|
||||
update_data["marked_comment"] = args.marked_comment
|
||||
|
||||
if not update_data:
|
||||
return {"message": "No valid fields to update"}, 400
|
||||
|
|
@ -1040,11 +979,8 @@ class DraftWorkflowTriggerRunApi(Resource):
|
|||
Poll for trigger events and execute full workflow when event arrives
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"node_id", type=str, required=True, location="json", nullable=False
|
||||
)
|
||||
args = parser.parse_args()
|
||||
node_id = args["node_id"]
|
||||
args = DraftWorkflowTriggerRunPayload.model_validate(console_ns.payload or {})
|
||||
node_id = args.node_id
|
||||
workflow_service = WorkflowService()
|
||||
draft_workflow = workflow_service.get_draft_workflow(app_model)
|
||||
if not draft_workflow:
|
||||
|
|
@ -1172,14 +1108,7 @@ class DraftWorkflowTriggerRunAllApi(Resource):
|
|||
@console_ns.doc("draft_workflow_trigger_run_all")
|
||||
@console_ns.doc(description="Full workflow debug when the start node is a trigger")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"DraftWorkflowTriggerRunAllRequest",
|
||||
{
|
||||
"node_ids": fields.List(fields.String, required=True, description="Node IDs"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[DraftWorkflowTriggerRunAllPayload.__name__])
|
||||
@console_ns.response(200, "Workflow executed successfully")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(500, "Internal server error")
|
||||
|
|
@ -1194,11 +1123,8 @@ class DraftWorkflowTriggerRunAllApi(Resource):
|
|||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"node_ids", type=list, required=True, location="json", nullable=False
|
||||
)
|
||||
args = parser.parse_args()
|
||||
node_ids = args["node_ids"]
|
||||
args = DraftWorkflowTriggerRunAllPayload.model_validate(console_ns.payload or {})
|
||||
node_ids = args.node_ids
|
||||
workflow_service = WorkflowService()
|
||||
draft_workflow = workflow_service.get_draft_workflow(app_model)
|
||||
if not draft_workflow:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,9 @@
|
|||
from datetime import datetime
|
||||
|
||||
from dateutil.parser import isoparse
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from flask import request
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console import console_ns
|
||||
|
|
@ -14,6 +17,48 @@ from models import App
|
|||
from models.model import AppMode
|
||||
from services.workflow_app_service import WorkflowAppService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class WorkflowAppLogQuery(BaseModel):
|
||||
keyword: str | None = Field(default=None, description="Search keyword for filtering logs")
|
||||
status: WorkflowExecutionStatus | None = Field(
|
||||
default=None, description="Execution status filter (succeeded, failed, stopped, partial-succeeded)"
|
||||
)
|
||||
created_at__before: datetime | None = Field(default=None, description="Filter logs created before this timestamp")
|
||||
created_at__after: datetime | None = Field(default=None, description="Filter logs created after this timestamp")
|
||||
created_by_end_user_session_id: str | None = Field(default=None, description="Filter by end user session ID")
|
||||
created_by_account: str | None = Field(default=None, description="Filter by account")
|
||||
detail: bool = Field(default=False, description="Whether to return detailed logs")
|
||||
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Number of items per page (1-100)")
|
||||
|
||||
@field_validator("created_at__before", "created_at__after", mode="before")
|
||||
@classmethod
|
||||
def parse_datetime(cls, value: str | None) -> datetime | None:
|
||||
if value in (None, ""):
|
||||
return None
|
||||
return isoparse(value) # type: ignore
|
||||
|
||||
@field_validator("detail", mode="before")
|
||||
@classmethod
|
||||
def parse_bool(cls, value: bool | str | None) -> bool:
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if value is None:
|
||||
return False
|
||||
lowered = value.lower()
|
||||
if lowered in {"1", "true", "yes", "on"}:
|
||||
return True
|
||||
if lowered in {"0", "false", "no", "off"}:
|
||||
return False
|
||||
raise ValueError("Invalid boolean value for detail")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
WorkflowAppLogQuery.__name__, WorkflowAppLogQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
# Register model for flask_restx to avoid dict type issues in Swagger
|
||||
workflow_app_log_pagination_model = build_workflow_app_log_pagination_model(console_ns)
|
||||
|
||||
|
|
@ -23,19 +68,7 @@ class WorkflowAppLogApi(Resource):
|
|||
@console_ns.doc("get_workflow_app_logs")
|
||||
@console_ns.doc(description="Get workflow application execution logs")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"keyword": "Search keyword for filtering logs",
|
||||
"status": "Filter by execution status (succeeded, failed, stopped, partial-succeeded)",
|
||||
"created_at__before": "Filter logs created before this timestamp",
|
||||
"created_at__after": "Filter logs created after this timestamp",
|
||||
"created_by_end_user_session_id": "Filter by end user session ID",
|
||||
"created_by_account": "Filter by account",
|
||||
"detail": "Whether to return detailed logs",
|
||||
"page": "Page number (1-99999)",
|
||||
"limit": "Number of items per page (1-100)",
|
||||
}
|
||||
)
|
||||
@console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__])
|
||||
@console_ns.response(200, "Workflow app logs retrieved successfully", workflow_app_log_pagination_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -46,44 +79,7 @@ class WorkflowAppLogApi(Resource):
|
|||
"""
|
||||
Get workflow app logs
|
||||
"""
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("keyword", type=str, location="args")
|
||||
.add_argument(
|
||||
"status", type=str, choices=["succeeded", "failed", "stopped", "partial-succeeded"], location="args"
|
||||
)
|
||||
.add_argument(
|
||||
"created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
|
||||
)
|
||||
.add_argument(
|
||||
"created_at__after", type=str, location="args", help="Filter logs created after this timestamp"
|
||||
)
|
||||
.add_argument(
|
||||
"created_by_end_user_session_id",
|
||||
type=str,
|
||||
location="args",
|
||||
required=False,
|
||||
default=None,
|
||||
)
|
||||
.add_argument(
|
||||
"created_by_account",
|
||||
type=str,
|
||||
location="args",
|
||||
required=False,
|
||||
default=None,
|
||||
)
|
||||
.add_argument("detail", type=bool, location="args", required=False, default=False)
|
||||
.add_argument("page", type=int_range(1, 99999), default=1, location="args")
|
||||
.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
args.status = WorkflowExecutionStatus(args.status) if args.status else None
|
||||
if args.created_at__before:
|
||||
args.created_at__before = isoparse(args.created_at__before)
|
||||
|
||||
if args.created_at__after:
|
||||
args.created_at__after = isoparse(args.created_at__after)
|
||||
args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
# get paginate workflow app logs
|
||||
workflow_app_service = WorkflowAppService()
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
from typing import cast
|
||||
from typing import Literal, cast
|
||||
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
|
|
@ -92,70 +93,51 @@ workflow_run_node_execution_list_model = console_ns.model(
|
|||
"WorkflowRunNodeExecutionList", workflow_run_node_execution_list_fields_copy
|
||||
)
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
def _parse_workflow_run_list_args():
|
||||
"""
|
||||
Parse common arguments for workflow run list endpoints.
|
||||
|
||||
Returns:
|
||||
Parsed arguments containing last_id, limit, status, and triggered_from filters
|
||||
"""
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("last_id", type=uuid_value, location="args")
|
||||
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
.add_argument(
|
||||
"status",
|
||||
type=str,
|
||||
choices=WORKFLOW_RUN_STATUS_CHOICES,
|
||||
location="args",
|
||||
required=False,
|
||||
)
|
||||
.add_argument(
|
||||
"triggered_from",
|
||||
type=str,
|
||||
choices=["debugging", "app-run"],
|
||||
location="args",
|
||||
required=False,
|
||||
help="Filter by trigger source: debugging or app-run",
|
||||
)
|
||||
class WorkflowRunListQuery(BaseModel):
|
||||
last_id: str | None = Field(default=None, description="Last run ID for pagination")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Number of items per page (1-100)")
|
||||
status: Literal["running", "succeeded", "failed", "stopped", "partial-succeeded"] | None = Field(
|
||||
default=None, description="Workflow run status filter"
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def _parse_workflow_run_count_args():
|
||||
"""
|
||||
Parse common arguments for workflow run count endpoints.
|
||||
|
||||
Returns:
|
||||
Parsed arguments containing status, time_range, and triggered_from filters
|
||||
"""
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"status",
|
||||
type=str,
|
||||
choices=WORKFLOW_RUN_STATUS_CHOICES,
|
||||
location="args",
|
||||
required=False,
|
||||
)
|
||||
.add_argument(
|
||||
"time_range",
|
||||
type=time_duration,
|
||||
location="args",
|
||||
required=False,
|
||||
help="Time range filter (e.g., 7d, 4h, 30m, 30s)",
|
||||
)
|
||||
.add_argument(
|
||||
"triggered_from",
|
||||
type=str,
|
||||
choices=["debugging", "app-run"],
|
||||
location="args",
|
||||
required=False,
|
||||
help="Filter by trigger source: debugging or app-run",
|
||||
)
|
||||
triggered_from: Literal["debugging", "app-run"] | None = Field(
|
||||
default=None, description="Filter by trigger source: debugging or app-run"
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
@field_validator("last_id")
|
||||
@classmethod
|
||||
def validate_last_id(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class WorkflowRunCountQuery(BaseModel):
|
||||
status: Literal["running", "succeeded", "failed", "stopped", "partial-succeeded"] | None = Field(
|
||||
default=None, description="Workflow run status filter"
|
||||
)
|
||||
time_range: str | None = Field(default=None, description="Time range filter (e.g., 7d, 4h, 30m, 30s)")
|
||||
triggered_from: Literal["debugging", "app-run"] | None = Field(
|
||||
default=None, description="Filter by trigger source: debugging or app-run"
|
||||
)
|
||||
|
||||
@field_validator("time_range")
|
||||
@classmethod
|
||||
def validate_time_range(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return time_duration(value)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
WorkflowRunListQuery.__name__, WorkflowRunListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
console_ns.schema_model(
|
||||
WorkflowRunCountQuery.__name__,
|
||||
WorkflowRunCountQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs")
|
||||
|
|
@ -170,6 +152,7 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
|
|||
@console_ns.doc(
|
||||
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
|
||||
)
|
||||
@console_ns.expect(console_ns.models[WorkflowRunListQuery.__name__])
|
||||
@console_ns.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -180,12 +163,13 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
|
|||
"""
|
||||
Get advanced chat app workflow run list
|
||||
"""
|
||||
args = _parse_workflow_run_list_args()
|
||||
args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = args_model.model_dump(exclude_none=True)
|
||||
|
||||
# Default to DEBUGGING if not specified
|
||||
triggered_from = (
|
||||
WorkflowRunTriggeredFrom(args.get("triggered_from"))
|
||||
if args.get("triggered_from")
|
||||
WorkflowRunTriggeredFrom(args_model.triggered_from)
|
||||
if args_model.triggered_from
|
||||
else WorkflowRunTriggeredFrom.DEBUGGING
|
||||
)
|
||||
|
||||
|
|
@ -217,6 +201,7 @@ class AdvancedChatAppWorkflowRunCountApi(Resource):
|
|||
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
|
||||
)
|
||||
@console_ns.response(200, "Workflow runs count retrieved successfully", workflow_run_count_model)
|
||||
@console_ns.expect(console_ns.models[WorkflowRunCountQuery.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -226,12 +211,13 @@ class AdvancedChatAppWorkflowRunCountApi(Resource):
|
|||
"""
|
||||
Get advanced chat workflow runs count statistics
|
||||
"""
|
||||
args = _parse_workflow_run_count_args()
|
||||
args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = args_model.model_dump(exclude_none=True)
|
||||
|
||||
# Default to DEBUGGING if not specified
|
||||
triggered_from = (
|
||||
WorkflowRunTriggeredFrom(args.get("triggered_from"))
|
||||
if args.get("triggered_from")
|
||||
WorkflowRunTriggeredFrom(args_model.triggered_from)
|
||||
if args_model.triggered_from
|
||||
else WorkflowRunTriggeredFrom.DEBUGGING
|
||||
)
|
||||
|
||||
|
|
@ -259,6 +245,7 @@ class WorkflowRunListApi(Resource):
|
|||
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
|
||||
)
|
||||
@console_ns.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_model)
|
||||
@console_ns.expect(console_ns.models[WorkflowRunListQuery.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -268,12 +255,13 @@ class WorkflowRunListApi(Resource):
|
|||
"""
|
||||
Get workflow run list
|
||||
"""
|
||||
args = _parse_workflow_run_list_args()
|
||||
args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = args_model.model_dump(exclude_none=True)
|
||||
|
||||
# Default to DEBUGGING for workflow if not specified (backward compatibility)
|
||||
triggered_from = (
|
||||
WorkflowRunTriggeredFrom(args.get("triggered_from"))
|
||||
if args.get("triggered_from")
|
||||
WorkflowRunTriggeredFrom(args_model.triggered_from)
|
||||
if args_model.triggered_from
|
||||
else WorkflowRunTriggeredFrom.DEBUGGING
|
||||
)
|
||||
|
||||
|
|
@ -305,6 +293,7 @@ class WorkflowRunCountApi(Resource):
|
|||
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
|
||||
)
|
||||
@console_ns.response(200, "Workflow runs count retrieved successfully", workflow_run_count_model)
|
||||
@console_ns.expect(console_ns.models[WorkflowRunCountQuery.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -314,12 +303,13 @@ class WorkflowRunCountApi(Resource):
|
|||
"""
|
||||
Get workflow runs count statistics
|
||||
"""
|
||||
args = _parse_workflow_run_count_args()
|
||||
args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = args_model.model_dump(exclude_none=True)
|
||||
|
||||
# Default to DEBUGGING for workflow if not specified (backward compatibility)
|
||||
triggered_from = (
|
||||
WorkflowRunTriggeredFrom(args.get("triggered_from"))
|
||||
if args.get("triggered_from")
|
||||
WorkflowRunTriggeredFrom(args_model.triggered_from)
|
||||
if args_model.triggered_from
|
||||
else WorkflowRunTriggeredFrom.DEBUGGING
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from flask import abort, jsonify
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask import abort, jsonify, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from controllers.console import console_ns
|
||||
|
|
@ -7,12 +8,31 @@ from controllers.console.app.wraps import get_app_model
|
|||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import parse_time_range
|
||||
from libs.helper import DatetimeString
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.model import AppMode
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class WorkflowStatisticQuery(BaseModel):
|
||||
start: str | None = Field(default=None, description="Start date and time (YYYY-MM-DD HH:MM)")
|
||||
end: str | None = Field(default=None, description="End date and time (YYYY-MM-DD HH:MM)")
|
||||
|
||||
@field_validator("start", "end", mode="before")
|
||||
@classmethod
|
||||
def blank_to_none(cls, value: str | None) -> str | None:
|
||||
if value == "":
|
||||
return None
|
||||
return value
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
WorkflowStatisticQuery.__name__,
|
||||
WorkflowStatisticQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/daily-conversations")
|
||||
class WorkflowDailyRunsStatistic(Resource):
|
||||
|
|
@ -24,9 +44,7 @@ class WorkflowDailyRunsStatistic(Resource):
|
|||
@console_ns.doc("get_workflow_daily_runs_statistic")
|
||||
@console_ns.doc(description="Get workflow daily runs statistics")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.doc(
|
||||
params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}
|
||||
)
|
||||
@console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
|
||||
@console_ns.response(200, "Daily runs statistics retrieved successfully")
|
||||
@get_app_model
|
||||
@setup_required
|
||||
|
|
@ -35,17 +53,12 @@ class WorkflowDailyRunsStatistic(Resource):
|
|||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
|
|
@ -71,9 +84,7 @@ class WorkflowDailyTerminalsStatistic(Resource):
|
|||
@console_ns.doc("get_workflow_daily_terminals_statistic")
|
||||
@console_ns.doc(description="Get workflow daily terminals statistics")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.doc(
|
||||
params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}
|
||||
)
|
||||
@console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
|
||||
@console_ns.response(200, "Daily terminals statistics retrieved successfully")
|
||||
@get_app_model
|
||||
@setup_required
|
||||
|
|
@ -82,17 +93,12 @@ class WorkflowDailyTerminalsStatistic(Resource):
|
|||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
|
|
@ -118,9 +124,7 @@ class WorkflowDailyTokenCostStatistic(Resource):
|
|||
@console_ns.doc("get_workflow_daily_token_cost_statistic")
|
||||
@console_ns.doc(description="Get workflow daily token cost statistics")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.doc(
|
||||
params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}
|
||||
)
|
||||
@console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
|
||||
@console_ns.response(200, "Daily token cost statistics retrieved successfully")
|
||||
@get_app_model
|
||||
@setup_required
|
||||
|
|
@ -129,17 +133,12 @@ class WorkflowDailyTokenCostStatistic(Resource):
|
|||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
|
|
@ -165,9 +164,7 @@ class WorkflowAverageAppInteractionStatistic(Resource):
|
|||
@console_ns.doc("get_workflow_average_app_interaction_statistic")
|
||||
@console_ns.doc(description="Get workflow average app interaction statistics")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.doc(
|
||||
params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}
|
||||
)
|
||||
@console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
|
||||
@console_ns.response(200, "Average app interaction statistics retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -176,17 +173,12 @@ class WorkflowAverageAppInteractionStatistic(Resource):
|
|||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
|
|
|
|||
|
|
@ -58,7 +58,7 @@ class VersionApi(Resource):
|
|||
response = httpx.get(
|
||||
check_update_url,
|
||||
params={"current_version": args["current_version"]},
|
||||
timeout=httpx.Timeout(connect=3, read=10),
|
||||
timeout=httpx.Timeout(timeout=10.0, connect=3.0),
|
||||
)
|
||||
except Exception as error:
|
||||
logger.warning("Check update version error: %s.", str(error))
|
||||
|
|
|
|||
|
|
@ -174,63 +174,25 @@ class CheckEmailUniquePayload(BaseModel):
|
|||
return email(value)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
AccountInitPayload.__name__, AccountInitPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
console_ns.schema_model(
|
||||
AccountNamePayload.__name__, AccountNamePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
console_ns.schema_model(
|
||||
AccountAvatarPayload.__name__, AccountAvatarPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
console_ns.schema_model(
|
||||
AccountInterfaceLanguagePayload.__name__,
|
||||
AccountInterfaceLanguagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
AccountInterfaceThemePayload.__name__,
|
||||
AccountInterfaceThemePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
AccountTimezonePayload.__name__,
|
||||
AccountTimezonePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
AccountPasswordPayload.__name__,
|
||||
AccountPasswordPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
AccountDeletePayload.__name__,
|
||||
AccountDeletePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
AccountDeletionFeedbackPayload.__name__,
|
||||
AccountDeletionFeedbackPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
EducationActivatePayload.__name__,
|
||||
EducationActivatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
EducationAutocompleteQuery.__name__,
|
||||
EducationAutocompleteQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
ChangeEmailSendPayload.__name__,
|
||||
ChangeEmailSendPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
ChangeEmailValidityPayload.__name__,
|
||||
ChangeEmailValidityPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
ChangeEmailResetPayload.__name__,
|
||||
ChangeEmailResetPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
CheckEmailUniquePayload.__name__,
|
||||
CheckEmailUniquePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(AccountInitPayload)
|
||||
reg(AccountNamePayload)
|
||||
reg(AccountAvatarPayload)
|
||||
reg(AccountInterfaceLanguagePayload)
|
||||
reg(AccountInterfaceThemePayload)
|
||||
reg(AccountTimezonePayload)
|
||||
reg(AccountPasswordPayload)
|
||||
reg(AccountDeletePayload)
|
||||
reg(AccountDeletionFeedbackPayload)
|
||||
reg(EducationActivatePayload)
|
||||
reg(EducationAutocompleteQuery)
|
||||
reg(ChangeEmailSendPayload)
|
||||
reg(ChangeEmailValidityPayload)
|
||||
reg(ChangeEmailResetPayload)
|
||||
reg(CheckEmailUniquePayload)
|
||||
|
||||
|
||||
@console_ns.route("/account/init")
|
||||
|
|
|
|||
|
|
@ -1,4 +1,8 @@
|
|||
from flask_restx import Resource, fields, reqparse
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
|
|
@ -7,21 +11,49 @@ from core.plugin.impl.exc import PluginPermissionDeniedError
|
|||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.plugin.endpoint_service import EndpointService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class EndpointCreatePayload(BaseModel):
|
||||
plugin_unique_identifier: str
|
||||
settings: dict[str, Any]
|
||||
name: str = Field(min_length=1)
|
||||
|
||||
|
||||
class EndpointIdPayload(BaseModel):
|
||||
endpoint_id: str
|
||||
|
||||
|
||||
class EndpointUpdatePayload(EndpointIdPayload):
|
||||
settings: dict[str, Any]
|
||||
name: str = Field(min_length=1)
|
||||
|
||||
|
||||
class EndpointListQuery(BaseModel):
|
||||
page: int = Field(ge=1)
|
||||
page_size: int = Field(gt=0)
|
||||
|
||||
|
||||
class EndpointListForPluginQuery(EndpointListQuery):
|
||||
plugin_id: str
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(EndpointCreatePayload)
|
||||
reg(EndpointIdPayload)
|
||||
reg(EndpointUpdatePayload)
|
||||
reg(EndpointListQuery)
|
||||
reg(EndpointListForPluginQuery)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/create")
|
||||
class EndpointCreateApi(Resource):
|
||||
@console_ns.doc("create_endpoint")
|
||||
@console_ns.doc(description="Create a new plugin endpoint")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"EndpointCreateRequest",
|
||||
{
|
||||
"plugin_unique_identifier": fields.String(required=True, description="Plugin unique identifier"),
|
||||
"settings": fields.Raw(required=True, description="Endpoint settings"),
|
||||
"name": fields.String(required=True, description="Endpoint name"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[EndpointCreatePayload.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint created successfully",
|
||||
|
|
@ -35,26 +67,16 @@ class EndpointCreateApi(Resource):
|
|||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("plugin_unique_identifier", type=str, required=True)
|
||||
.add_argument("settings", type=dict, required=True)
|
||||
.add_argument("name", type=str, required=True)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
plugin_unique_identifier = args["plugin_unique_identifier"]
|
||||
settings = args["settings"]
|
||||
name = args["name"]
|
||||
args = EndpointCreatePayload.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
return {
|
||||
"success": EndpointService.create_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
plugin_unique_identifier=plugin_unique_identifier,
|
||||
name=name,
|
||||
settings=settings,
|
||||
plugin_unique_identifier=args.plugin_unique_identifier,
|
||||
name=args.name,
|
||||
settings=args.settings,
|
||||
)
|
||||
}
|
||||
except PluginPermissionDeniedError as e:
|
||||
|
|
@ -65,11 +87,7 @@ class EndpointCreateApi(Resource):
|
|||
class EndpointListApi(Resource):
|
||||
@console_ns.doc("list_endpoints")
|
||||
@console_ns.doc(description="List plugin endpoints with pagination")
|
||||
@console_ns.expect(
|
||||
console_ns.parser()
|
||||
.add_argument("page", type=int, required=True, location="args", help="Page number")
|
||||
.add_argument("page_size", type=int, required=True, location="args", help="Page size")
|
||||
)
|
||||
@console_ns.expect(console_ns.models[EndpointListQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
|
|
@ -83,15 +101,10 @@ class EndpointListApi(Resource):
|
|||
def get(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("page", type=int, required=True, location="args")
|
||||
.add_argument("page_size", type=int, required=True, location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = EndpointListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
page = args["page"]
|
||||
page_size = args["page_size"]
|
||||
page = args.page
|
||||
page_size = args.page_size
|
||||
|
||||
return jsonable_encoder(
|
||||
{
|
||||
|
|
@ -109,12 +122,7 @@ class EndpointListApi(Resource):
|
|||
class EndpointListForSinglePluginApi(Resource):
|
||||
@console_ns.doc("list_plugin_endpoints")
|
||||
@console_ns.doc(description="List endpoints for a specific plugin")
|
||||
@console_ns.expect(
|
||||
console_ns.parser()
|
||||
.add_argument("page", type=int, required=True, location="args", help="Page number")
|
||||
.add_argument("page_size", type=int, required=True, location="args", help="Page size")
|
||||
.add_argument("plugin_id", type=str, required=True, location="args", help="Plugin ID")
|
||||
)
|
||||
@console_ns.expect(console_ns.models[EndpointListForPluginQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
|
|
@ -128,17 +136,11 @@ class EndpointListForSinglePluginApi(Resource):
|
|||
def get(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("page", type=int, required=True, location="args")
|
||||
.add_argument("page_size", type=int, required=True, location="args")
|
||||
.add_argument("plugin_id", type=str, required=True, location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = EndpointListForPluginQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
page = args["page"]
|
||||
page_size = args["page_size"]
|
||||
plugin_id = args["plugin_id"]
|
||||
page = args.page
|
||||
page_size = args.page_size
|
||||
plugin_id = args.plugin_id
|
||||
|
||||
return jsonable_encoder(
|
||||
{
|
||||
|
|
@ -157,11 +159,7 @@ class EndpointListForSinglePluginApi(Resource):
|
|||
class EndpointDeleteApi(Resource):
|
||||
@console_ns.doc("delete_endpoint")
|
||||
@console_ns.doc(description="Delete a plugin endpoint")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"EndpointDeleteRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")}
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[EndpointIdPayload.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint deleted successfully",
|
||||
|
|
@ -175,13 +173,12 @@ class EndpointDeleteApi(Resource):
|
|||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
endpoint_id = args["endpoint_id"]
|
||||
args = EndpointIdPayload.model_validate(console_ns.payload)
|
||||
|
||||
return {
|
||||
"success": EndpointService.delete_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
|
||||
"success": EndpointService.delete_endpoint(
|
||||
tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -189,16 +186,7 @@ class EndpointDeleteApi(Resource):
|
|||
class EndpointUpdateApi(Resource):
|
||||
@console_ns.doc("update_endpoint")
|
||||
@console_ns.doc(description="Update a plugin endpoint")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"EndpointUpdateRequest",
|
||||
{
|
||||
"endpoint_id": fields.String(required=True, description="Endpoint ID"),
|
||||
"settings": fields.Raw(required=True, description="Updated settings"),
|
||||
"name": fields.String(required=True, description="Updated name"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[EndpointUpdatePayload.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint updated successfully",
|
||||
|
|
@ -212,25 +200,15 @@ class EndpointUpdateApi(Resource):
|
|||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("endpoint_id", type=str, required=True)
|
||||
.add_argument("settings", type=dict, required=True)
|
||||
.add_argument("name", type=str, required=True)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
endpoint_id = args["endpoint_id"]
|
||||
settings = args["settings"]
|
||||
name = args["name"]
|
||||
args = EndpointUpdatePayload.model_validate(console_ns.payload)
|
||||
|
||||
return {
|
||||
"success": EndpointService.update_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
endpoint_id=endpoint_id,
|
||||
name=name,
|
||||
settings=settings,
|
||||
endpoint_id=args.endpoint_id,
|
||||
name=args.name,
|
||||
settings=args.settings,
|
||||
)
|
||||
}
|
||||
|
||||
|
|
@ -239,11 +217,7 @@ class EndpointUpdateApi(Resource):
|
|||
class EndpointEnableApi(Resource):
|
||||
@console_ns.doc("enable_endpoint")
|
||||
@console_ns.doc(description="Enable a plugin endpoint")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"EndpointEnableRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")}
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[EndpointIdPayload.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint enabled successfully",
|
||||
|
|
@ -257,13 +231,12 @@ class EndpointEnableApi(Resource):
|
|||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
endpoint_id = args["endpoint_id"]
|
||||
args = EndpointIdPayload.model_validate(console_ns.payload)
|
||||
|
||||
return {
|
||||
"success": EndpointService.enable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
|
||||
"success": EndpointService.enable_endpoint(
|
||||
tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -271,11 +244,7 @@ class EndpointEnableApi(Resource):
|
|||
class EndpointDisableApi(Resource):
|
||||
@console_ns.doc("disable_endpoint")
|
||||
@console_ns.doc(description="Disable a plugin endpoint")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"EndpointDisableRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")}
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[EndpointIdPayload.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint disabled successfully",
|
||||
|
|
@ -289,11 +258,10 @@ class EndpointDisableApi(Resource):
|
|||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
endpoint_id = args["endpoint_id"]
|
||||
args = EndpointIdPayload.model_validate(console_ns.payload)
|
||||
|
||||
return {
|
||||
"success": EndpointService.disable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
|
||||
"success": EndpointService.disable_endpoint(
|
||||
tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -58,26 +58,15 @@ class OwnerTransferPayload(BaseModel):
|
|||
token: str
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
MemberInvitePayload.__name__,
|
||||
MemberInvitePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
MemberRoleUpdatePayload.__name__,
|
||||
MemberRoleUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
OwnerTransferEmailPayload.__name__,
|
||||
OwnerTransferEmailPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
OwnerTransferCheckPayload.__name__,
|
||||
OwnerTransferCheckPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
OwnerTransferPayload.__name__,
|
||||
OwnerTransferPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(MemberInvitePayload)
|
||||
reg(MemberRoleUpdatePayload)
|
||||
reg(OwnerTransferEmailPayload)
|
||||
reg(OwnerTransferCheckPayload)
|
||||
reg(OwnerTransferPayload)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/members")
|
||||
|
|
|
|||
|
|
@ -75,44 +75,18 @@ class ParserPreferredProviderType(BaseModel):
|
|||
preferred_provider_type: Literal["system", "custom"]
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserModelList.__name__, ParserModelList.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserCredentialId.__name__,
|
||||
ParserCredentialId.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserCredentialCreate.__name__,
|
||||
ParserCredentialCreate.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserCredentialUpdate.__name__,
|
||||
ParserCredentialUpdate.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserCredentialDelete.__name__,
|
||||
ParserCredentialDelete.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserCredentialSwitch.__name__,
|
||||
ParserCredentialSwitch.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserCredentialValidate.__name__,
|
||||
ParserCredentialValidate.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserPreferredProviderType.__name__,
|
||||
ParserPreferredProviderType.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
reg(ParserModelList)
|
||||
reg(ParserCredentialId)
|
||||
reg(ParserCredentialCreate)
|
||||
reg(ParserCredentialUpdate)
|
||||
reg(ParserCredentialDelete)
|
||||
reg(ParserCredentialSwitch)
|
||||
reg(ParserCredentialValidate)
|
||||
reg(ParserPreferredProviderType)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers")
|
||||
|
|
|
|||
|
|
@ -32,25 +32,11 @@ class ParserPostDefault(BaseModel):
|
|||
model_settings: list[Inner]
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserGetDefault.__name__, ParserGetDefault.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserPostDefault.__name__, ParserPostDefault.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
|
||||
class ParserDeleteModels(BaseModel):
|
||||
model: str
|
||||
model_type: ModelType
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserDeleteModels.__name__, ParserDeleteModels.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
|
||||
class LoadBalancingPayload(BaseModel):
|
||||
configs: list[dict[str, Any]] | None = None
|
||||
enabled: bool | None = None
|
||||
|
|
@ -119,33 +105,19 @@ class ParserParameter(BaseModel):
|
|||
model: str
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserPostModels.__name__, ParserPostModels.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserGetCredentials.__name__,
|
||||
ParserGetCredentials.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserCreateCredential.__name__,
|
||||
ParserCreateCredential.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserUpdateCredential.__name__,
|
||||
ParserUpdateCredential.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserDeleteCredential.__name__,
|
||||
ParserDeleteCredential.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserParameter.__name__, ParserParameter.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
reg(ParserGetDefault)
|
||||
reg(ParserPostDefault)
|
||||
reg(ParserDeleteModels)
|
||||
reg(ParserPostModels)
|
||||
reg(ParserGetCredentials)
|
||||
reg(ParserCreateCredential)
|
||||
reg(ParserUpdateCredential)
|
||||
reg(ParserDeleteCredential)
|
||||
reg(ParserParameter)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/default-model")
|
||||
|
|
|
|||
|
|
@ -22,6 +22,10 @@ from services.plugin.plugin_service import PluginService
|
|||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/debugging-key")
|
||||
class PluginDebuggingKeyApi(Resource):
|
||||
@setup_required
|
||||
|
|
@ -46,9 +50,7 @@ class ParserList(BaseModel):
|
|||
page_size: int = Field(default=256)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserList.__name__, ParserList.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
reg(ParserList)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/list")
|
||||
|
|
@ -72,11 +74,6 @@ class ParserLatest(BaseModel):
|
|||
plugin_ids: list[str]
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserLatest.__name__, ParserLatest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
|
||||
class ParserIcon(BaseModel):
|
||||
tenant_id: str
|
||||
filename: str
|
||||
|
|
@ -173,72 +170,22 @@ class ParserReadme(BaseModel):
|
|||
language: str = Field(default="en-US")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserIcon.__name__, ParserIcon.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserAsset.__name__, ParserAsset.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserGithubUpload.__name__, ParserGithubUpload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserPluginIdentifiers.__name__,
|
||||
ParserPluginIdentifiers.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserGithubInstall.__name__, ParserGithubInstall.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserPluginIdentifierQuery.__name__,
|
||||
ParserPluginIdentifierQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserTasks.__name__, ParserTasks.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserMarketplaceUpgrade.__name__,
|
||||
ParserMarketplaceUpgrade.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserGithubUpgrade.__name__, ParserGithubUpgrade.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserUninstall.__name__, ParserUninstall.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserPermissionChange.__name__,
|
||||
ParserPermissionChange.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserDynamicOptions.__name__,
|
||||
ParserDynamicOptions.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserPreferencesChange.__name__,
|
||||
ParserPreferencesChange.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserExcludePlugin.__name__,
|
||||
ParserExcludePlugin.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserReadme.__name__, ParserReadme.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
reg(ParserLatest)
|
||||
reg(ParserIcon)
|
||||
reg(ParserAsset)
|
||||
reg(ParserGithubUpload)
|
||||
reg(ParserPluginIdentifiers)
|
||||
reg(ParserGithubInstall)
|
||||
reg(ParserPluginIdentifierQuery)
|
||||
reg(ParserTasks)
|
||||
reg(ParserMarketplaceUpgrade)
|
||||
reg(ParserGithubUpgrade)
|
||||
reg(ParserUninstall)
|
||||
reg(ParserPermissionChange)
|
||||
reg(ParserDynamicOptions)
|
||||
reg(ParserPreferencesChange)
|
||||
reg(ParserExcludePlugin)
|
||||
reg(ParserReadme)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/list/latest-versions")
|
||||
|
|
|
|||
|
|
@ -54,25 +54,14 @@ class WorkspaceInfoPayload(BaseModel):
|
|||
name: str
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
WorkspaceListQuery.__name__, WorkspaceListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
console_ns.schema_model(
|
||||
SwitchWorkspacePayload.__name__,
|
||||
SwitchWorkspacePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
WorkspaceCustomConfigPayload.__name__,
|
||||
WorkspaceCustomConfigPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
WorkspaceInfoPayload.__name__,
|
||||
WorkspaceInfoPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
reg(WorkspaceListQuery)
|
||||
reg(SwitchWorkspacePayload)
|
||||
reg(WorkspaceCustomConfigPayload)
|
||||
reg(WorkspaceInfoPayload)
|
||||
|
||||
provider_fields = {
|
||||
"provider_name": fields.String,
|
||||
|
|
|
|||
|
|
@ -270,6 +270,10 @@ class OceanBaseVector(BaseVector):
|
|||
self._client.set_ob_hnsw_ef_search(ef_search)
|
||||
self._hnsw_ef_search = ef_search
|
||||
topk = kwargs.get("top_k", 10)
|
||||
try:
|
||||
score_threshold = float(val) if (val := kwargs.get("score_threshold")) is not None else 0.0
|
||||
except (ValueError, TypeError) as e:
|
||||
raise ValueError(f"Invalid score_threshold parameter: {e}") from e
|
||||
try:
|
||||
cur = self._client.ann_search(
|
||||
table_name=self._collection_name,
|
||||
|
|
@ -285,14 +289,20 @@ class OceanBaseVector(BaseVector):
|
|||
raise Exception("Failed to search by vector. ", e)
|
||||
docs = []
|
||||
for _text, metadata, distance in cur:
|
||||
metadata = json.loads(metadata)
|
||||
metadata["score"] = 1 - distance / math.sqrt(2)
|
||||
docs.append(
|
||||
Document(
|
||||
page_content=_text,
|
||||
metadata=metadata,
|
||||
score = 1 - distance / math.sqrt(2)
|
||||
if score >= score_threshold:
|
||||
try:
|
||||
metadata = json.loads(metadata)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Invalid JSON metadata: %s", metadata)
|
||||
metadata = {}
|
||||
metadata["score"] = score
|
||||
docs.append(
|
||||
Document(
|
||||
page_content=_text,
|
||||
metadata=metadata,
|
||||
)
|
||||
)
|
||||
)
|
||||
return docs
|
||||
|
||||
def delete(self):
|
||||
|
|
|
|||
|
|
@ -54,6 +54,8 @@ class ToolProviderApiEntity(BaseModel):
|
|||
configuration: MCPConfiguration | None = Field(
|
||||
default=None, description="The timeout and sse_read_timeout of the MCP tool"
|
||||
)
|
||||
# Workflow
|
||||
workflow_app_id: str | None = Field(default=None, description="The app id of the workflow tool")
|
||||
|
||||
@field_validator("tools", mode="before")
|
||||
@classmethod
|
||||
|
|
@ -87,6 +89,8 @@ class ToolProviderApiEntity(BaseModel):
|
|||
optional_fields.update(self.optional_field("is_dynamic_registration", self.is_dynamic_registration))
|
||||
optional_fields.update(self.optional_field("masked_headers", self.masked_headers))
|
||||
optional_fields.update(self.optional_field("original_headers", self.original_headers))
|
||||
elif self.type == ToolProviderType.WORKFLOW:
|
||||
optional_fields.update(self.optional_field("workflow_app_id", self.workflow_app_id))
|
||||
return {
|
||||
"id": self.id,
|
||||
"author": self.author,
|
||||
|
|
|
|||
|
|
@ -240,23 +240,23 @@ class Node(Generic[NodeDataT]):
|
|||
from core.workflow.nodes.tool.tool_node import ToolNode
|
||||
|
||||
if isinstance(self, ToolNode):
|
||||
start_event.provider_id = getattr(self.get_base_node_data(), "provider_id", "")
|
||||
start_event.provider_type = getattr(self.get_base_node_data(), "provider_type", "")
|
||||
start_event.provider_id = getattr(self.node_data, "provider_id", "")
|
||||
start_event.provider_type = getattr(self.node_data, "provider_type", "")
|
||||
|
||||
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
|
||||
|
||||
if isinstance(self, DatasourceNode):
|
||||
plugin_id = getattr(self.get_base_node_data(), "plugin_id", "")
|
||||
provider_name = getattr(self.get_base_node_data(), "provider_name", "")
|
||||
plugin_id = getattr(self.node_data, "plugin_id", "")
|
||||
provider_name = getattr(self.node_data, "provider_name", "")
|
||||
|
||||
start_event.provider_id = f"{plugin_id}/{provider_name}"
|
||||
start_event.provider_type = getattr(self.get_base_node_data(), "provider_type", "")
|
||||
start_event.provider_type = getattr(self.node_data, "provider_type", "")
|
||||
|
||||
from core.workflow.nodes.trigger_plugin.trigger_event_node import TriggerEventNode
|
||||
|
||||
if isinstance(self, TriggerEventNode):
|
||||
start_event.provider_id = getattr(self.get_base_node_data(), "provider_id", "")
|
||||
start_event.provider_type = getattr(self.get_base_node_data(), "provider_type", "")
|
||||
start_event.provider_id = getattr(self.node_data, "provider_id", "")
|
||||
start_event.provider_type = getattr(self.node_data, "provider_type", "")
|
||||
|
||||
from typing import cast
|
||||
|
||||
|
|
@ -265,7 +265,7 @@ class Node(Generic[NodeDataT]):
|
|||
|
||||
if isinstance(self, AgentNode):
|
||||
start_event.agent_strategy = AgentNodeStrategyInit(
|
||||
name=cast(AgentNodeData, self.get_base_node_data()).agent_strategy_name,
|
||||
name=cast(AgentNodeData, self.node_data).agent_strategy_name,
|
||||
icon=self.agent_strategy_icon,
|
||||
)
|
||||
|
||||
|
|
@ -419,10 +419,6 @@ class Node(Generic[NodeDataT]):
|
|||
"""Get the default values dictionary for this node."""
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
"""Get the BaseNodeData object for this node."""
|
||||
return self._node_data
|
||||
|
||||
# Public interface properties that delegate to abstract methods
|
||||
@property
|
||||
def error_strategy(self) -> ErrorStrategy | None:
|
||||
|
|
@ -548,7 +544,7 @@ class Node(Generic[NodeDataT]):
|
|||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.get_base_node_data().title,
|
||||
node_title=self.node_data.title,
|
||||
start_at=event.start_at,
|
||||
inputs=event.inputs,
|
||||
metadata=event.metadata,
|
||||
|
|
@ -561,7 +557,7 @@ class Node(Generic[NodeDataT]):
|
|||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.get_base_node_data().title,
|
||||
node_title=self.node_data.title,
|
||||
index=event.index,
|
||||
pre_loop_output=event.pre_loop_output,
|
||||
)
|
||||
|
|
@ -572,7 +568,7 @@ class Node(Generic[NodeDataT]):
|
|||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.get_base_node_data().title,
|
||||
node_title=self.node_data.title,
|
||||
start_at=event.start_at,
|
||||
inputs=event.inputs,
|
||||
outputs=event.outputs,
|
||||
|
|
@ -586,7 +582,7 @@ class Node(Generic[NodeDataT]):
|
|||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.get_base_node_data().title,
|
||||
node_title=self.node_data.title,
|
||||
start_at=event.start_at,
|
||||
inputs=event.inputs,
|
||||
outputs=event.outputs,
|
||||
|
|
@ -601,7 +597,7 @@ class Node(Generic[NodeDataT]):
|
|||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.get_base_node_data().title,
|
||||
node_title=self.node_data.title,
|
||||
start_at=event.start_at,
|
||||
inputs=event.inputs,
|
||||
metadata=event.metadata,
|
||||
|
|
@ -614,7 +610,7 @@ class Node(Generic[NodeDataT]):
|
|||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.get_base_node_data().title,
|
||||
node_title=self.node_data.title,
|
||||
index=event.index,
|
||||
pre_iteration_output=event.pre_iteration_output,
|
||||
)
|
||||
|
|
@ -625,7 +621,7 @@ class Node(Generic[NodeDataT]):
|
|||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.get_base_node_data().title,
|
||||
node_title=self.node_data.title,
|
||||
start_at=event.start_at,
|
||||
inputs=event.inputs,
|
||||
outputs=event.outputs,
|
||||
|
|
@ -639,7 +635,7 @@ class Node(Generic[NodeDataT]):
|
|||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.get_base_node_data().title,
|
||||
node_title=self.node_data.title,
|
||||
start_at=event.start_at,
|
||||
inputs=event.inputs,
|
||||
outputs=event.outputs,
|
||||
|
|
|
|||
|
|
@ -201,7 +201,9 @@ class ToolTransformService:
|
|||
|
||||
@staticmethod
|
||||
def workflow_provider_to_user_provider(
|
||||
provider_controller: WorkflowToolProviderController, labels: list[str] | None = None
|
||||
provider_controller: WorkflowToolProviderController,
|
||||
labels: list[str] | None = None,
|
||||
workflow_app_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
convert provider controller to user provider
|
||||
|
|
@ -221,6 +223,7 @@ class ToolTransformService:
|
|||
plugin_unique_identifier=None,
|
||||
tools=[],
|
||||
labels=labels or [],
|
||||
workflow_app_id=workflow_app_id,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -189,6 +189,9 @@ class WorkflowToolManageService:
|
|||
select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id)
|
||||
).all()
|
||||
|
||||
# Create a mapping from provider_id to app_id
|
||||
provider_id_to_app_id = {provider.id: provider.app_id for provider in db_tools}
|
||||
|
||||
tools: list[WorkflowToolProviderController] = []
|
||||
for provider in db_tools:
|
||||
try:
|
||||
|
|
@ -202,8 +205,11 @@ class WorkflowToolManageService:
|
|||
result = []
|
||||
|
||||
for tool in tools:
|
||||
workflow_app_id = provider_id_to_app_id.get(tool.provider_id)
|
||||
user_tool_provider = ToolTransformService.workflow_provider_to_user_provider(
|
||||
provider_controller=tool, labels=labels.get(tool.provider_id, [])
|
||||
provider_controller=tool,
|
||||
labels=labels.get(tool.provider_id, []),
|
||||
workflow_app_id=workflow_app_id,
|
||||
)
|
||||
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_tool_provider)
|
||||
user_tool_provider.tools = [
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,100 @@
|
|||
"""
|
||||
Unit tests for ToolProviderApiEntity workflow_app_id field.
|
||||
|
||||
This test suite covers:
|
||||
- ToolProviderApiEntity workflow_app_id field creation and default value
|
||||
- ToolProviderApiEntity.to_dict() method behavior with workflow_app_id
|
||||
"""
|
||||
|
||||
from core.tools.entities.api_entities import ToolProviderApiEntity
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
|
||||
|
||||
class TestToolProviderApiEntityWorkflowAppId:
|
||||
"""Test suite for ToolProviderApiEntity workflow_app_id field."""
|
||||
|
||||
def test_workflow_app_id_field_default_none(self):
|
||||
"""Test that workflow_app_id defaults to None when not provided."""
|
||||
entity = ToolProviderApiEntity(
|
||||
id="test_id",
|
||||
author="test_author",
|
||||
name="test_name",
|
||||
description=I18nObject(en_US="Test description"),
|
||||
icon="test_icon",
|
||||
label=I18nObject(en_US="Test label"),
|
||||
type=ToolProviderType.WORKFLOW,
|
||||
)
|
||||
|
||||
assert entity.workflow_app_id is None
|
||||
|
||||
def test_to_dict_includes_workflow_app_id_when_workflow_type_and_has_value(self):
|
||||
"""Test that to_dict() includes workflow_app_id when type is WORKFLOW and value is set."""
|
||||
workflow_app_id = "app_123"
|
||||
entity = ToolProviderApiEntity(
|
||||
id="test_id",
|
||||
author="test_author",
|
||||
name="test_name",
|
||||
description=I18nObject(en_US="Test description"),
|
||||
icon="test_icon",
|
||||
label=I18nObject(en_US="Test label"),
|
||||
type=ToolProviderType.WORKFLOW,
|
||||
workflow_app_id=workflow_app_id,
|
||||
)
|
||||
|
||||
result = entity.to_dict()
|
||||
|
||||
assert "workflow_app_id" in result
|
||||
assert result["workflow_app_id"] == workflow_app_id
|
||||
|
||||
def test_to_dict_excludes_workflow_app_id_when_workflow_type_and_none(self):
|
||||
"""Test that to_dict() excludes workflow_app_id when type is WORKFLOW but value is None."""
|
||||
entity = ToolProviderApiEntity(
|
||||
id="test_id",
|
||||
author="test_author",
|
||||
name="test_name",
|
||||
description=I18nObject(en_US="Test description"),
|
||||
icon="test_icon",
|
||||
label=I18nObject(en_US="Test label"),
|
||||
type=ToolProviderType.WORKFLOW,
|
||||
workflow_app_id=None,
|
||||
)
|
||||
|
||||
result = entity.to_dict()
|
||||
|
||||
assert "workflow_app_id" not in result
|
||||
|
||||
def test_to_dict_excludes_workflow_app_id_when_not_workflow_type(self):
|
||||
"""Test that to_dict() excludes workflow_app_id when type is not WORKFLOW."""
|
||||
workflow_app_id = "app_123"
|
||||
entity = ToolProviderApiEntity(
|
||||
id="test_id",
|
||||
author="test_author",
|
||||
name="test_name",
|
||||
description=I18nObject(en_US="Test description"),
|
||||
icon="test_icon",
|
||||
label=I18nObject(en_US="Test label"),
|
||||
type=ToolProviderType.BUILT_IN,
|
||||
workflow_app_id=workflow_app_id,
|
||||
)
|
||||
|
||||
result = entity.to_dict()
|
||||
|
||||
assert "workflow_app_id" not in result
|
||||
|
||||
def test_to_dict_includes_workflow_app_id_for_workflow_type_with_empty_string(self):
|
||||
"""Test that to_dict() excludes workflow_app_id when value is empty string (falsy)."""
|
||||
entity = ToolProviderApiEntity(
|
||||
id="test_id",
|
||||
author="test_author",
|
||||
name="test_name",
|
||||
description=I18nObject(en_US="Test description"),
|
||||
icon="test_icon",
|
||||
label=I18nObject(en_US="Test label"),
|
||||
type=ToolProviderType.WORKFLOW,
|
||||
workflow_app_id="",
|
||||
)
|
||||
|
||||
result = entity.to_dict()
|
||||
|
||||
assert "workflow_app_id" not in result
|
||||
|
|
@ -744,7 +744,7 @@ def test_graph_run_emits_partial_success_when_node_failure_recovered():
|
|||
)
|
||||
|
||||
llm_node = graph.nodes["llm"]
|
||||
base_node_data = llm_node.get_base_node_data()
|
||||
base_node_data = llm_node.node_data
|
||||
base_node_data.error_strategy = ErrorStrategy.DEFAULT_VALUE
|
||||
base_node_data.default_value = [DefaultValue(key="text", value="fallback response", type=DefaultValueType.STRING)]
|
||||
|
||||
|
|
|
|||
|
|
@ -471,8 +471,8 @@ class TestCodeNodeInitialization:
|
|||
|
||||
assert node._get_description() is None
|
||||
|
||||
def test_get_base_node_data(self):
|
||||
"""Test get_base_node_data returns node data."""
|
||||
def test_node_data_property(self):
|
||||
"""Test node_data property returns node data."""
|
||||
node = CodeNode.__new__(CodeNode)
|
||||
node._node_data = CodeNodeData(
|
||||
title="Base Test",
|
||||
|
|
@ -482,7 +482,7 @@ class TestCodeNodeInitialization:
|
|||
outputs={},
|
||||
)
|
||||
|
||||
result = node.get_base_node_data()
|
||||
result = node.node_data
|
||||
|
||||
assert result == node._node_data
|
||||
assert result.title == "Base Test"
|
||||
|
|
|
|||
|
|
@ -240,8 +240,8 @@ class TestIterationNodeInitialization:
|
|||
|
||||
assert node._get_description() == "This is a description"
|
||||
|
||||
def test_get_base_node_data(self):
|
||||
"""Test get_base_node_data returns node data."""
|
||||
def test_node_data_property(self):
|
||||
"""Test node_data property returns node data."""
|
||||
node = IterationNode.__new__(IterationNode)
|
||||
node._node_data = IterationNodeData(
|
||||
title="Base Test",
|
||||
|
|
@ -249,7 +249,7 @@ class TestIterationNodeInitialization:
|
|||
output_selector=["y"],
|
||||
)
|
||||
|
||||
result = node.get_base_node_data()
|
||||
result = node.node_data
|
||||
|
||||
assert result == node._node_data
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,718 @@
|
|||
"""
|
||||
Comprehensive unit tests for AudioService.
|
||||
|
||||
This test suite provides complete coverage of audio processing operations in Dify,
|
||||
following TDD principles with the Arrange-Act-Assert pattern.
|
||||
|
||||
## Test Coverage
|
||||
|
||||
### 1. Speech-to-Text (ASR) Operations (TestAudioServiceASR)
|
||||
Tests audio transcription functionality:
|
||||
- Successful transcription for different app modes
|
||||
- File validation (size, type, presence)
|
||||
- Feature flag validation (speech-to-text enabled)
|
||||
- Error handling for various failure scenarios
|
||||
- Model instance availability checks
|
||||
|
||||
### 2. Text-to-Speech (TTS) Operations (TestAudioServiceTTS)
|
||||
Tests text-to-audio conversion:
|
||||
- TTS with text input
|
||||
- TTS with message ID
|
||||
- Voice selection (explicit and default)
|
||||
- Feature flag validation (text-to-speech enabled)
|
||||
- Draft workflow handling
|
||||
- Streaming response handling
|
||||
- Error handling for missing/invalid inputs
|
||||
|
||||
### 3. TTS Voice Listing (TestAudioServiceTTSVoices)
|
||||
Tests available voice retrieval:
|
||||
- Get available voices for a tenant
|
||||
- Language filtering
|
||||
- Error handling for missing provider
|
||||
|
||||
## Testing Approach
|
||||
|
||||
- **Mocking Strategy**: All external dependencies (ModelManager, db, FileStorage) are mocked
|
||||
for fast, isolated unit tests
|
||||
- **Factory Pattern**: AudioServiceTestDataFactory provides consistent test data
|
||||
- **Fixtures**: Mock objects are configured per test method
|
||||
- **Assertions**: Each test verifies return values, side effects, and error conditions
|
||||
|
||||
## Key Concepts
|
||||
|
||||
**Audio Formats:**
|
||||
- Supported: mp3, wav, m4a, flac, ogg, opus, webm
|
||||
- File size limit: 30 MB
|
||||
|
||||
**App Modes:**
|
||||
- ADVANCED_CHAT/WORKFLOW: Use workflow features
|
||||
- CHAT/COMPLETION: Use app_model_config
|
||||
|
||||
**Feature Flags:**
|
||||
- speech_to_text: Enables ASR functionality
|
||||
- text_to_speech: Enables TTS functionality
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, Mock, create_autospec, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.datastructures import FileStorage
|
||||
|
||||
from models.enums import MessageStatus
|
||||
from models.model import App, AppMode, AppModelConfig, Message
|
||||
from models.workflow import Workflow
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
NoAudioUploadedServiceError,
|
||||
ProviderNotSupportSpeechToTextServiceError,
|
||||
ProviderNotSupportTextToSpeechServiceError,
|
||||
UnsupportedAudioTypeServiceError,
|
||||
)
|
||||
|
||||
|
||||
class AudioServiceTestDataFactory:
    """
    Factory for creating test data and mock objects.

    Every helper returns a mock pre-populated with sensible defaults for
    audio-related tests; any extra keyword arguments are applied verbatim
    as attributes so individual tests can override whatever they need.
    """

    @staticmethod
    def _apply_overrides(target: Mock, overrides: dict) -> None:
        """Assign each key/value pair in *overrides* as an attribute on *target*."""
        for attr_name, attr_value in overrides.items():
            setattr(target, attr_name, attr_value)

    @staticmethod
    def create_app_mock(
        app_id: str = "app-123",
        mode: AppMode = AppMode.CHAT,
        tenant_id: str = "tenant-123",
        **kwargs,
    ) -> Mock:
        """
        Create a mock App object.

        Args:
            app_id: Unique identifier for the app
            mode: App mode (CHAT, ADVANCED_CHAT, WORKFLOW, etc.)
            tenant_id: Tenant identifier
            **kwargs: Additional attributes to set on the mock

        Returns:
            Mock App object with the requested attributes
        """
        mock_app = create_autospec(App, instance=True)
        mock_app.id = app_id
        mock_app.mode = mode
        mock_app.tenant_id = tenant_id
        # Default both to None unless supplied; the override pass below
        # simply re-applies them when present in kwargs.
        mock_app.workflow = kwargs.get("workflow")
        mock_app.app_model_config = kwargs.get("app_model_config")
        AudioServiceTestDataFactory._apply_overrides(mock_app, kwargs)
        return mock_app

    @staticmethod
    def create_workflow_mock(features_dict: dict | None = None, **kwargs) -> Mock:
        """
        Create a mock Workflow object.

        Args:
            features_dict: Dictionary of workflow features
            **kwargs: Additional attributes to set on the mock

        Returns:
            Mock Workflow object with the requested attributes
        """
        mock_workflow = create_autospec(Workflow, instance=True)
        mock_workflow.features_dict = features_dict or {}
        AudioServiceTestDataFactory._apply_overrides(mock_workflow, kwargs)
        return mock_workflow

    @staticmethod
    def create_app_model_config_mock(
        speech_to_text_dict: dict | None = None,
        text_to_speech_dict: dict | None = None,
        **kwargs,
    ) -> Mock:
        """
        Create a mock AppModelConfig object.

        Args:
            speech_to_text_dict: Speech-to-text configuration
            text_to_speech_dict: Text-to-speech configuration
            **kwargs: Additional attributes to set on the mock

        Returns:
            Mock AppModelConfig object with the requested attributes
        """
        mock_config = create_autospec(AppModelConfig, instance=True)
        # Both features default to "disabled" so tests must opt in explicitly.
        mock_config.speech_to_text_dict = speech_to_text_dict or {"enabled": False}
        mock_config.text_to_speech_dict = text_to_speech_dict or {"enabled": False}
        AudioServiceTestDataFactory._apply_overrides(mock_config, kwargs)
        return mock_config

    @staticmethod
    def create_file_storage_mock(
        filename: str = "test.mp3",
        mimetype: str = "audio/mp3",
        content: bytes = b"fake audio content",
        **kwargs,
    ) -> Mock:
        """
        Create a mock FileStorage object.

        Args:
            filename: Name of the file
            mimetype: MIME type of the file
            content: File content as bytes
            **kwargs: Additional attributes to set on the mock

        Returns:
            Mock FileStorage object with the requested attributes
        """
        mock_file = Mock(spec=FileStorage)
        mock_file.filename = filename
        mock_file.mimetype = mimetype
        mock_file.read = Mock(return_value=content)
        AudioServiceTestDataFactory._apply_overrides(mock_file, kwargs)
        return mock_file

    @staticmethod
    def create_message_mock(
        message_id: str = "msg-123",
        answer: str = "Test answer",
        status: MessageStatus = MessageStatus.NORMAL,
        **kwargs,
    ) -> Mock:
        """
        Create a mock Message object.

        Args:
            message_id: Unique identifier for the message
            answer: Message answer text
            status: Message status
            **kwargs: Additional attributes to set on the mock

        Returns:
            Mock Message object with the requested attributes
        """
        mock_message = create_autospec(Message, instance=True)
        mock_message.id = message_id
        mock_message.answer = answer
        mock_message.status = status
        AudioServiceTestDataFactory._apply_overrides(mock_message, kwargs)
        return mock_message
|
||||
|
||||
|
||||
@pytest.fixture
def factory():
    """Expose the shared AudioServiceTestDataFactory class to every test."""
    return AudioServiceTestDataFactory
|
||||
|
||||
|
||||
class TestAudioServiceASR:
    """
    Test speech-to-text (ASR) operations.

    Each test builds an App mock (plus a Workflow or AppModelConfig mock,
    depending on the app mode) via the module-level ``factory`` fixture,
    patches ModelManager where a model call is expected, and calls
    AudioService.transcript_asr directly.
    """

    @patch("services.audio_service.ModelManager")
    def test_transcript_asr_success_chat_mode(self, mock_model_manager_class, factory):
        """Test successful ASR transcription in CHAT mode."""
        # Arrange
        app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True})
        app = factory.create_app_mock(
            mode=AppMode.CHAT,
            app_model_config=app_model_config,
        )
        file = factory.create_file_storage_mock()

        # Mock ModelManager: the service is expected to fetch the default
        # speech2text model instance and invoke it.
        mock_model_manager = MagicMock()
        mock_model_manager_class.return_value = mock_model_manager

        mock_model_instance = MagicMock()
        mock_model_instance.invoke_speech2text.return_value = "Transcribed text"
        mock_model_manager.get_default_model_instance.return_value = mock_model_instance

        # Act
        result = AudioService.transcript_asr(app_model=app, file=file, end_user="user-123")

        # Assert: the transcription is wrapped in a dict and the end user
        # is forwarded to the model call as the ``user`` kwarg.
        assert result == {"text": "Transcribed text"}
        mock_model_instance.invoke_speech2text.assert_called_once()
        call_args = mock_model_instance.invoke_speech2text.call_args
        assert call_args.kwargs["user"] == "user-123"

    @patch("services.audio_service.ModelManager")
    def test_transcript_asr_success_advanced_chat_mode(self, mock_model_manager_class, factory):
        """Test successful ASR transcription in ADVANCED_CHAT mode."""
        # Arrange: ADVANCED_CHAT reads the feature flag from the workflow's
        # features_dict rather than from app_model_config.
        workflow = factory.create_workflow_mock(features_dict={"speech_to_text": {"enabled": True}})
        app = factory.create_app_mock(
            mode=AppMode.ADVANCED_CHAT,
            workflow=workflow,
        )
        file = factory.create_file_storage_mock()

        # Mock ModelManager
        mock_model_manager = MagicMock()
        mock_model_manager_class.return_value = mock_model_manager

        mock_model_instance = MagicMock()
        mock_model_instance.invoke_speech2text.return_value = "Workflow transcribed text"
        mock_model_manager.get_default_model_instance.return_value = mock_model_instance

        # Act
        result = AudioService.transcript_asr(app_model=app, file=file)

        # Assert
        assert result == {"text": "Workflow transcribed text"}

    def test_transcript_asr_raises_error_when_feature_disabled_chat_mode(self, factory):
        """Test that ASR raises error when speech-to-text is disabled in CHAT mode."""
        # Arrange
        app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": False})
        app = factory.create_app_mock(
            mode=AppMode.CHAT,
            app_model_config=app_model_config,
        )
        file = factory.create_file_storage_mock()

        # Act & Assert
        with pytest.raises(ValueError, match="Speech to text is not enabled"):
            AudioService.transcript_asr(app_model=app, file=file)

    def test_transcript_asr_raises_error_when_feature_disabled_workflow_mode(self, factory):
        """Test that ASR raises error when speech-to-text is disabled in WORKFLOW mode."""
        # Arrange
        workflow = factory.create_workflow_mock(features_dict={"speech_to_text": {"enabled": False}})
        app = factory.create_app_mock(
            mode=AppMode.WORKFLOW,
            workflow=workflow,
        )
        file = factory.create_file_storage_mock()

        # Act & Assert
        with pytest.raises(ValueError, match="Speech to text is not enabled"):
            AudioService.transcript_asr(app_model=app, file=file)

    def test_transcript_asr_raises_error_when_workflow_missing(self, factory):
        """Test that ASR raises error when workflow is missing in WORKFLOW mode."""
        # Arrange: a WORKFLOW-mode app with no workflow attached is treated
        # the same as the feature being disabled.
        app = factory.create_app_mock(
            mode=AppMode.WORKFLOW,
            workflow=None,
        )
        file = factory.create_file_storage_mock()

        # Act & Assert
        with pytest.raises(ValueError, match="Speech to text is not enabled"):
            AudioService.transcript_asr(app_model=app, file=file)

    def test_transcript_asr_raises_error_when_no_file_uploaded(self, factory):
        """Test that ASR raises error when no file is uploaded."""
        # Arrange
        app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True})
        app = factory.create_app_mock(
            mode=AppMode.CHAT,
            app_model_config=app_model_config,
        )

        # Act & Assert
        with pytest.raises(NoAudioUploadedServiceError):
            AudioService.transcript_asr(app_model=app, file=None)

    def test_transcript_asr_raises_error_for_unsupported_audio_type(self, factory):
        """Test that ASR raises error for unsupported audio file types."""
        # Arrange: a non-audio mimetype must be rejected before any model call.
        app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True})
        app = factory.create_app_mock(
            mode=AppMode.CHAT,
            app_model_config=app_model_config,
        )
        file = factory.create_file_storage_mock(mimetype="video/mp4")

        # Act & Assert
        with pytest.raises(UnsupportedAudioTypeServiceError):
            AudioService.transcript_asr(app_model=app, file=file)

    def test_transcript_asr_raises_error_for_large_file(self, factory):
        """Test that ASR raises error when file exceeds size limit (30MB)."""
        # Arrange
        app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True})
        app = factory.create_app_mock(
            mode=AppMode.CHAT,
            app_model_config=app_model_config,
        )
        # Create file larger than 30MB (31 MB — just over the service's limit)
        large_content = b"x" * (31 * 1024 * 1024)
        file = factory.create_file_storage_mock(content=large_content)

        # Act & Assert
        with pytest.raises(AudioTooLargeServiceError, match="Audio size larger than 30 mb"):
            AudioService.transcript_asr(app_model=app, file=file)

    @patch("services.audio_service.ModelManager")
    def test_transcript_asr_raises_error_when_no_model_instance(self, mock_model_manager_class, factory):
        """Test that ASR raises error when no model instance is available."""
        # Arrange
        app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True})
        app = factory.create_app_mock(
            mode=AppMode.CHAT,
            app_model_config=app_model_config,
        )
        file = factory.create_file_storage_mock()

        # Mock ModelManager to return None, simulating a tenant whose
        # provider exposes no speech2text model.
        mock_model_manager = MagicMock()
        mock_model_manager_class.return_value = mock_model_manager
        mock_model_manager.get_default_model_instance.return_value = None

        # Act & Assert
        with pytest.raises(ProviderNotSupportSpeechToTextServiceError):
            AudioService.transcript_asr(app_model=app, file=file)
|
||||
|
||||
|
||||
class TestAudioServiceTTS:
    """
    Test text-to-speech (TTS) operations.

    Covers text and message-id inputs, voice selection (explicit, config
    default, first-available), draft-workflow handling, and the error/None
    return paths of AudioService.transcript_tts.
    """

    @patch("services.audio_service.ModelManager")
    def test_transcript_tts_with_text_success(self, mock_model_manager_class, factory):
        """Test successful TTS with text input."""
        # Arrange
        app_model_config = factory.create_app_model_config_mock(
            text_to_speech_dict={"enabled": True, "voice": "en-US-Neural"}
        )
        app = factory.create_app_mock(
            mode=AppMode.CHAT,
            app_model_config=app_model_config,
        )

        # Mock ModelManager
        mock_model_manager = MagicMock()
        mock_model_manager_class.return_value = mock_model_manager

        mock_model_instance = MagicMock()
        mock_model_instance.invoke_tts.return_value = b"audio data"
        mock_model_manager.get_default_model_instance.return_value = mock_model_instance

        # Act
        result = AudioService.transcript_tts(
            app_model=app,
            text="Hello world",
            voice="en-US-Neural",
            end_user="user-123",
        )

        # Assert: the raw audio bytes come back and the model call carries
        # the text, user, tenant and voice exactly as provided.
        assert result == b"audio data"
        mock_model_instance.invoke_tts.assert_called_once_with(
            content_text="Hello world",
            user="user-123",
            tenant_id=app.tenant_id,
            voice="en-US-Neural",
        )

    # NOTE: @patch decorators apply bottom-up, so the ModelManager mock is
    # injected first and db.session second.
    @patch("services.audio_service.db.session")
    @patch("services.audio_service.ModelManager")
    def test_transcript_tts_with_message_id_success(self, mock_model_manager_class, mock_db_session, factory):
        """Test successful TTS with message ID."""
        # Arrange
        app_model_config = factory.create_app_model_config_mock(
            text_to_speech_dict={"enabled": True, "voice": "en-US-Neural"}
        )
        app = factory.create_app_mock(
            mode=AppMode.CHAT,
            app_model_config=app_model_config,
        )

        message = factory.create_message_mock(
            message_id="550e8400-e29b-41d4-a716-446655440000",
            answer="Message answer text",
        )

        # Mock database query: the chained query/where/first lookup
        # resolves to the message above.
        mock_query = MagicMock()
        mock_db_session.query.return_value = mock_query
        mock_query.where.return_value = mock_query
        mock_query.first.return_value = message

        # Mock ModelManager
        mock_model_manager = MagicMock()
        mock_model_manager_class.return_value = mock_model_manager

        mock_model_instance = MagicMock()
        mock_model_instance.invoke_tts.return_value = b"audio from message"
        mock_model_manager.get_default_model_instance.return_value = mock_model_instance

        # Act
        result = AudioService.transcript_tts(
            app_model=app,
            message_id="550e8400-e29b-41d4-a716-446655440000",
        )

        # Assert
        assert result == b"audio from message"
        mock_model_instance.invoke_tts.assert_called_once()

    @patch("services.audio_service.ModelManager")
    def test_transcript_tts_with_default_voice(self, mock_model_manager_class, factory):
        """Test TTS uses default voice when none specified."""
        # Arrange
        app_model_config = factory.create_app_model_config_mock(
            text_to_speech_dict={"enabled": True, "voice": "default-voice"}
        )
        app = factory.create_app_mock(
            mode=AppMode.CHAT,
            app_model_config=app_model_config,
        )

        # Mock ModelManager
        mock_model_manager = MagicMock()
        mock_model_manager_class.return_value = mock_model_manager

        mock_model_instance = MagicMock()
        mock_model_instance.invoke_tts.return_value = b"audio data"
        mock_model_manager.get_default_model_instance.return_value = mock_model_instance

        # Act: no explicit ``voice`` argument is passed.
        result = AudioService.transcript_tts(
            app_model=app,
            text="Test",
        )

        # Assert
        assert result == b"audio data"
        # Verify default voice was used
        call_args = mock_model_instance.invoke_tts.call_args
        assert call_args.kwargs["voice"] == "default-voice"

    @patch("services.audio_service.ModelManager")
    def test_transcript_tts_gets_first_available_voice_when_none_configured(self, mock_model_manager_class, factory):
        """Test TTS gets first available voice when none is configured."""
        # Arrange
        app_model_config = factory.create_app_model_config_mock(
            text_to_speech_dict={"enabled": True}  # No voice specified
        )
        app = factory.create_app_mock(
            mode=AppMode.CHAT,
            app_model_config=app_model_config,
        )

        # Mock ModelManager: with no configured voice the service is
        # expected to ask the model for its voice list and pick the first.
        mock_model_manager = MagicMock()
        mock_model_manager_class.return_value = mock_model_manager

        mock_model_instance = MagicMock()
        mock_model_instance.get_tts_voices.return_value = [{"value": "auto-voice"}]
        mock_model_instance.invoke_tts.return_value = b"audio data"
        mock_model_manager.get_default_model_instance.return_value = mock_model_instance

        # Act
        result = AudioService.transcript_tts(
            app_model=app,
            text="Test",
        )

        # Assert
        assert result == b"audio data"
        call_args = mock_model_instance.invoke_tts.call_args
        assert call_args.kwargs["voice"] == "auto-voice"

    @patch("services.audio_service.WorkflowService")
    @patch("services.audio_service.ModelManager")
    def test_transcript_tts_workflow_mode_with_draft(
        self, mock_model_manager_class, mock_workflow_service_class, factory
    ):
        """Test TTS in WORKFLOW mode with draft workflow."""
        # Arrange: in draft mode the TTS feature config is read from the
        # draft workflow fetched via WorkflowService, not from the app.
        draft_workflow = factory.create_workflow_mock(
            features_dict={"text_to_speech": {"enabled": True, "voice": "draft-voice"}}
        )
        app = factory.create_app_mock(
            mode=AppMode.WORKFLOW,
        )

        # Mock WorkflowService
        mock_workflow_service = MagicMock()
        mock_workflow_service_class.return_value = mock_workflow_service
        mock_workflow_service.get_draft_workflow.return_value = draft_workflow

        # Mock ModelManager
        mock_model_manager = MagicMock()
        mock_model_manager_class.return_value = mock_model_manager

        mock_model_instance = MagicMock()
        mock_model_instance.invoke_tts.return_value = b"draft audio"
        mock_model_manager.get_default_model_instance.return_value = mock_model_instance

        # Act
        result = AudioService.transcript_tts(
            app_model=app,
            text="Draft test",
            is_draft=True,
        )

        # Assert
        assert result == b"draft audio"
        mock_workflow_service.get_draft_workflow.assert_called_once_with(app_model=app)

    def test_transcript_tts_raises_error_when_text_missing(self, factory):
        """Test that TTS raises error when text is missing."""
        # Arrange
        app = factory.create_app_mock()

        # Act & Assert: with neither text nor message_id there is nothing
        # to synthesize.
        with pytest.raises(ValueError, match="Text is required"):
            AudioService.transcript_tts(app_model=app, text=None)

    @patch("services.audio_service.db.session")
    def test_transcript_tts_returns_none_for_invalid_message_id(self, mock_db_session, factory):
        """Test that TTS returns None for invalid message ID format."""
        # Arrange
        app = factory.create_app_mock()

        # Act: a non-UUID message id short-circuits to None rather than
        # raising.
        result = AudioService.transcript_tts(
            app_model=app,
            message_id="invalid-uuid",
        )

        # Assert
        assert result is None

    @patch("services.audio_service.db.session")
    def test_transcript_tts_returns_none_for_nonexistent_message(self, mock_db_session, factory):
        """Test that TTS returns None when message doesn't exist."""
        # Arrange
        app = factory.create_app_mock()

        # Mock database query returning None
        mock_query = MagicMock()
        mock_db_session.query.return_value = mock_query
        mock_query.where.return_value = mock_query
        mock_query.first.return_value = None

        # Act
        result = AudioService.transcript_tts(
            app_model=app,
            message_id="550e8400-e29b-41d4-a716-446655440000",
        )

        # Assert
        assert result is None

    @patch("services.audio_service.db.session")
    def test_transcript_tts_returns_none_for_empty_message_answer(self, mock_db_session, factory):
        """Test that TTS returns None when message answer is empty."""
        # Arrange
        app = factory.create_app_mock()

        message = factory.create_message_mock(
            answer="",
            status=MessageStatus.NORMAL,
        )

        # Mock database query
        mock_query = MagicMock()
        mock_db_session.query.return_value = mock_query
        mock_query.where.return_value = mock_query
        mock_query.first.return_value = message

        # Act
        result = AudioService.transcript_tts(
            app_model=app,
            message_id="550e8400-e29b-41d4-a716-446655440000",
        )

        # Assert
        assert result is None

    @patch("services.audio_service.ModelManager")
    def test_transcript_tts_raises_error_when_no_voices_available(self, mock_model_manager_class, factory):
        """Test that TTS raises error when no voices are available."""
        # Arrange
        app_model_config = factory.create_app_model_config_mock(
            text_to_speech_dict={"enabled": True}  # No voice specified
        )
        app = factory.create_app_mock(
            mode=AppMode.CHAT,
            app_model_config=app_model_config,
        )

        # Mock ModelManager
        mock_model_manager = MagicMock()
        mock_model_manager_class.return_value = mock_model_manager

        mock_model_instance = MagicMock()
        mock_model_instance.get_tts_voices.return_value = []  # No voices available
        mock_model_manager.get_default_model_instance.return_value = mock_model_instance

        # Act & Assert
        with pytest.raises(ValueError, match="Sorry, no voice available"):
            AudioService.transcript_tts(app_model=app, text="Test")
|
||||
|
||||
|
||||
class TestAudioServiceTTSVoices:
    """Tests for listing the TTS voices exposed by AudioService."""

    @staticmethod
    def _install_model(manager_cls, model):
        """Wire a patched ModelManager class so its default instance is *model*."""
        manager = MagicMock()
        manager.get_default_model_instance.return_value = model
        manager_cls.return_value = manager
        return manager

    @patch("services.audio_service.ModelManager")
    def test_transcript_tts_voices_success(self, manager_cls, factory):
        """Voices reported by the model instance are returned unchanged."""
        # Arrange
        expected_voices = [
            {"name": "Voice 1", "value": "voice-1"},
            {"name": "Voice 2", "value": "voice-2"},
        ]
        model = MagicMock()
        model.get_tts_voices.return_value = expected_voices
        self._install_model(manager_cls, model)

        # Act
        result = AudioService.transcript_tts_voices(tenant_id="tenant-123", language="en-US")

        # Assert: the voice list is passed through and the language filter
        # is forwarded to the model.
        assert result == expected_voices
        model.get_tts_voices.assert_called_once_with("en-US")

    @patch("services.audio_service.ModelManager")
    def test_transcript_tts_voices_raises_error_when_no_model_instance(self, manager_cls, factory):
        """A missing default model instance maps to a provider-support error."""
        # Arrange
        self._install_model(manager_cls, None)

        # Act & Assert
        with pytest.raises(ProviderNotSupportTextToSpeechServiceError):
            AudioService.transcript_tts_voices(tenant_id="tenant-123", language="en-US")

    @patch("services.audio_service.ModelManager")
    def test_transcript_tts_voices_propagates_exceptions(self, manager_cls, factory):
        """Errors raised while fetching voices bubble up to the caller."""
        # Arrange
        model = MagicMock()
        model.get_tts_voices.side_effect = RuntimeError("Model error")
        self._install_model(manager_cls, model)

        # Act & Assert
        with pytest.raises(RuntimeError, match="Model error"):
            AudioService.transcript_tts_voices(tenant_id="tenant-123", language="en-US")
|
||||
|
|
@ -0,0 +1,494 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from models.model import App, DefaultEndUserSessionID, EndUser
|
||||
from services.end_user_service import EndUserService
|
||||
|
||||
|
||||
class TestEndUserServiceFactory:
    """Builds the mock objects shared by the end-user service test suites."""

    @staticmethod
    def create_app_mock(
        app_id: str = "app-123",
        tenant_id: str = "tenant-456",
        name: str = "Test App",
    ) -> MagicMock:
        """Return a MagicMock standing in for an App row."""
        mock_app = MagicMock(spec=App)
        mock_app.id = app_id
        mock_app.tenant_id = tenant_id
        mock_app.name = name
        return mock_app

    @staticmethod
    def create_end_user_mock(
        user_id: str = "user-789",
        tenant_id: str = "tenant-456",
        app_id: str = "app-123",
        session_id: str = "session-001",
        type: InvokeFrom = InvokeFrom.SERVICE_API,
        is_anonymous: bool = False,
    ) -> MagicMock:
        """Return a MagicMock standing in for an EndUser row."""
        mock_user = MagicMock(spec=EndUser)
        mock_user.id = user_id
        mock_user.tenant_id = tenant_id
        mock_user.app_id = app_id
        mock_user.session_id = session_id
        mock_user.type = type
        mock_user.is_anonymous = is_anonymous
        # The external id deliberately mirrors the session id in the fixture.
        mock_user.external_user_id = session_id
        return mock_user
|
||||
|
||||
|
||||
class TestEndUserServiceGetOrCreateEndUser:
    """
    Unit tests for EndUserService.get_or_create_end_user method.

    This test suite covers:
    - Creating new end users
    - Retrieving existing end users
    - Default session ID handling
    - Anonymous user creation
    """

    @pytest.fixture
    def factory(self):
        """Provide test data factory."""
        return TestEndUserServiceFactory()

    # Test 01: Get or create with custom user_id
    @patch("services.end_user_service.Session")
    @patch("services.end_user_service.db")
    def test_get_or_create_end_user_with_custom_user_id(self, mock_db, mock_session_class, factory):
        """Test getting or creating end user with custom user_id."""
        # Arrange
        app = factory.create_app_mock()
        user_id = "custom-user-123"

        # Session is used as a context manager by the service, so the mock
        # is attached to __enter__'s return value.
        mock_session = MagicMock()
        mock_session_class.return_value.__enter__.return_value = mock_session

        mock_query = MagicMock()
        mock_session.query.return_value = mock_query
        mock_query.where.return_value = mock_query
        mock_query.order_by.return_value = mock_query
        mock_query.first.return_value = None  # No existing user

        # Act — return value not inspected; the assertions examine what the
        # service handed to the session instead.
        result = EndUserService.get_or_create_end_user(app_model=app, user_id=user_id)

        # Assert
        mock_session.add.assert_called_once()
        mock_session.commit.assert_called_once()
        # Verify the created user has correct attributes
        added_user = mock_session.add.call_args[0][0]
        assert added_user.tenant_id == app.tenant_id
        assert added_user.app_id == app.id
        assert added_user.session_id == user_id
        assert added_user.type == InvokeFrom.SERVICE_API
        assert added_user.is_anonymous is False

    # Test 02: Get or create without user_id (default session)
    @patch("services.end_user_service.Session")
    @patch("services.end_user_service.db")
    def test_get_or_create_end_user_without_user_id(self, mock_db, mock_session_class, factory):
        """Test getting or creating end user without user_id uses default session."""
        # Arrange
        app = factory.create_app_mock()

        mock_session = MagicMock()
        mock_session_class.return_value.__enter__.return_value = mock_session

        mock_query = MagicMock()
        mock_session.query.return_value = mock_query
        mock_query.where.return_value = mock_query
        mock_query.order_by.return_value = mock_query
        mock_query.first.return_value = None  # No existing user

        # Act
        result = EndUserService.get_or_create_end_user(app_model=app, user_id=None)

        # Assert
        mock_session.add.assert_called_once()
        added_user = mock_session.add.call_args[0][0]
        assert added_user.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
        # NOTE(review): the service writes the private _is_anonymous flag
        # directly; the public is_anonymous property may be derived
        # differently — confirm against the EndUser model.
        assert added_user._is_anonymous is True

    # Test 03: Get existing end user
    @patch("services.end_user_service.Session")
    @patch("services.end_user_service.db")
    def test_get_existing_end_user(self, mock_db, mock_session_class, factory):
        """Test retrieving an existing end user."""
        # Arrange
        app = factory.create_app_mock()
        user_id = "existing-user-123"
        existing_user = factory.create_end_user_mock(
            tenant_id=app.tenant_id,
            app_id=app.id,
            session_id=user_id,
            type=InvokeFrom.SERVICE_API,
        )

        mock_session = MagicMock()
        mock_session_class.return_value.__enter__.return_value = mock_session

        mock_query = MagicMock()
        mock_session.query.return_value = mock_query
        mock_query.where.return_value = mock_query
        mock_query.order_by.return_value = mock_query
        mock_query.first.return_value = existing_user

        # Act
        result = EndUserService.get_or_create_end_user(app_model=app, user_id=user_id)

        # Assert
        assert result == existing_user
        mock_session.add.assert_not_called()  # Should not create new user
|
||||
|
||||
|
||||
class TestEndUserServiceGetOrCreateEndUserByType:
    """
    Unit tests for EndUserService.get_or_create_end_user_by_type method.

    This test suite covers:
    - Creating end users with different InvokeFrom types
    - Type migration for legacy users
    - Query ordering and prioritization
    - Session management
    """

    @pytest.fixture
    def factory(self):
        """Provide test data factory."""
        return TestEndUserServiceFactory()

    @staticmethod
    def _enter_session(mock_session_class):
        """Return the session mock yielded by the patched Session context manager."""
        session = MagicMock()
        mock_session_class.return_value.__enter__.return_value = session
        return session

    @staticmethod
    def _stub_lookup(session, found):
        """Wire session.query(...).where(...).order_by(...).first() to yield *found*."""
        query = MagicMock()
        session.query.return_value = query
        query.where.return_value = query
        query.order_by.return_value = query
        query.first.return_value = found
        return query

    # Test 04: Create new end user with SERVICE_API type
    @patch("services.end_user_service.Session")
    @patch("services.end_user_service.db")
    def test_create_end_user_service_api_type(self, mock_db, mock_session_class, factory):
        """Test creating new end user with SERVICE_API type."""
        # Arrange: no existing record, so the service must create one.
        session = self._enter_session(mock_session_class)
        self._stub_lookup(session, None)

        # Act
        EndUserService.get_or_create_end_user_by_type(
            type=InvokeFrom.SERVICE_API,
            tenant_id="tenant-123",
            app_id="app-456",
            user_id="user-789",
        )

        # Assert: a new user was added and committed with the requested fields.
        session.add.assert_called_once()
        session.commit.assert_called_once()
        persisted = session.add.call_args[0][0]
        assert persisted.type == InvokeFrom.SERVICE_API
        assert persisted.tenant_id == "tenant-123"
        assert persisted.app_id == "app-456"
        assert persisted.session_id == "user-789"

    # Test 05: Create new end user with WEB_APP type
    @patch("services.end_user_service.Session")
    @patch("services.end_user_service.db")
    def test_create_end_user_web_app_type(self, mock_db, mock_session_class, factory):
        """Test creating new end user with WEB_APP type."""
        # Arrange
        session = self._enter_session(mock_session_class)
        self._stub_lookup(session, None)

        # Act
        EndUserService.get_or_create_end_user_by_type(
            type=InvokeFrom.WEB_APP,
            tenant_id="tenant-123",
            app_id="app-456",
            user_id="user-789",
        )

        # Assert
        session.add.assert_called_once()
        persisted = session.add.call_args[0][0]
        assert persisted.type == InvokeFrom.WEB_APP

    # Test 06: Upgrade legacy end user type
    @patch("services.end_user_service.logger")
    @patch("services.end_user_service.Session")
    @patch("services.end_user_service.db")
    def test_upgrade_legacy_end_user_type(self, mock_db, mock_session_class, mock_logger, factory):
        """Test upgrading legacy end user with different type."""
        # Arrange: a record already stored under an older type.
        legacy = factory.create_end_user_mock(
            tenant_id="tenant-123",
            app_id="app-456",
            session_id="user-789",
            type=InvokeFrom.SERVICE_API,
        )
        session = self._enter_session(mock_session_class)
        self._stub_lookup(session, legacy)

        # Act: request the same user under a different type.
        outcome = EndUserService.get_or_create_end_user_by_type(
            type=InvokeFrom.WEB_APP,
            tenant_id="tenant-123",
            app_id="app-456",
            user_id="user-789",
        )

        # Assert: the legacy record is returned, migrated in place, and logged.
        assert outcome == legacy
        assert legacy.type == InvokeFrom.WEB_APP  # type migrated to the requested value
        session.commit.assert_called_once()
        mock_logger.info.assert_called_once()
        assert "Upgrading legacy EndUser" in mock_logger.info.call_args[0][0]

    # Test 07: Get existing end user with matching type (no upgrade needed)
    @patch("services.end_user_service.logger")
    @patch("services.end_user_service.Session")
    @patch("services.end_user_service.db")
    def test_get_existing_end_user_matching_type(self, mock_db, mock_session_class, mock_logger, factory):
        """Test retrieving existing end user with matching type."""
        # Arrange
        existing = factory.create_end_user_mock(
            tenant_id="tenant-123",
            app_id="app-456",
            session_id="user-789",
            type=InvokeFrom.SERVICE_API,
        )
        session = self._enter_session(mock_session_class)
        self._stub_lookup(session, existing)

        # Act: request with the same type.
        outcome = EndUserService.get_or_create_end_user_by_type(
            type=InvokeFrom.SERVICE_API,
            tenant_id="tenant-123",
            app_id="app-456",
            user_id="user-789",
        )

        # Assert: record returned unchanged; no commit or log since no migration ran.
        assert outcome == existing
        assert existing.type == InvokeFrom.SERVICE_API
        session.commit.assert_not_called()
        mock_logger.info.assert_not_called()

    # Test 08: Create anonymous user with default session ID
    @patch("services.end_user_service.Session")
    @patch("services.end_user_service.db")
    def test_create_anonymous_user_with_default_session(self, mock_db, mock_session_class, factory):
        """Test creating anonymous user when user_id is None."""
        # Arrange
        session = self._enter_session(mock_session_class)
        self._stub_lookup(session, None)

        # Act
        EndUserService.get_or_create_end_user_by_type(
            type=InvokeFrom.SERVICE_API,
            tenant_id="tenant-123",
            app_id="app-456",
            user_id=None,
        )

        # Assert: anonymous users fall back to the default session identifier.
        session.add.assert_called_once()
        persisted = session.add.call_args[0][0]
        assert persisted.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
        # The internal anonymous flag is set on the newly created record.
        assert persisted._is_anonymous is True
        assert persisted.external_user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID

    # Test 09: Query ordering prioritizes matching type
    @patch("services.end_user_service.Session")
    @patch("services.end_user_service.db")
    def test_query_ordering_prioritizes_matching_type(self, mock_db, mock_session_class, factory):
        """Test that query ordering prioritizes records with matching type."""
        # Arrange
        session = self._enter_session(mock_session_class)
        query = self._stub_lookup(session, None)

        # Act
        EndUserService.get_or_create_end_user_by_type(
            type=InvokeFrom.SERVICE_API,
            tenant_id="tenant-123",
            app_id="app-456",
            user_id="user-789",
        )

        # Assert: order_by was applied (used to prefer type-matching rows).
        query.order_by.assert_called_once()

    # Test 10: Session context manager properly closes
    @patch("services.end_user_service.Session")
    @patch("services.end_user_service.db")
    def test_session_context_manager_closes(self, mock_db, mock_session_class, factory):
        """Test that Session context manager is properly used."""
        # Arrange: hand-build the context manager so both dunder calls are observable.
        session = MagicMock()
        context = MagicMock()
        context.__enter__.return_value = session
        mock_session_class.return_value = context
        self._stub_lookup(session, None)

        # Act
        EndUserService.get_or_create_end_user_by_type(
            type=InvokeFrom.SERVICE_API,
            tenant_id="tenant-123",
            app_id="app-456",
            user_id="user-789",
        )

        # Assert: the context manager was entered and exited exactly once.
        context.__enter__.assert_called_once()
        context.__exit__.assert_called_once()

    # Test 11: External user ID matches session ID
    @patch("services.end_user_service.Session")
    @patch("services.end_user_service.db")
    def test_external_user_id_matches_session_id(self, mock_db, mock_session_class, factory):
        """Test that external_user_id is set to match session_id."""
        # Arrange
        session = self._enter_session(mock_session_class)
        self._stub_lookup(session, None)

        # Act
        EndUserService.get_or_create_end_user_by_type(
            type=InvokeFrom.SERVICE_API,
            tenant_id="tenant-123",
            app_id="app-456",
            user_id="custom-external-id",
        )

        # Assert
        persisted = session.add.call_args[0][0]
        assert persisted.external_user_id == "custom-external-id"
        assert persisted.session_id == "custom-external-id"

    # Test 12: Different InvokeFrom types
    @pytest.mark.parametrize(
        "invoke_type",
        [
            InvokeFrom.SERVICE_API,
            InvokeFrom.WEB_APP,
            InvokeFrom.EXPLORE,
            InvokeFrom.DEBUGGER,
        ],
    )
    @patch("services.end_user_service.Session")
    @patch("services.end_user_service.db")
    def test_create_end_user_with_different_invoke_types(self, mock_db, mock_session_class, invoke_type, factory):
        """Test creating end users with different InvokeFrom types."""
        # Arrange
        session = self._enter_session(mock_session_class)
        self._stub_lookup(session, None)

        # Act
        EndUserService.get_or_create_end_user_by_type(
            type=invoke_type,
            tenant_id="tenant-123",
            app_id="app-456",
            user_id="user-789",
        )

        # Assert: each InvokeFrom variant is stored verbatim on the new record.
        persisted = session.add.call_args[0][0]
        assert persisted.type == invoke_type
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,440 @@
|
|||
"""
|
||||
Comprehensive unit tests for RecommendedAppService.
|
||||
|
||||
This test suite provides complete coverage of recommended app operations in Dify,
|
||||
following TDD principles with the Arrange-Act-Assert pattern.
|
||||
|
||||
## Test Coverage
|
||||
|
||||
### 1. Get Recommended Apps and Categories (TestRecommendedAppServiceGetApps)
|
||||
Tests fetching recommended apps with categories:
|
||||
- Successful retrieval with recommended apps
|
||||
- Fallback to builtin when no recommended apps
|
||||
- Different language support
|
||||
- Factory mode selection (remote, builtin, db)
|
||||
- Empty result handling
|
||||
|
||||
### 2. Get Recommend App Detail (TestRecommendedAppServiceGetDetail)
|
||||
Tests fetching individual app details:
|
||||
- Successful app detail retrieval
|
||||
- Different factory modes
|
||||
- App not found scenarios
|
||||
- Language-specific details
|
||||
|
||||
## Testing Approach
|
||||
|
||||
- **Mocking Strategy**: All external dependencies (dify_config, RecommendAppRetrievalFactory)
|
||||
are mocked for fast, isolated unit tests
|
||||
- **Factory Pattern**: Tests verify correct factory selection based on mode
|
||||
- **Fixtures**: Mock objects are configured per test method
|
||||
- **Assertions**: Each test verifies return values and factory method calls
|
||||
|
||||
## Key Concepts
|
||||
|
||||
**Factory Modes:**
|
||||
- remote: Fetch from remote API
|
||||
- builtin: Use built-in templates
|
||||
- db: Fetch from database
|
||||
|
||||
**Fallback Logic:**
|
||||
- If remote/db returns no apps, fallback to builtin en-US templates
|
||||
- Ensures users always see some recommended apps
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from services.recommended_app_service import RecommendedAppService
|
||||
|
||||
|
||||
class RecommendedAppServiceTestDataFactory:
|
||||
"""
|
||||
Factory for creating test data and mock objects.
|
||||
|
||||
Provides reusable methods to create consistent mock objects for testing
|
||||
recommended app operations.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def create_recommended_apps_response(
|
||||
recommended_apps: list[dict] | None = None,
|
||||
categories: list[str] | None = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Create a mock response for recommended apps.
|
||||
|
||||
Args:
|
||||
recommended_apps: List of recommended app dictionaries
|
||||
categories: List of category names
|
||||
|
||||
Returns:
|
||||
Dictionary with recommended_apps and categories
|
||||
"""
|
||||
if recommended_apps is None:
|
||||
recommended_apps = [
|
||||
{
|
||||
"id": "app-1",
|
||||
"name": "Test App 1",
|
||||
"description": "Test description 1",
|
||||
"category": "productivity",
|
||||
},
|
||||
{
|
||||
"id": "app-2",
|
||||
"name": "Test App 2",
|
||||
"description": "Test description 2",
|
||||
"category": "communication",
|
||||
},
|
||||
]
|
||||
if categories is None:
|
||||
categories = ["productivity", "communication", "utilities"]
|
||||
|
||||
return {
|
||||
"recommended_apps": recommended_apps,
|
||||
"categories": categories,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def create_app_detail_response(
|
||||
app_id: str = "app-123",
|
||||
name: str = "Test App",
|
||||
description: str = "Test description",
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
"""
|
||||
Create a mock response for app detail.
|
||||
|
||||
Args:
|
||||
app_id: App identifier
|
||||
name: App name
|
||||
description: App description
|
||||
**kwargs: Additional fields
|
||||
|
||||
Returns:
|
||||
Dictionary with app details
|
||||
"""
|
||||
detail = {
|
||||
"id": app_id,
|
||||
"name": name,
|
||||
"description": description,
|
||||
"category": kwargs.get("category", "productivity"),
|
||||
"icon": kwargs.get("icon", "🚀"),
|
||||
"model_config": kwargs.get("model_config", {}),
|
||||
}
|
||||
detail.update(kwargs)
|
||||
return detail
|
||||
|
||||
|
||||
@pytest.fixture
def factory():
    """Expose the shared test data factory class to every test."""
    return RecommendedAppServiceTestDataFactory
|
||||
|
||||
class TestRecommendedAppServiceGetApps:
    """Test get_recommended_apps_and_categories operations."""

    @staticmethod
    def _wire_retrieval(mock_factory_class, response):
        """Make the patched factory hand back a retrieval whose listing call returns *response*."""
        retrieval = MagicMock()
        retrieval.get_recommended_apps_and_categories.return_value = response
        mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=retrieval)
        return retrieval

    @patch("services.recommended_app_service.RecommendAppRetrievalFactory")
    @patch("services.recommended_app_service.dify_config")
    def test_get_recommended_apps_success_with_apps(self, mock_config, mock_factory_class, factory):
        """Test successful retrieval of recommended apps when apps are returned."""
        # Arrange
        mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
        expected = factory.create_recommended_apps_response()
        retrieval = self._wire_retrieval(mock_factory_class, expected)

        # Act
        result = RecommendedAppService.get_recommended_apps_and_categories("en-US")

        # Assert: the populated listing is passed through untouched.
        assert result == expected
        assert len(result["recommended_apps"]) == 2
        assert len(result["categories"]) == 3
        mock_factory_class.get_recommend_app_factory.assert_called_once_with("remote")
        retrieval.get_recommended_apps_and_categories.assert_called_once_with("en-US")

    @patch("services.recommended_app_service.RecommendAppRetrievalFactory")
    @patch("services.recommended_app_service.dify_config")
    def test_get_recommended_apps_fallback_to_builtin_when_empty(self, mock_config, mock_factory_class, factory):
        """Test fallback to builtin when no recommended apps are returned."""
        # Arrange: the remote source yields an empty listing.
        mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
        self._wire_retrieval(mock_factory_class, {"recommended_apps": [], "categories": []})

        builtin_payload = factory.create_recommended_apps_response(
            recommended_apps=[{"id": "builtin-1", "name": "Builtin App", "category": "default"}]
        )
        builtin = MagicMock()
        builtin.fetch_recommended_apps_from_builtin.return_value = builtin_payload
        mock_factory_class.get_buildin_recommend_app_retrieval.return_value = builtin

        # Act
        result = RecommendedAppService.get_recommended_apps_and_categories("zh-CN")

        # Assert: builtin templates replace the empty remote result.
        assert result == builtin_payload
        assert len(result["recommended_apps"]) == 1
        assert result["recommended_apps"][0]["id"] == "builtin-1"
        # Fallback always requests the hardcoded en-US templates.
        builtin.fetch_recommended_apps_from_builtin.assert_called_once_with("en-US")

    @patch("services.recommended_app_service.RecommendAppRetrievalFactory")
    @patch("services.recommended_app_service.dify_config")
    def test_get_recommended_apps_fallback_when_none_recommended_apps(self, mock_config, mock_factory_class, factory):
        """Test fallback when recommended_apps key is None."""
        # Arrange: db mode returns a payload whose recommended_apps is None.
        mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "db"
        self._wire_retrieval(mock_factory_class, {"recommended_apps": None, "categories": ["test"]})

        builtin_payload = factory.create_recommended_apps_response()
        builtin = MagicMock()
        builtin.fetch_recommended_apps_from_builtin.return_value = builtin_payload
        mock_factory_class.get_buildin_recommend_app_retrieval.return_value = builtin

        # Act
        result = RecommendedAppService.get_recommended_apps_and_categories("en-US")

        # Assert
        assert result == builtin_payload
        builtin.fetch_recommended_apps_from_builtin.assert_called_once()

    @patch("services.recommended_app_service.RecommendAppRetrievalFactory")
    @patch("services.recommended_app_service.dify_config")
    def test_get_recommended_apps_with_different_languages(self, mock_config, mock_factory_class, factory):
        """Test retrieval with different language codes."""
        # Arrange
        mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "builtin"

        for language in ("en-US", "zh-CN", "ja-JP", "fr-FR"):
            lang_payload = factory.create_recommended_apps_response(
                recommended_apps=[{"id": f"app-{language}", "name": f"App {language}", "category": "test"}]
            )
            retrieval = self._wire_retrieval(mock_factory_class, lang_payload)

            # Act
            result = RecommendedAppService.get_recommended_apps_and_categories(language)

            # Assert: each language is forwarded verbatim to the retrieval layer.
            assert result["recommended_apps"][0]["id"] == f"app-{language}"
            retrieval.get_recommended_apps_and_categories.assert_called_with(language)

    @patch("services.recommended_app_service.RecommendAppRetrievalFactory")
    @patch("services.recommended_app_service.dify_config")
    def test_get_recommended_apps_uses_correct_factory_mode(self, mock_config, mock_factory_class, factory):
        """Test that correct factory is selected based on mode."""
        # Arrange / Act / Assert for every supported fetch mode.
        for mode in ("remote", "builtin", "db"):
            mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = mode
            self._wire_retrieval(mock_factory_class, factory.create_recommended_apps_response())

            RecommendedAppService.get_recommended_apps_and_categories("en-US")

            mock_factory_class.get_recommend_app_factory.assert_called_with(mode)
|
||||
|
||||
class TestRecommendedAppServiceGetDetail:
    """Test get_recommend_app_detail operations."""

    @staticmethod
    def _wire_detail(mock_factory_class, detail):
        """Make the patched factory hand back a retrieval whose detail call returns *detail*."""
        retrieval = MagicMock()
        retrieval.get_recommend_app_detail.return_value = detail
        mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=retrieval)
        return retrieval

    @patch("services.recommended_app_service.RecommendAppRetrievalFactory")
    @patch("services.recommended_app_service.dify_config")
    def test_get_recommend_app_detail_success(self, mock_config, mock_factory_class, factory):
        """Test successful retrieval of app detail."""
        # Arrange
        mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
        expected = factory.create_app_detail_response(
            app_id="app-123",
            name="Productivity App",
            description="A great productivity app",
            category="productivity",
        )
        retrieval = self._wire_detail(mock_factory_class, expected)

        # Act
        result = RecommendedAppService.get_recommend_app_detail("app-123")

        # Assert
        assert result == expected
        assert result["id"] == "app-123"
        assert result["name"] == "Productivity App"
        retrieval.get_recommend_app_detail.assert_called_once_with("app-123")

    @patch("services.recommended_app_service.RecommendAppRetrievalFactory")
    @patch("services.recommended_app_service.dify_config")
    def test_get_recommend_app_detail_with_different_modes(self, mock_config, mock_factory_class, factory):
        """Test app detail retrieval with different factory modes."""
        # Arrange / Act / Assert for every supported fetch mode.
        for mode in ("remote", "builtin", "db"):
            mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = mode
            detail = factory.create_app_detail_response(app_id="test-app", name=f"App from {mode}")
            self._wire_detail(mock_factory_class, detail)

            result = RecommendedAppService.get_recommend_app_detail("test-app")

            assert result["name"] == f"App from {mode}"
            mock_factory_class.get_recommend_app_factory.assert_called_with(mode)

    @patch("services.recommended_app_service.RecommendAppRetrievalFactory")
    @patch("services.recommended_app_service.dify_config")
    def test_get_recommend_app_detail_returns_none_when_not_found(self, mock_config, mock_factory_class, factory):
        """Test that None is returned when app is not found."""
        # Arrange: the retrieval layer reports no such app.
        mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
        retrieval = self._wire_detail(mock_factory_class, None)

        # Act
        result = RecommendedAppService.get_recommend_app_detail("nonexistent-app")

        # Assert: the None is propagated unchanged.
        assert result is None
        retrieval.get_recommend_app_detail.assert_called_once_with("nonexistent-app")

    @patch("services.recommended_app_service.RecommendAppRetrievalFactory")
    @patch("services.recommended_app_service.dify_config")
    def test_get_recommend_app_detail_returns_empty_dict(self, mock_config, mock_factory_class, factory):
        """Test handling of empty dict response."""
        # Arrange
        mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "builtin"
        self._wire_detail(mock_factory_class, {})

        # Act
        result = RecommendedAppService.get_recommend_app_detail("app-empty")

        # Assert
        assert result == {}

    @patch("services.recommended_app_service.RecommendAppRetrievalFactory")
    @patch("services.recommended_app_service.dify_config")
    def test_get_recommend_app_detail_with_complex_model_config(self, mock_config, mock_factory_class, factory):
        """Test app detail with complex model configuration."""
        # Arrange: nested model config plus extra list-valued fields.
        mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
        complex_model_config = {
            "provider": "openai",
            "model": "gpt-4",
            "parameters": {
                "temperature": 0.7,
                "max_tokens": 2000,
                "top_p": 1.0,
            },
        }
        expected = factory.create_app_detail_response(
            app_id="complex-app",
            name="Complex App",
            model_config=complex_model_config,
            workflows=["workflow-1", "workflow-2"],
            tools=["tool-1", "tool-2", "tool-3"],
        )
        self._wire_detail(mock_factory_class, expected)

        # Act
        result = RecommendedAppService.get_recommend_app_detail("complex-app")

        # Assert: nested structures survive the round trip intact.
        assert result["model_config"] == complex_model_config
        assert len(result["workflows"]) == 2
        assert len(result["tools"]) == 3
|
|
@ -0,0 +1,626 @@
|
|||
"""
|
||||
Comprehensive unit tests for SavedMessageService.
|
||||
|
||||
This test suite provides complete coverage of saved message operations in Dify,
|
||||
following TDD principles with the Arrange-Act-Assert pattern.
|
||||
|
||||
## Test Coverage
|
||||
|
||||
### 1. Pagination (TestSavedMessageServicePagination)
|
||||
Tests saved message listing and pagination:
|
||||
- Pagination with valid user (Account and EndUser)
|
||||
- Pagination without user raises ValueError
|
||||
- Pagination with last_id parameter
|
||||
- Empty results when no saved messages exist
|
||||
- Integration with MessageService pagination
|
||||
|
||||
### 2. Save Operations (TestSavedMessageServiceSave)
|
||||
Tests saving messages:
|
||||
- Save message for Account user
|
||||
- Save message for EndUser
|
||||
- Save without user (no-op)
|
||||
- Prevent duplicate saves (idempotent)
|
||||
- Message validation through MessageService
|
||||
|
||||
### 3. Delete Operations (TestSavedMessageServiceDelete)
|
||||
Tests deleting saved messages:
|
||||
- Delete saved message for Account user
|
||||
- Delete saved message for EndUser
|
||||
- Delete without user (no-op)
|
||||
- Delete non-existent saved message (no-op)
|
||||
- Proper database cleanup
|
||||
|
||||
## Testing Approach
|
||||
|
||||
- **Mocking Strategy**: All external dependencies (database, MessageService) are mocked
|
||||
for fast, isolated unit tests
|
||||
- **Factory Pattern**: SavedMessageServiceTestDataFactory provides consistent test data
|
||||
- **Fixtures**: Mock objects are configured per test method
|
||||
- **Assertions**: Each test verifies return values and side effects
|
||||
(database operations, method calls)
|
||||
|
||||
## Key Concepts
|
||||
|
||||
**User Types:**
|
||||
- Account: Workspace members (console users)
|
||||
- EndUser: API users (end users)
|
||||
|
||||
**Saved Messages:**
|
||||
- Users can save messages for later reference
|
||||
- Each user has their own saved message list
|
||||
- Saving is idempotent (duplicate saves ignored)
|
||||
- Deletion is safe (non-existent deletes ignored)
|
||||
"""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import MagicMock, Mock, create_autospec, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from models import Account
|
||||
from models.model import App, EndUser, Message
|
||||
from models.web import SavedMessage
|
||||
from services.saved_message_service import SavedMessageService
|
||||
|
||||
|
||||
class SavedMessageServiceTestDataFactory:
|
||||
"""
|
||||
Factory for creating test data and mock objects.
|
||||
|
||||
Provides reusable methods to create consistent mock objects for testing
|
||||
saved message operations.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def create_account_mock(account_id: str = "account-123", **kwargs) -> Mock:
|
||||
"""
|
||||
Create a mock Account object.
|
||||
|
||||
Args:
|
||||
account_id: Unique identifier for the account
|
||||
**kwargs: Additional attributes to set on the mock
|
||||
|
||||
Returns:
|
||||
Mock Account object with specified attributes
|
||||
"""
|
||||
account = create_autospec(Account, instance=True)
|
||||
account.id = account_id
|
||||
for key, value in kwargs.items():
|
||||
setattr(account, key, value)
|
||||
return account
|
||||
|
||||
@staticmethod
|
||||
def create_end_user_mock(user_id: str = "user-123", **kwargs) -> Mock:
|
||||
"""
|
||||
Create a mock EndUser object.
|
||||
|
||||
Args:
|
||||
user_id: Unique identifier for the end user
|
||||
**kwargs: Additional attributes to set on the mock
|
||||
|
||||
Returns:
|
||||
Mock EndUser object with specified attributes
|
||||
"""
|
||||
user = create_autospec(EndUser, instance=True)
|
||||
user.id = user_id
|
||||
for key, value in kwargs.items():
|
||||
setattr(user, key, value)
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
def create_app_mock(app_id: str = "app-123", tenant_id: str = "tenant-123", **kwargs) -> Mock:
|
||||
"""
|
||||
Create a mock App object.
|
||||
|
||||
Args:
|
||||
app_id: Unique identifier for the app
|
||||
tenant_id: Tenant/workspace identifier
|
||||
**kwargs: Additional attributes to set on the mock
|
||||
|
||||
Returns:
|
||||
Mock App object with specified attributes
|
||||
"""
|
||||
app = create_autospec(App, instance=True)
|
||||
app.id = app_id
|
||||
app.tenant_id = tenant_id
|
||||
app.name = kwargs.get("name", "Test App")
|
||||
app.mode = kwargs.get("mode", "chat")
|
||||
for key, value in kwargs.items():
|
||||
setattr(app, key, value)
|
||||
return app
|
||||
|
||||
@staticmethod
|
||||
def create_message_mock(
|
||||
message_id: str = "msg-123",
|
||||
app_id: str = "app-123",
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""
|
||||
Create a mock Message object.
|
||||
|
||||
Args:
|
||||
message_id: Unique identifier for the message
|
||||
app_id: Associated app identifier
|
||||
**kwargs: Additional attributes to set on the mock
|
||||
|
||||
Returns:
|
||||
Mock Message object with specified attributes
|
||||
"""
|
||||
message = create_autospec(Message, instance=True)
|
||||
message.id = message_id
|
||||
message.app_id = app_id
|
||||
message.query = kwargs.get("query", "Test query")
|
||||
message.answer = kwargs.get("answer", "Test answer")
|
||||
message.created_at = kwargs.get("created_at", datetime.now(UTC))
|
||||
for key, value in kwargs.items():
|
||||
setattr(message, key, value)
|
||||
return message
|
||||
|
||||
@staticmethod
|
||||
def create_saved_message_mock(
|
||||
saved_message_id: str = "saved-123",
|
||||
app_id: str = "app-123",
|
||||
message_id: str = "msg-123",
|
||||
created_by: str = "user-123",
|
||||
created_by_role: str = "account",
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""
|
||||
Create a mock SavedMessage object.
|
||||
|
||||
Args:
|
||||
saved_message_id: Unique identifier for the saved message
|
||||
app_id: Associated app identifier
|
||||
message_id: Associated message identifier
|
||||
created_by: User who saved the message
|
||||
created_by_role: Role of the user ('account' or 'end_user')
|
||||
**kwargs: Additional attributes to set on the mock
|
||||
|
||||
Returns:
|
||||
Mock SavedMessage object with specified attributes
|
||||
"""
|
||||
saved_message = create_autospec(SavedMessage, instance=True)
|
||||
saved_message.id = saved_message_id
|
||||
saved_message.app_id = app_id
|
||||
saved_message.message_id = message_id
|
||||
saved_message.created_by = created_by
|
||||
saved_message.created_by_role = created_by_role
|
||||
saved_message.created_at = kwargs.get("created_at", datetime.now(UTC))
|
||||
for key, value in kwargs.items():
|
||||
setattr(saved_message, key, value)
|
||||
return saved_message
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def factory():
|
||||
"""Provide the test data factory to all tests."""
|
||||
return SavedMessageServiceTestDataFactory
|
||||
|
||||
|
||||
class TestSavedMessageServicePagination:
|
||||
"""Test saved message pagination operations."""
|
||||
|
||||
@patch("services.saved_message_service.MessageService.pagination_by_last_id")
|
||||
@patch("services.saved_message_service.db.session")
|
||||
def test_pagination_with_account_user(self, mock_db_session, mock_message_pagination, factory):
|
||||
"""Test pagination with an Account user."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user = factory.create_account_mock()
|
||||
|
||||
# Create saved messages for this user
|
||||
saved_messages = [
|
||||
factory.create_saved_message_mock(
|
||||
saved_message_id=f"saved-{i}",
|
||||
app_id=app.id,
|
||||
message_id=f"msg-{i}",
|
||||
created_by=user.id,
|
||||
created_by_role="account",
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
# Mock database query
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.all.return_value = saved_messages
|
||||
|
||||
# Mock MessageService pagination response
|
||||
expected_pagination = InfiniteScrollPagination(data=[], limit=20, has_more=False)
|
||||
mock_message_pagination.return_value = expected_pagination
|
||||
|
||||
# Act
|
||||
result = SavedMessageService.pagination_by_last_id(app_model=app, user=user, last_id=None, limit=20)
|
||||
|
||||
# Assert
|
||||
assert result == expected_pagination
|
||||
mock_db_session.query.assert_called_once_with(SavedMessage)
|
||||
# Verify MessageService was called with correct message IDs
|
||||
mock_message_pagination.assert_called_once_with(
|
||||
app_model=app,
|
||||
user=user,
|
||||
last_id=None,
|
||||
limit=20,
|
||||
include_ids=["msg-0", "msg-1", "msg-2"],
|
||||
)
|
||||
|
||||
@patch("services.saved_message_service.MessageService.pagination_by_last_id")
|
||||
@patch("services.saved_message_service.db.session")
|
||||
def test_pagination_with_end_user(self, mock_db_session, mock_message_pagination, factory):
|
||||
"""Test pagination with an EndUser."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user = factory.create_end_user_mock()
|
||||
|
||||
# Create saved messages for this end user
|
||||
saved_messages = [
|
||||
factory.create_saved_message_mock(
|
||||
saved_message_id=f"saved-{i}",
|
||||
app_id=app.id,
|
||||
message_id=f"msg-{i}",
|
||||
created_by=user.id,
|
||||
created_by_role="end_user",
|
||||
)
|
||||
for i in range(2)
|
||||
]
|
||||
|
||||
# Mock database query
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.all.return_value = saved_messages
|
||||
|
||||
# Mock MessageService pagination response
|
||||
expected_pagination = InfiniteScrollPagination(data=[], limit=10, has_more=False)
|
||||
mock_message_pagination.return_value = expected_pagination
|
||||
|
||||
# Act
|
||||
result = SavedMessageService.pagination_by_last_id(app_model=app, user=user, last_id=None, limit=10)
|
||||
|
||||
# Assert
|
||||
assert result == expected_pagination
|
||||
# Verify correct role was used in query
|
||||
mock_message_pagination.assert_called_once_with(
|
||||
app_model=app,
|
||||
user=user,
|
||||
last_id=None,
|
||||
limit=10,
|
||||
include_ids=["msg-0", "msg-1"],
|
||||
)
|
||||
|
||||
def test_pagination_without_user_raises_error(self, factory):
|
||||
"""Test that pagination without user raises ValueError."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="User is required"):
|
||||
SavedMessageService.pagination_by_last_id(app_model=app, user=None, last_id=None, limit=20)
|
||||
|
||||
@patch("services.saved_message_service.MessageService.pagination_by_last_id")
|
||||
@patch("services.saved_message_service.db.session")
|
||||
def test_pagination_with_last_id(self, mock_db_session, mock_message_pagination, factory):
|
||||
"""Test pagination with last_id parameter."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user = factory.create_account_mock()
|
||||
last_id = "msg-last"
|
||||
|
||||
saved_messages = [
|
||||
factory.create_saved_message_mock(
|
||||
message_id=f"msg-{i}",
|
||||
app_id=app.id,
|
||||
created_by=user.id,
|
||||
)
|
||||
for i in range(5)
|
||||
]
|
||||
|
||||
# Mock database query
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.all.return_value = saved_messages
|
||||
|
||||
# Mock MessageService pagination response
|
||||
expected_pagination = InfiniteScrollPagination(data=[], limit=10, has_more=True)
|
||||
mock_message_pagination.return_value = expected_pagination
|
||||
|
||||
# Act
|
||||
result = SavedMessageService.pagination_by_last_id(app_model=app, user=user, last_id=last_id, limit=10)
|
||||
|
||||
# Assert
|
||||
assert result == expected_pagination
|
||||
# Verify last_id was passed to MessageService
|
||||
mock_message_pagination.assert_called_once()
|
||||
call_args = mock_message_pagination.call_args
|
||||
assert call_args.kwargs["last_id"] == last_id
|
||||
|
||||
@patch("services.saved_message_service.MessageService.pagination_by_last_id")
|
||||
@patch("services.saved_message_service.db.session")
|
||||
def test_pagination_with_empty_saved_messages(self, mock_db_session, mock_message_pagination, factory):
|
||||
"""Test pagination when user has no saved messages."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user = factory.create_account_mock()
|
||||
|
||||
# Mock database query returning empty list
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.all.return_value = []
|
||||
|
||||
# Mock MessageService pagination response
|
||||
expected_pagination = InfiniteScrollPagination(data=[], limit=20, has_more=False)
|
||||
mock_message_pagination.return_value = expected_pagination
|
||||
|
||||
# Act
|
||||
result = SavedMessageService.pagination_by_last_id(app_model=app, user=user, last_id=None, limit=20)
|
||||
|
||||
# Assert
|
||||
assert result == expected_pagination
|
||||
# Verify MessageService was called with empty include_ids
|
||||
mock_message_pagination.assert_called_once_with(
|
||||
app_model=app,
|
||||
user=user,
|
||||
last_id=None,
|
||||
limit=20,
|
||||
include_ids=[],
|
||||
)
|
||||
|
||||
|
||||
class TestSavedMessageServiceSave:
|
||||
"""Test save message operations."""
|
||||
|
||||
@patch("services.saved_message_service.MessageService.get_message")
|
||||
@patch("services.saved_message_service.db.session")
|
||||
def test_save_message_for_account(self, mock_db_session, mock_get_message, factory):
|
||||
"""Test saving a message for an Account user."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user = factory.create_account_mock()
|
||||
message = factory.create_message_mock(message_id="msg-123", app_id=app.id)
|
||||
|
||||
# Mock database query - no existing saved message
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
|
||||
# Mock MessageService.get_message
|
||||
mock_get_message.return_value = message
|
||||
|
||||
# Act
|
||||
SavedMessageService.save(app_model=app, user=user, message_id=message.id)
|
||||
|
||||
# Assert
|
||||
mock_db_session.add.assert_called_once()
|
||||
saved_message = mock_db_session.add.call_args[0][0]
|
||||
assert saved_message.app_id == app.id
|
||||
assert saved_message.message_id == message.id
|
||||
assert saved_message.created_by == user.id
|
||||
assert saved_message.created_by_role == "account"
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
@patch("services.saved_message_service.MessageService.get_message")
|
||||
@patch("services.saved_message_service.db.session")
|
||||
def test_save_message_for_end_user(self, mock_db_session, mock_get_message, factory):
|
||||
"""Test saving a message for an EndUser."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user = factory.create_end_user_mock()
|
||||
message = factory.create_message_mock(message_id="msg-456", app_id=app.id)
|
||||
|
||||
# Mock database query - no existing saved message
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
|
||||
# Mock MessageService.get_message
|
||||
mock_get_message.return_value = message
|
||||
|
||||
# Act
|
||||
SavedMessageService.save(app_model=app, user=user, message_id=message.id)
|
||||
|
||||
# Assert
|
||||
mock_db_session.add.assert_called_once()
|
||||
saved_message = mock_db_session.add.call_args[0][0]
|
||||
assert saved_message.app_id == app.id
|
||||
assert saved_message.message_id == message.id
|
||||
assert saved_message.created_by == user.id
|
||||
assert saved_message.created_by_role == "end_user"
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
@patch("services.saved_message_service.db.session")
|
||||
def test_save_without_user_does_nothing(self, mock_db_session, factory):
|
||||
"""Test that saving without user is a no-op."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
|
||||
# Act
|
||||
SavedMessageService.save(app_model=app, user=None, message_id="msg-123")
|
||||
|
||||
# Assert
|
||||
mock_db_session.query.assert_not_called()
|
||||
mock_db_session.add.assert_not_called()
|
||||
mock_db_session.commit.assert_not_called()
|
||||
|
||||
@patch("services.saved_message_service.MessageService.get_message")
|
||||
@patch("services.saved_message_service.db.session")
|
||||
def test_save_duplicate_message_is_idempotent(self, mock_db_session, mock_get_message, factory):
|
||||
"""Test that saving an already saved message is idempotent."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user = factory.create_account_mock()
|
||||
message_id = "msg-789"
|
||||
|
||||
# Mock database query - existing saved message found
|
||||
existing_saved = factory.create_saved_message_mock(
|
||||
app_id=app.id,
|
||||
message_id=message_id,
|
||||
created_by=user.id,
|
||||
created_by_role="account",
|
||||
)
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.first.return_value = existing_saved
|
||||
|
||||
# Act
|
||||
SavedMessageService.save(app_model=app, user=user, message_id=message_id)
|
||||
|
||||
# Assert - no new saved message created
|
||||
mock_db_session.add.assert_not_called()
|
||||
mock_db_session.commit.assert_not_called()
|
||||
mock_get_message.assert_not_called()
|
||||
|
||||
@patch("services.saved_message_service.MessageService.get_message")
|
||||
@patch("services.saved_message_service.db.session")
|
||||
def test_save_validates_message_exists(self, mock_db_session, mock_get_message, factory):
|
||||
"""Test that save validates message exists through MessageService."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user = factory.create_account_mock()
|
||||
message = factory.create_message_mock()
|
||||
|
||||
# Mock database query - no existing saved message
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
|
||||
# Mock MessageService.get_message
|
||||
mock_get_message.return_value = message
|
||||
|
||||
# Act
|
||||
SavedMessageService.save(app_model=app, user=user, message_id=message.id)
|
||||
|
||||
# Assert - MessageService.get_message was called for validation
|
||||
mock_get_message.assert_called_once_with(app_model=app, user=user, message_id=message.id)
|
||||
|
||||
|
||||
class TestSavedMessageServiceDelete:
|
||||
"""Test delete saved message operations."""
|
||||
|
||||
@patch("services.saved_message_service.db.session")
|
||||
def test_delete_saved_message_for_account(self, mock_db_session, factory):
|
||||
"""Test deleting a saved message for an Account user."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user = factory.create_account_mock()
|
||||
message_id = "msg-123"
|
||||
|
||||
# Mock database query - existing saved message found
|
||||
saved_message = factory.create_saved_message_mock(
|
||||
app_id=app.id,
|
||||
message_id=message_id,
|
||||
created_by=user.id,
|
||||
created_by_role="account",
|
||||
)
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.first.return_value = saved_message
|
||||
|
||||
# Act
|
||||
SavedMessageService.delete(app_model=app, user=user, message_id=message_id)
|
||||
|
||||
# Assert
|
||||
mock_db_session.delete.assert_called_once_with(saved_message)
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
@patch("services.saved_message_service.db.session")
|
||||
def test_delete_saved_message_for_end_user(self, mock_db_session, factory):
|
||||
"""Test deleting a saved message for an EndUser."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user = factory.create_end_user_mock()
|
||||
message_id = "msg-456"
|
||||
|
||||
# Mock database query - existing saved message found
|
||||
saved_message = factory.create_saved_message_mock(
|
||||
app_id=app.id,
|
||||
message_id=message_id,
|
||||
created_by=user.id,
|
||||
created_by_role="end_user",
|
||||
)
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.first.return_value = saved_message
|
||||
|
||||
# Act
|
||||
SavedMessageService.delete(app_model=app, user=user, message_id=message_id)
|
||||
|
||||
# Assert
|
||||
mock_db_session.delete.assert_called_once_with(saved_message)
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
@patch("services.saved_message_service.db.session")
|
||||
def test_delete_without_user_does_nothing(self, mock_db_session, factory):
|
||||
"""Test that deleting without user is a no-op."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
|
||||
# Act
|
||||
SavedMessageService.delete(app_model=app, user=None, message_id="msg-123")
|
||||
|
||||
# Assert
|
||||
mock_db_session.query.assert_not_called()
|
||||
mock_db_session.delete.assert_not_called()
|
||||
mock_db_session.commit.assert_not_called()
|
||||
|
||||
@patch("services.saved_message_service.db.session")
|
||||
def test_delete_non_existent_saved_message_does_nothing(self, mock_db_session, factory):
|
||||
"""Test that deleting a non-existent saved message is a no-op."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user = factory.create_account_mock()
|
||||
message_id = "msg-nonexistent"
|
||||
|
||||
# Mock database query - no saved message found
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
|
||||
# Act
|
||||
SavedMessageService.delete(app_model=app, user=user, message_id=message_id)
|
||||
|
||||
# Assert - no deletion occurred
|
||||
mock_db_session.delete.assert_not_called()
|
||||
mock_db_session.commit.assert_not_called()
|
||||
|
||||
@patch("services.saved_message_service.db.session")
|
||||
def test_delete_only_affects_user_own_saved_messages(self, mock_db_session, factory):
|
||||
"""Test that delete only removes the user's own saved message."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user1 = factory.create_account_mock(account_id="user-1")
|
||||
message_id = "msg-shared"
|
||||
|
||||
# Mock database query - finds user1's saved message
|
||||
saved_message = factory.create_saved_message_mock(
|
||||
app_id=app.id,
|
||||
message_id=message_id,
|
||||
created_by=user1.id,
|
||||
created_by_role="account",
|
||||
)
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.first.return_value = saved_message
|
||||
|
||||
# Act
|
||||
SavedMessageService.delete(app_model=app, user=user1, message_id=message_id)
|
||||
|
||||
# Assert - only user1's saved message is deleted
|
||||
mock_db_session.delete.assert_called_once_with(saved_message)
|
||||
# Verify the query filters by user
|
||||
assert mock_query.where.called
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,9 +1,9 @@
|
|||
from unittest.mock import Mock
|
||||
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.api_entities import ToolApiEntity
|
||||
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolParameter
|
||||
from core.tools.entities.tool_entities import ToolParameter, ToolProviderType
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
|
||||
|
|
@ -299,3 +299,154 @@ class TestToolTransformService:
|
|||
param2 = result.parameters[1]
|
||||
assert param2.name == "param2"
|
||||
assert param2.label == "Runtime Param 2"
|
||||
|
||||
|
||||
class TestWorkflowProviderToUserProvider:
|
||||
"""Test cases for ToolTransformService.workflow_provider_to_user_provider method"""
|
||||
|
||||
def test_workflow_provider_to_user_provider_with_workflow_app_id(self):
|
||||
"""Test that workflow_provider_to_user_provider correctly sets workflow_app_id."""
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
|
||||
# Create mock workflow tool provider controller
|
||||
workflow_app_id = "app_123"
|
||||
provider_id = "provider_123"
|
||||
mock_controller = Mock(spec=WorkflowToolProviderController)
|
||||
mock_controller.provider_id = provider_id
|
||||
mock_controller.entity = Mock()
|
||||
mock_controller.entity.identity = Mock()
|
||||
mock_controller.entity.identity.author = "test_author"
|
||||
mock_controller.entity.identity.name = "test_workflow_tool"
|
||||
mock_controller.entity.identity.description = I18nObject(en_US="Test description")
|
||||
mock_controller.entity.identity.icon = {"type": "emoji", "content": "🔧"}
|
||||
mock_controller.entity.identity.icon_dark = None
|
||||
mock_controller.entity.identity.label = I18nObject(en_US="Test Workflow Tool")
|
||||
|
||||
# Call the method
|
||||
result = ToolTransformService.workflow_provider_to_user_provider(
|
||||
provider_controller=mock_controller,
|
||||
labels=["label1", "label2"],
|
||||
workflow_app_id=workflow_app_id,
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, ToolProviderApiEntity)
|
||||
assert result.id == provider_id
|
||||
assert result.author == "test_author"
|
||||
assert result.name == "test_workflow_tool"
|
||||
assert result.type == ToolProviderType.WORKFLOW
|
||||
assert result.workflow_app_id == workflow_app_id
|
||||
assert result.labels == ["label1", "label2"]
|
||||
assert result.is_team_authorization is True
|
||||
assert result.plugin_id is None
|
||||
assert result.plugin_unique_identifier is None
|
||||
assert result.tools == []
|
||||
|
||||
def test_workflow_provider_to_user_provider_without_workflow_app_id(self):
|
||||
"""Test that workflow_provider_to_user_provider works when workflow_app_id is not provided."""
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
|
||||
# Create mock workflow tool provider controller
|
||||
provider_id = "provider_123"
|
||||
mock_controller = Mock(spec=WorkflowToolProviderController)
|
||||
mock_controller.provider_id = provider_id
|
||||
mock_controller.entity = Mock()
|
||||
mock_controller.entity.identity = Mock()
|
||||
mock_controller.entity.identity.author = "test_author"
|
||||
mock_controller.entity.identity.name = "test_workflow_tool"
|
||||
mock_controller.entity.identity.description = I18nObject(en_US="Test description")
|
||||
mock_controller.entity.identity.icon = {"type": "emoji", "content": "🔧"}
|
||||
mock_controller.entity.identity.icon_dark = None
|
||||
mock_controller.entity.identity.label = I18nObject(en_US="Test Workflow Tool")
|
||||
|
||||
# Call the method without workflow_app_id
|
||||
result = ToolTransformService.workflow_provider_to_user_provider(
|
||||
provider_controller=mock_controller,
|
||||
labels=["label1"],
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, ToolProviderApiEntity)
|
||||
assert result.id == provider_id
|
||||
assert result.workflow_app_id is None
|
||||
assert result.labels == ["label1"]
|
||||
|
||||
def test_workflow_provider_to_user_provider_workflow_app_id_none(self):
|
||||
"""Test that workflow_provider_to_user_provider handles None workflow_app_id explicitly."""
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
|
||||
# Create mock workflow tool provider controller
|
||||
provider_id = "provider_123"
|
||||
mock_controller = Mock(spec=WorkflowToolProviderController)
|
||||
mock_controller.provider_id = provider_id
|
||||
mock_controller.entity = Mock()
|
||||
mock_controller.entity.identity = Mock()
|
||||
mock_controller.entity.identity.author = "test_author"
|
||||
mock_controller.entity.identity.name = "test_workflow_tool"
|
||||
mock_controller.entity.identity.description = I18nObject(en_US="Test description")
|
||||
mock_controller.entity.identity.icon = {"type": "emoji", "content": "🔧"}
|
||||
mock_controller.entity.identity.icon_dark = None
|
||||
mock_controller.entity.identity.label = I18nObject(en_US="Test Workflow Tool")
|
||||
|
||||
# Call the method with explicit None values
|
||||
result = ToolTransformService.workflow_provider_to_user_provider(
|
||||
provider_controller=mock_controller,
|
||||
labels=None,
|
||||
workflow_app_id=None,
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, ToolProviderApiEntity)
|
||||
assert result.id == provider_id
|
||||
assert result.workflow_app_id is None
|
||||
assert result.labels == []
|
||||
|
||||
def test_workflow_provider_to_user_provider_preserves_other_fields(self):
|
||||
"""Test that workflow_provider_to_user_provider preserves all other entity fields."""
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
|
||||
# Create mock workflow tool provider controller with various fields
|
||||
workflow_app_id = "app_456"
|
||||
provider_id = "provider_456"
|
||||
mock_controller = Mock(spec=WorkflowToolProviderController)
|
||||
mock_controller.provider_id = provider_id
|
||||
mock_controller.entity = Mock()
|
||||
mock_controller.entity.identity = Mock()
|
||||
mock_controller.entity.identity.author = "another_author"
|
||||
mock_controller.entity.identity.name = "another_workflow_tool"
|
||||
mock_controller.entity.identity.description = I18nObject(
|
||||
en_US="Another description", zh_Hans="Another description"
|
||||
)
|
||||
mock_controller.entity.identity.icon = {"type": "emoji", "content": "⚙️"}
|
||||
mock_controller.entity.identity.icon_dark = {"type": "emoji", "content": "🔧"}
|
||||
mock_controller.entity.identity.label = I18nObject(
|
||||
en_US="Another Workflow Tool", zh_Hans="Another Workflow Tool"
|
||||
)
|
||||
|
||||
# Call the method
|
||||
result = ToolTransformService.workflow_provider_to_user_provider(
|
||||
provider_controller=mock_controller,
|
||||
labels=["automation", "workflow"],
|
||||
workflow_app_id=workflow_app_id,
|
||||
)
|
||||
|
||||
# Verify all fields are preserved correctly
|
||||
assert isinstance(result, ToolProviderApiEntity)
|
||||
assert result.id == provider_id
|
||||
assert result.author == "another_author"
|
||||
assert result.name == "another_workflow_tool"
|
||||
assert result.description.en_US == "Another description"
|
||||
assert result.description.zh_Hans == "Another description"
|
||||
assert result.icon == {"type": "emoji", "content": "⚙️"}
|
||||
assert result.icon_dark == {"type": "emoji", "content": "🔧"}
|
||||
assert result.label.en_US == "Another Workflow Tool"
|
||||
assert result.label.zh_Hans == "Another Workflow Tool"
|
||||
assert result.type == ToolProviderType.WORKFLOW
|
||||
assert result.workflow_app_id == workflow_app_id
|
||||
assert result.labels == ["automation", "workflow"]
|
||||
assert result.masked_credentials == {}
|
||||
assert result.is_team_authorization is True
|
||||
assert result.allow_delete is True
|
||||
assert result.plugin_id is None
|
||||
assert result.plugin_unique_identifier is None
|
||||
assert result.tools == []
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -23,7 +23,7 @@ const Empty = () => {
|
|||
return (
|
||||
<>
|
||||
<DefaultCards />
|
||||
<div className='absolute bottom-0 left-0 right-0 top-0 flex items-center justify-center bg-gradient-to-t from-background-body to-transparent'>
|
||||
<div className='absolute inset-0 z-20 flex items-center justify-center bg-gradient-to-t from-background-body to-transparent pointer-events-none'>
|
||||
<span className='system-md-medium text-text-tertiary'>
|
||||
{t('app.newApp.noAppsFound')}
|
||||
</span>
|
||||
|
|
|
|||
|
|
@ -187,6 +187,19 @@ const GotoAnything: FC<Props> = ({
|
|||
}, {} as { [key: string]: SearchResult[] }),
|
||||
[searchResults])
|
||||
|
||||
useEffect(() => {
|
||||
if (isCommandsMode)
|
||||
return
|
||||
|
||||
if (!searchResults.length)
|
||||
return
|
||||
|
||||
const currentValueExists = searchResults.some(result => `${result.type}-${result.id}` === cmdVal)
|
||||
|
||||
if (!currentValueExists)
|
||||
setCmdVal(`${searchResults[0].type}-${searchResults[0].id}`)
|
||||
}, [isCommandsMode, searchResults, cmdVal])
|
||||
|
||||
const emptyResult = useMemo(() => {
|
||||
if (searchResults.length || !searchQuery.trim() || isLoading || isCommandsMode)
|
||||
return null
|
||||
|
|
@ -386,7 +399,7 @@ const GotoAnything: FC<Props> = ({
|
|||
<Command.Item
|
||||
key={`${result.type}-${result.id}`}
|
||||
value={`${result.type}-${result.id}`}
|
||||
className='flex cursor-pointer items-center gap-3 rounded-md p-3 will-change-[background-color] aria-[selected=true]:bg-state-base-hover data-[selected=true]:bg-state-base-hover'
|
||||
className='flex cursor-pointer items-center gap-3 rounded-md p-3 will-change-[background-color] hover:bg-state-base-hover aria-[selected=true]:bg-state-base-hover-alt data-[selected=true]:bg-state-base-hover-alt'
|
||||
onSelect={() => handleNavigate(result)}
|
||||
>
|
||||
{result.icon}
|
||||
|
|
|
|||
|
|
@ -52,7 +52,12 @@ const Nav = ({
|
|||
`}>
|
||||
<Link href={link + (linkLastSearchParams && `?${linkLastSearchParams}`)}>
|
||||
<div
|
||||
onClick={() => setAppDetail()}
|
||||
onClick={(e) => {
|
||||
// Don't clear state if opening in new tab/window
|
||||
if (e.metaKey || e.ctrlKey || e.shiftKey || e.button !== 0)
|
||||
return
|
||||
setAppDetail()
|
||||
}}
|
||||
className={classNames(
|
||||
'flex h-7 cursor-pointer items-center rounded-[10px] px-2.5',
|
||||
isActivated ? 'text-components-main-nav-nav-button-text-active' : 'text-components-main-nav-nav-button-text',
|
||||
|
|
|
|||
|
|
@ -77,6 +77,8 @@ export type Collection = {
|
|||
timeout?: number
|
||||
sse_read_timeout?: number
|
||||
}
|
||||
// Workflow
|
||||
workflow_app_id?: string
|
||||
}
|
||||
|
||||
export type ToolParameter = {
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import {
|
||||
memo,
|
||||
useMemo,
|
||||
} from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useEdges } from 'reactflow'
|
||||
|
|
@ -16,6 +17,10 @@ import {
|
|||
} from '@/app/components/workflow/hooks'
|
||||
import ShortcutsName from '@/app/components/workflow/shortcuts-name'
|
||||
import type { Node } from '@/app/components/workflow/types'
|
||||
import { BlockEnum } from '@/app/components/workflow/types'
|
||||
import { CollectionType } from '@/app/components/tools/types'
|
||||
import { useAllWorkflowTools } from '@/service/use-tools'
|
||||
import { canFindTool } from '@/utils'
|
||||
|
||||
type PanelOperatorPopupProps = {
|
||||
id: string
|
||||
|
|
@ -45,6 +50,14 @@ const PanelOperatorPopup = ({
|
|||
const showChangeBlock = !nodeMetaData.isTypeFixed && !nodesReadOnly
|
||||
const isChildNode = !!(data.isInIteration || data.isInLoop)
|
||||
|
||||
const { data: workflowTools } = useAllWorkflowTools()
|
||||
const isWorkflowTool = data.type === BlockEnum.Tool && data.provider_type === CollectionType.workflow
|
||||
const workflowAppId = useMemo(() => {
|
||||
if (!isWorkflowTool || !workflowTools || !data.provider_id) return undefined
|
||||
const workflowTool = workflowTools.find(item => canFindTool(item.id, data.provider_id))
|
||||
return workflowTool?.workflow_app_id
|
||||
}, [isWorkflowTool, workflowTools, data.provider_id])
|
||||
|
||||
return (
|
||||
<div className='w-[240px] rounded-lg border-[0.5px] border-components-panel-border bg-components-panel-bg shadow-xl'>
|
||||
{
|
||||
|
|
@ -137,6 +150,22 @@ const PanelOperatorPopup = ({
|
|||
</>
|
||||
)
|
||||
}
|
||||
{
|
||||
isWorkflowTool && workflowAppId && (
|
||||
<>
|
||||
<div className='p-1'>
|
||||
<a
|
||||
href={`/app/${workflowAppId}/workflow`}
|
||||
target='_blank'
|
||||
className='flex h-8 cursor-pointer items-center rounded-lg px-3 text-sm text-text-secondary hover:bg-state-base-hover'
|
||||
>
|
||||
{t('workflow.panel.openWorkflow')}
|
||||
</a>
|
||||
</div>
|
||||
<div className='h-px bg-divider-regular'></div>
|
||||
</>
|
||||
)
|
||||
}
|
||||
{
|
||||
showHelpLink && nodeMetaData.helpLinkUri && (
|
||||
<>
|
||||
|
|
|
|||
|
|
@ -374,6 +374,7 @@ const translation = {
|
|||
optional_and_hidden: '(optional & hidden)',
|
||||
goTo: 'Gehe zu',
|
||||
startNode: 'Startknoten',
|
||||
openWorkflow: 'Workflow öffnen',
|
||||
},
|
||||
nodes: {
|
||||
common: {
|
||||
|
|
|
|||
|
|
@ -383,6 +383,7 @@ const translation = {
|
|||
userInputField: 'User Input Field',
|
||||
changeBlock: 'Change Node',
|
||||
helpLink: 'View Docs',
|
||||
openWorkflow: 'Open Workflow',
|
||||
about: 'About',
|
||||
createdBy: 'Created By ',
|
||||
nextStep: 'Next Step',
|
||||
|
|
|
|||
|
|
@ -374,6 +374,7 @@ const translation = {
|
|||
optional_and_hidden: '(opcional y oculto)',
|
||||
goTo: 'Ir a',
|
||||
startNode: 'Nodo de inicio',
|
||||
openWorkflow: 'Abrir flujo de trabajo',
|
||||
},
|
||||
nodes: {
|
||||
common: {
|
||||
|
|
|
|||
|
|
@ -374,6 +374,7 @@ const translation = {
|
|||
optional_and_hidden: '(اختیاری و پنهان)',
|
||||
goTo: 'برو به',
|
||||
startNode: 'گره شروع',
|
||||
openWorkflow: 'باز کردن جریان کاری',
|
||||
},
|
||||
nodes: {
|
||||
common: {
|
||||
|
|
|
|||
|
|
@ -374,6 +374,7 @@ const translation = {
|
|||
optional_and_hidden: '(optionnel et caché)',
|
||||
goTo: 'Aller à',
|
||||
startNode: 'Nœud de départ',
|
||||
openWorkflow: 'Ouvrir le flux de travail',
|
||||
},
|
||||
nodes: {
|
||||
common: {
|
||||
|
|
|
|||
|
|
@ -386,6 +386,7 @@ const translation = {
|
|||
optional_and_hidden: '(वैकल्पिक और छिपा हुआ)',
|
||||
goTo: 'जाओ',
|
||||
startNode: 'प्रारंभ नोड',
|
||||
openWorkflow: 'वर्कफ़्लो खोलें',
|
||||
},
|
||||
nodes: {
|
||||
common: {
|
||||
|
|
|
|||
|
|
@ -381,6 +381,7 @@ const translation = {
|
|||
goTo: 'Pergi ke',
|
||||
startNode: 'Mulai Node',
|
||||
scrollToSelectedNode: 'Gulir ke node yang dipilih',
|
||||
openWorkflow: 'Buka Alur Kerja',
|
||||
},
|
||||
nodes: {
|
||||
common: {
|
||||
|
|
|
|||
|
|
@ -389,6 +389,7 @@ const translation = {
|
|||
optional_and_hidden: '(opzionale e nascosto)',
|
||||
goTo: 'Vai a',
|
||||
startNode: 'Nodo iniziale',
|
||||
openWorkflow: 'Apri flusso di lavoro',
|
||||
},
|
||||
nodes: {
|
||||
common: {
|
||||
|
|
|
|||
|
|
@ -401,6 +401,7 @@ const translation = {
|
|||
minimize: '全画面を終了する',
|
||||
scrollToSelectedNode: '選択したノードまでスクロール',
|
||||
optional_and_hidden: '(オプションおよび非表示)',
|
||||
openWorkflow: 'ワークフローを開く',
|
||||
},
|
||||
nodes: {
|
||||
common: {
|
||||
|
|
|
|||
|
|
@ -395,6 +395,7 @@ const translation = {
|
|||
optional_and_hidden: '(선택 사항 및 숨김)',
|
||||
goTo: '로 이동',
|
||||
startNode: '시작 노드',
|
||||
openWorkflow: '워크플로 열기',
|
||||
},
|
||||
nodes: {
|
||||
common: {
|
||||
|
|
|
|||
|
|
@ -374,6 +374,7 @@ const translation = {
|
|||
optional_and_hidden: '(opcjonalne i ukryte)',
|
||||
goTo: 'Idź do',
|
||||
startNode: 'Węzeł początkowy',
|
||||
openWorkflow: 'Otwórz przepływ pracy',
|
||||
},
|
||||
nodes: {
|
||||
common: {
|
||||
|
|
|
|||
|
|
@ -374,6 +374,7 @@ const translation = {
|
|||
optional_and_hidden: '(opcional & oculto)',
|
||||
goTo: 'Ir para',
|
||||
startNode: 'Iniciar Nó',
|
||||
openWorkflow: 'Abrir fluxo de trabalho',
|
||||
},
|
||||
nodes: {
|
||||
common: {
|
||||
|
|
|
|||
|
|
@ -374,6 +374,7 @@ const translation = {
|
|||
optional_and_hidden: '(opțional și ascuns)',
|
||||
goTo: 'Du-te la',
|
||||
startNode: 'Nod de start',
|
||||
openWorkflow: 'Deschide fluxul de lucru',
|
||||
},
|
||||
nodes: {
|
||||
common: {
|
||||
|
|
|
|||
|
|
@ -374,6 +374,7 @@ const translation = {
|
|||
optional_and_hidden: '(необязательно и скрыто)',
|
||||
goTo: 'Перейти к',
|
||||
startNode: 'Начальный узел',
|
||||
openWorkflow: 'Открыть рабочий процесс',
|
||||
},
|
||||
nodes: {
|
||||
common: {
|
||||
|
|
|
|||
|
|
@ -381,6 +381,7 @@ const translation = {
|
|||
optional_and_hidden: '(neobvezno in skrito)',
|
||||
goTo: 'Pojdi na',
|
||||
startNode: 'Začetni vozel',
|
||||
openWorkflow: 'Odpri delovni tok',
|
||||
},
|
||||
nodes: {
|
||||
common: {
|
||||
|
|
|
|||
|
|
@ -374,6 +374,7 @@ const translation = {
|
|||
optional_and_hidden: '(ตัวเลือก & ซ่อน)',
|
||||
goTo: 'ไปที่',
|
||||
startNode: 'เริ่มต้นโหนด',
|
||||
openWorkflow: 'เปิดเวิร์กโฟลว์',
|
||||
},
|
||||
nodes: {
|
||||
common: {
|
||||
|
|
|
|||
|
|
@ -374,6 +374,7 @@ const translation = {
|
|||
optional_and_hidden: '(isteğe bağlı ve gizli)',
|
||||
goTo: 'Git',
|
||||
startNode: 'Başlangıç Düğümü',
|
||||
openWorkflow: 'İş Akışını Aç',
|
||||
},
|
||||
nodes: {
|
||||
common: {
|
||||
|
|
|
|||
|
|
@ -374,6 +374,7 @@ const translation = {
|
|||
optional_and_hidden: '(необов\'язково & приховано)',
|
||||
goTo: 'Перейти до',
|
||||
startNode: 'Початковий вузол',
|
||||
openWorkflow: 'Відкрити робочий процес',
|
||||
},
|
||||
nodes: {
|
||||
common: {
|
||||
|
|
|
|||
|
|
@ -374,6 +374,7 @@ const translation = {
|
|||
optional_and_hidden: '(tùy chọn & ẩn)',
|
||||
goTo: 'Đi tới',
|
||||
startNode: 'Nút Bắt đầu',
|
||||
openWorkflow: 'Mở quy trình làm việc',
|
||||
},
|
||||
nodes: {
|
||||
common: {
|
||||
|
|
|
|||
|
|
@ -383,6 +383,7 @@ const translation = {
|
|||
userInputField: '用户输入字段',
|
||||
changeBlock: '更改节点',
|
||||
helpLink: '查看帮助文档',
|
||||
openWorkflow: '打开工作流',
|
||||
about: '关于',
|
||||
createdBy: '作者',
|
||||
nextStep: '下一步',
|
||||
|
|
|
|||
|
|
@ -379,6 +379,7 @@ const translation = {
|
|||
optional_and_hidden: '(可選且隱藏)',
|
||||
goTo: '前往',
|
||||
startNode: '起始節點',
|
||||
openWorkflow: '打開工作流程',
|
||||
},
|
||||
nodes: {
|
||||
common: {
|
||||
|
|
|
|||
Loading…
Reference in New Issue