mirror of
https://github.com/langgenius/dify.git
synced 2026-05-13 08:57:28 +08:00
evaluations
This commit is contained in:
parent
4f2cd40498
commit
bea428e308
@ -115,6 +115,12 @@ from .explore import (
|
|||||||
trial,
|
trial,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Import evaluation controllers
|
||||||
|
from .evaluation import evaluation
|
||||||
|
|
||||||
|
# Import snippet controllers
|
||||||
|
from .snippets import snippet_workflow
|
||||||
|
|
||||||
# Import tag controllers
|
# Import tag controllers
|
||||||
from .tag import tags
|
from .tag import tags
|
||||||
|
|
||||||
@ -128,6 +134,7 @@ from .workspace import (
|
|||||||
model_providers,
|
model_providers,
|
||||||
models,
|
models,
|
||||||
plugin,
|
plugin,
|
||||||
|
snippets,
|
||||||
tool_providers,
|
tool_providers,
|
||||||
trigger_providers,
|
trigger_providers,
|
||||||
workspace,
|
workspace,
|
||||||
@ -165,6 +172,7 @@ __all__ = [
|
|||||||
"datasource_content_preview",
|
"datasource_content_preview",
|
||||||
"email_register",
|
"email_register",
|
||||||
"endpoint",
|
"endpoint",
|
||||||
|
"evaluation",
|
||||||
"extension",
|
"extension",
|
||||||
"external",
|
"external",
|
||||||
"feature",
|
"feature",
|
||||||
@ -197,6 +205,8 @@ __all__ = [
|
|||||||
"saved_message",
|
"saved_message",
|
||||||
"setup",
|
"setup",
|
||||||
"site",
|
"site",
|
||||||
|
"snippet_workflow",
|
||||||
|
"snippets",
|
||||||
"spec",
|
"spec",
|
||||||
"statistic",
|
"statistic",
|
||||||
"tags",
|
"tags",
|
||||||
|
|||||||
1
api/controllers/console/evaluation/__init__.py
Normal file
1
api/controllers/console/evaluation/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
# Evaluation controller module
|
||||||
288
api/controllers/console/evaluation/evaluation.py
Normal file
288
api/controllers/console/evaluation/evaluation.py
Normal file
@ -0,0 +1,288 @@
|
|||||||
|
import logging
|
||||||
|
from collections.abc import Callable
|
||||||
|
from functools import wraps
|
||||||
|
from typing import ParamSpec, TypeVar, Union
|
||||||
|
|
||||||
|
from flask import request
|
||||||
|
from flask_restx import Resource, fields
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
|
from controllers.common.schema import register_schema_models
|
||||||
|
from controllers.console import console_ns
|
||||||
|
from controllers.console.wraps import (
|
||||||
|
account_initialization_required,
|
||||||
|
edit_permission_required,
|
||||||
|
setup_required,
|
||||||
|
)
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from libs.helper import TimestampField
|
||||||
|
from libs.login import current_account_with_tenant, login_required
|
||||||
|
from models import App
|
||||||
|
from models.snippet import CustomizedSnippet
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
R = TypeVar("R")
|
||||||
|
|
||||||
|
# Valid evaluation target types
|
||||||
|
EVALUATE_TARGET_TYPES = {"app", "snippets"}
|
||||||
|
|
||||||
|
|
||||||
|
class VersionQuery(BaseModel):
|
||||||
|
"""Query parameters for version endpoint."""
|
||||||
|
|
||||||
|
version: str
|
||||||
|
|
||||||
|
|
||||||
|
register_schema_models(
|
||||||
|
console_ns,
|
||||||
|
VersionQuery,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Response field definitions
|
||||||
|
file_info_fields = {
|
||||||
|
"id": fields.String,
|
||||||
|
"name": fields.String,
|
||||||
|
}
|
||||||
|
|
||||||
|
evaluation_log_fields = {
|
||||||
|
"created_at": TimestampField,
|
||||||
|
"created_by": fields.String,
|
||||||
|
"test_file": fields.Nested(
|
||||||
|
console_ns.model(
|
||||||
|
"EvaluationTestFile",
|
||||||
|
file_info_fields,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
"result_file": fields.Nested(
|
||||||
|
console_ns.model(
|
||||||
|
"EvaluationResultFile",
|
||||||
|
file_info_fields,
|
||||||
|
),
|
||||||
|
allow_null=True,
|
||||||
|
),
|
||||||
|
"version": fields.String,
|
||||||
|
}
|
||||||
|
|
||||||
|
evaluation_log_list_model = console_ns.model(
|
||||||
|
"EvaluationLogList",
|
||||||
|
{
|
||||||
|
"data": fields.List(fields.Nested(console_ns.model("EvaluationLog", evaluation_log_fields))),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
customized_matrix_fields = {
|
||||||
|
"evaluation_workflow_id": fields.String,
|
||||||
|
"input_fields": fields.Raw,
|
||||||
|
"output_fields": fields.Raw,
|
||||||
|
}
|
||||||
|
|
||||||
|
condition_fields = {
|
||||||
|
"name": fields.List(fields.String),
|
||||||
|
"comparison_operator": fields.String,
|
||||||
|
"value": fields.String,
|
||||||
|
}
|
||||||
|
|
||||||
|
judgement_conditions_fields = {
|
||||||
|
"logical_operator": fields.String,
|
||||||
|
"conditions": fields.List(fields.Nested(console_ns.model("EvaluationCondition", condition_fields))),
|
||||||
|
}
|
||||||
|
|
||||||
|
evaluation_detail_fields = {
|
||||||
|
"evaluation_model": fields.String,
|
||||||
|
"evaluation_model_provider": fields.String,
|
||||||
|
"customized_matrix": fields.Nested(
|
||||||
|
console_ns.model("EvaluationCustomizedMatrix", customized_matrix_fields),
|
||||||
|
allow_null=True,
|
||||||
|
),
|
||||||
|
"judgement_conditions": fields.Nested(
|
||||||
|
console_ns.model("EvaluationJudgementConditions", judgement_conditions_fields),
|
||||||
|
allow_null=True,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
evaluation_detail_model = console_ns.model("EvaluationDetail", evaluation_detail_fields)
|
||||||
|
|
||||||
|
|
||||||
|
def get_evaluation_target(view_func: Callable[P, R]):
|
||||||
|
"""
|
||||||
|
Decorator to resolve polymorphic evaluation target (app or snippet).
|
||||||
|
|
||||||
|
Validates the target_type parameter and fetches the corresponding
|
||||||
|
model (App or CustomizedSnippet) with tenant isolation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@wraps(view_func)
|
||||||
|
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||||
|
target_type = kwargs.get("evaluate_target_type")
|
||||||
|
target_id = kwargs.get("evaluate_target_id")
|
||||||
|
|
||||||
|
if target_type not in EVALUATE_TARGET_TYPES:
|
||||||
|
raise NotFound(f"Invalid evaluation target type: {target_type}")
|
||||||
|
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
|
target_id = str(target_id)
|
||||||
|
|
||||||
|
# Remove path parameters
|
||||||
|
del kwargs["evaluate_target_type"]
|
||||||
|
del kwargs["evaluate_target_id"]
|
||||||
|
|
||||||
|
target: Union[App, CustomizedSnippet] | None = None
|
||||||
|
|
||||||
|
if target_type == "app":
|
||||||
|
target = (
|
||||||
|
db.session.query(App).where(App.id == target_id, App.tenant_id == current_tenant_id).first()
|
||||||
|
)
|
||||||
|
elif target_type == "snippets":
|
||||||
|
target = (
|
||||||
|
db.session.query(CustomizedSnippet)
|
||||||
|
.where(CustomizedSnippet.id == target_id, CustomizedSnippet.tenant_id == current_tenant_id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if not target:
|
||||||
|
raise NotFound(f"{str(target_type)} not found")
|
||||||
|
|
||||||
|
kwargs["target"] = target
|
||||||
|
kwargs["target_type"] = target_type
|
||||||
|
|
||||||
|
return view_func(*args, **kwargs)
|
||||||
|
|
||||||
|
return decorated_view
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/dataset-template/download")
|
||||||
|
class EvaluationDatasetTemplateDownloadApi(Resource):
|
||||||
|
@console_ns.doc("download_evaluation_dataset_template")
|
||||||
|
@console_ns.response(200, "Template download URL generated successfully")
|
||||||
|
@console_ns.response(404, "Target not found")
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@get_evaluation_target
|
||||||
|
@edit_permission_required
|
||||||
|
def post(self, target: Union[App, CustomizedSnippet], target_type: str):
|
||||||
|
"""
|
||||||
|
Download evaluation dataset template.
|
||||||
|
|
||||||
|
Generates a download URL for the evaluation dataset template
|
||||||
|
based on the target type (app or snippets).
|
||||||
|
"""
|
||||||
|
# TODO: Implement actual template generation logic
|
||||||
|
# This is a placeholder implementation
|
||||||
|
return {
|
||||||
|
"download_url": f"/api/evaluation/{target_type}/{target.id}/template.csv",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation")
|
||||||
|
class EvaluationDetailApi(Resource):
|
||||||
|
@console_ns.doc("get_evaluation_detail")
|
||||||
|
@console_ns.response(200, "Evaluation details retrieved successfully", evaluation_detail_model)
|
||||||
|
@console_ns.response(404, "Target not found")
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@get_evaluation_target
|
||||||
|
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
|
||||||
|
"""
|
||||||
|
Get evaluation details for the target.
|
||||||
|
|
||||||
|
Returns evaluation configuration including model settings,
|
||||||
|
customized matrix, and judgement conditions.
|
||||||
|
"""
|
||||||
|
# TODO: Implement actual evaluation detail retrieval
|
||||||
|
# This is a placeholder implementation
|
||||||
|
return {
|
||||||
|
"evaluation_model": None,
|
||||||
|
"evaluation_model_provider": None,
|
||||||
|
"customized_matrix": None,
|
||||||
|
"judgement_conditions": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/logs")
|
||||||
|
class EvaluationLogsApi(Resource):
|
||||||
|
@console_ns.doc("get_evaluation_logs")
|
||||||
|
@console_ns.response(200, "Evaluation logs retrieved successfully", evaluation_log_list_model)
|
||||||
|
@console_ns.response(404, "Target not found")
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@get_evaluation_target
|
||||||
|
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
|
||||||
|
"""
|
||||||
|
Get offline evaluation logs for the target.
|
||||||
|
|
||||||
|
Returns a list of evaluation runs with test files,
|
||||||
|
result files, and version information.
|
||||||
|
"""
|
||||||
|
# TODO: Implement actual evaluation logs retrieval
|
||||||
|
# This is a placeholder implementation
|
||||||
|
return {
|
||||||
|
"data": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/files/<uuid:file_id>")
|
||||||
|
class EvaluationFileDownloadApi(Resource):
|
||||||
|
@console_ns.doc("download_evaluation_file")
|
||||||
|
@console_ns.response(200, "File download URL generated successfully")
|
||||||
|
@console_ns.response(404, "Target or file not found")
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@get_evaluation_target
|
||||||
|
def get(self, target: Union[App, CustomizedSnippet], target_type: str, file_id: str):
|
||||||
|
"""
|
||||||
|
Download evaluation test file or result file.
|
||||||
|
|
||||||
|
Returns file information and download URL for the specified file.
|
||||||
|
"""
|
||||||
|
file_id = str(file_id)
|
||||||
|
|
||||||
|
# TODO: Implement actual file download logic
|
||||||
|
# This is a placeholder implementation
|
||||||
|
return {
|
||||||
|
"created_at": None,
|
||||||
|
"created_by": None,
|
||||||
|
"test_file": None,
|
||||||
|
"result_file": None,
|
||||||
|
"version": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/version")
|
||||||
|
class EvaluationVersionApi(Resource):
|
||||||
|
@console_ns.doc("get_evaluation_version_detail")
|
||||||
|
@console_ns.expect(console_ns.models.get(VersionQuery.__name__))
|
||||||
|
@console_ns.response(200, "Version details retrieved successfully")
|
||||||
|
@console_ns.response(404, "Target or version not found")
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@get_evaluation_target
|
||||||
|
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
|
||||||
|
"""
|
||||||
|
Get evaluation target version details.
|
||||||
|
|
||||||
|
Returns the workflow graph for the specified version.
|
||||||
|
"""
|
||||||
|
version = request.args.get("version")
|
||||||
|
|
||||||
|
if not version:
|
||||||
|
return {"message": "version parameter is required"}, 400
|
||||||
|
|
||||||
|
# TODO: Implement actual version detail retrieval
|
||||||
|
# For now, return the current graph if available
|
||||||
|
graph = {}
|
||||||
|
if target_type == "snippets" and isinstance(target, CustomizedSnippet):
|
||||||
|
graph = target.graph_dict
|
||||||
|
|
||||||
|
return {
|
||||||
|
"graph": graph,
|
||||||
|
}
|
||||||
75
api/controllers/console/snippets/payloads.py
Normal file
75
api/controllers/console/snippets/payloads.py
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class SnippetListQuery(BaseModel):
|
||||||
|
"""Query parameters for listing snippets."""
|
||||||
|
|
||||||
|
page: int = Field(default=1, ge=1, le=99999)
|
||||||
|
limit: int = Field(default=20, ge=1, le=100)
|
||||||
|
keyword: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class IconInfo(BaseModel):
|
||||||
|
"""Icon information model."""
|
||||||
|
|
||||||
|
icon: str | None = None
|
||||||
|
icon_type: Literal["emoji", "image"] | None = None
|
||||||
|
icon_background: str | None = None
|
||||||
|
icon_url: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class InputFieldDefinition(BaseModel):
|
||||||
|
"""Input field definition for snippet parameters."""
|
||||||
|
|
||||||
|
default: str | None = None
|
||||||
|
hint: bool | None = None
|
||||||
|
label: str | None = None
|
||||||
|
max_length: int | None = None
|
||||||
|
options: list[str] | None = None
|
||||||
|
placeholder: str | None = None
|
||||||
|
required: bool | None = None
|
||||||
|
type: str | None = None # e.g., "text-input"
|
||||||
|
|
||||||
|
|
||||||
|
class CreateSnippetPayload(BaseModel):
|
||||||
|
"""Payload for creating a new snippet."""
|
||||||
|
|
||||||
|
name: str = Field(..., min_length=1, max_length=255)
|
||||||
|
description: str | None = Field(default=None, max_length=2000)
|
||||||
|
type: Literal["node", "group"] = "node"
|
||||||
|
icon_info: IconInfo | None = None
|
||||||
|
graph: dict[str, Any] | None = None
|
||||||
|
input_fields: list[InputFieldDefinition] | None = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateSnippetPayload(BaseModel):
|
||||||
|
"""Payload for updating a snippet."""
|
||||||
|
|
||||||
|
name: str | None = Field(default=None, min_length=1, max_length=255)
|
||||||
|
description: str | None = Field(default=None, max_length=2000)
|
||||||
|
icon_info: IconInfo | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class SnippetDraftSyncPayload(BaseModel):
|
||||||
|
"""Payload for syncing snippet draft workflow."""
|
||||||
|
|
||||||
|
graph: dict[str, Any]
|
||||||
|
hash: str | None = None
|
||||||
|
environment_variables: list[dict[str, Any]] | None = None
|
||||||
|
conversation_variables: list[dict[str, Any]] | None = None
|
||||||
|
input_variables: list[dict[str, Any]] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowRunQuery(BaseModel):
|
||||||
|
"""Query parameters for workflow runs."""
|
||||||
|
|
||||||
|
last_id: str | None = None
|
||||||
|
limit: int = Field(default=20, ge=1, le=100)
|
||||||
|
|
||||||
|
|
||||||
|
class PublishWorkflowPayload(BaseModel):
|
||||||
|
"""Payload for publishing snippet workflow."""
|
||||||
|
|
||||||
|
knowledge_base_setting: dict[str, Any] | None = None
|
||||||
306
api/controllers/console/snippets/snippet_workflow.py
Normal file
306
api/controllers/console/snippets/snippet_workflow.py
Normal file
@ -0,0 +1,306 @@
|
|||||||
|
import logging
|
||||||
|
from collections.abc import Callable
|
||||||
|
from functools import wraps
|
||||||
|
from typing import ParamSpec, TypeVar
|
||||||
|
|
||||||
|
from flask import request
|
||||||
|
from flask_restx import Resource, marshal_with
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
|
from controllers.common.schema import register_schema_models
|
||||||
|
from controllers.console import console_ns
|
||||||
|
from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync
|
||||||
|
from controllers.console.app.workflow import workflow_model
|
||||||
|
from controllers.console.app.workflow_run import (
|
||||||
|
workflow_run_detail_model,
|
||||||
|
workflow_run_node_execution_list_model,
|
||||||
|
workflow_run_pagination_model,
|
||||||
|
)
|
||||||
|
from controllers.console.snippets.payloads import (
|
||||||
|
PublishWorkflowPayload,
|
||||||
|
SnippetDraftSyncPayload,
|
||||||
|
WorkflowRunQuery,
|
||||||
|
)
|
||||||
|
from controllers.console.wraps import (
|
||||||
|
account_initialization_required,
|
||||||
|
edit_permission_required,
|
||||||
|
setup_required,
|
||||||
|
)
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from factories import variable_factory
|
||||||
|
from libs.helper import TimestampField
|
||||||
|
from libs.login import current_account_with_tenant, login_required
|
||||||
|
from models.snippet import CustomizedSnippet
|
||||||
|
from services.errors.app import WorkflowHashNotEqualError
|
||||||
|
from services.snippet_service import SnippetService
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
R = TypeVar("R")
|
||||||
|
|
||||||
|
# Register Pydantic models with Swagger
|
||||||
|
register_schema_models(
|
||||||
|
console_ns,
|
||||||
|
SnippetDraftSyncPayload,
|
||||||
|
WorkflowRunQuery,
|
||||||
|
PublishWorkflowPayload,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SnippetNotFoundError(Exception):
|
||||||
|
"""Snippet not found error."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def get_snippet(view_func: Callable[P, R]):
|
||||||
|
"""Decorator to fetch and validate snippet access."""
|
||||||
|
|
||||||
|
@wraps(view_func)
|
||||||
|
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||||
|
if not kwargs.get("snippet_id"):
|
||||||
|
raise ValueError("missing snippet_id in path parameters")
|
||||||
|
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
|
snippet_id = str(kwargs.get("snippet_id"))
|
||||||
|
del kwargs["snippet_id"]
|
||||||
|
|
||||||
|
snippet = SnippetService.get_snippet_by_id(
|
||||||
|
snippet_id=snippet_id,
|
||||||
|
tenant_id=current_tenant_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not snippet:
|
||||||
|
raise NotFound("Snippet not found")
|
||||||
|
|
||||||
|
kwargs["snippet"] = snippet
|
||||||
|
|
||||||
|
return view_func(*args, **kwargs)
|
||||||
|
|
||||||
|
return decorated_view
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft")
|
||||||
|
class SnippetDraftWorkflowApi(Resource):
|
||||||
|
@console_ns.doc("get_snippet_draft_workflow")
|
||||||
|
@console_ns.response(200, "Draft workflow retrieved successfully", workflow_model)
|
||||||
|
@console_ns.response(404, "Snippet or draft workflow not found")
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@get_snippet
|
||||||
|
@edit_permission_required
|
||||||
|
@marshal_with(workflow_model)
|
||||||
|
def get(self, snippet: CustomizedSnippet):
|
||||||
|
"""Get draft workflow for snippet."""
|
||||||
|
snippet_service = SnippetService()
|
||||||
|
workflow = snippet_service.get_draft_workflow(snippet=snippet)
|
||||||
|
|
||||||
|
if not workflow:
|
||||||
|
raise DraftWorkflowNotExist()
|
||||||
|
|
||||||
|
return workflow
|
||||||
|
|
||||||
|
@console_ns.doc("sync_snippet_draft_workflow")
|
||||||
|
@console_ns.expect(console_ns.models.get(SnippetDraftSyncPayload.__name__))
|
||||||
|
@console_ns.response(200, "Draft workflow synced successfully")
|
||||||
|
@console_ns.response(400, "Hash mismatch")
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@get_snippet
|
||||||
|
@edit_permission_required
|
||||||
|
def post(self, snippet: CustomizedSnippet):
|
||||||
|
"""Sync draft workflow for snippet."""
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
|
payload = SnippetDraftSyncPayload.model_validate(console_ns.payload or {})
|
||||||
|
|
||||||
|
try:
|
||||||
|
environment_variables_list = payload.environment_variables or []
|
||||||
|
environment_variables = [
|
||||||
|
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
|
||||||
|
]
|
||||||
|
conversation_variables_list = payload.conversation_variables or []
|
||||||
|
conversation_variables = [
|
||||||
|
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
|
||||||
|
]
|
||||||
|
snippet_service = SnippetService()
|
||||||
|
workflow = snippet_service.sync_draft_workflow(
|
||||||
|
snippet=snippet,
|
||||||
|
graph=payload.graph,
|
||||||
|
unique_hash=payload.hash,
|
||||||
|
account=current_user,
|
||||||
|
environment_variables=environment_variables,
|
||||||
|
conversation_variables=conversation_variables,
|
||||||
|
input_variables=payload.input_variables,
|
||||||
|
)
|
||||||
|
except WorkflowHashNotEqualError:
|
||||||
|
raise DraftWorkflowNotSync()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"result": "success",
|
||||||
|
"hash": workflow.unique_hash,
|
||||||
|
"updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/config")
|
||||||
|
class SnippetDraftConfigApi(Resource):
|
||||||
|
@console_ns.doc("get_snippet_draft_config")
|
||||||
|
@console_ns.response(200, "Draft config retrieved successfully")
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@get_snippet
|
||||||
|
@edit_permission_required
|
||||||
|
def get(self, snippet: CustomizedSnippet):
|
||||||
|
"""Get snippet draft workflow configuration limits."""
|
||||||
|
return {
|
||||||
|
"parallel_depth_limit": 3,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/publish")
|
||||||
|
class SnippetPublishedWorkflowApi(Resource):
|
||||||
|
@console_ns.doc("get_snippet_published_workflow")
|
||||||
|
@console_ns.response(200, "Published workflow retrieved successfully", workflow_model)
|
||||||
|
@console_ns.response(404, "Snippet not found")
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@get_snippet
|
||||||
|
@edit_permission_required
|
||||||
|
@marshal_with(workflow_model)
|
||||||
|
def get(self, snippet: CustomizedSnippet):
|
||||||
|
"""Get published workflow for snippet."""
|
||||||
|
if not snippet.is_published:
|
||||||
|
return None
|
||||||
|
|
||||||
|
snippet_service = SnippetService()
|
||||||
|
workflow = snippet_service.get_published_workflow(snippet=snippet)
|
||||||
|
|
||||||
|
return workflow
|
||||||
|
|
||||||
|
@console_ns.doc("publish_snippet_workflow")
|
||||||
|
@console_ns.expect(console_ns.models.get(PublishWorkflowPayload.__name__))
|
||||||
|
@console_ns.response(200, "Workflow published successfully")
|
||||||
|
@console_ns.response(400, "No draft workflow found")
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@get_snippet
|
||||||
|
@edit_permission_required
|
||||||
|
def post(self, snippet: CustomizedSnippet):
|
||||||
|
"""Publish snippet workflow."""
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
|
snippet_service = SnippetService()
|
||||||
|
|
||||||
|
with Session(db.engine) as session:
|
||||||
|
snippet = session.merge(snippet)
|
||||||
|
try:
|
||||||
|
workflow = snippet_service.publish_workflow(
|
||||||
|
session=session,
|
||||||
|
snippet=snippet,
|
||||||
|
account=current_user,
|
||||||
|
)
|
||||||
|
workflow_created_at = TimestampField().format(workflow.created_at)
|
||||||
|
session.commit()
|
||||||
|
except ValueError as e:
|
||||||
|
return {"message": str(e)}, 400
|
||||||
|
|
||||||
|
return {
|
||||||
|
"result": "success",
|
||||||
|
"created_at": workflow_created_at,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/default-workflow-block-configs")
|
||||||
|
class SnippetDefaultBlockConfigsApi(Resource):
|
||||||
|
@console_ns.doc("get_snippet_default_block_configs")
|
||||||
|
@console_ns.response(200, "Default block configs retrieved successfully")
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@get_snippet
|
||||||
|
@edit_permission_required
|
||||||
|
def get(self, snippet: CustomizedSnippet):
|
||||||
|
"""Get default block configurations for snippet workflow."""
|
||||||
|
snippet_service = SnippetService()
|
||||||
|
return snippet_service.get_default_block_configs()
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs")
|
||||||
|
class SnippetWorkflowRunsApi(Resource):
|
||||||
|
@console_ns.doc("list_snippet_workflow_runs")
|
||||||
|
@console_ns.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_model)
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@get_snippet
|
||||||
|
@marshal_with(workflow_run_pagination_model)
|
||||||
|
def get(self, snippet: CustomizedSnippet):
|
||||||
|
"""List workflow runs for snippet."""
|
||||||
|
query = WorkflowRunQuery.model_validate(
|
||||||
|
{
|
||||||
|
"last_id": request.args.get("last_id"),
|
||||||
|
"limit": request.args.get("limit", type=int, default=20),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
args = {
|
||||||
|
"last_id": query.last_id,
|
||||||
|
"limit": query.limit,
|
||||||
|
}
|
||||||
|
|
||||||
|
snippet_service = SnippetService()
|
||||||
|
result = snippet_service.get_snippet_workflow_runs(snippet=snippet, args=args)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/<uuid:run_id>")
|
||||||
|
class SnippetWorkflowRunDetailApi(Resource):
|
||||||
|
@console_ns.doc("get_snippet_workflow_run_detail")
|
||||||
|
@console_ns.response(200, "Workflow run detail retrieved successfully", workflow_run_detail_model)
|
||||||
|
@console_ns.response(404, "Workflow run not found")
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@get_snippet
|
||||||
|
@marshal_with(workflow_run_detail_model)
|
||||||
|
def get(self, snippet: CustomizedSnippet, run_id):
|
||||||
|
"""Get workflow run detail for snippet."""
|
||||||
|
run_id = str(run_id)
|
||||||
|
|
||||||
|
snippet_service = SnippetService()
|
||||||
|
workflow_run = snippet_service.get_snippet_workflow_run(snippet=snippet, run_id=run_id)
|
||||||
|
|
||||||
|
if not workflow_run:
|
||||||
|
raise NotFound("Workflow run not found")
|
||||||
|
|
||||||
|
return workflow_run
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/<uuid:run_id>/node-executions")
|
||||||
|
class SnippetWorkflowRunNodeExecutionsApi(Resource):
|
||||||
|
@console_ns.doc("list_snippet_workflow_run_node_executions")
|
||||||
|
@console_ns.response(200, "Node executions retrieved successfully", workflow_run_node_execution_list_model)
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@get_snippet
|
||||||
|
@marshal_with(workflow_run_node_execution_list_model)
|
||||||
|
def get(self, snippet: CustomizedSnippet, run_id):
|
||||||
|
"""List node executions for a workflow run."""
|
||||||
|
run_id = str(run_id)
|
||||||
|
|
||||||
|
snippet_service = SnippetService()
|
||||||
|
node_executions = snippet_service.get_snippet_workflow_run_node_executions(
|
||||||
|
snippet=snippet,
|
||||||
|
run_id=run_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return {"data": node_executions}
|
||||||
202
api/controllers/console/workspace/snippets.py
Normal file
202
api/controllers/console/workspace/snippets.py
Normal file
@ -0,0 +1,202 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
from flask import request
|
||||||
|
from flask_restx import Resource, marshal, marshal_with
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
|
from controllers.common.schema import register_schema_models
|
||||||
|
from controllers.console import console_ns
|
||||||
|
from controllers.console.snippets.payloads import (
|
||||||
|
CreateSnippetPayload,
|
||||||
|
SnippetListQuery,
|
||||||
|
UpdateSnippetPayload,
|
||||||
|
)
|
||||||
|
from controllers.console.wraps import (
|
||||||
|
account_initialization_required,
|
||||||
|
edit_permission_required,
|
||||||
|
setup_required,
|
||||||
|
)
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from fields.snippet_fields import snippet_fields, snippet_list_fields, snippet_pagination_fields
|
||||||
|
from libs.login import current_account_with_tenant, login_required
|
||||||
|
from models.snippet import SnippetType
|
||||||
|
from services.snippet_service import SnippetService
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Register Pydantic models with Swagger
|
||||||
|
register_schema_models(
|
||||||
|
console_ns,
|
||||||
|
SnippetListQuery,
|
||||||
|
CreateSnippetPayload,
|
||||||
|
UpdateSnippetPayload,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create namespace models for marshaling
|
||||||
|
snippet_model = console_ns.model("Snippet", snippet_fields)
|
||||||
|
snippet_list_model = console_ns.model("SnippetList", snippet_list_fields)
|
||||||
|
snippet_pagination_model = console_ns.model("SnippetPagination", snippet_pagination_fields)
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/workspaces/current/customized-snippets")
|
||||||
|
class CustomizedSnippetsApi(Resource):
|
||||||
|
@console_ns.doc("list_customized_snippets")
|
||||||
|
@console_ns.expect(console_ns.models.get(SnippetListQuery.__name__))
|
||||||
|
@console_ns.response(200, "Snippets retrieved successfully", snippet_pagination_model)
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def get(self):
|
||||||
|
"""List customized snippets with pagination and search."""
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
|
query_params = request.args.to_dict()
|
||||||
|
query = SnippetListQuery.model_validate(query_params)
|
||||||
|
|
||||||
|
snippets, total, has_more = SnippetService.get_snippets(
|
||||||
|
tenant_id=current_tenant_id,
|
||||||
|
page=query.page,
|
||||||
|
limit=query.limit,
|
||||||
|
keyword=query.keyword,
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"data": marshal(snippets, snippet_list_fields),
|
||||||
|
"page": query.page,
|
||||||
|
"limit": query.limit,
|
||||||
|
"total": total,
|
||||||
|
"has_more": has_more,
|
||||||
|
}, 200
|
||||||
|
|
||||||
|
@console_ns.doc("create_customized_snippet")
|
||||||
|
@console_ns.expect(console_ns.models.get(CreateSnippetPayload.__name__))
|
||||||
|
@console_ns.response(201, "Snippet created successfully", snippet_model)
|
||||||
|
@console_ns.response(400, "Invalid request or name already exists")
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@edit_permission_required
|
||||||
|
def post(self):
|
||||||
|
"""Create a new customized snippet."""
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
|
payload = CreateSnippetPayload.model_validate(console_ns.payload or {})
|
||||||
|
|
||||||
|
try:
|
||||||
|
snippet_type = SnippetType(payload.type)
|
||||||
|
except ValueError:
|
||||||
|
snippet_type = SnippetType.NODE
|
||||||
|
|
||||||
|
try:
|
||||||
|
snippet = SnippetService.create_snippet(
|
||||||
|
tenant_id=current_tenant_id,
|
||||||
|
name=payload.name,
|
||||||
|
description=payload.description,
|
||||||
|
snippet_type=snippet_type,
|
||||||
|
icon_info=payload.icon_info.model_dump() if payload.icon_info else None,
|
||||||
|
graph=payload.graph,
|
||||||
|
input_fields=[f.model_dump() for f in payload.input_fields] if payload.input_fields else None,
|
||||||
|
account=current_user,
|
||||||
|
)
|
||||||
|
except ValueError as e:
|
||||||
|
return {"message": str(e)}, 400
|
||||||
|
|
||||||
|
return marshal(snippet, snippet_fields), 201
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/workspaces/current/customized-snippets/<uuid:snippet_id>")
|
||||||
|
class CustomizedSnippetDetailApi(Resource):
|
||||||
|
@console_ns.doc("get_customized_snippet")
|
||||||
|
@console_ns.response(200, "Snippet retrieved successfully", snippet_model)
|
||||||
|
@console_ns.response(404, "Snippet not found")
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def get(self, snippet_id: str):
|
||||||
|
"""Get customized snippet details."""
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
|
snippet = SnippetService.get_snippet_by_id(
|
||||||
|
snippet_id=str(snippet_id),
|
||||||
|
tenant_id=current_tenant_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not snippet:
|
||||||
|
raise NotFound("Snippet not found")
|
||||||
|
|
||||||
|
return marshal(snippet, snippet_fields), 200
|
||||||
|
|
||||||
|
@console_ns.doc("update_customized_snippet")
|
||||||
|
@console_ns.expect(console_ns.models.get(UpdateSnippetPayload.__name__))
|
||||||
|
@console_ns.response(200, "Snippet updated successfully", snippet_model)
|
||||||
|
@console_ns.response(400, "Invalid request or name already exists")
|
||||||
|
@console_ns.response(404, "Snippet not found")
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@edit_permission_required
|
||||||
|
def patch(self, snippet_id: str):
|
||||||
|
"""Update customized snippet."""
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
|
snippet = SnippetService.get_snippet_by_id(
|
||||||
|
snippet_id=str(snippet_id),
|
||||||
|
tenant_id=current_tenant_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not snippet:
|
||||||
|
raise NotFound("Snippet not found")
|
||||||
|
|
||||||
|
payload = UpdateSnippetPayload.model_validate(console_ns.payload or {})
|
||||||
|
update_data = payload.model_dump(exclude_unset=True)
|
||||||
|
|
||||||
|
if "icon_info" in update_data and update_data["icon_info"] is not None:
|
||||||
|
update_data["icon_info"] = payload.icon_info.model_dump() if payload.icon_info else None
|
||||||
|
|
||||||
|
if not update_data:
|
||||||
|
return {"message": "No valid fields to update"}, 400
|
||||||
|
|
||||||
|
try:
|
||||||
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
|
snippet = session.merge(snippet)
|
||||||
|
snippet = SnippetService.update_snippet(
|
||||||
|
session=session,
|
||||||
|
snippet=snippet,
|
||||||
|
account_id=current_user.id,
|
||||||
|
data=update_data,
|
||||||
|
)
|
||||||
|
session.commit()
|
||||||
|
except ValueError as e:
|
||||||
|
return {"message": str(e)}, 400
|
||||||
|
|
||||||
|
return marshal(snippet, snippet_fields), 200
|
||||||
|
|
||||||
|
@console_ns.doc("delete_customized_snippet")
|
||||||
|
@console_ns.response(204, "Snippet deleted successfully")
|
||||||
|
@console_ns.response(404, "Snippet not found")
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@edit_permission_required
|
||||||
|
def delete(self, snippet_id: str):
|
||||||
|
"""Delete customized snippet."""
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
|
snippet = SnippetService.get_snippet_by_id(
|
||||||
|
snippet_id=str(snippet_id),
|
||||||
|
tenant_id=current_tenant_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not snippet:
|
||||||
|
raise NotFound("Snippet not found")
|
||||||
|
|
||||||
|
with Session(db.engine) as session:
|
||||||
|
snippet = session.merge(snippet)
|
||||||
|
SnippetService.delete_snippet(
|
||||||
|
session=session,
|
||||||
|
snippet=snippet,
|
||||||
|
)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
return "", 204
|
||||||
45
api/fields/snippet_fields.py
Normal file
45
api/fields/snippet_fields.py
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
from flask_restx import fields
|
||||||
|
|
||||||
|
from fields.member_fields import simple_account_fields
|
||||||
|
from libs.helper import TimestampField
|
||||||
|
|
||||||
|
# Snippet list item fields (lightweight for list display)
|
||||||
|
snippet_list_fields = {
|
||||||
|
"id": fields.String,
|
||||||
|
"name": fields.String,
|
||||||
|
"description": fields.String,
|
||||||
|
"type": fields.String,
|
||||||
|
"version": fields.Integer,
|
||||||
|
"use_count": fields.Integer,
|
||||||
|
"is_published": fields.Boolean,
|
||||||
|
"icon_info": fields.Raw,
|
||||||
|
"created_at": TimestampField,
|
||||||
|
"updated_at": TimestampField,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Full snippet fields (includes creator info and graph data)
|
||||||
|
snippet_fields = {
|
||||||
|
"id": fields.String,
|
||||||
|
"name": fields.String,
|
||||||
|
"description": fields.String,
|
||||||
|
"type": fields.String,
|
||||||
|
"version": fields.Integer,
|
||||||
|
"use_count": fields.Integer,
|
||||||
|
"is_published": fields.Boolean,
|
||||||
|
"icon_info": fields.Raw,
|
||||||
|
"graph": fields.Raw(attribute="graph_dict"),
|
||||||
|
"input_fields": fields.Raw(attribute="input_fields_list"),
|
||||||
|
"created_by": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
|
||||||
|
"created_at": TimestampField,
|
||||||
|
"updated_by": fields.Nested(simple_account_fields, attribute="updated_by_account", allow_null=True),
|
||||||
|
"updated_at": TimestampField,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Pagination response fields
|
||||||
|
snippet_pagination_fields = {
|
||||||
|
"data": fields.List(fields.Nested(snippet_list_fields)),
|
||||||
|
"page": fields.Integer,
|
||||||
|
"limit": fields.Integer,
|
||||||
|
"total": fields.Integer,
|
||||||
|
"has_more": fields.Boolean,
|
||||||
|
}
|
||||||
@ -0,0 +1,83 @@
|
|||||||
|
"""add_customized_snippets_table
|
||||||
|
|
||||||
|
Revision ID: 1c05e80d2380
|
||||||
|
Revises: 788d3099ae3a
|
||||||
|
Create Date: 2026-01-29 12:00:00.000000
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
import models as models
|
||||||
|
|
||||||
|
|
||||||
|
def _is_pg(conn):
|
||||||
|
return conn.dialect.name == "postgresql"
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = "1c05e80d2380"
|
||||||
|
down_revision = "788d3099ae3a"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
conn = op.get_bind()
|
||||||
|
|
||||||
|
if _is_pg(conn):
|
||||||
|
op.create_table(
|
||||||
|
"customized_snippets",
|
||||||
|
sa.Column("id", models.types.StringUUID(), server_default=sa.text("uuidv7()"), nullable=False),
|
||||||
|
sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
|
||||||
|
sa.Column("name", sa.String(length=255), nullable=False),
|
||||||
|
sa.Column("description", sa.Text(), nullable=True),
|
||||||
|
sa.Column("type", sa.String(length=50), server_default=sa.text("'node'"), nullable=False),
|
||||||
|
sa.Column("workflow_id", models.types.StringUUID(), nullable=True),
|
||||||
|
sa.Column("is_published", sa.Boolean(), server_default=sa.text("false"), nullable=False),
|
||||||
|
sa.Column("version", sa.Integer(), server_default=sa.text("1"), nullable=False),
|
||||||
|
sa.Column("use_count", sa.Integer(), server_default=sa.text("0"), nullable=False),
|
||||||
|
sa.Column("icon_info", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||||
|
sa.Column("graph", sa.Text(), nullable=True),
|
||||||
|
sa.Column("input_fields", sa.Text(), nullable=True),
|
||||||
|
sa.Column("created_by", models.types.StringUUID(), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||||
|
sa.Column("updated_by", models.types.StringUUID(), nullable=True),
|
||||||
|
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint("id", name="customized_snippet_pkey"),
|
||||||
|
sa.UniqueConstraint("tenant_id", "name", name="customized_snippet_tenant_name_key"),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
op.create_table(
|
||||||
|
"customized_snippets",
|
||||||
|
sa.Column("id", models.types.StringUUID(), nullable=False),
|
||||||
|
sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
|
||||||
|
sa.Column("name", sa.String(length=255), nullable=False),
|
||||||
|
sa.Column("description", models.types.LongText(), nullable=True),
|
||||||
|
sa.Column("type", sa.String(length=50), server_default=sa.text("'node'"), nullable=False),
|
||||||
|
sa.Column("workflow_id", models.types.StringUUID(), nullable=True),
|
||||||
|
sa.Column("is_published", sa.Boolean(), server_default=sa.text("false"), nullable=False),
|
||||||
|
sa.Column("version", sa.Integer(), server_default=sa.text("1"), nullable=False),
|
||||||
|
sa.Column("use_count", sa.Integer(), server_default=sa.text("0"), nullable=False),
|
||||||
|
sa.Column("icon_info", models.types.AdjustedJSON(astext_type=sa.Text()), nullable=True),
|
||||||
|
sa.Column("graph", models.types.LongText(), nullable=True),
|
||||||
|
sa.Column("input_fields", models.types.LongText(), nullable=True),
|
||||||
|
sa.Column("created_by", models.types.StringUUID(), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
|
||||||
|
sa.Column("updated_by", models.types.StringUUID(), nullable=True),
|
||||||
|
sa.Column("updated_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint("id", name="customized_snippet_pkey"),
|
||||||
|
sa.UniqueConstraint("tenant_id", "name", name="customized_snippet_tenant_name_key"),
|
||||||
|
)
|
||||||
|
|
||||||
|
with op.batch_alter_table("customized_snippets", schema=None) as batch_op:
|
||||||
|
batch_op.create_index("customized_snippet_tenant_idx", ["tenant_id"], unique=False)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
with op.batch_alter_table("customized_snippets", schema=None) as batch_op:
|
||||||
|
batch_op.drop_index("customized_snippet_tenant_idx")
|
||||||
|
|
||||||
|
op.drop_table("customized_snippets")
|
||||||
@ -79,6 +79,7 @@ from .provider import (
|
|||||||
TenantDefaultModel,
|
TenantDefaultModel,
|
||||||
TenantPreferredModelProvider,
|
TenantPreferredModelProvider,
|
||||||
)
|
)
|
||||||
|
from .snippet import CustomizedSnippet, SnippetType
|
||||||
from .source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding
|
from .source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding
|
||||||
from .task import CeleryTask, CeleryTaskSet
|
from .task import CeleryTask, CeleryTaskSet
|
||||||
from .tools import (
|
from .tools import (
|
||||||
@ -138,6 +139,7 @@ __all__ = [
|
|||||||
"Conversation",
|
"Conversation",
|
||||||
"ConversationVariable",
|
"ConversationVariable",
|
||||||
"CreatorUserRole",
|
"CreatorUserRole",
|
||||||
|
"CustomizedSnippet",
|
||||||
"DataSourceApiKeyAuthBinding",
|
"DataSourceApiKeyAuthBinding",
|
||||||
"DataSourceOauthBinding",
|
"DataSourceOauthBinding",
|
||||||
"Dataset",
|
"Dataset",
|
||||||
@ -179,6 +181,7 @@ __all__ = [
|
|||||||
"RecommendedApp",
|
"RecommendedApp",
|
||||||
"SavedMessage",
|
"SavedMessage",
|
||||||
"Site",
|
"Site",
|
||||||
|
"SnippetType",
|
||||||
"Tag",
|
"Tag",
|
||||||
"TagBinding",
|
"TagBinding",
|
||||||
"Tenant",
|
"Tenant",
|
||||||
|
|||||||
96
api/models/snippet.py
Normal file
96
api/models/snippet.py
Normal file
@ -0,0 +1,96 @@
|
|||||||
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import StrEnum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy import DateTime, String, func
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
|
from libs.uuid_utils import uuidv7
|
||||||
|
|
||||||
|
from .account import Account
|
||||||
|
from .base import Base
|
||||||
|
from .engine import db
|
||||||
|
from .types import AdjustedJSON, LongText, StringUUID
|
||||||
|
|
||||||
|
|
||||||
|
class SnippetType(StrEnum):
|
||||||
|
"""Snippet Type Enum"""
|
||||||
|
|
||||||
|
NODE = "node"
|
||||||
|
GROUP = "group"
|
||||||
|
|
||||||
|
|
||||||
|
class CustomizedSnippet(Base):
|
||||||
|
"""
|
||||||
|
Customized Snippet Model
|
||||||
|
|
||||||
|
Stores reusable workflow components (nodes or node groups) that can be
|
||||||
|
shared across applications within a workspace.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "customized_snippets"
|
||||||
|
__table_args__ = (
|
||||||
|
sa.PrimaryKeyConstraint("id", name="customized_snippet_pkey"),
|
||||||
|
sa.Index("customized_snippet_tenant_idx", "tenant_id"),
|
||||||
|
sa.UniqueConstraint("tenant_id", "name", name="customized_snippet_tenant_name_key"),
|
||||||
|
)
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()))
|
||||||
|
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||||
|
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
description: Mapped[str | None] = mapped_column(LongText, nullable=True)
|
||||||
|
type: Mapped[str] = mapped_column(String(50), nullable=False, server_default=sa.text("'node'"))
|
||||||
|
|
||||||
|
# Workflow reference for published version
|
||||||
|
workflow_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||||
|
|
||||||
|
# State flags
|
||||||
|
is_published: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||||
|
version: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("1"))
|
||||||
|
use_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
|
||||||
|
|
||||||
|
# Visual customization
|
||||||
|
icon_info: Mapped[dict | None] = mapped_column(AdjustedJSON, nullable=True)
|
||||||
|
|
||||||
|
# Snippet configuration (stored as JSON text)
|
||||||
|
graph: Mapped[str | None] = mapped_column(LongText, nullable=True)
|
||||||
|
input_fields: Mapped[str | None] = mapped_column(LongText, nullable=True)
|
||||||
|
|
||||||
|
# Audit fields
|
||||||
|
created_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
|
updated_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def graph_dict(self) -> dict[str, Any]:
|
||||||
|
"""Parse graph JSON to dict."""
|
||||||
|
return json.loads(self.graph) if self.graph else {}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_fields_list(self) -> list[dict[str, Any]]:
|
||||||
|
"""Parse input_fields JSON to list."""
|
||||||
|
return json.loads(self.input_fields) if self.input_fields else []
|
||||||
|
|
||||||
|
@property
|
||||||
|
def created_by_account(self) -> Account | None:
|
||||||
|
"""Get the account that created this snippet."""
|
||||||
|
if self.created_by:
|
||||||
|
return db.session.get(Account, self.created_by)
|
||||||
|
return None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def updated_by_account(self) -> Account | None:
|
||||||
|
"""Get the account that last updated this snippet."""
|
||||||
|
if self.updated_by:
|
||||||
|
return db.session.get(Account, self.updated_by)
|
||||||
|
return None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def version_str(self) -> str:
|
||||||
|
"""Get version as string for API response."""
|
||||||
|
return str(self.version)
|
||||||
@ -65,6 +65,7 @@ class WorkflowType(StrEnum):
|
|||||||
WORKFLOW = "workflow"
|
WORKFLOW = "workflow"
|
||||||
CHAT = "chat"
|
CHAT = "chat"
|
||||||
RAG_PIPELINE = "rag-pipeline"
|
RAG_PIPELINE = "rag-pipeline"
|
||||||
|
SNIPPET = "snippet"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def value_of(cls, value: str) -> "WorkflowType":
|
def value_of(cls, value: str) -> "WorkflowType":
|
||||||
|
|||||||
542
api/services/snippet_service.py
Normal file
542
api/services/snippet_service.py
Normal file
@ -0,0 +1,542 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from collections.abc import Mapping, Sequence
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy import func, select
|
||||||
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
|
from core.variables.variables import VariableBase
|
||||||
|
from core.workflow.enums import NodeType
|
||||||
|
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||||
|
from models import Account
|
||||||
|
from models.enums import WorkflowRunTriggeredFrom
|
||||||
|
from models.snippet import CustomizedSnippet, SnippetType
|
||||||
|
from models.workflow import (
|
||||||
|
Workflow,
|
||||||
|
WorkflowNodeExecutionModel,
|
||||||
|
WorkflowRun,
|
||||||
|
WorkflowType,
|
||||||
|
)
|
||||||
|
from repositories.factory import DifyAPIRepositoryFactory
|
||||||
|
from services.errors.app import WorkflowHashNotEqualError
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SnippetService:
|
||||||
|
"""Service for managing customized snippets."""
|
||||||
|
|
||||||
|
def __init__(self, session_maker: sessionmaker | None = None):
|
||||||
|
"""Initialize SnippetService with repository dependencies."""
|
||||||
|
if session_maker is None:
|
||||||
|
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||||
|
self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
|
||||||
|
session_maker
|
||||||
|
)
|
||||||
|
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||||
|
|
||||||
|
# --- CRUD Operations ---
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_snippets(
|
||||||
|
*,
|
||||||
|
tenant_id: str,
|
||||||
|
page: int = 1,
|
||||||
|
limit: int = 20,
|
||||||
|
keyword: str | None = None,
|
||||||
|
) -> tuple[Sequence[CustomizedSnippet], int, bool]:
|
||||||
|
"""
|
||||||
|
Get paginated list of snippets with optional search.
|
||||||
|
|
||||||
|
:param tenant_id: Tenant ID
|
||||||
|
:param page: Page number (1-indexed)
|
||||||
|
:param limit: Number of items per page
|
||||||
|
:param keyword: Optional search keyword for name/description
|
||||||
|
:return: Tuple of (snippets list, total count, has_more flag)
|
||||||
|
"""
|
||||||
|
stmt = (
|
||||||
|
select(CustomizedSnippet)
|
||||||
|
.where(CustomizedSnippet.tenant_id == tenant_id)
|
||||||
|
.order_by(CustomizedSnippet.created_at.desc())
|
||||||
|
)
|
||||||
|
|
||||||
|
if keyword:
|
||||||
|
stmt = stmt.where(
|
||||||
|
CustomizedSnippet.name.ilike(f"%{keyword}%") | CustomizedSnippet.description.ilike(f"%{keyword}%")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get total count
|
||||||
|
count_stmt = select(func.count()).select_from(stmt.subquery())
|
||||||
|
total = db.session.scalar(count_stmt) or 0
|
||||||
|
|
||||||
|
# Apply pagination
|
||||||
|
stmt = stmt.limit(limit + 1).offset((page - 1) * limit)
|
||||||
|
snippets = list(db.session.scalars(stmt).all())
|
||||||
|
|
||||||
|
has_more = len(snippets) > limit
|
||||||
|
if has_more:
|
||||||
|
snippets = snippets[:-1]
|
||||||
|
|
||||||
|
return snippets, total, has_more
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_snippet_by_id(
|
||||||
|
*,
|
||||||
|
snippet_id: str,
|
||||||
|
tenant_id: str,
|
||||||
|
) -> CustomizedSnippet | None:
|
||||||
|
"""
|
||||||
|
Get snippet by ID with tenant isolation.
|
||||||
|
|
||||||
|
:param snippet_id: Snippet ID
|
||||||
|
:param tenant_id: Tenant ID
|
||||||
|
:return: CustomizedSnippet or None
|
||||||
|
"""
|
||||||
|
return (
|
||||||
|
db.session.query(CustomizedSnippet)
|
||||||
|
.where(
|
||||||
|
CustomizedSnippet.id == snippet_id,
|
||||||
|
CustomizedSnippet.tenant_id == tenant_id,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_snippet(
|
||||||
|
*,
|
||||||
|
tenant_id: str,
|
||||||
|
name: str,
|
||||||
|
description: str | None,
|
||||||
|
snippet_type: SnippetType,
|
||||||
|
icon_info: dict | None,
|
||||||
|
graph: dict | None,
|
||||||
|
input_fields: list[dict] | None,
|
||||||
|
account: Account,
|
||||||
|
) -> CustomizedSnippet:
|
||||||
|
"""
|
||||||
|
Create a new snippet.
|
||||||
|
|
||||||
|
:param tenant_id: Tenant ID
|
||||||
|
:param name: Snippet name (must be unique per tenant)
|
||||||
|
:param description: Snippet description
|
||||||
|
:param snippet_type: Type of snippet (node or group)
|
||||||
|
:param icon_info: Icon information
|
||||||
|
:param graph: Workflow graph structure
|
||||||
|
:param input_fields: Input field definitions
|
||||||
|
:param account: Creator account
|
||||||
|
:return: Created CustomizedSnippet
|
||||||
|
:raises ValueError: If name already exists
|
||||||
|
"""
|
||||||
|
# Check if name already exists for this tenant
|
||||||
|
existing = (
|
||||||
|
db.session.query(CustomizedSnippet)
|
||||||
|
.where(
|
||||||
|
CustomizedSnippet.tenant_id == tenant_id,
|
||||||
|
CustomizedSnippet.name == name,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if existing:
|
||||||
|
raise ValueError(f"Snippet with name '{name}' already exists")
|
||||||
|
|
||||||
|
snippet = CustomizedSnippet(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
name=name,
|
||||||
|
description=description or "",
|
||||||
|
type=snippet_type.value,
|
||||||
|
icon_info=icon_info,
|
||||||
|
graph=json.dumps(graph) if graph else None,
|
||||||
|
input_fields=json.dumps(input_fields) if input_fields else None,
|
||||||
|
created_by=account.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
db.session.add(snippet)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
return snippet
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def update_snippet(
|
||||||
|
*,
|
||||||
|
session: Session,
|
||||||
|
snippet: CustomizedSnippet,
|
||||||
|
account_id: str,
|
||||||
|
data: dict,
|
||||||
|
) -> CustomizedSnippet:
|
||||||
|
"""
|
||||||
|
Update snippet attributes.
|
||||||
|
|
||||||
|
:param session: Database session
|
||||||
|
:param snippet: Snippet to update
|
||||||
|
:param account_id: ID of account making the update
|
||||||
|
:param data: Dictionary of fields to update
|
||||||
|
:return: Updated CustomizedSnippet
|
||||||
|
"""
|
||||||
|
if "name" in data:
|
||||||
|
# Check if new name already exists for this tenant
|
||||||
|
existing = (
|
||||||
|
session.query(CustomizedSnippet)
|
||||||
|
.where(
|
||||||
|
CustomizedSnippet.tenant_id == snippet.tenant_id,
|
||||||
|
CustomizedSnippet.name == data["name"],
|
||||||
|
CustomizedSnippet.id != snippet.id,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if existing:
|
||||||
|
raise ValueError(f"Snippet with name '{data['name']}' already exists")
|
||||||
|
snippet.name = data["name"]
|
||||||
|
|
||||||
|
if "description" in data:
|
||||||
|
snippet.description = data["description"]
|
||||||
|
|
||||||
|
if "icon_info" in data:
|
||||||
|
snippet.icon_info = data["icon_info"]
|
||||||
|
|
||||||
|
snippet.updated_by = account_id
|
||||||
|
snippet.updated_at = datetime.now(UTC).replace(tzinfo=None)
|
||||||
|
|
||||||
|
session.add(snippet)
|
||||||
|
return snippet
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def delete_snippet(
|
||||||
|
*,
|
||||||
|
session: Session,
|
||||||
|
snippet: CustomizedSnippet,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Delete a snippet.
|
||||||
|
|
||||||
|
:param session: Database session
|
||||||
|
:param snippet: Snippet to delete
|
||||||
|
:return: True if deleted successfully
|
||||||
|
"""
|
||||||
|
session.delete(snippet)
|
||||||
|
return True
|
||||||
|
|
||||||
|
# --- Workflow Operations ---
|
||||||
|
|
||||||
|
def get_draft_workflow(self, snippet: CustomizedSnippet) -> Workflow | None:
|
||||||
|
"""
|
||||||
|
Get draft workflow for snippet.
|
||||||
|
|
||||||
|
:param snippet: CustomizedSnippet instance
|
||||||
|
:return: Draft Workflow or None
|
||||||
|
"""
|
||||||
|
workflow = (
|
||||||
|
db.session.query(Workflow)
|
||||||
|
.where(
|
||||||
|
Workflow.tenant_id == snippet.tenant_id,
|
||||||
|
Workflow.app_id == snippet.id,
|
||||||
|
Workflow.type == WorkflowType.SNIPPET.value,
|
||||||
|
Workflow.version == "draft",
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
return workflow
|
||||||
|
|
||||||
|
def get_published_workflow(self, snippet: CustomizedSnippet) -> Workflow | None:
|
||||||
|
"""
|
||||||
|
Get published workflow for snippet.
|
||||||
|
|
||||||
|
:param snippet: CustomizedSnippet instance
|
||||||
|
:return: Published Workflow or None
|
||||||
|
"""
|
||||||
|
if not snippet.workflow_id:
|
||||||
|
return None
|
||||||
|
|
||||||
|
workflow = (
|
||||||
|
db.session.query(Workflow)
|
||||||
|
.where(
|
||||||
|
Workflow.tenant_id == snippet.tenant_id,
|
||||||
|
Workflow.app_id == snippet.id,
|
||||||
|
Workflow.type == WorkflowType.SNIPPET.value,
|
||||||
|
Workflow.id == snippet.workflow_id,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
return workflow
|
||||||
|
|
||||||
|
def sync_draft_workflow(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
snippet: CustomizedSnippet,
|
||||||
|
graph: dict,
|
||||||
|
unique_hash: str | None,
|
||||||
|
account: Account,
|
||||||
|
environment_variables: Sequence[VariableBase],
|
||||||
|
conversation_variables: Sequence[VariableBase],
|
||||||
|
input_variables: list[dict] | None = None,
|
||||||
|
) -> Workflow:
|
||||||
|
"""
|
||||||
|
Sync draft workflow for snippet.
|
||||||
|
|
||||||
|
:param snippet: CustomizedSnippet instance
|
||||||
|
:param graph: Workflow graph configuration
|
||||||
|
:param unique_hash: Hash for conflict detection
|
||||||
|
:param account: Account making the change
|
||||||
|
:param environment_variables: Environment variables
|
||||||
|
:param conversation_variables: Conversation variables
|
||||||
|
:param input_variables: Input variables for snippet
|
||||||
|
:return: Synced Workflow
|
||||||
|
:raises WorkflowHashNotEqualError: If hash mismatch
|
||||||
|
"""
|
||||||
|
workflow = self.get_draft_workflow(snippet=snippet)
|
||||||
|
|
||||||
|
if workflow and workflow.unique_hash != unique_hash:
|
||||||
|
raise WorkflowHashNotEqualError()
|
||||||
|
|
||||||
|
# Create draft workflow if not found
|
||||||
|
if not workflow:
|
||||||
|
workflow = Workflow(
|
||||||
|
tenant_id=snippet.tenant_id,
|
||||||
|
app_id=snippet.id,
|
||||||
|
features="{}",
|
||||||
|
type=WorkflowType.SNIPPET.value,
|
||||||
|
version="draft",
|
||||||
|
graph=json.dumps(graph),
|
||||||
|
created_by=account.id,
|
||||||
|
environment_variables=environment_variables,
|
||||||
|
conversation_variables=conversation_variables,
|
||||||
|
)
|
||||||
|
db.session.add(workflow)
|
||||||
|
db.session.flush()
|
||||||
|
else:
|
||||||
|
# Update existing draft workflow
|
||||||
|
workflow.graph = json.dumps(graph)
|
||||||
|
workflow.updated_by = account.id
|
||||||
|
workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
|
||||||
|
workflow.environment_variables = environment_variables
|
||||||
|
workflow.conversation_variables = conversation_variables
|
||||||
|
|
||||||
|
# Update snippet's input_fields if provided
|
||||||
|
if input_variables is not None:
|
||||||
|
snippet.input_fields = json.dumps(input_variables)
|
||||||
|
snippet.updated_by = account.id
|
||||||
|
snippet.updated_at = datetime.now(UTC).replace(tzinfo=None)
|
||||||
|
|
||||||
|
db.session.commit()
|
||||||
|
return workflow
|
||||||
|
|
||||||
|
def publish_workflow(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
session: Session,
|
||||||
|
snippet: CustomizedSnippet,
|
||||||
|
account: Account,
|
||||||
|
) -> Workflow:
|
||||||
|
"""
|
||||||
|
Publish the draft workflow as a new version.
|
||||||
|
|
||||||
|
:param session: Database session
|
||||||
|
:param snippet: CustomizedSnippet instance
|
||||||
|
:param account: Account making the change
|
||||||
|
:return: Published Workflow
|
||||||
|
:raises ValueError: If no draft workflow exists
|
||||||
|
"""
|
||||||
|
draft_workflow_stmt = select(Workflow).where(
|
||||||
|
Workflow.tenant_id == snippet.tenant_id,
|
||||||
|
Workflow.app_id == snippet.id,
|
||||||
|
Workflow.type == WorkflowType.SNIPPET.value,
|
||||||
|
Workflow.version == "draft",
|
||||||
|
)
|
||||||
|
draft_workflow = session.scalar(draft_workflow_stmt)
|
||||||
|
if not draft_workflow:
|
||||||
|
raise ValueError("No valid workflow found.")
|
||||||
|
|
||||||
|
# Create new published workflow
|
||||||
|
workflow = Workflow.new(
|
||||||
|
tenant_id=snippet.tenant_id,
|
||||||
|
app_id=snippet.id,
|
||||||
|
type=draft_workflow.type,
|
||||||
|
version=str(datetime.now(UTC).replace(tzinfo=None)),
|
||||||
|
graph=draft_workflow.graph,
|
||||||
|
features=draft_workflow.features,
|
||||||
|
created_by=account.id,
|
||||||
|
environment_variables=draft_workflow.environment_variables,
|
||||||
|
conversation_variables=draft_workflow.conversation_variables,
|
||||||
|
marked_name="",
|
||||||
|
marked_comment="",
|
||||||
|
)
|
||||||
|
session.add(workflow)
|
||||||
|
|
||||||
|
# Update snippet version
|
||||||
|
snippet.version += 1
|
||||||
|
snippet.is_published = True
|
||||||
|
snippet.workflow_id = workflow.id
|
||||||
|
snippet.updated_by = account.id
|
||||||
|
session.add(snippet)
|
||||||
|
|
||||||
|
return workflow
|
||||||
|
|
||||||
|
def get_all_published_workflows(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
session: Session,
|
||||||
|
snippet: CustomizedSnippet,
|
||||||
|
page: int,
|
||||||
|
limit: int,
|
||||||
|
) -> tuple[Sequence[Workflow], bool]:
|
||||||
|
"""
|
||||||
|
Get all published workflow versions for snippet.
|
||||||
|
|
||||||
|
:param session: Database session
|
||||||
|
:param snippet: CustomizedSnippet instance
|
||||||
|
:param page: Page number
|
||||||
|
:param limit: Items per page
|
||||||
|
:return: Tuple of (workflows list, has_more flag)
|
||||||
|
"""
|
||||||
|
if not snippet.workflow_id:
|
||||||
|
return [], False
|
||||||
|
|
||||||
|
stmt = (
|
||||||
|
select(Workflow)
|
||||||
|
.where(
|
||||||
|
Workflow.app_id == snippet.id,
|
||||||
|
Workflow.type == WorkflowType.SNIPPET.value,
|
||||||
|
Workflow.version != "draft",
|
||||||
|
)
|
||||||
|
.order_by(Workflow.version.desc())
|
||||||
|
.limit(limit + 1)
|
||||||
|
.offset((page - 1) * limit)
|
||||||
|
)
|
||||||
|
|
||||||
|
workflows = list(session.scalars(stmt).all())
|
||||||
|
has_more = len(workflows) > limit
|
||||||
|
if has_more:
|
||||||
|
workflows = workflows[:-1]
|
||||||
|
|
||||||
|
return workflows, has_more
|
||||||
|
|
||||||
|
# --- Default Block Configs ---
|
||||||
|
|
||||||
|
def get_default_block_configs(self) -> list[dict]:
|
||||||
|
"""
|
||||||
|
Get default block configurations for all node types.
|
||||||
|
|
||||||
|
:return: List of default configurations
|
||||||
|
"""
|
||||||
|
default_block_configs: list[dict[str, Any]] = []
|
||||||
|
for node_class_mapping in NODE_TYPE_CLASSES_MAPPING.values():
|
||||||
|
node_class = node_class_mapping[LATEST_VERSION]
|
||||||
|
default_config = node_class.get_default_config()
|
||||||
|
if default_config:
|
||||||
|
default_block_configs.append(dict(default_config))
|
||||||
|
|
||||||
|
return default_block_configs
|
||||||
|
|
||||||
|
def get_default_block_config(self, node_type: str, filters: dict | None = None) -> Mapping[str, object] | None:
|
||||||
|
"""
|
||||||
|
Get default config for specific node type.
|
||||||
|
|
||||||
|
:param node_type: Node type string
|
||||||
|
:param filters: Optional filters
|
||||||
|
:return: Default configuration or None
|
||||||
|
"""
|
||||||
|
node_type_enum = NodeType(node_type)
|
||||||
|
|
||||||
|
if node_type_enum not in NODE_TYPE_CLASSES_MAPPING:
|
||||||
|
return None
|
||||||
|
|
||||||
|
node_class = NODE_TYPE_CLASSES_MAPPING[node_type_enum][LATEST_VERSION]
|
||||||
|
default_config = node_class.get_default_config(filters=filters)
|
||||||
|
if not default_config:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return default_config
|
||||||
|
|
||||||
|
# --- Workflow Run Operations ---
|
||||||
|
|
||||||
|
def get_snippet_workflow_runs(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
snippet: CustomizedSnippet,
|
||||||
|
args: dict,
|
||||||
|
) -> InfiniteScrollPagination:
|
||||||
|
"""
|
||||||
|
Get paginated workflow runs for snippet.
|
||||||
|
|
||||||
|
:param snippet: CustomizedSnippet instance
|
||||||
|
:param args: Request arguments (last_id, limit)
|
||||||
|
:return: InfiniteScrollPagination result
|
||||||
|
"""
|
||||||
|
limit = int(args.get("limit", 20))
|
||||||
|
last_id = args.get("last_id")
|
||||||
|
|
||||||
|
triggered_from_values = [
|
||||||
|
WorkflowRunTriggeredFrom.DEBUGGING,
|
||||||
|
]
|
||||||
|
|
||||||
|
return self._workflow_run_repo.get_paginated_workflow_runs(
|
||||||
|
tenant_id=snippet.tenant_id,
|
||||||
|
app_id=snippet.id,
|
||||||
|
triggered_from=triggered_from_values,
|
||||||
|
limit=limit,
|
||||||
|
last_id=last_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_snippet_workflow_run(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
snippet: CustomizedSnippet,
|
||||||
|
run_id: str,
|
||||||
|
) -> WorkflowRun | None:
|
||||||
|
"""
|
||||||
|
Get workflow run details.
|
||||||
|
|
||||||
|
:param snippet: CustomizedSnippet instance
|
||||||
|
:param run_id: Workflow run ID
|
||||||
|
:return: WorkflowRun or None
|
||||||
|
"""
|
||||||
|
return self._workflow_run_repo.get_workflow_run_by_id(
|
||||||
|
tenant_id=snippet.tenant_id,
|
||||||
|
app_id=snippet.id,
|
||||||
|
run_id=run_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_snippet_workflow_run_node_executions(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
snippet: CustomizedSnippet,
|
||||||
|
run_id: str,
|
||||||
|
) -> Sequence[WorkflowNodeExecutionModel]:
|
||||||
|
"""
|
||||||
|
Get workflow run node execution list.
|
||||||
|
|
||||||
|
:param snippet: CustomizedSnippet instance
|
||||||
|
:param run_id: Workflow run ID
|
||||||
|
:return: List of WorkflowNodeExecutionModel
|
||||||
|
"""
|
||||||
|
workflow_run = self.get_snippet_workflow_run(snippet=snippet, run_id=run_id)
|
||||||
|
if not workflow_run:
|
||||||
|
return []
|
||||||
|
|
||||||
|
node_executions = self._node_execution_service_repo.get_executions_by_workflow_run(
|
||||||
|
tenant_id=snippet.tenant_id,
|
||||||
|
app_id=snippet.id,
|
||||||
|
workflow_run_id=workflow_run.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return node_executions
|
||||||
|
|
||||||
|
# --- Use Count ---
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def increment_use_count(
|
||||||
|
*,
|
||||||
|
session: Session,
|
||||||
|
snippet: CustomizedSnippet,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Increment the use_count when snippet is used.
|
||||||
|
|
||||||
|
:param session: Database session
|
||||||
|
:param snippet: CustomizedSnippet instance
|
||||||
|
"""
|
||||||
|
snippet.use_count += 1
|
||||||
|
session.add(snippet)
|
||||||
Loading…
Reference in New Issue
Block a user