feat: evaluation (#35251)

Co-authored-by: jyong <718720800@qq.com>
Co-authored-by: Yansong Zhang <916125788@qq.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: hj24 <mambahj24@gmail.com>
Co-authored-by: hj24 <huangjian@dify.ai>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: Stephen Zhou <38493346+hyoban@users.noreply.github.com>
Co-authored-by: CodingOnStar <hanxujiang@dify.com>
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
FFXN 2026-04-15 16:09:40 +08:00 committed by GitHub
parent 12b1cc3d2e
commit f9b76f0f52
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
94 changed files with 12010 additions and 335 deletions

View File

@@ -1366,6 +1366,32 @@ class SandboxExpiredRecordsCleanConfig(BaseSettings):
)
class EvaluationConfig(BaseSettings):
"""
Configuration for evaluation runtime
"""
EVALUATION_FRAMEWORK: str = Field(
description="Evaluation framework to use (ragas/deepeval/none)",
default="none",
)
EVALUATION_MAX_CONCURRENT_RUNS: PositiveInt = Field(
description="Maximum number of concurrent evaluation runs per tenant",
default=3,
)
EVALUATION_MAX_DATASET_ROWS: PositiveInt = Field(
description="Maximum number of rows allowed in an evaluation dataset",
default=500,
)
EVALUATION_TASK_TIMEOUT: PositiveInt = Field(
description="Timeout in seconds for a single evaluation task",
default=3600,
)
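The fields above follow the same pydantic-settings pattern as the other config classes, so they can be overridden through environment variables. A minimal sketch, assuming standard BaseSettings env loading; the values are examples only:

import os

# Example overrides only; any unset variable falls back to the defaults above.
os.environ["EVALUATION_FRAMEWORK"] = "ragas"
os.environ["EVALUATION_MAX_CONCURRENT_RUNS"] = "5"

config = EvaluationConfig()
assert config.EVALUATION_FRAMEWORK == "ragas"
assert config.EVALUATION_MAX_CONCURRENT_RUNS == 5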
class FeatureConfig(
# place the configs in alphabet order
AppExecutionConfig,
@@ -1378,6 +1404,7 @@ class FeatureConfig(
MarketplaceConfig,
DataSetConfig,
EndpointConfig,
EvaluationConfig,
FileAccessConfig,
FileUploadConfig,
HttpConfig,

View File

@@ -107,6 +107,9 @@ from .datasets.rag_pipeline import (
rag_pipeline_workflow,
)
# Import evaluation controllers
from .evaluation import evaluation
# Import explore controllers
from .explore import (
banner,
@@ -117,6 +120,9 @@ from .explore import (
trial,
)
# Import snippet controllers
from .snippets import snippet_workflow, snippet_workflow_draft_variable
# Import tag controllers
from .tag import tags
@@ -130,6 +136,7 @@ from .workspace import (
model_providers,
models,
plugin,
snippets,
tool_providers,
trigger_providers,
workspace,
@@ -167,6 +174,7 @@ __all__ = [
"datasource_content_preview",
"email_register",
"endpoint",
"evaluation",
"extension",
"external",
"feature",
@@ -201,6 +209,9 @@ __all__ = [
"saved_message",
"setup",
"site",
"snippet_workflow",
"snippet_workflow_draft_variable",
"snippets",
"spec",
"statistic",
"tags",

View File

@@ -329,6 +329,7 @@ class AppPartial(ResponseModel):
create_user_name: str | None = None
author_name: str | None = None
has_draft_trigger: bool | None = None
workflow_type: str | None = None
@computed_field(return_type=str | None) # type: ignore
@property
@@ -363,6 +364,7 @@ class AppDetail(ResponseModel):
updated_by: str | None = None
updated_at: int | None = None
access_mode: str | None = None
workflow_type: str | None = None
tags: list[Tag] = Field(default_factory=list)
@field_validator("created_at", "updated_at", mode="before")
@@ -505,6 +507,17 @@ class AppListApi(Resource):
for app in app_pagination.items:
app.has_draft_trigger = str(app.id) in draft_trigger_app_ids
workflow_ids = [str(app.workflow_id) for app in app_pagination.items if app.workflow_id]
workflow_type_map: dict[str, str] = {}
if workflow_ids:
rows = db.session.execute(
select(Workflow.id, Workflow.type).where(Workflow.id.in_(workflow_ids))
).all()
workflow_type_map = {str(row.id): row.type for row in rows}
for app in app_pagination.items:
app.workflow_type = workflow_type_map.get(str(app.workflow_id)) if app.workflow_id else None
pagination_model = AppPagination.model_validate(app_pagination, from_attributes=True)
return pagination_model.model_dump(mode="json"), 200
@@ -551,6 +564,14 @@ class AppApi(Resource):
app_setting = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=str(app_model.id))
app_model.access_mode = app_setting.access_mode
if app_model.workflow_id:
row = db.session.execute(
select(Workflow.type).where(Workflow.id == app_model.workflow_id)
).scalar()
app_model.workflow_type = row if row else None
else:
app_model.workflow_type = None
response_model = AppDetailWithSite.model_validate(app_model, from_attributes=True)
return response_model.model_dump(mode="json")

View File

@@ -1,7 +1,7 @@
import json
import logging
from collections.abc import Sequence
from typing import Any
from typing import Any, Literal
from flask import abort, request
from flask_restx import Resource, fields, marshal, marshal_with
@@ -46,7 +46,7 @@ from libs.helper import TimestampField, uuid_value
from libs.login import current_account_with_tenant, login_required
from models import App
from models.model import AppMode
from models.workflow import Workflow
from models.workflow import Workflow, WorkflowType
from services.app_generate_service import AppGenerateService
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError
@@ -150,6 +150,24 @@ class ConvertToWorkflowPayload(BaseModel):
icon_background: str | None = None
class WorkflowListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=99999)
limit: int = Field(default=10, ge=1, le=100)
user_id: str | None = None
named_only: bool = False
keyword: str | None = Field(default=None, max_length=255)
class WorkflowUpdatePayload(BaseModel):
marked_name: str | None = Field(default=None, max_length=20)
marked_comment: str | None = Field(default=None, max_length=100)
class WorkflowTypeConvertQuery(BaseModel):
target_type: Literal["workflow", "evaluation"]
class DraftWorkflowTriggerRunPayload(BaseModel):
node_id: str
@@ -173,6 +191,7 @@ reg(DefaultBlockConfigQuery)
reg(ConvertToWorkflowPayload)
reg(WorkflowListQuery)
reg(WorkflowUpdatePayload)
reg(WorkflowTypeConvertQuery)
reg(DraftWorkflowTriggerRunPayload)
reg(DraftWorkflowTriggerRunAllPayload)
@@ -845,6 +864,54 @@ class PublishedWorkflowApi(Resource):
}
@console_ns.route("/apps/<uuid:app_id>/workflows/publish/evaluation")
class EvaluationPublishedWorkflowApi(Resource):
@console_ns.doc("publish_evaluation_workflow")
@console_ns.doc(description="Publish draft workflow as evaluation workflow")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[PublishWorkflowPayload.__name__])
@console_ns.response(200, "Evaluation workflow published successfully")
@console_ns.response(400, "Invalid workflow or unsupported node type")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App):
"""
Publish draft workflow as evaluation workflow.
Evaluation workflows cannot include trigger or human-input nodes.
"""
current_user, _ = current_account_with_tenant()
args = PublishWorkflowPayload.model_validate(console_ns.payload or {})
workflow_service = WorkflowService()
with Session(db.engine) as session:
workflow = workflow_service.publish_evaluation_workflow(
session=session,
app_model=app_model,
account=current_user,
marked_name=args.marked_name or "",
marked_comment=args.marked_comment or "",
)
# Keep workflow_id aligned with the latest published workflow.
app_model_in_session = session.get(App, app_model.id)
if app_model_in_session:
app_model_in_session.workflow_id = workflow.id
app_model_in_session.updated_by = current_user.id
app_model_in_session.updated_at = naive_utc_now()
workflow_created_at = TimestampField().format(workflow.created_at)
session.commit()
return {
"result": "success",
"created_at": workflow_created_at,
}
@console_ns.route("/apps/<uuid:app_id>/workflows/default-workflow-block-configs")
class DefaultBlockConfigsApi(Resource):
@console_ns.doc("get_default_block_configs")
@@ -1016,6 +1083,51 @@ class DraftWorkflowRestoreApi(Resource):
}
@console_ns.route("/apps/<uuid:app_id>/workflows/convert-type")
class WorkflowTypeConvertApi(Resource):
@console_ns.doc("convert_published_workflow_type")
@console_ns.doc(description="Convert current effective published workflow type in-place")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[WorkflowTypeConvertQuery.__name__])
@console_ns.response(200, "Workflow type converted successfully")
@console_ns.response(400, "Invalid workflow type or unsupported workflow graph")
@console_ns.response(404, "Workflow not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App):
current_user, _ = current_account_with_tenant()
args = WorkflowTypeConvertQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
target_type = WorkflowType.value_of(args.target_type)
workflow_service = WorkflowService()
with Session(db.engine) as session:
try:
workflow = workflow_service.convert_published_workflow_type(
session=session,
app_model=app_model,
target_type=target_type,
account=current_user,
)
except WorkflowNotFoundError as exc:
raise NotFound(str(exc)) from exc
except IsDraftWorkflowError as exc:
raise BadRequest(str(exc)) from exc
except ValueError as exc:
raise BadRequest(str(exc)) from exc
session.commit()
return {
"result": "success",
"workflow_id": workflow.id,
"type": workflow.type.value,
"updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
}
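Together with the publish endpoint above, the convert-type route can be exercised from a console client roughly as follows. This is a sketch only; the base URL, app ID, and auth header are placeholders, not values from this change:

import requests

BASE_URL = "https://example.com/console/api"           # placeholder
APP_ID = "00000000-0000-0000-0000-000000000000"        # placeholder
HEADERS = {"Authorization": "Bearer <console-token>"}  # placeholder auth

# Publish the current draft as an evaluation workflow.
requests.post(
    f"{BASE_URL}/apps/{APP_ID}/workflows/publish/evaluation",
    headers=HEADERS,
    json={"marked_name": "eval baseline", "marked_comment": ""},
)

# Convert the effective published workflow back to a regular workflow;
# target_type comes from the query string (see WorkflowTypeConvertQuery).
requests.post(
    f"{BASE_URL}/apps/{APP_ID}/workflows/convert-type",
    headers=HEADERS,
    params={"target_type": "workflow"},
)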
@console_ns.route("/apps/<uuid:app_id>/workflows/<string:workflow_id>")
class WorkflowByIdApi(Resource):
@console_ns.doc("update_workflow_by_id")

View File

@@ -1,4 +1,6 @@
import base64
import json
from datetime import UTC, datetime, timedelta
from typing import Literal
from flask import request
@@ -10,6 +12,7 @@ from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from enums.cloud_plan import CloudPlan
from extensions.ext_redis import redis_client
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
@@ -77,3 +80,39 @@ class PartnerTenants(Resource):
raise BadRequest("Invalid partner information")
return BillingService.sync_partner_tenants_bindings(current_user.id, decoded_partner_key, click_id)
_DEBUG_KEY = "billing:debug"
_DEBUG_TTL = timedelta(days=7)
class DebugDataPayload(BaseModel):
type: str = Field(..., min_length=1, description="Data type key")
data: str = Field(..., min_length=1, description="Data value to append")
@console_ns.route("/billing/debug/data")
class DebugData(Resource):
def post(self):
body = DebugDataPayload.model_validate(request.get_json(force=True))
item = json.dumps({
"type": body.type,
"data": body.data,
"createTime": datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ"),
})
redis_client.lpush(_DEBUG_KEY, item)
redis_client.expire(_DEBUG_KEY, _DEBUG_TTL)
return {"result": "ok"}, 201
def get(self):
recent = request.args.get("recent", 10, type=int)
items = redis_client.lrange(_DEBUG_KEY, 0, recent - 1)
return {
"data": [
json.loads(item.decode("utf-8") if isinstance(item, bytes) else item) for item in items
]
}
def delete(self):
redis_client.delete(_DEBUG_KEY)
return {"result": "ok"}

View File

@@ -1,11 +1,14 @@
import json
from typing import Any, cast
from urllib.parse import quote
from flask import request
from flask import Response, request
from flask_restx import Resource, fields, marshal, marshal_with
from graphon.model_runtime.entities.model_entities import ModelType
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import func, select
from werkzeug.exceptions import Forbidden, NotFound
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
import services
from configs import dify_config
@@ -22,6 +25,7 @@ from controllers.console.wraps import (
setup_required,
)
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.evaluation.entities.evaluation_entity import EvaluationCategory, EvaluationConfigData, EvaluationRunRequest
from core.indexing_runner import IndexingRunner
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from core.rag.datasource.vdb.vector_type import VectorType
@@ -30,6 +34,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from extensions.ext_storage import storage
from fields.app_fields import app_detail_kernel_fields, related_app_list
from fields.dataset_fields import (
content_fields,
@@ -50,12 +55,19 @@ from fields.dataset_fields import (
)
from fields.document_fields import document_status_fields
from libs.login import current_account_with_tenant, login_required
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models import ApiToken, Dataset, Document, DocumentSegment, EvaluationRun, EvaluationTargetType, UploadFile
from models.dataset import DatasetPermission, DatasetPermissionEnum
from models.enums import ApiTokenType, SegmentStatus
from models.provider_ids import ModelProviderID
from services.api_token_service import ApiTokenCache
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
from services.errors.evaluation import (
EvaluationDatasetInvalidError,
EvaluationFrameworkNotConfiguredError,
EvaluationMaxConcurrentRunsError,
EvaluationNotFoundError,
)
from services.evaluation_service import EvaluationService
# Register models for flask_restx to avoid dict type issues in Swagger
dataset_base_model = get_or_create_model("DatasetBase", dataset_fields)
@@ -983,3 +995,432 @@ class DatasetAutoDisableLogApi(Resource):
if dataset is None:
raise NotFound("Dataset not found.")
return DatasetService.get_dataset_auto_disable_logs(dataset_id_str), 200
# ---- Knowledge Base Retrieval Evaluation ----
def _serialize_dataset_evaluation_run(run: EvaluationRun) -> dict[str, Any]:
return {
"id": run.id,
"tenant_id": run.tenant_id,
"target_type": run.target_type,
"target_id": run.target_id,
"evaluation_config_id": run.evaluation_config_id,
"status": run.status,
"dataset_file_id": run.dataset_file_id,
"result_file_id": run.result_file_id,
"total_items": run.total_items,
"completed_items": run.completed_items,
"failed_items": run.failed_items,
"progress": run.progress,
"metrics_summary": json.loads(run.metrics_summary) if run.metrics_summary else {},
"error": run.error,
"created_by": run.created_by,
"started_at": int(run.started_at.timestamp()) if run.started_at else None,
"completed_at": int(run.completed_at.timestamp()) if run.completed_at else None,
"created_at": int(run.created_at.timestamp()) if run.created_at else None,
}
def _serialize_dataset_evaluation_run_item(item: Any) -> dict[str, Any]:
return {
"id": item.id,
"item_index": item.item_index,
"inputs": item.inputs_dict,
"expected_output": item.expected_output,
"actual_output": item.actual_output,
"metrics": item.metrics_list,
"judgment": item.judgment_dict,
"metadata": item.metadata_dict,
"error": item.error,
"overall_score": item.overall_score,
}
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/template/download")
class DatasetEvaluationTemplateDownloadApi(Resource):
@console_ns.doc("download_dataset_evaluation_template")
@console_ns.response(200, "Template file streamed as XLSX attachment")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Dataset not found")
@setup_required
@login_required
@account_initialization_required
def post(self, dataset_id):
"""Download evaluation dataset template for knowledge base retrieval."""
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
xlsx_content, filename = EvaluationService.generate_retrieval_dataset_template()
encoded_filename = quote(filename)
response = Response(
xlsx_content,
mimetype="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
)
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
response.headers["Content-Length"] = str(len(xlsx_content))
return response
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation")
class DatasetEvaluationDetailApi(Resource):
@console_ns.doc("get_dataset_evaluation_config")
@console_ns.response(200, "Evaluation configuration retrieved")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Dataset not found")
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id):
"""Get evaluation configuration for the knowledge base."""
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
with Session(db.engine, expire_on_commit=False) as session:
config = EvaluationService.get_evaluation_config(
session, current_tenant_id, "dataset", dataset_id_str
)
if config is None:
return {
"evaluation_model": None,
"evaluation_model_provider": None,
"default_metrics": None,
"customized_metrics": None,
"judgment_config": None,
}
return {
"evaluation_model": config.evaluation_model,
"evaluation_model_provider": config.evaluation_model_provider,
"default_metrics": config.default_metrics_list,
"customized_metrics": config.customized_metrics_dict,
"judgment_config": config.judgment_config_dict,
}
@console_ns.doc("save_dataset_evaluation_config")
@console_ns.response(200, "Evaluation configuration saved")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Dataset not found")
@setup_required
@login_required
@account_initialization_required
def put(self, dataset_id):
"""Save evaluation configuration for the knowledge base."""
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
body = request.get_json(force=True)
try:
config_data = EvaluationConfigData.model_validate(body)
except Exception as e:
raise BadRequest(f"Invalid request body: {e}")
with Session(db.engine, expire_on_commit=False) as session:
config = EvaluationService.save_evaluation_config(
session=session,
tenant_id=current_tenant_id,
target_type="dataset",
target_id=dataset_id_str,
account_id=str(current_user.id),
data=config_data,
)
return {
"evaluation_model": config.evaluation_model,
"evaluation_model_provider": config.evaluation_model_provider,
"default_metrics": config.default_metrics_list,
"customized_metrics": config.customized_metrics_dict,
"judgment_config": config.judgment_config_dict,
}
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/run")
class DatasetEvaluationRunApi(Resource):
@console_ns.doc("start_dataset_evaluation_run")
@console_ns.response(200, "Evaluation run started")
@console_ns.response(400, "Invalid request")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Dataset not found")
@setup_required
@login_required
@account_initialization_required
def post(self, dataset_id):
"""Start an evaluation run for the knowledge base retrieval."""
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
body = request.get_json(force=True)
if not body:
raise BadRequest("Request body is required.")
try:
run_request = EvaluationRunRequest.model_validate(body)
except Exception as e:
raise BadRequest(f"Invalid request body: {e}")
upload_file = (
db.session.query(UploadFile).filter_by(id=run_request.file_id, tenant_id=current_tenant_id).first()
)
if not upload_file:
raise NotFound("Dataset file not found.")
try:
dataset_content = storage.load_once(upload_file.key)
except Exception:
raise BadRequest("Failed to read dataset file.")
if not dataset_content:
raise BadRequest("Dataset file is empty.")
try:
with Session(db.engine, expire_on_commit=False) as session:
evaluation_run = EvaluationService.start_evaluation_run(
session=session,
tenant_id=current_tenant_id,
target_type=EvaluationTargetType.KNOWLEDGE_BASE,
target_id=dataset_id_str,
account_id=str(current_user.id),
dataset_file_content=dataset_content,
run_request=run_request,
)
return _serialize_dataset_evaluation_run(evaluation_run), 200
except EvaluationFrameworkNotConfiguredError as e:
return {"message": str(e.description)}, 400
except EvaluationNotFoundError as e:
return {"message": str(e.description)}, 404
except EvaluationMaxConcurrentRunsError as e:
return {"message": str(e.description)}, 429
except EvaluationDatasetInvalidError as e:
return {"message": str(e.description)}, 400
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/logs")
class DatasetEvaluationLogsApi(Resource):
@console_ns.doc("get_dataset_evaluation_logs")
@console_ns.response(200, "Evaluation logs retrieved")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Dataset not found")
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id):
"""Get evaluation run history for the knowledge base."""
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
page = request.args.get("page", 1, type=int)
page_size = request.args.get("page_size", 20, type=int)
with Session(db.engine, expire_on_commit=False) as session:
runs, total = EvaluationService.get_evaluation_runs(
session=session,
tenant_id=current_tenant_id,
target_type="dataset",
target_id=dataset_id_str,
page=page,
page_size=page_size,
)
return {
"data": [_serialize_dataset_evaluation_run(run) for run in runs],
"total": total,
"page": page,
"page_size": page_size,
}
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/runs/<uuid:run_id>")
class DatasetEvaluationRunDetailApi(Resource):
@console_ns.doc("get_dataset_evaluation_run_detail")
@console_ns.response(200, "Evaluation run detail retrieved")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Dataset or run not found")
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id, run_id):
"""Get evaluation run detail including per-item results."""
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
run_id_str = str(run_id)
page = request.args.get("page", 1, type=int)
page_size = request.args.get("page_size", 50, type=int)
try:
with Session(db.engine, expire_on_commit=False) as session:
run = EvaluationService.get_evaluation_run_detail(
session=session,
tenant_id=current_tenant_id,
run_id=run_id_str,
)
items, total_items = EvaluationService.get_evaluation_run_items(
session=session,
run_id=run_id_str,
page=page,
page_size=page_size,
)
return {
"run": _serialize_dataset_evaluation_run(run),
"items": {
"data": [_serialize_dataset_evaluation_run_item(item) for item in items],
"total": total_items,
"page": page,
"page_size": page_size,
},
}
except EvaluationNotFoundError as e:
return {"message": str(e.description)}, 404
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/runs/<uuid:run_id>/cancel")
class DatasetEvaluationRunCancelApi(Resource):
@console_ns.doc("cancel_dataset_evaluation_run")
@console_ns.response(200, "Evaluation run cancelled")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Dataset or run not found")
@setup_required
@login_required
@account_initialization_required
def post(self, dataset_id, run_id):
"""Cancel a running knowledge base evaluation."""
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
run_id_str = str(run_id)
try:
with Session(db.engine, expire_on_commit=False) as session:
run = EvaluationService.cancel_evaluation_run(
session=session,
tenant_id=current_tenant_id,
run_id=run_id_str,
)
return _serialize_dataset_evaluation_run(run)
except EvaluationNotFoundError as e:
return {"message": str(e.description)}, 404
except ValueError as e:
return {"message": str(e)}, 400
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/metrics")
class DatasetEvaluationMetricsApi(Resource):
@console_ns.doc("get_dataset_evaluation_metrics")
@console_ns.response(200, "Available retrieval metrics retrieved")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Dataset not found")
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id):
"""Get available evaluation metrics for knowledge base retrieval."""
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
return {
"metrics": EvaluationService.get_supported_metrics(EvaluationCategory.KNOWLEDGE_BASE)
}
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/files/<uuid:file_id>")
class DatasetEvaluationFileDownloadApi(Resource):
@console_ns.doc("download_dataset_evaluation_file")
@console_ns.response(200, "File download URL generated")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Dataset or file not found")
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id, file_id):
"""Download evaluation test file or result file for the knowledge base."""
from core.workflow.file import helpers as file_helpers
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
file_id_str = str(file_id)
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(UploadFile).where(
UploadFile.id == file_id_str,
UploadFile.tenant_id == current_tenant_id,
)
upload_file = session.execute(stmt).scalar_one_or_none()
if not upload_file:
raise NotFound("File not found.")
download_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id, as_attachment=True)
return {
"id": upload_file.id,
"name": upload_file.name,
"size": upload_file.size,
"extension": upload_file.extension,
"mime_type": upload_file.mime_type,
"created_at": int(upload_file.created_at.timestamp()) if upload_file.created_at else None,
"download_url": download_url,
}

View File

@@ -0,0 +1 @@
# Evaluation controller module

View File

@@ -0,0 +1,869 @@
from __future__ import annotations
import logging
from collections.abc import Callable
from functools import wraps
from typing import TYPE_CHECKING, ParamSpec, TypeVar, Union
from urllib.parse import quote
from flask import Response, request
from flask_restx import Resource, fields, marshal
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.workflow import WorkflowListQuery
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
)
from core.evaluation.entities.evaluation_entity import EvaluationCategory, EvaluationConfigData, EvaluationRunRequest
from extensions.ext_database import db
from extensions.ext_storage import storage
from fields.member_fields import simple_account_fields
from graphon.file import helpers as file_helpers
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models import App, Dataset
from models.model import UploadFile
from models.snippet import CustomizedSnippet
from services.errors.evaluation import (
EvaluationDatasetInvalidError,
EvaluationFrameworkNotConfiguredError,
EvaluationMaxConcurrentRunsError,
EvaluationNotFoundError,
)
from services.evaluation_service import EvaluationService
from services.workflow_service import WorkflowService
if TYPE_CHECKING:
from models.evaluation import EvaluationRun, EvaluationRunItem
logger = logging.getLogger(__name__)
P = ParamSpec("P")
R = TypeVar("R")
# Valid evaluation target types
EVALUATE_TARGET_TYPES = {"app", "snippets"}
class VersionQuery(BaseModel):
"""Query parameters for version endpoint."""
version: str
register_schema_models(
console_ns,
VersionQuery,
)
# Response field definitions
file_info_fields = {
"id": fields.String,
"name": fields.String,
}
evaluation_log_fields = {
"created_at": TimestampField,
"created_by": fields.String,
"test_file": fields.Nested(
console_ns.model(
"EvaluationTestFile",
file_info_fields,
)
),
"result_file": fields.Nested(
console_ns.model(
"EvaluationResultFile",
file_info_fields,
),
allow_null=True,
),
"version": fields.String,
}
evaluation_log_list_model = console_ns.model(
"EvaluationLogList",
{
"data": fields.List(fields.Nested(console_ns.model("EvaluationLog", evaluation_log_fields))),
},
)
evaluation_default_metric_node_info_fields = {
"node_id": fields.String,
"type": fields.String,
"title": fields.String,
}
evaluation_default_metric_item_fields = {
"metric": fields.String,
"value_type": fields.String,
"node_info_list": fields.List(
fields.Nested(
console_ns.model("EvaluationDefaultMetricNodeInfo", evaluation_default_metric_node_info_fields),
),
),
}
customized_metrics_fields = {
"evaluation_workflow_id": fields.String,
"input_fields": fields.Raw,
"output_fields": fields.Raw,
}
judgment_condition_fields = {
"variable_selector": fields.List(fields.String),
"comparison_operator": fields.String,
"value": fields.String,
}
judgment_config_fields = {
"logical_operator": fields.String,
"conditions": fields.List(fields.Nested(console_ns.model("JudgmentCondition", judgment_condition_fields))),
}
evaluation_detail_fields = {
"evaluation_model": fields.String,
"evaluation_model_provider": fields.String,
"default_metrics": fields.List(
fields.Nested(console_ns.model("EvaluationDefaultMetricItem_Detail", evaluation_default_metric_item_fields)),
allow_null=True,
),
"customized_metrics": fields.Nested(
console_ns.model("EvaluationCustomizedMetrics", customized_metrics_fields),
allow_null=True,
),
"judgment_config": fields.Nested(
console_ns.model("EvaluationJudgmentConfig", judgment_config_fields),
allow_null=True,
),
}
evaluation_detail_model = console_ns.model("EvaluationDetail", evaluation_detail_fields)
available_evaluation_workflow_list_fields = {
"id": fields.String,
"app_id": fields.String,
"app_name": fields.String,
"type": fields.String,
"version": fields.String,
"marked_name": fields.String,
"marked_comment": fields.String,
"hash": fields.String,
"created_by": fields.Nested(simple_account_fields),
"created_at": TimestampField,
"updated_by": fields.Nested(simple_account_fields, allow_null=True),
"updated_at": TimestampField,
}
available_evaluation_workflow_pagination_fields = {
"items": fields.List(fields.Nested(available_evaluation_workflow_list_fields)),
"page": fields.Integer,
"limit": fields.Integer,
"has_more": fields.Boolean,
}
available_evaluation_workflow_pagination_model = console_ns.model(
"AvailableEvaluationWorkflowPagination",
available_evaluation_workflow_pagination_fields,
)
evaluation_default_metrics_response_model = console_ns.model(
"EvaluationDefaultMetricsResponse",
{
"default_metrics": fields.List(
fields.Nested(console_ns.model("EvaluationDefaultMetricItem", evaluation_default_metric_item_fields)),
),
},
)
def get_evaluation_target(view_func: Callable[P, R]):
"""
Decorator to resolve polymorphic evaluation target (app or snippet).
Validates the target_type parameter and fetches the corresponding
model (App or CustomizedSnippet) with tenant isolation.
"""
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs):
target_type = kwargs.get("evaluate_target_type")
target_id = kwargs.get("evaluate_target_id")
if target_type not in EVALUATE_TARGET_TYPES:
raise NotFound(f"Invalid evaluation target type: {target_type}")
_, current_tenant_id = current_account_with_tenant()
target_id = str(target_id)
# Remove path parameters
del kwargs["evaluate_target_type"]
del kwargs["evaluate_target_id"]
target: Union[App, CustomizedSnippet, Dataset] | None = None
if target_type == "app":
target = db.session.query(App).where(App.id == target_id, App.tenant_id == current_tenant_id).first()
elif target_type == "snippets":
target = (
db.session.query(CustomizedSnippet)
.where(CustomizedSnippet.id == target_id, CustomizedSnippet.tenant_id == current_tenant_id)
.first()
)
elif target_type == "knowledge":
target = (db.session.query(Dataset)
.where(Dataset.id == target_id, Dataset.tenant_id == current_tenant_id)
.first())
if not target:
raise NotFound(f"{str(target_type)} not found")
kwargs["target"] = target
kwargs["target_type"] = target_type
return view_func(*args, **kwargs)
return decorated_view
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/dataset-template/download")
class EvaluationDatasetTemplateDownloadApi(Resource):
@console_ns.doc("download_evaluation_dataset_template")
@console_ns.response(200, "Template file streamed as XLSX attachment")
@console_ns.response(400, "Invalid target type or excluded app mode")
@console_ns.response(404, "Target not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
@edit_permission_required
def post(self, target: Union[App, CustomizedSnippet], target_type: str):
"""
Download evaluation dataset template.
Generates an XLSX template based on the target's input parameters
and streams it directly as a file attachment.
"""
try:
xlsx_content, filename = EvaluationService.generate_dataset_template(
target=target,
target_type=target_type,
)
except ValueError as e:
return {"message": str(e)}, 400
encoded_filename = quote(filename)
response = Response(
xlsx_content,
mimetype="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
)
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
response.headers["Content-Length"] = str(len(xlsx_content))
return response
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation")
class EvaluationDetailApi(Resource):
@console_ns.doc("get_evaluation_detail")
@console_ns.response(200, "Evaluation details retrieved successfully", evaluation_detail_model)
@console_ns.response(404, "Target not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
"""
Get evaluation configuration for the target.
Returns evaluation configuration including model settings,
metrics config, and judgment conditions.
"""
_, current_tenant_id = current_account_with_tenant()
with Session(db.engine, expire_on_commit=False) as session:
config = EvaluationService.get_evaluation_config(session, current_tenant_id, target_type, str(target.id))
if config is None:
return {
"evaluation_model": None,
"evaluation_model_provider": None,
"default_metrics": None,
"customized_metrics": None,
"judgment_config": None,
}
return {
"evaluation_model": config.evaluation_model,
"evaluation_model_provider": config.evaluation_model_provider,
"default_metrics": config.default_metrics_list,
"customized_metrics": config.customized_metrics_dict,
"judgment_config": config.judgment_config_dict,
}
@console_ns.doc("save_evaluation_detail")
@console_ns.response(200, "Evaluation configuration saved successfully")
@console_ns.response(404, "Target not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
@edit_permission_required
def put(self, target: Union[App, CustomizedSnippet], target_type: str):
"""
Save evaluation configuration for the target.
"""
current_account, current_tenant_id = current_account_with_tenant()
body = request.get_json(force=True)
try:
config_data = EvaluationConfigData.model_validate(body)
except Exception as e:
raise BadRequest(f"Invalid request body: {e}")
with Session(db.engine, expire_on_commit=False) as session:
config = EvaluationService.save_evaluation_config(
session=session,
tenant_id=current_tenant_id,
target_type=target_type,
target_id=str(target.id),
account_id=str(current_account.id),
data=config_data,
)
return {
"evaluation_model": config.evaluation_model,
"evaluation_model_provider": config.evaluation_model_provider,
"default_metrics": config.default_metrics_list,
"customized_metrics": config.customized_metrics_dict,
"judgment_config": config.judgment_config_dict,
}
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/logs")
class EvaluationLogsApi(Resource):
@console_ns.doc("get_evaluation_logs")
@console_ns.response(200, "Evaluation logs retrieved successfully")
@console_ns.response(404, "Target not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
"""
Get evaluation run history for the target.
Returns a paginated list of evaluation runs.
"""
_, current_tenant_id = current_account_with_tenant()
page = request.args.get("page", 1, type=int)
page_size = request.args.get("page_size", 20, type=int)
with Session(db.engine, expire_on_commit=False) as session:
runs, total = EvaluationService.get_evaluation_runs(
session=session,
tenant_id=current_tenant_id,
target_type=target_type,
target_id=str(target.id),
page=page,
page_size=page_size,
)
return {
"data": [_serialize_evaluation_run(run) for run in runs],
"total": total,
"page": page,
"page_size": page_size,
}
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/run")
class EvaluationRunApi(Resource):
@console_ns.doc("start_evaluation_run")
@console_ns.response(200, "Evaluation run started")
@console_ns.response(400, "Invalid request")
@console_ns.response(404, "Target not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
@edit_permission_required
def post(self, target: Union[App, CustomizedSnippet, Dataset], target_type: str):
"""
Start an evaluation run.
Expects JSON body with:
- file_id: uploaded dataset file ID
- evaluation_model: evaluation model name
- evaluation_model_provider: evaluation model provider
- default_metrics: list of default metric objects
- customized_metrics: customized metrics object (optional)
- judgment_config: judgment conditions config (optional)
"""
current_account, current_tenant_id = current_account_with_tenant()
body = request.get_json(force=True)
if not body:
raise BadRequest("Request body is required.")
# Validate and parse request body
try:
run_request = EvaluationRunRequest.model_validate(body)
except Exception as e:
raise BadRequest(f"Invalid request body: {e}")
# Load dataset file
upload_file = (
db.session.query(UploadFile).filter_by(id=run_request.file_id, tenant_id=current_tenant_id).first()
)
if not upload_file:
raise NotFound("Dataset file not found.")
try:
dataset_content = storage.load_once(upload_file.key)
except Exception:
raise BadRequest("Failed to read dataset file.")
if not dataset_content:
raise BadRequest("Dataset file is empty.")
try:
with Session(db.engine, expire_on_commit=False) as session:
evaluation_run = EvaluationService.start_evaluation_run(
session=session,
tenant_id=current_tenant_id,
target_type=target_type,
target_id=str(target.id),
account_id=str(current_account.id),
dataset_file_content=dataset_content,
run_request=run_request,
)
return _serialize_evaluation_run(evaluation_run), 200
except EvaluationFrameworkNotConfiguredError as e:
return {"message": str(e.description)}, 400
except EvaluationNotFoundError as e:
return {"message": str(e.description)}, 404
except EvaluationMaxConcurrentRunsError as e:
return {"message": str(e.description)}, 429
except EvaluationDatasetInvalidError as e:
return {"message": str(e.description)}, 400
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/runs/<uuid:run_id>")
class EvaluationRunDetailApi(Resource):
@console_ns.doc("get_evaluation_run_detail")
@console_ns.response(200, "Evaluation run detail retrieved")
@console_ns.response(404, "Run not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
def get(self, target: Union[App, CustomizedSnippet], target_type: str, run_id: str):
"""
Get evaluation run detail including items.
"""
_, current_tenant_id = current_account_with_tenant()
run_id = str(run_id)
page = request.args.get("page", 1, type=int)
page_size = request.args.get("page_size", 50, type=int)
try:
with Session(db.engine, expire_on_commit=False) as session:
run = EvaluationService.get_evaluation_run_detail(
session=session,
tenant_id=current_tenant_id,
run_id=run_id,
)
items, total_items = EvaluationService.get_evaluation_run_items(
session=session,
run_id=run_id,
page=page,
page_size=page_size,
)
return {
"run": _serialize_evaluation_run(run),
"items": {
"data": [_serialize_evaluation_run_item(item) for item in items],
"total": total_items,
"page": page,
"page_size": page_size,
},
}
except EvaluationNotFoundError as e:
return {"message": str(e.description)}, 404
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/runs/<uuid:run_id>/cancel")
class EvaluationRunCancelApi(Resource):
@console_ns.doc("cancel_evaluation_run")
@console_ns.response(200, "Evaluation run cancelled")
@console_ns.response(404, "Run not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
@edit_permission_required
def post(self, target: Union[App, CustomizedSnippet], target_type: str, run_id: str):
"""Cancel a running evaluation."""
_, current_tenant_id = current_account_with_tenant()
run_id = str(run_id)
try:
with Session(db.engine, expire_on_commit=False) as session:
run = EvaluationService.cancel_evaluation_run(
session=session,
tenant_id=current_tenant_id,
run_id=run_id,
)
return _serialize_evaluation_run(run)
except EvaluationNotFoundError as e:
return {"message": str(e.description)}, 404
except ValueError as e:
return {"message": str(e)}, 400
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/metrics")
class EvaluationMetricsApi(Resource):
@console_ns.doc("get_evaluation_metrics")
@console_ns.response(200, "Available metrics retrieved")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
"""
Get available evaluation metrics for the current framework.
"""
result = {}
for category in EvaluationCategory:
result[category.value] = EvaluationService.get_supported_metrics(category)
return {"metrics": result}
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/default-metrics")
class EvaluationDefaultMetricsApi(Resource):
@console_ns.doc(
"get_evaluation_default_metrics_with_nodes",
description=(
"List default metrics supported by the current evaluation framework with matching nodes "
"from the target's published workflow only (draft is ignored)."
),
)
@console_ns.response(
200,
"Default metrics and node candidates for the published workflow",
evaluation_default_metrics_response_model,
)
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
default_metrics = EvaluationService.get_default_metrics_with_nodes_for_published_target(
target=target,
target_type=target_type,
)
return {"default_metrics": [m.model_dump() for m in default_metrics]}
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/node-info")
class EvaluationNodeInfoApi(Resource):
@console_ns.doc("get_evaluation_node_info")
@console_ns.response(200, "Node info grouped by metric")
@console_ns.response(404, "Target not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
def post(self, target: Union[App, CustomizedSnippet], target_type: str):
"""Return workflow/snippet node info grouped by requested metrics.
Request body (JSON):
- metrics: list[str] | None metric names to query; omit or pass
an empty list to get all nodes under key ``"all"``.
Response:
``{metric_or_all: [{"node_id": ..., "type": ..., "title": ...}, ...]}``
"""
body = request.get_json(silent=True) or {}
metrics: list[str] | None = body.get("metrics") or None
result = EvaluationService.get_nodes_for_metrics(
target=target,
target_type=target_type,
metrics=metrics,
)
return result
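Under that contract, the request and response shapes might look like the sketch below; node ids, types, and titles are placeholders:

# Illustrative request/response for the node-info endpoint above.
request_body = {"metrics": ["faithfulness"]}
response_body = {
    "faithfulness": [
        {"node_id": "node-1", "type": "llm", "title": "Answer generation"},
    ],
}
# Omitting "metrics" (or sending an empty list) returns every node under "all".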
@console_ns.route("/evaluation/available-metrics")
class EvaluationAvailableMetricsApi(Resource):
@console_ns.doc("get_available_evaluation_metrics")
@console_ns.response(200, "Available metrics list")
@setup_required
@login_required
@account_initialization_required
def get(self):
"""Return the centrally-defined list of evaluation metrics."""
return {"metrics": EvaluationService.get_available_metrics()}
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/files/<uuid:file_id>")
class EvaluationFileDownloadApi(Resource):
@console_ns.doc("download_evaluation_file")
@console_ns.response(200, "File download URL generated successfully")
@console_ns.response(404, "Target or file not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
def get(self, target: Union[App, CustomizedSnippet], target_type: str, file_id: str):
"""
Download evaluation test file or result file.
Looks up the specified file, verifies it belongs to the same tenant,
and returns file info and download URL.
"""
file_id = str(file_id)
_, current_tenant_id = current_account_with_tenant()
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(UploadFile).where(
UploadFile.id == file_id,
UploadFile.tenant_id == current_tenant_id,
)
upload_file = session.execute(stmt).scalar_one_or_none()
if not upload_file:
raise NotFound("File not found")
download_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id, as_attachment=True)
return {
"id": upload_file.id,
"name": upload_file.name,
"size": upload_file.size,
"extension": upload_file.extension,
"mime_type": upload_file.mime_type,
"created_at": int(upload_file.created_at.timestamp()) if upload_file.created_at else None,
"download_url": download_url,
}
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/version")
class EvaluationVersionApi(Resource):
@console_ns.doc("get_evaluation_version_detail")
@console_ns.expect(console_ns.models.get(VersionQuery.__name__))
@console_ns.response(200, "Version details retrieved successfully")
@console_ns.response(404, "Target or version not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
"""
Get evaluation target version details.
Returns the workflow graph for the specified version.
"""
version = request.args.get("version")
if not version:
return {"message": "version parameter is required"}, 400
graph = {}
if target_type == "snippets" and isinstance(target, CustomizedSnippet):
graph = target.graph_dict
return {
"graph": graph,
}
@console_ns.route("/workspaces/current/available-evaluation-workflows")
class AvailableEvaluationWorkflowsApi(Resource):
@console_ns.expect(console_ns.models[WorkflowListQuery.__name__])
@console_ns.doc("list_available_evaluation_workflows")
@console_ns.doc(description="List published evaluation workflows in the current workspace (all apps)")
@console_ns.response(
200,
"Available evaluation workflows retrieved",
available_evaluation_workflow_pagination_model,
)
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def get(self):
"""List published evaluation-type workflows for the current tenant (cross-app)."""
current_user, current_tenant_id = current_account_with_tenant()
args = WorkflowListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
page = args.page
limit = args.limit
user_id = args.user_id
named_only = args.named_only
keyword = args.keyword
if user_id and user_id != current_user.id:
raise Forbidden()
workflow_service = WorkflowService()
with Session(db.engine) as session:
workflows, has_more = workflow_service.list_published_evaluation_workflows(
session=session,
tenant_id=current_tenant_id,
page=page,
limit=limit,
user_id=user_id,
named_only=named_only,
keyword=keyword,
)
app_ids = {w.app_id for w in workflows}
if app_ids:
apps = session.scalars(select(App).where(App.id.in_(app_ids))).all()
app_names = {a.id: a.name for a in apps}
else:
app_names = {}
items = []
for wf in workflows:
items.append(
{
"id": wf.id,
"app_id": wf.app_id,
"app_name": app_names.get(wf.app_id, ""),
"type": wf.type.value,
"version": wf.version,
"marked_name": wf.marked_name,
"marked_comment": wf.marked_comment,
"hash": wf.unique_hash,
"created_by": wf.created_by_account,
"created_at": wf.created_at,
"updated_by": wf.updated_by_account,
"updated_at": wf.updated_at,
}
)
return (
marshal(
{"items": items, "page": page, "limit": limit, "has_more": has_more},
available_evaluation_workflow_pagination_fields,
),
200,
)
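The endpoint accepts the WorkflowListQuery parameters defined in the workflow controller; an illustrative call with a placeholder URL and auth header:

import requests

requests.get(
    "https://example.com/console/api/workspaces/current/available-evaluation-workflows",  # placeholder
    headers={"Authorization": "Bearer <console-token>"},  # placeholder auth
    params={"page": 1, "limit": 10, "named_only": "true", "keyword": "retrieval"},
)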
@console_ns.route("/workspaces/current/evaluation-workflows/<string:workflow_id>/associated-targets")
class EvaluationWorkflowAssociatedTargetsApi(Resource):
@console_ns.doc("list_evaluation_workflow_associated_targets")
@console_ns.doc(
description="List targets (apps / snippets / knowledge bases) that use the given workflow as customized metrics"
)
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def get(self, workflow_id: str):
"""Return all evaluation targets that reference this workflow as customized metrics."""
_, current_tenant_id = current_account_with_tenant()
with Session(db.engine) as session:
configs = EvaluationService.list_targets_by_customized_workflow(
session=session,
tenant_id=current_tenant_id,
customized_workflow_id=workflow_id,
)
target_ids_by_type: dict[str, list[str]] = {}
for cfg in configs:
target_ids_by_type.setdefault(cfg.target_type, []).append(cfg.target_id)
app_names: dict[str, str] = {}
if "app" in target_ids_by_type:
apps = session.scalars(select(App).where(App.id.in_(target_ids_by_type["app"]))).all()
app_names = {a.id: a.name for a in apps}
snippet_names: dict[str, str] = {}
if "snippets" in target_ids_by_type:
snippets = session.scalars(
select(CustomizedSnippet).where(CustomizedSnippet.id.in_(target_ids_by_type["snippets"]))
).all()
snippet_names = {s.id: s.name for s in snippets}
dataset_names: dict[str, str] = {}
if "knowledge_base" in target_ids_by_type:
datasets = session.scalars(
select(Dataset).where(Dataset.id.in_(target_ids_by_type["knowledge_base"]))
).all()
dataset_names = {d.id: d.name for d in datasets}
items = []
for cfg in configs:
name = ""
if cfg.target_type == "app":
name = app_names.get(cfg.target_id, "")
elif cfg.target_type == "snippets":
name = snippet_names.get(cfg.target_id, "")
elif cfg.target_type == "knowledge_base":
name = dataset_names.get(cfg.target_id, "")
items.append(
{
"target_type": cfg.target_type,
"target_id": cfg.target_id,
"target_name": name,
}
)
return {"items": items}, 200
# ---- Serialization Helpers ----
def _serialize_evaluation_run(run: EvaluationRun) -> dict[str, object]:
return {
"id": run.id,
"tenant_id": run.tenant_id,
"target_type": run.target_type,
"target_id": run.target_id,
"evaluation_config_id": run.evaluation_config_id,
"status": run.status,
"dataset_file_id": run.dataset_file_id,
"result_file_id": run.result_file_id,
"total_items": run.total_items,
"completed_items": run.completed_items,
"failed_items": run.failed_items,
"progress": run.progress,
"metrics_summary": run.metrics_summary_dict,
"error": run.error,
"created_by": run.created_by,
"started_at": int(run.started_at.timestamp()) if run.started_at else None,
"completed_at": int(run.completed_at.timestamp()) if run.completed_at else None,
"created_at": int(run.created_at.timestamp()) if run.created_at else None,
}
def _serialize_evaluation_run_item(item: EvaluationRunItem) -> dict[str, object]:
return {
"id": item.id,
"item_index": item.item_index,
"inputs": item.inputs_dict,
"expected_output": item.expected_output,
"actual_output": item.actual_output,
"metrics": item.metrics_list,
"judgment": item.judgment_dict,
"metadata": item.metadata_dict,
"error": item.error,
"overall_score": item.overall_score,
}

View File

@@ -0,0 +1,135 @@
from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
class SnippetListQuery(BaseModel):
"""Query parameters for listing snippets."""
page: int = Field(default=1, ge=1, le=99999)
limit: int = Field(default=20, ge=1, le=100)
keyword: str | None = None
is_published: bool | None = Field(default=None, description="Filter by published status")
creators: list[str] | None = Field(default=None, description="Filter by creator account IDs")
@field_validator("creators", mode="before")
@classmethod
def parse_creators(cls, value: object) -> list[str] | None:
"""Normalize creators filter from query string or list input."""
if value is None:
return None
if isinstance(value, str):
return [creator.strip() for creator in value.split(",") if creator.strip()] or None
if isinstance(value, list):
return [str(creator).strip() for creator in value if str(creator).strip()] or None
return None
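A quick sketch of the normalization performed by the validator above, using illustrative account IDs:

# Comma-separated query strings are split, trimmed, and empty entries dropped.
query = SnippetListQuery.model_validate({"creators": " acc-1 , acc-2 ,"})
assert query.creators == ["acc-1", "acc-2"]

# An empty filter normalizes to None rather than an empty list.
assert SnippetListQuery.model_validate({"creators": ""}).creators is None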
class IconInfo(BaseModel):
"""Icon information model."""
icon: str | None = None
icon_type: Literal["emoji", "image"] | None = None
icon_background: str | None = None
icon_url: str | None = None
class InputFieldDefinition(BaseModel):
"""Input field definition for snippet parameters."""
default: str | None = None
hint: bool | None = None
label: str | None = None
max_length: int | None = None
options: list[str] | None = None
placeholder: str | None = None
required: bool | None = None
type: str | None = None # e.g., "text-input"
class CreateSnippetPayload(BaseModel):
"""Payload for creating a new snippet."""
name: str = Field(..., min_length=1, max_length=255)
description: str | None = Field(default=None, max_length=2000)
type: Literal["node", "group"] = "node"
icon_info: IconInfo | None = None
graph: dict[str, Any] | None = None
input_fields: list[InputFieldDefinition] | None = Field(default_factory=list)
class UpdateSnippetPayload(BaseModel):
"""Payload for updating a snippet."""
name: str | None = Field(default=None, min_length=1, max_length=255)
description: str | None = Field(default=None, max_length=2000)
icon_info: IconInfo | None = None
class SnippetDraftSyncPayload(BaseModel):
"""Payload for syncing snippet draft workflow."""
graph: dict[str, Any]
hash: str | None = None
conversation_variables: list[dict[str, Any]] | None = Field(
default=None,
description="Ignored. Snippet workflows do not persist conversation variables.",
)
input_fields: list[dict[str, Any]] | None = None
class WorkflowRunQuery(BaseModel):
"""Query parameters for workflow runs."""
last_id: str | None = None
limit: int = Field(default=20, ge=1, le=100)
class SnippetDraftRunPayload(BaseModel):
"""Payload for running snippet draft workflow."""
inputs: dict[str, Any]
files: list[dict[str, Any]] | None = None
class SnippetDraftNodeRunPayload(BaseModel):
"""Payload for running a single node in snippet draft workflow."""
inputs: dict[str, Any]
query: str = ""
files: list[dict[str, Any]] | None = None
class SnippetIterationNodeRunPayload(BaseModel):
"""Payload for running an iteration node in snippet draft workflow."""
inputs: dict[str, Any] | None = None
class SnippetLoopNodeRunPayload(BaseModel):
"""Payload for running a loop node in snippet draft workflow."""
inputs: dict[str, Any] | None = None
class PublishWorkflowPayload(BaseModel):
"""Payload for publishing snippet workflow."""
knowledge_base_setting: dict[str, Any] | None = None
class SnippetImportPayload(BaseModel):
"""Payload for importing snippet from DSL."""
mode: str = Field(..., description="Import mode: yaml-content or yaml-url")
yaml_content: str | None = Field(default=None, description="YAML content (required for yaml-content mode)")
yaml_url: str | None = Field(default=None, description="YAML URL (required for yaml-url mode)")
name: str | None = Field(default=None, description="Override snippet name")
description: str | None = Field(default=None, description="Override snippet description")
snippet_id: str | None = Field(default=None, description="Snippet ID to update (optional)")
class IncludeSecretQuery(BaseModel):
"""Query parameter for including secret variables in export."""
include_secret: str = Field(default="false", description="Whether to include secret variables")
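# A minimal sketch of payload validation (hypothetical values only), exercising
# the defaults declared above; intended to be run as a standalone script.
if __name__ == "__main__":
    listing = SnippetListQuery.model_validate({"creators": " acc-1 , acc-2 ,", "page": "2"})
    assert listing.creators == ["acc-1", "acc-2"] and listing.page == 2 and listing.limit == 20

    created = CreateSnippetPayload.model_validate(
        {
            "name": "Summarize paragraph",
            "icon_info": {"icon": "🧩", "icon_type": "emoji"},
        }
    )
    assert created.type == "node" and created.input_fields == [] and created.graph is None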

View File

@ -0,0 +1,534 @@
import logging
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
from flask import request
from flask_restx import Resource, marshal_with
from sqlalchemy.orm import Session
from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync
from controllers.console.app.workflow import workflow_model
from controllers.console.app.workflow_run import (
workflow_run_detail_model,
workflow_run_node_execution_list_model,
workflow_run_node_execution_model,
workflow_run_pagination_model,
)
from controllers.console.snippets.payloads import (
PublishWorkflowPayload,
SnippetDraftNodeRunPayload,
SnippetDraftRunPayload,
SnippetDraftSyncPayload,
SnippetIterationNodeRunPayload,
SnippetLoopNodeRunPayload,
WorkflowRunQuery,
)
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
)
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from graphon.graph_engine.manager import GraphEngineManager
from libs import helper
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models.snippet import CustomizedSnippet
from services.errors.app import WorkflowHashNotEqualError
from services.snippet_generate_service import SnippetGenerateService
from services.snippet_service import SnippetService
logger = logging.getLogger(__name__)
P = ParamSpec("P")
R = TypeVar("R")
# Register Pydantic models with Swagger
register_schema_models(
console_ns,
SnippetDraftSyncPayload,
SnippetDraftNodeRunPayload,
SnippetDraftRunPayload,
SnippetIterationNodeRunPayload,
SnippetLoopNodeRunPayload,
WorkflowRunQuery,
PublishWorkflowPayload,
)
class SnippetNotFoundError(Exception):
"""Snippet not found error."""
pass
def get_snippet(view_func: Callable[P, R]):
"""Decorator to fetch and validate snippet access."""
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs):
if not kwargs.get("snippet_id"):
raise ValueError("missing snippet_id in path parameters")
_, current_tenant_id = current_account_with_tenant()
snippet_id = str(kwargs.get("snippet_id"))
del kwargs["snippet_id"]
snippet = SnippetService.get_snippet_by_id(
snippet_id=snippet_id,
tenant_id=current_tenant_id,
)
if not snippet:
raise NotFound("Snippet not found")
kwargs["snippet"] = snippet
return view_func(*args, **kwargs)
return decorated_view
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft")
class SnippetDraftWorkflowApi(Resource):
@console_ns.doc("get_snippet_draft_workflow")
@console_ns.response(200, "Draft workflow retrieved successfully", workflow_model)
@console_ns.response(404, "Snippet or draft workflow not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
@marshal_with(workflow_model)
def get(self, snippet: CustomizedSnippet):
"""Get draft workflow for snippet."""
snippet_service = SnippetService()
workflow = snippet_service.get_draft_workflow(snippet=snippet)
if not workflow:
raise DraftWorkflowNotExist()
db.session.expunge(workflow)
workflow.conversation_variables = []
return workflow
@console_ns.doc("sync_snippet_draft_workflow")
@console_ns.expect(console_ns.models.get(SnippetDraftSyncPayload.__name__))
@console_ns.response(200, "Draft workflow synced successfully")
@console_ns.response(400, "Hash mismatch")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def post(self, snippet: CustomizedSnippet):
"""Sync draft workflow for snippet."""
current_user, _ = current_account_with_tenant()
payload = SnippetDraftSyncPayload.model_validate(console_ns.payload or {})
try:
snippet_service = SnippetService()
workflow = snippet_service.sync_draft_workflow(
snippet=snippet,
graph=payload.graph,
unique_hash=payload.hash,
account=current_user,
input_fields=payload.input_fields,
)
except WorkflowHashNotEqualError:
raise DraftWorkflowNotSync()
except ValueError as e:
return {"message": str(e)}, 400
return {
"result": "success",
"hash": workflow.unique_hash,
"updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
}
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/config")
class SnippetDraftConfigApi(Resource):
@console_ns.doc("get_snippet_draft_config")
@console_ns.response(200, "Draft config retrieved successfully")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def get(self, snippet: CustomizedSnippet):
"""Get snippet draft workflow configuration limits."""
return {
"parallel_depth_limit": 3,
}
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/publish")
class SnippetPublishedWorkflowApi(Resource):
@console_ns.doc("get_snippet_published_workflow")
@console_ns.response(200, "Published workflow retrieved successfully", workflow_model)
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
@marshal_with(workflow_model)
def get(self, snippet: CustomizedSnippet):
"""Get published workflow for snippet."""
if not snippet.is_published:
return None
snippet_service = SnippetService()
workflow = snippet_service.get_published_workflow(snippet=snippet)
return workflow
@console_ns.doc("publish_snippet_workflow")
@console_ns.expect(console_ns.models.get(PublishWorkflowPayload.__name__))
@console_ns.response(200, "Workflow published successfully")
@console_ns.response(400, "No draft workflow found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def post(self, snippet: CustomizedSnippet):
"""Publish snippet workflow."""
current_user, _ = current_account_with_tenant()
snippet_service = SnippetService()
with Session(db.engine) as session:
snippet = session.merge(snippet)
try:
workflow = snippet_service.publish_workflow(
session=session,
snippet=snippet,
account=current_user,
)
workflow_created_at = TimestampField().format(workflow.created_at)
session.commit()
except ValueError as e:
return {"message": str(e)}, 400
return {
"result": "success",
"created_at": workflow_created_at,
}
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/default-workflow-block-configs")
class SnippetDefaultBlockConfigsApi(Resource):
@console_ns.doc("get_snippet_default_block_configs")
@console_ns.response(200, "Default block configs retrieved successfully")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def get(self, snippet: CustomizedSnippet):
"""Get default block configurations for snippet workflow."""
snippet_service = SnippetService()
return snippet_service.get_default_block_configs()
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs")
class SnippetWorkflowRunsApi(Resource):
@console_ns.doc("list_snippet_workflow_runs")
@console_ns.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_model)
@setup_required
@login_required
@account_initialization_required
@get_snippet
@marshal_with(workflow_run_pagination_model)
def get(self, snippet: CustomizedSnippet):
"""List workflow runs for snippet."""
query = WorkflowRunQuery.model_validate(
{
"last_id": request.args.get("last_id"),
"limit": request.args.get("limit", type=int, default=20),
}
)
args = {
"last_id": query.last_id,
"limit": query.limit,
}
snippet_service = SnippetService()
result = snippet_service.get_snippet_workflow_runs(snippet=snippet, args=args)
return result
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/<uuid:run_id>")
class SnippetWorkflowRunDetailApi(Resource):
@console_ns.doc("get_snippet_workflow_run_detail")
@console_ns.response(200, "Workflow run detail retrieved successfully", workflow_run_detail_model)
@console_ns.response(404, "Workflow run not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@marshal_with(workflow_run_detail_model)
def get(self, snippet: CustomizedSnippet, run_id):
"""Get workflow run detail for snippet."""
run_id = str(run_id)
snippet_service = SnippetService()
workflow_run = snippet_service.get_snippet_workflow_run(snippet=snippet, run_id=run_id)
if not workflow_run:
raise NotFound("Workflow run not found")
return workflow_run
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/<uuid:run_id>/node-executions")
class SnippetWorkflowRunNodeExecutionsApi(Resource):
@console_ns.doc("list_snippet_workflow_run_node_executions")
@console_ns.response(200, "Node executions retrieved successfully", workflow_run_node_execution_list_model)
@setup_required
@login_required
@account_initialization_required
@get_snippet
@marshal_with(workflow_run_node_execution_list_model)
def get(self, snippet: CustomizedSnippet, run_id):
"""List node executions for a workflow run."""
run_id = str(run_id)
snippet_service = SnippetService()
node_executions = snippet_service.get_snippet_workflow_run_node_executions(
snippet=snippet,
run_id=run_id,
)
return {"data": node_executions}
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/nodes/<string:node_id>/run")
class SnippetDraftNodeRunApi(Resource):
@console_ns.doc("run_snippet_draft_node")
@console_ns.doc(description="Run a single node in snippet draft workflow (single-step debugging)")
@console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models.get(SnippetDraftNodeRunPayload.__name__))
@console_ns.response(200, "Node run completed successfully", workflow_run_node_execution_model)
@console_ns.response(404, "Snippet or draft workflow not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@marshal_with(workflow_run_node_execution_model)
@edit_permission_required
def post(self, snippet: CustomizedSnippet, node_id: str):
"""
Run a single node in snippet draft workflow.
Executes a specific node with provided inputs for single-step debugging.
Returns the node execution result including status, outputs, and timing.
"""
current_user, _ = current_account_with_tenant()
payload = SnippetDraftNodeRunPayload.model_validate(console_ns.payload or {})
user_inputs = payload.inputs
# Get draft workflow for file parsing
snippet_service = SnippetService()
draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
if not draft_workflow:
raise NotFound("Draft workflow not found")
files = SnippetGenerateService.parse_files(draft_workflow, payload.files)
workflow_node_execution = SnippetGenerateService.run_draft_node(
snippet=snippet,
node_id=node_id,
user_inputs=user_inputs,
account=current_user,
query=payload.query,
files=files,
)
return workflow_node_execution
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/nodes/<string:node_id>/last-run")
class SnippetDraftNodeLastRunApi(Resource):
@console_ns.doc("get_snippet_draft_node_last_run")
@console_ns.doc(description="Get last run result for a node in snippet draft workflow")
@console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
@console_ns.response(200, "Node last run retrieved successfully", workflow_run_node_execution_model)
@console_ns.response(404, "Snippet, draft workflow, or node last run not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@marshal_with(workflow_run_node_execution_model)
def get(self, snippet: CustomizedSnippet, node_id: str):
"""
Get the last run result for a specific node in snippet draft workflow.
Returns the most recent execution record for the given node,
including status, inputs, outputs, and timing information.
"""
snippet_service = SnippetService()
draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
if not draft_workflow:
raise NotFound("Draft workflow not found")
node_exec = snippet_service.get_snippet_node_last_run(
snippet=snippet,
workflow=draft_workflow,
node_id=node_id,
)
if node_exec is None:
raise NotFound("Node last run not found")
return node_exec
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/iteration/nodes/<string:node_id>/run")
class SnippetDraftRunIterationNodeApi(Resource):
@console_ns.doc("run_snippet_draft_iteration_node")
@console_ns.doc(description="Run draft workflow iteration node for snippet")
@console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models.get(SnippetIterationNodeRunPayload.__name__))
@console_ns.response(200, "Iteration node run started successfully (SSE stream)")
@console_ns.response(404, "Snippet or draft workflow not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def post(self, snippet: CustomizedSnippet, node_id: str):
"""
Run a draft workflow iteration node for snippet.
Iteration nodes execute their internal sub-graph multiple times over an input list.
Returns an SSE event stream with iteration progress and results.
"""
current_user, _ = current_account_with_tenant()
args = SnippetIterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
try:
response = SnippetGenerateService.generate_single_iteration(
snippet=snippet, user=current_user, node_id=node_id, args=args, streaming=True
)
return helper.compact_generate_response(response)
except ValueError as e:
raise e
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/loop/nodes/<string:node_id>/run")
class SnippetDraftRunLoopNodeApi(Resource):
@console_ns.doc("run_snippet_draft_loop_node")
@console_ns.doc(description="Run draft workflow loop node for snippet")
@console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models.get(SnippetLoopNodeRunPayload.__name__))
@console_ns.response(200, "Loop node run started successfully (SSE stream)")
@console_ns.response(404, "Snippet or draft workflow not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def post(self, snippet: CustomizedSnippet, node_id: str):
"""
Run a draft workflow loop node for snippet.
Loop nodes execute their internal sub-graph repeatedly until a condition is met.
Returns an SSE event stream with loop progress and results.
"""
current_user, _ = current_account_with_tenant()
args = SnippetLoopNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
try:
response = SnippetGenerateService.generate_single_loop(
snippet=snippet, user=current_user, node_id=node_id, args=args, streaming=True
)
return helper.compact_generate_response(response)
except ValueError as e:
raise e
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/run")
class SnippetDraftWorkflowRunApi(Resource):
@console_ns.doc("run_snippet_draft_workflow")
@console_ns.expect(console_ns.models.get(SnippetDraftRunPayload.__name__))
@console_ns.response(200, "Draft workflow run started successfully (SSE stream)")
@console_ns.response(404, "Snippet or draft workflow not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def post(self, snippet: CustomizedSnippet):
"""
Run draft workflow for snippet.
Executes the snippet's draft workflow with the provided inputs
and returns an SSE event stream with execution progress and results.
"""
current_user, _ = current_account_with_tenant()
payload = SnippetDraftRunPayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
try:
response = SnippetGenerateService.generate(
snippet=snippet,
user=current_user,
args=args,
invoke_from=InvokeFrom.DEBUGGER,
streaming=True,
)
return helper.compact_generate_response(response)
except ValueError as e:
raise e
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/tasks/<string:task_id>/stop")
class SnippetWorkflowTaskStopApi(Resource):
@console_ns.doc("stop_snippet_workflow_task")
@console_ns.response(200, "Task stopped successfully")
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def post(self, snippet: CustomizedSnippet, task_id: str):
"""
Stop a running snippet workflow task.
Uses both the legacy stop flag mechanism and the graph engine
command channel for backward compatibility.
"""
# Stop using both mechanisms for backward compatibility
# Legacy stop flag mechanism (without user check)
AppQueueManager.set_stop_flag_no_user_check(task_id)
# New graph engine command channel mechanism
GraphEngineManager(redis_client).send_stop_command(task_id)
return {"result": "success"}

View File

@ -0,0 +1,319 @@
"""
Snippet draft workflow variable APIs.
Mirrors console app routes under /apps/.../workflows/draft/variables for snippet scope,
using CustomizedSnippet.id as WorkflowDraftVariable.app_id (same invariant as snippet execution).
Snippet workflows do not expose system variables (`node_id == sys`) or conversation variables
(`node_id == conversation`): paginated list queries exclude those rows; single-variable GET/PATCH/DELETE/reset
reject them; `GET .../system-variables` and `GET .../conversation-variables` return empty lists for API parity.
Other routes mirror `workflow_draft_variable` app APIs under `/snippets/...`.
"""
from collections.abc import Callable
from functools import wraps
from typing import Any, ParamSpec, TypeVar
from flask import Response, request
from flask_restx import Resource, marshal, marshal_with
from sqlalchemy.orm import Session
from controllers.console import console_ns
from controllers.console.app.error import DraftWorkflowNotExist
from controllers.console.app.workflow_draft_variable import (
WorkflowDraftVariableListQuery,
WorkflowDraftVariableUpdatePayload,
_ensure_variable_access,
_file_access_controller,
validate_node_id,
workflow_draft_variable_list_model,
workflow_draft_variable_list_without_value_model,
workflow_draft_variable_model,
)
from controllers.console.snippets.snippet_workflow import get_snippet
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type
from graphon.variables.types import SegmentType
from libs.login import current_user, login_required
from models.snippet import CustomizedSnippet
from models.workflow import WorkflowDraftVariable
from services.snippet_service import SnippetService
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
P = ParamSpec("P")
R = TypeVar("R")
_SNIPPET_EXCLUDED_DRAFT_VARIABLE_NODE_IDS: frozenset[str] = frozenset(
{SYSTEM_VARIABLE_NODE_ID, CONVERSATION_VARIABLE_NODE_ID}
)
def _ensure_snippet_draft_variable_row_allowed(
*,
variable: WorkflowDraftVariable,
variable_id: str,
) -> None:
"""Snippet scope only supports canvas-node draft variables; treat sys/conversation rows as not found."""
if variable.node_id in _SNIPPET_EXCLUDED_DRAFT_VARIABLE_NODE_IDS:
raise NotFoundError(description=f"variable not found, id={variable_id}")
def _snippet_draft_var_prerequisite(f: Callable[P, R]) -> Callable[P, R]:
"""Setup, auth, snippet resolution, and tenant edit permission (same stack as snippet workflow APIs)."""
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
@wraps(f)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
return f(*args, **kwargs)
return wrapper
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/variables")
class SnippetWorkflowVariableCollectionApi(Resource):
@console_ns.expect(console_ns.models[WorkflowDraftVariableListQuery.__name__])
@console_ns.doc("get_snippet_workflow_variables")
@console_ns.doc(description="List draft workflow variables without values (paginated, snippet scope)")
@console_ns.response(
200,
"Workflow variables retrieved successfully",
workflow_draft_variable_list_without_value_model,
)
@_snippet_draft_var_prerequisite
@marshal_with(workflow_draft_variable_list_without_value_model)
def get(self, snippet: CustomizedSnippet) -> WorkflowDraftVariableList:
args = WorkflowDraftVariableListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
snippet_service = SnippetService()
if snippet_service.get_draft_workflow(snippet=snippet) is None:
raise DraftWorkflowNotExist()
with Session(bind=db.engine, expire_on_commit=False) as session:
draft_var_srv = WorkflowDraftVariableService(session=session)
workflow_vars = draft_var_srv.list_variables_without_values(
app_id=snippet.id,
page=args.page,
limit=args.limit,
user_id=current_user.id,
exclude_node_ids=_SNIPPET_EXCLUDED_DRAFT_VARIABLE_NODE_IDS,
)
return workflow_vars
@console_ns.doc("delete_snippet_workflow_variables")
@console_ns.doc(description="Delete all draft workflow variables for the current user (snippet scope)")
@console_ns.response(204, "Workflow variables deleted successfully")
@_snippet_draft_var_prerequisite
def delete(self, snippet: CustomizedSnippet) -> Response:
draft_var_srv = WorkflowDraftVariableService(session=db.session())
draft_var_srv.delete_user_workflow_variables(snippet.id, user_id=current_user.id)
db.session.commit()
return Response("", 204)
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/nodes/<string:node_id>/variables")
class SnippetNodeVariableCollectionApi(Resource):
@console_ns.doc("get_snippet_node_variables")
@console_ns.doc(description="Get variables for a specific node (snippet draft workflow)")
@console_ns.response(200, "Node variables retrieved successfully", workflow_draft_variable_list_model)
@_snippet_draft_var_prerequisite
@marshal_with(workflow_draft_variable_list_model)
def get(self, snippet: CustomizedSnippet, node_id: str) -> WorkflowDraftVariableList:
validate_node_id(node_id)
with Session(bind=db.engine, expire_on_commit=False) as session:
draft_var_srv = WorkflowDraftVariableService(session=session)
node_vars = draft_var_srv.list_node_variables(snippet.id, node_id, user_id=current_user.id)
return node_vars
@console_ns.doc("delete_snippet_node_variables")
@console_ns.doc(description="Delete all variables for a specific node (snippet draft workflow)")
@console_ns.response(204, "Node variables deleted successfully")
@_snippet_draft_var_prerequisite
def delete(self, snippet: CustomizedSnippet, node_id: str) -> Response:
validate_node_id(node_id)
srv = WorkflowDraftVariableService(db.session())
srv.delete_node_variables(snippet.id, node_id, user_id=current_user.id)
db.session.commit()
return Response("", 204)
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/variables/<uuid:variable_id>")
class SnippetVariableApi(Resource):
@console_ns.doc("get_snippet_workflow_variable")
@console_ns.doc(description="Get a specific draft workflow variable (snippet scope)")
@console_ns.response(200, "Variable retrieved successfully", workflow_draft_variable_model)
@console_ns.response(404, "Variable not found")
@_snippet_draft_var_prerequisite
@marshal_with(workflow_draft_variable_model)
def get(self, snippet: CustomizedSnippet, variable_id: str) -> WorkflowDraftVariable:
draft_var_srv = WorkflowDraftVariableService(session=db.session())
variable = _ensure_variable_access(
variable=draft_var_srv.get_variable(variable_id=variable_id),
app_id=snippet.id,
variable_id=variable_id,
)
_ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id)
return variable
@console_ns.doc("update_snippet_workflow_variable")
@console_ns.doc(description="Update a draft workflow variable (snippet scope)")
@console_ns.expect(console_ns.models[WorkflowDraftVariableUpdatePayload.__name__])
@console_ns.response(200, "Variable updated successfully", workflow_draft_variable_model)
@console_ns.response(404, "Variable not found")
@_snippet_draft_var_prerequisite
@marshal_with(workflow_draft_variable_model)
def patch(self, snippet: CustomizedSnippet, variable_id: str) -> WorkflowDraftVariable:
draft_var_srv = WorkflowDraftVariableService(session=db.session())
args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {})
variable = _ensure_variable_access(
variable=draft_var_srv.get_variable(variable_id=variable_id),
app_id=snippet.id,
variable_id=variable_id,
)
_ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id)
new_name = args_model.name
raw_value = args_model.value
if new_name is None and raw_value is None:
return variable
new_value = None
if raw_value is not None:
if variable.value_type == SegmentType.FILE:
if not isinstance(raw_value, dict):
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
raw_value = build_from_mapping(
mapping=raw_value,
tenant_id=snippet.tenant_id,
access_controller=_file_access_controller,
)
elif variable.value_type == SegmentType.ARRAY_FILE:
if not isinstance(raw_value, list):
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
raw_value = build_from_mappings(
mappings=raw_value,
tenant_id=snippet.tenant_id,
access_controller=_file_access_controller,
)
new_value = build_segment_with_type(variable.value_type, raw_value)
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
db.session.commit()
return variable
@console_ns.doc("delete_snippet_workflow_variable")
@console_ns.doc(description="Delete a draft workflow variable (snippet scope)")
@console_ns.response(204, "Variable deleted successfully")
@console_ns.response(404, "Variable not found")
@_snippet_draft_var_prerequisite
def delete(self, snippet: CustomizedSnippet, variable_id: str) -> Response:
draft_var_srv = WorkflowDraftVariableService(session=db.session())
variable = _ensure_variable_access(
variable=draft_var_srv.get_variable(variable_id=variable_id),
app_id=snippet.id,
variable_id=variable_id,
)
_ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id)
draft_var_srv.delete_variable(variable)
db.session.commit()
return Response("", 204)
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/variables/<uuid:variable_id>/reset")
class SnippetVariableResetApi(Resource):
@console_ns.doc("reset_snippet_workflow_variable")
@console_ns.doc(description="Reset a draft workflow variable to its default value (snippet scope)")
@console_ns.response(200, "Variable reset successfully", workflow_draft_variable_model)
@console_ns.response(204, "Variable reset (no content)")
@console_ns.response(404, "Variable not found")
@_snippet_draft_var_prerequisite
def put(self, snippet: CustomizedSnippet, variable_id: str) -> Response | Any:
draft_var_srv = WorkflowDraftVariableService(session=db.session())
snippet_service = SnippetService()
draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
if draft_workflow is None:
raise NotFoundError(
f"Draft workflow not found, snippet_id={snippet.id}",
)
variable = _ensure_variable_access(
variable=draft_var_srv.get_variable(variable_id=variable_id),
app_id=snippet.id,
variable_id=variable_id,
)
_ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id)
resetted = draft_var_srv.reset_variable(draft_workflow, variable)
db.session.commit()
if resetted is None:
return Response("", 204)
return marshal(resetted, workflow_draft_variable_model)
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/conversation-variables")
class SnippetConversationVariableCollectionApi(Resource):
@console_ns.doc("get_snippet_conversation_variables")
@console_ns.doc(
description="Conversation variables are not used in snippet workflows; returns an empty list for API parity"
)
@console_ns.response(200, "Conversation variables retrieved successfully", workflow_draft_variable_list_model)
@_snippet_draft_var_prerequisite
@marshal_with(workflow_draft_variable_list_model)
def get(self, snippet: CustomizedSnippet) -> WorkflowDraftVariableList:
return WorkflowDraftVariableList(variables=[])
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/system-variables")
class SnippetSystemVariableCollectionApi(Resource):
@console_ns.doc("get_snippet_system_variables")
@console_ns.doc(
description="System variables are not used in snippet workflows; returns an empty list for API parity"
)
@console_ns.response(200, "System variables retrieved successfully", workflow_draft_variable_list_model)
@_snippet_draft_var_prerequisite
@marshal_with(workflow_draft_variable_list_model)
def get(self, snippet: CustomizedSnippet) -> WorkflowDraftVariableList:
return WorkflowDraftVariableList(variables=[])
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/environment-variables")
class SnippetEnvironmentVariableCollectionApi(Resource):
@console_ns.doc("get_snippet_environment_variables")
@console_ns.doc(description="Get environment variables from snippet draft workflow graph")
@console_ns.response(200, "Environment variables retrieved successfully")
@console_ns.response(404, "Draft workflow not found")
@_snippet_draft_var_prerequisite
def get(self, snippet: CustomizedSnippet) -> dict[str, list[dict[str, Any]]]:
snippet_service = SnippetService()
workflow = snippet_service.get_draft_workflow(snippet=snippet)
if workflow is None:
raise DraftWorkflowNotExist()
env_vars_list: list[dict[str, Any]] = []
for v in workflow.environment_variables:
env_vars_list.append(
{
"id": v.id,
"type": "env",
"name": v.name,
"description": v.description,
"selector": v.selector,
"value_type": v.value_type.exposed_type().value,
"value": v.value,
"edited": False,
"visible": True,
"editable": True,
}
)
return {"items": env_vars_list}

View File

@ -0,0 +1,380 @@
import logging
from urllib.parse import quote
from flask import Response, request
from flask_restx import Resource, marshal
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.snippets.payloads import (
CreateSnippetPayload,
IncludeSecretQuery,
SnippetImportPayload,
SnippetListQuery,
UpdateSnippetPayload,
)
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
)
from extensions.ext_database import db
from fields.snippet_fields import snippet_fields, snippet_list_fields, snippet_pagination_fields
from libs.login import current_account_with_tenant, login_required
from models.snippet import SnippetType
from services.app_dsl_service import ImportStatus
from services.snippet_dsl_service import SnippetDslService
from services.snippet_service import SnippetService
logger = logging.getLogger(__name__)
# Register Pydantic models with Swagger
register_schema_models(
console_ns,
SnippetListQuery,
CreateSnippetPayload,
UpdateSnippetPayload,
SnippetImportPayload,
IncludeSecretQuery,
)
# Create namespace models for marshaling
snippet_model = console_ns.model("Snippet", snippet_fields)
snippet_list_model = console_ns.model("SnippetList", snippet_list_fields)
snippet_pagination_model = console_ns.model("SnippetPagination", snippet_pagination_fields)
@console_ns.route("/workspaces/current/customized-snippets")
class CustomizedSnippetsApi(Resource):
@console_ns.doc("list_customized_snippets")
@console_ns.expect(console_ns.models.get(SnippetListQuery.__name__))
@console_ns.response(200, "Snippets retrieved successfully", snippet_pagination_model)
@setup_required
@login_required
@account_initialization_required
def get(self):
"""List customized snippets with pagination and search."""
_, current_tenant_id = current_account_with_tenant()
query_params = request.args.to_dict()
query = SnippetListQuery.model_validate(query_params)
snippets, total, has_more = SnippetService.get_snippets(
tenant_id=current_tenant_id,
page=query.page,
limit=query.limit,
keyword=query.keyword,
is_published=query.is_published,
creators=query.creators,
)
return {
"data": marshal(snippets, snippet_list_fields),
"page": query.page,
"limit": query.limit,
"total": total,
"has_more": has_more,
}, 200
@console_ns.doc("create_customized_snippet")
@console_ns.expect(console_ns.models.get(CreateSnippetPayload.__name__))
@console_ns.response(201, "Snippet created successfully", snippet_model)
@console_ns.response(400, "Invalid request or name already exists")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self):
"""Create a new customized snippet."""
current_user, current_tenant_id = current_account_with_tenant()
payload = CreateSnippetPayload.model_validate(console_ns.payload or {})
try:
snippet_type = SnippetType(payload.type)
except ValueError:
snippet_type = SnippetType.NODE
try:
snippet = SnippetService.create_snippet(
tenant_id=current_tenant_id,
name=payload.name,
description=payload.description,
snippet_type=snippet_type,
icon_info=payload.icon_info.model_dump() if payload.icon_info else None,
input_fields=[f.model_dump() for f in payload.input_fields] if payload.input_fields else None,
account=current_user,
)
except ValueError as e:
return {"message": str(e)}, 400
return marshal(snippet, snippet_fields), 201
@console_ns.route("/workspaces/current/customized-snippets/<uuid:snippet_id>")
class CustomizedSnippetDetailApi(Resource):
@console_ns.doc("get_customized_snippet")
@console_ns.response(200, "Snippet retrieved successfully", snippet_model)
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
def get(self, snippet_id: str):
"""Get customized snippet details."""
_, current_tenant_id = current_account_with_tenant()
snippet = SnippetService.get_snippet_by_id(
snippet_id=str(snippet_id),
tenant_id=current_tenant_id,
)
if not snippet:
raise NotFound("Snippet not found")
return marshal(snippet, snippet_fields), 200
@console_ns.doc("update_customized_snippet")
@console_ns.expect(console_ns.models.get(UpdateSnippetPayload.__name__))
@console_ns.response(200, "Snippet updated successfully", snippet_model)
@console_ns.response(400, "Invalid request or name already exists")
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def patch(self, snippet_id: str):
"""Update customized snippet."""
current_user, current_tenant_id = current_account_with_tenant()
snippet = SnippetService.get_snippet_by_id(
snippet_id=str(snippet_id),
tenant_id=current_tenant_id,
)
if not snippet:
raise NotFound("Snippet not found")
payload = UpdateSnippetPayload.model_validate(console_ns.payload or {})
update_data = payload.model_dump(exclude_unset=True)
if "icon_info" in update_data and update_data["icon_info"] is not None:
update_data["icon_info"] = payload.icon_info.model_dump() if payload.icon_info else None
if not update_data:
return {"message": "No valid fields to update"}, 400
try:
with Session(db.engine, expire_on_commit=False) as session:
snippet = session.merge(snippet)
snippet = SnippetService.update_snippet(
session=session,
snippet=snippet,
account_id=current_user.id,
data=update_data,
)
session.commit()
except ValueError as e:
return {"message": str(e)}, 400
return marshal(snippet, snippet_fields), 200
@console_ns.doc("delete_customized_snippet")
@console_ns.response(204, "Snippet deleted successfully")
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def delete(self, snippet_id: str):
"""Delete customized snippet."""
_, current_tenant_id = current_account_with_tenant()
snippet = SnippetService.get_snippet_by_id(
snippet_id=str(snippet_id),
tenant_id=current_tenant_id,
)
if not snippet:
raise NotFound("Snippet not found")
with Session(db.engine) as session:
snippet = session.merge(snippet)
SnippetService.delete_snippet(
session=session,
snippet=snippet,
)
session.commit()
return "", 204
@console_ns.route("/workspaces/current/customized-snippets/<uuid:snippet_id>/export")
class CustomizedSnippetExportApi(Resource):
@console_ns.doc("export_customized_snippet")
@console_ns.doc(description="Export snippet configuration as DSL")
@console_ns.doc(params={"snippet_id": "Snippet ID to export"})
@console_ns.response(200, "Snippet exported successfully")
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def get(self, snippet_id: str):
"""Export snippet as DSL."""
_, current_tenant_id = current_account_with_tenant()
snippet = SnippetService.get_snippet_by_id(
snippet_id=str(snippet_id),
tenant_id=current_tenant_id,
)
if not snippet:
raise NotFound("Snippet not found")
# Get include_secret parameter
query = IncludeSecretQuery.model_validate(request.args.to_dict())
with Session(db.engine) as session:
export_service = SnippetDslService(session)
result = export_service.export_snippet_dsl(snippet=snippet, include_secret=query.include_secret == "true")
# Set filename with .snippet extension
filename = f"{snippet.name}.snippet"
encoded_filename = quote(filename)
response = Response(
result,
mimetype="application/x-yaml",
)
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
response.headers["Content-Type"] = "application/x-yaml"
return response
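# A minimal sketch of the RFC 5987 filename encoding used for the
# Content-Disposition header above (hypothetical non-ASCII snippet name).
if __name__ == "__main__":
    _name = "日报摘要.snippet"
    _header = f"attachment; filename*=UTF-8''{quote(_name)}"
    assert _header == "attachment; filename*=UTF-8''%E6%97%A5%E6%8A%A5%E6%91%98%E8%A6%81.snippet"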
@console_ns.route("/workspaces/current/customized-snippets/imports")
class CustomizedSnippetImportApi(Resource):
@console_ns.doc("import_customized_snippet")
@console_ns.doc(description="Import snippet from DSL")
@console_ns.expect(console_ns.models.get(SnippetImportPayload.__name__))
@console_ns.response(200, "Snippet imported successfully")
@console_ns.response(202, "Import pending confirmation")
@console_ns.response(400, "Import failed")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self):
"""Import snippet from DSL."""
current_user, _ = current_account_with_tenant()
payload = SnippetImportPayload.model_validate(console_ns.payload or {})
with Session(db.engine) as session:
import_service = SnippetDslService(session)
result = import_service.import_snippet(
account=current_user,
import_mode=payload.mode,
yaml_content=payload.yaml_content,
yaml_url=payload.yaml_url,
snippet_id=payload.snippet_id,
name=payload.name,
description=payload.description,
)
session.commit()
# Return appropriate status code based on result
status = result.status
if status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400
elif status == ImportStatus.PENDING:
return result.model_dump(mode="json"), 202
return result.model_dump(mode="json"), 200
@console_ns.route("/workspaces/current/customized-snippets/imports/<string:import_id>/confirm")
class CustomizedSnippetImportConfirmApi(Resource):
@console_ns.doc("confirm_snippet_import")
@console_ns.doc(description="Confirm a pending snippet import")
@console_ns.doc(params={"import_id": "Import ID to confirm"})
@console_ns.response(200, "Import confirmed successfully")
@console_ns.response(400, "Import failed")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self, import_id: str):
"""Confirm a pending snippet import."""
current_user, _ = current_account_with_tenant()
with Session(db.engine) as session:
import_service = SnippetDslService(session)
result = import_service.confirm_import(import_id=import_id, account=current_user)
session.commit()
if result.status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400
return result.model_dump(mode="json"), 200
@console_ns.route("/workspaces/current/customized-snippets/<uuid:snippet_id>/check-dependencies")
class CustomizedSnippetCheckDependenciesApi(Resource):
@console_ns.doc("check_snippet_dependencies")
@console_ns.doc(description="Check dependencies for a snippet")
@console_ns.doc(params={"snippet_id": "Snippet ID"})
@console_ns.response(200, "Dependencies checked successfully")
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def get(self, snippet_id: str):
"""Check dependencies for a snippet."""
_, current_tenant_id = current_account_with_tenant()
snippet = SnippetService.get_snippet_by_id(
snippet_id=str(snippet_id),
tenant_id=current_tenant_id,
)
if not snippet:
raise NotFound("Snippet not found")
with Session(db.engine) as session:
import_service = SnippetDslService(session)
result = import_service.check_dependencies(snippet=snippet)
return result.model_dump(mode="json"), 200
@console_ns.route("/workspaces/current/customized-snippets/<uuid:snippet_id>/use-count/increment")
class CustomizedSnippetUseCountIncrementApi(Resource):
@console_ns.doc("increment_snippet_use_count")
@console_ns.doc(description="Increment snippet use count by 1")
@console_ns.doc(params={"snippet_id": "Snippet ID"})
@console_ns.response(200, "Use count incremented successfully")
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self, snippet_id: str):
"""Increment snippet use count when it is inserted into a workflow."""
_, current_tenant_id = current_account_with_tenant()
snippet = SnippetService.get_snippet_by_id(
snippet_id=str(snippet_id),
tenant_id=current_tenant_id,
)
if not snippet:
raise NotFound("Snippet not found")
with Session(db.engine) as session:
snippet = session.merge(snippet)
SnippetService.increment_use_count(session=session, snippet=snippet)
session.commit()
session.refresh(snippet)
return {"result": "success", "use_count": snippet.use_count}, 200

View File

@ -14,7 +14,7 @@ from graphon.runtime import GraphRuntimeState
from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session, sessionmaker
import contexts
from configs import dify_config
@ -54,6 +54,25 @@ logger = logging.getLogger(__name__)
class WorkflowAppGenerator(BaseAppGenerator):
@staticmethod
def _ensure_snippet_start_node_in_worker(*, session: Session, workflow: Workflow) -> Workflow:
"""Re-apply snippet virtual Start injection after worker reloads workflow from DB."""
if workflow.type != "snippet":
return workflow
from models.snippet import CustomizedSnippet
from services.snippet_generate_service import SnippetGenerateService
snippet = session.scalar(
select(CustomizedSnippet).where(
CustomizedSnippet.id == workflow.app_id,
CustomizedSnippet.tenant_id == workflow.tenant_id,
)
)
if snippet is None:
return workflow
return SnippetGenerateService.ensure_start_node_for_worker(workflow, snippet)
@staticmethod
def _should_prepare_user_inputs(args: Mapping[str, Any]) -> bool:
return not bool(args.get(SKIP_PREPARE_USER_INPUTS_KEY))
@ -557,6 +576,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
if workflow is None:
raise ValueError("Workflow not found")
workflow = self._ensure_snippet_start_node_in_worker(session=session, workflow=workflow)
# Determine system_user_id based on invocation source
is_external_api_call = application_generate_entity.invoke_from in {
InvokeFrom.WEB_APP,

View File

View File

@ -0,0 +1,279 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import Any
from core.evaluation.entities.evaluation_entity import (
CustomizedMetrics,
EvaluationCategory,
EvaluationItemInput,
EvaluationItemResult,
EvaluationMetric,
NodeInfo,
)
from graphon.node_events.base import NodeRunResult
logger = logging.getLogger(__name__)
class BaseEvaluationInstance(ABC):
"""Abstract base class for evaluation framework adapters."""
@abstractmethod
def evaluate_llm(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
"""Evaluate LLM outputs using the configured framework."""
...
@abstractmethod
def evaluate_retrieval(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
"""Evaluate retrieval quality using the configured framework."""
...
@abstractmethod
def evaluate_agent(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
"""Evaluate agent outputs using the configured framework."""
...
@abstractmethod
def get_supported_metrics(self, category: EvaluationCategory) -> list[str]:
"""Return the list of supported metric names for a given evaluation category."""
...
def evaluate_with_customized_workflow(
self,
node_run_result_mapping_list: list[dict[str, NodeRunResult]],
customized_metrics: CustomizedMetrics,
tenant_id: str,
) -> list[EvaluationItemResult]:
"""Evaluate using a published workflow as the evaluator.
The evaluator workflow's output variables are treated as metrics:
each output variable name becomes a metric name, and its value
becomes the score.
Args:
node_run_result_mapping_list: One mapping per test-data item,
where each mapping is ``{node_id: NodeRunResult}`` from the
target execution.
customized_metrics: Contains ``evaluation_workflow_id`` (the
published evaluator workflow) and ``input_fields`` (value
sources for the evaluator's input variables).
tenant_id: Tenant scope.
Returns:
A list of ``EvaluationItemResult`` with metrics extracted from
the evaluator workflow's output variables.
"""
from sqlalchemy.orm import Session
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from core.evaluation.runners import get_service_account_for_app
from models.engine import db
from models.model import App
from services.workflow_service import WorkflowService
workflow_id = customized_metrics.evaluation_workflow_id
if not workflow_id:
raise ValueError("customized_metrics must contain 'evaluation_workflow_id' for customized evaluator")
# Load the evaluator workflow resources using a dedicated session
with Session(db.engine, expire_on_commit=False) as session, session.begin():
app = session.query(App).filter_by(id=workflow_id, tenant_id=tenant_id).first()
if not app:
raise ValueError(f"Evaluation workflow app {workflow_id} not found in tenant {tenant_id}")
service_account = get_service_account_for_app(session, workflow_id)
workflow_service = WorkflowService()
published_workflow = workflow_service.get_published_workflow(app_model=app)
if not published_workflow:
raise ValueError(f"No published workflow found for evaluation app {workflow_id}")
eval_results: list[EvaluationItemResult] = []
for idx, node_run_result_mapping in enumerate(node_run_result_mapping_list):
try:
workflow_inputs = self._build_workflow_inputs(
customized_metrics.input_fields,
node_run_result_mapping,
)
generator = WorkflowAppGenerator()
response: Mapping[str, Any] = generator.generate(
app_model=app,
workflow=published_workflow,
user=service_account,
args={"inputs": workflow_inputs},
invoke_from=InvokeFrom.SERVICE_API,
streaming=False,
call_depth=0,
)
metrics = self._extract_workflow_metrics(response, workflow_id)
eval_results.append(
EvaluationItemResult(
index=idx,
metrics=metrics,
)
)
except Exception:
logger.exception(
"Customized evaluator failed for item %d with workflow %s",
idx,
workflow_id,
)
eval_results.append(EvaluationItemResult(index=idx))
return eval_results
@staticmethod
def _build_workflow_inputs(
input_fields: dict[str, Any],
node_run_result_mapping: dict[str, NodeRunResult],
) -> dict[str, Any]:
"""Build customized workflow inputs by resolving value sources.
Each entry in ``input_fields`` maps a workflow input variable name
to its value source, which can be:
- **Constant**: a plain string without ``{{#…#}}`` used as-is.
- **Expression**: a string containing one or more
``{{#node_id.output_key#}}`` selectors (same format as
``VariableTemplateParser``) resolved from
``node_run_result_mapping``.
"""
from graphon.nodes.base.variable_template_parser import REGEX as VARIABLE_REGEX
workflow_inputs: dict[str, Any] = {}
for field_name, value_source in input_fields.items():
if not isinstance(value_source, str):
# Non-string values (numbers, bools, dicts) are used directly.
workflow_inputs[field_name] = value_source
continue
# Check if the entire value is a single expression.
full_match = VARIABLE_REGEX.fullmatch(value_source)
if full_match:
workflow_inputs[field_name] = resolve_variable_selector(
full_match.group(1),
node_run_result_mapping,
)
elif VARIABLE_REGEX.search(value_source):
# Mixed template: interpolate all expressions as strings.
workflow_inputs[field_name] = VARIABLE_REGEX.sub(
lambda m: str(resolve_variable_selector(m.group(1), node_run_result_mapping)),
value_source,
)
else:
# Plain constant — no expression markers.
workflow_inputs[field_name] = value_source
return workflow_inputs
@staticmethod
def _extract_workflow_metrics(
response: Mapping[str, object],
evaluation_workflow_id: str,
) -> list[EvaluationMetric]:
"""Extract evaluation metrics from workflow output variables.
Each metric's ``node_info`` is set with *evaluation_workflow_id* as
the ``node_id``, so that judgment conditions can reference customized
metrics via ``variable_selector: [evaluation_workflow_id, metric_name]``.
"""
metrics: list[EvaluationMetric] = []
node_info = NodeInfo(node_id=evaluation_workflow_id, type="customized", title="customized")
data = response.get("data")
if not isinstance(data, Mapping):
logger.warning("Unexpected workflow response format: missing 'data' dict")
return metrics
outputs = data.get("outputs")
if not isinstance(outputs, dict):
logger.warning("Unexpected workflow response format: 'outputs' is not a dict")
return metrics
for key, raw_value in outputs.items():
if not isinstance(key, str):
continue
metrics.append(EvaluationMetric(name=key, value=raw_value, node_info=node_info))
return metrics
def resolve_variable_selector(
selector_raw: str,
node_run_result_mapping: dict[str, NodeRunResult],
) -> object:
"""
Resolve a ``#node_id.output_key#`` selector against node run results.
"""
# Strip the surrounding "#" markers so both "#node_id.output_key#" and "node_id.output_key" are accepted.
cleaned = selector_raw.strip("#")
parts = cleaned.split(".")
if len(parts) < 2:
logger.warning(
"Selector '%s' must have at least node_id.output_key",
selector_raw,
)
return ""
node_id = parts[0]
output_path = parts[1:]
node_result = node_run_result_mapping.get(node_id)
if not node_result or not node_result.outputs:
logger.warning(
"Selector '%s': node '%s' not found or has no outputs",
selector_raw,
node_id,
)
return ""
# Traverse the output path to support nested keys.
current: object = node_result.outputs
for key in output_path:
if isinstance(current, Mapping):
next_val = current.get(key)
if next_val is None:
logger.warning(
"Selector '%s': key '%s' not found in node '%s' outputs",
selector_raw,
key,
node_id,
)
return ""
current = next_val
else:
logger.warning(
"Selector '%s': cannot traverse into non-dict value at key '%s'",
selector_raw,
key,
)
return ""
return current if current is not None else ""
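# A minimal sketch of selector resolution (hypothetical node ids and keys).
# SimpleNamespace stands in for NodeRunResult, since only `.outputs` is read
# here; _build_workflow_inputs resolves ``{{#...#}}`` expressions the same way.
if __name__ == "__main__":
    from types import SimpleNamespace

    _results = {
        "llm_node": SimpleNamespace(outputs={"text": "final answer", "usage": {"total_tokens": 42}}),
    }
    assert resolve_variable_selector("#llm_node.text#", _results) == "final answer"
    assert resolve_variable_selector("#llm_node.usage.total_tokens#", _results) == 42
    # Unknown nodes or keys fall back to "" and log a warning instead of raising.
    assert resolve_variable_selector("#missing_node.text#", _results) == ""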

View File

View File

@ -0,0 +1,27 @@
from enum import StrEnum
from pydantic import BaseModel
class EvaluationFrameworkEnum(StrEnum):
RAGAS = "ragas"
DEEPEVAL = "deepeval"
NONE = "none"
class BaseEvaluationConfig(BaseModel):
"""Base configuration for evaluation frameworks."""
pass
class RagasConfig(BaseEvaluationConfig):
"""RAGAS-specific configuration."""
pass
class DeepEvalConfig(BaseEvaluationConfig):
"""DeepEval-specific configuration."""
pass

View File

@ -0,0 +1,226 @@
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, Field
from core.evaluation.entities.judgment_entity import JudgmentConfig, JudgmentResult
class EvaluationCategory(StrEnum):
LLM = "llm"
RETRIEVAL = "knowledge_retrieval"
AGENT = "agent"
WORKFLOW = "workflow"
SNIPPET = "snippet"
KNOWLEDGE_BASE = "knowledge_base"
class EvaluationMetricName(StrEnum):
"""Canonical metric names shared across all evaluation frameworks.
Each framework maps these names to its own internal implementation.
A framework that does not support a given metric should log a warning
and skip it rather than raising an error.
LLM / general text-quality metrics
FAITHFULNESS
Measures whether every claim in the model's response is grounded in
the provided retrieved context. A high score means the answer
contains no hallucinated content; each statement can be traced back
to a passage in the context.
Required fields: user_input, response, retrieved_contexts.
ANSWER_RELEVANCY
Measures how well the model's response addresses the user's question.
A high score means the answer stays on-topic; a low score indicates
irrelevant content or a failure to answer the actual question.
Required fields: user_input, response.
ANSWER_CORRECTNESS
Measures the factual accuracy and completeness of the model's answer
relative to a ground-truth reference. It combines semantic similarity
with key-fact coverage, so both meaning and content matter.
Required fields: user_input, response, reference (expected_output).
SEMANTIC_SIMILARITY
Measures the cosine similarity between the model's response and the
reference answer in an embedding space. It evaluates whether the two
texts convey the same meaning, independent of factual correctness.
Required fields: response, reference (expected_output).
Retrieval-quality metrics
CONTEXT_PRECISION
Measures the proportion of retrieved context chunks that are actually
relevant to the question (precision). A high score means the retrieval
pipeline returns little noise.
Required fields: user_input, reference, retrieved_contexts.
CONTEXT_RECALL
Measures the proportion of ground-truth information that is covered by
the retrieved context chunks (recall). A high score means the retrieval
pipeline does not miss important supporting evidence.
Required fields: user_input, reference, retrieved_contexts.
CONTEXT_RELEVANCE
Measures how relevant each individual retrieved chunk is to the query.
Similar to CONTEXT_PRECISION but evaluated at the chunk level rather
than against a reference answer.
Required fields: user_input, retrieved_contexts.
Agent-quality metrics
TOOL_CORRECTNESS
Measures the correctness of the tool calls made by the agent during
task execution: both the choice of tool and the arguments passed.
A high score means the agent's tool-use strategy matches the expected
behavior.
Required fields: actual tool calls vs. expected tool calls.
TASK_COMPLETION
Measures whether the agent ultimately achieves the user's stated goal.
It evaluates the reasoning chain, intermediate steps, and final output
holistically; a high score means the task was fully accomplished.
Required fields: user_input, actual_output.
"""
# LLM / general text-quality metrics
FAITHFULNESS = "faithfulness"
ANSWER_RELEVANCY = "answer_relevancy"
ANSWER_CORRECTNESS = "answer_correctness"
SEMANTIC_SIMILARITY = "semantic_similarity"
# Retrieval-quality metrics
CONTEXT_PRECISION = "context_precision"
CONTEXT_RECALL = "context_recall"
CONTEXT_RELEVANCE = "context_relevance"
# Agent-quality metrics
TOOL_CORRECTNESS = "tool_correctness"
TASK_COMPLETION = "task_completion"
# Per-category canonical metric lists used by get_supported_metrics().
LLM_METRIC_NAMES: list[EvaluationMetricName] = [
EvaluationMetricName.FAITHFULNESS, # Every claim is grounded in context; no hallucinations
EvaluationMetricName.ANSWER_RELEVANCY, # Response stays on-topic and addresses the question
EvaluationMetricName.ANSWER_CORRECTNESS, # Factual accuracy and completeness vs. reference
EvaluationMetricName.SEMANTIC_SIMILARITY, # Semantic closeness to the reference answer
]
RETRIEVAL_METRIC_NAMES: list[EvaluationMetricName] = [
EvaluationMetricName.CONTEXT_PRECISION, # Fraction of retrieved chunks that are relevant (precision)
EvaluationMetricName.CONTEXT_RECALL, # Fraction of ground-truth info covered by retrieval (recall)
EvaluationMetricName.CONTEXT_RELEVANCE, # Per-chunk relevance to the query
]
AGENT_METRIC_NAMES: list[EvaluationMetricName] = [
EvaluationMetricName.TOOL_CORRECTNESS, # Correct tool selection and arguments
EvaluationMetricName.TASK_COMPLETION, # Whether the agent fully achieves the user's goal
]
WORKFLOW_METRIC_NAMES: list[EvaluationMetricName] = [
EvaluationMetricName.FAITHFULNESS,
EvaluationMetricName.ANSWER_RELEVANCY,
EvaluationMetricName.ANSWER_CORRECTNESS,
]
METRIC_NODE_TYPE_MAPPING: dict[str, str] = {
**{m.value: "llm" for m in LLM_METRIC_NAMES},
**{m.value: "knowledge-retrieval" for m in RETRIEVAL_METRIC_NAMES},
**{m.value: "agent" for m in AGENT_METRIC_NAMES},
}
METRIC_VALUE_TYPE_MAPPING: dict[str, str] = {
EvaluationMetricName.FAITHFULNESS: "number",
EvaluationMetricName.ANSWER_RELEVANCY: "number",
EvaluationMetricName.ANSWER_CORRECTNESS: "number",
EvaluationMetricName.SEMANTIC_SIMILARITY: "number",
EvaluationMetricName.CONTEXT_PRECISION: "number",
EvaluationMetricName.CONTEXT_RECALL: "number",
EvaluationMetricName.CONTEXT_RELEVANCE: "number",
EvaluationMetricName.TOOL_CORRECTNESS: "number",
EvaluationMetricName.TASK_COMPLETION: "number",
}
class NodeInfo(BaseModel):
node_id: str
type: str
title: str
class EvaluationMetric(BaseModel):
name: str
value: Any
details: dict[str, Any] = Field(default_factory=dict)
node_info: NodeInfo | None = None
class EvaluationItemInput(BaseModel):
index: int
inputs: dict[str, Any]
output: str
expected_output: str | None = None
context: list[str] | None = None
class EvaluationDatasetInput(BaseModel):
index: int
inputs: dict[str, Any]
expected_output: str | None = None
class EvaluationItemResult(BaseModel):
index: int
actual_output: str | None = None
metrics: list[EvaluationMetric] = Field(default_factory=list)
metadata: dict[str, Any] = Field(default_factory=dict)
judgment: JudgmentResult = Field(default_factory=JudgmentResult)
error: str | None = None
class DefaultMetric(BaseModel):
metric: str
value_type: str = ""
node_info_list: list[NodeInfo]
class CustomizedMetricOutputField(BaseModel):
variable: str
value_type: str
class CustomizedMetrics(BaseModel):
evaluation_workflow_id: str
input_fields: dict[str, Any]
output_fields: list[CustomizedMetricOutputField]
class EvaluationConfigData(BaseModel):
"""Structured data for saving evaluation configuration."""
evaluation_model: str = ""
evaluation_model_provider: str = ""
default_metrics: list[DefaultMetric] = Field(default_factory=list)
customized_metrics: CustomizedMetrics | None = None
judgment_config: JudgmentConfig | None = None
class EvaluationRunRequest(EvaluationConfigData):
"""Request body for starting an evaluation run."""
file_id: str
class EvaluationRunData(BaseModel):
"""Serializable data for Celery task."""
evaluation_run_id: str
tenant_id: str
target_type: str
target_id: str
evaluation_model_provider: str
evaluation_model: str
default_metrics: list[DefaultMetric] = Field(default_factory=list)
customized_metrics: CustomizedMetrics | None = None
judgment_config: JudgmentConfig | None = None
input_list: list[EvaluationDatasetInput]
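
# Illustrative sketch (not part of the original module): how the entities above
# compose into a single run request. All identifiers, model names, and the
# threshold are placeholders; JudgmentCondition is imported here only for the demo.
if __name__ == "__main__":
    from core.evaluation.entities.judgment_entity import JudgmentCondition

    example_request = EvaluationRunRequest(
        file_id="<uploaded-dataset-file-id>",
        evaluation_model_provider="<provider>",
        evaluation_model="<model>",
        default_metrics=[
            DefaultMetric(
                metric=EvaluationMetricName.FAITHFULNESS,
                value_type="number",
                node_info_list=[NodeInfo(node_id="node_abc", type="llm", title="LLM")],
            )
        ],
        judgment_config=JudgmentConfig(
            logical_operator="and",
            conditions=[
                JudgmentCondition(
                    variable_selector=["node_abc", EvaluationMetricName.FAITHFULNESS],
                    comparison_operator=">",
                    value="0.8",
                )
            ],
        ),
    )
    print(example_request.model_dump_json(indent=2))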

View File

@ -0,0 +1,96 @@
"""Judgment condition entities for evaluation metric assessment.
Condition structure mirrors the workflow if-else ``Condition`` model from
``graphon.utils.condition.entities``. The left-hand side uses
``variable_selector``, a two-element list ``[node_id, metric_name]``, to
uniquely identify an evaluation metric (different nodes may produce metrics
with the same name).
Operators reuse ``SupportedComparisonOperator`` from the workflow engine so
that type semantics stay consistent across the platform.
Typical usage::
judgment_config = JudgmentConfig(
logical_operator="and",
conditions=[
JudgmentCondition(
variable_selector=["node_abc", "faithfulness"],
comparison_operator=">",
value="0.8",
)
],
)
"""
from collections.abc import Sequence
from typing import Any, Literal
from pydantic import BaseModel, Field
from graphon.utils.condition.entities import SupportedComparisonOperator
class JudgmentCondition(BaseModel):
"""A single judgment condition that checks one metric value.
Mirrors ``graphon.utils.condition.entities.Condition`` with the left-hand
side being a metric selector instead of a workflow variable selector.
Attributes:
variable_selector: ``[node_id, metric_name]`` identifying the metric.
comparison_operator: Reuses workflow's ``SupportedComparisonOperator``.
value: The comparison target (right side). For unary operators such
as ``empty`` or ``null`` this can be ``None``.
"""
variable_selector: list[str]
comparison_operator: SupportedComparisonOperator
value: str | Sequence[str] | bool | None = None
class JudgmentConfig(BaseModel):
"""A group of judgment conditions combined with a logical operator.
Attributes:
logical_operator: How to combine condition results: "and" requires
all conditions to pass, "or" requires at least one.
conditions: The list of individual conditions to evaluate.
"""
logical_operator: Literal["and", "or"] = "and"
conditions: list[JudgmentCondition] = Field(default_factory=list)
class JudgmentConditionResult(BaseModel):
"""Result of evaluating a single judgment condition.
Attributes:
variable_selector: ``[node_id, metric_name]`` that was checked.
comparison_operator: The operator that was applied.
expected_value: The resolved comparison value.
actual_value: The actual metric value that was evaluated.
passed: Whether this individual condition passed.
error: Error message if the condition evaluation failed.
"""
variable_selector: list[str]
comparison_operator: str
expected_value: Any = None
actual_value: Any = None
passed: bool = False
error: str | None = None
class JudgmentResult(BaseModel):
"""Overall result of evaluating all judgment conditions for one item.
Attributes:
passed: Whether the overall judgment passed (based on logical_operator).
logical_operator: The logical operator used to combine conditions.
condition_results: Detailed result for each individual condition.
"""
passed: bool = False
logical_operator: Literal["and", "or"] = "and"
condition_results: list[JudgmentConditionResult] = Field(default_factory=list)

View File

@ -0,0 +1,61 @@
import collections
import logging
from typing import Any
from configs import dify_config
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.config_entity import EvaluationFrameworkEnum
from core.evaluation.entities.evaluation_entity import EvaluationCategory
logger = logging.getLogger(__name__)
class EvaluationFrameworkConfigMap(collections.UserDict[str, dict[str, Any]]):
"""Registry mapping framework enum -> {config_class, evaluator_class}."""
def __getitem__(self, framework: str) -> dict[str, Any]:
match framework:
case EvaluationFrameworkEnum.RAGAS:
from core.evaluation.entities.config_entity import RagasConfig
from core.evaluation.frameworks.ragas.ragas_evaluator import RagasEvaluator
return {
"config_class": RagasConfig,
"evaluator_class": RagasEvaluator,
}
case EvaluationFrameworkEnum.DEEPEVAL:
raise NotImplementedError("DeepEval adapter is not yet implemented.")
case _:
raise ValueError(f"Unknown evaluation framework: {framework}")
evaluation_framework_config_map = EvaluationFrameworkConfigMap()
class EvaluationManager:
"""Factory for evaluation instances based on global configuration."""
@staticmethod
def get_evaluation_instance() -> BaseEvaluationInstance | None:
"""Create and return an evaluation instance based on EVALUATION_FRAMEWORK env var."""
framework = dify_config.EVALUATION_FRAMEWORK
if not framework or framework == EvaluationFrameworkEnum.NONE:
return None
try:
config_map = evaluation_framework_config_map[framework]
evaluator_class = config_map["evaluator_class"]
config_class = config_map["config_class"]
config = config_class()
return evaluator_class(config)
except Exception:
logger.exception("Failed to create evaluation instance for framework: %s", framework)
return None
@staticmethod
def get_supported_metrics(category: EvaluationCategory) -> list[str]:
"""Return supported metrics for the current framework and given category."""
instance = EvaluationManager.get_evaluation_instance()
if instance is None:
return []
return instance.get_supported_metrics(category)
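
# Illustrative usage sketch (not part of the original module): resolving the
# configured framework and listing its LLM metrics. Assumes EVALUATION_FRAMEWORK
# is set to "ragas"; with "none" or a misconfigured value the factory returns None.
if __name__ == "__main__":
    instance = EvaluationManager.get_evaluation_instance()
    if instance is None:
        print("Evaluation is disabled or misconfigured.")
    else:
        # For the RAGAS adapter this includes faithfulness, answer_relevancy,
        # answer_correctness and semantic_similarity.
        print(EvaluationManager.get_supported_metrics(EvaluationCategory.LLM))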

View File

@ -0,0 +1 @@

View File

@ -0,0 +1,299 @@
import logging
from typing import Any
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.config_entity import DeepEvalConfig
from core.evaluation.entities.evaluation_entity import (
AGENT_METRIC_NAMES,
LLM_METRIC_NAMES,
RETRIEVAL_METRIC_NAMES,
WORKFLOW_METRIC_NAMES,
EvaluationCategory,
EvaluationItemInput,
EvaluationItemResult,
EvaluationMetric,
EvaluationMetricName,
)
from core.evaluation.frameworks.ragas.ragas_model_wrapper import DifyModelWrapper
logger = logging.getLogger(__name__)
# Maps canonical EvaluationMetricName to the corresponding deepeval metric class name.
# deepeval metric field requirements (LLMTestCase fields):
# - faithfulness: input, actual_output, retrieval_context
# - answer_relevancy: input, actual_output
# - context_precision: input, actual_output, expected_output, retrieval_context
# - context_recall: input, actual_output, expected_output, retrieval_context
# - context_relevance: input, actual_output, retrieval_context
# - tool_correctness: input, actual_output, expected_tools
# - task_completion: input, actual_output
# Metrics not listed here are unsupported by deepeval and will be skipped.
_DEEPEVAL_METRIC_MAP: dict[EvaluationMetricName, str] = {
EvaluationMetricName.FAITHFULNESS: "FaithfulnessMetric",
EvaluationMetricName.ANSWER_RELEVANCY: "AnswerRelevancyMetric",
EvaluationMetricName.CONTEXT_PRECISION: "ContextualPrecisionMetric",
EvaluationMetricName.CONTEXT_RECALL: "ContextualRecallMetric",
EvaluationMetricName.CONTEXT_RELEVANCE: "ContextualRelevancyMetric",
EvaluationMetricName.TOOL_CORRECTNESS: "ToolCorrectnessMetric",
EvaluationMetricName.TASK_COMPLETION: "TaskCompletionMetric",
}
class DeepEvalEvaluator(BaseEvaluationInstance):
"""DeepEval framework adapter for evaluation."""
def __init__(self, config: DeepEvalConfig):
self.config = config
def get_supported_metrics(self, category: EvaluationCategory) -> list[str]:
match category:
case EvaluationCategory.LLM:
candidates = LLM_METRIC_NAMES
case EvaluationCategory.RETRIEVAL:
candidates = RETRIEVAL_METRIC_NAMES
case EvaluationCategory.AGENT:
candidates = AGENT_METRIC_NAMES
case EvaluationCategory.WORKFLOW | EvaluationCategory.SNIPPET:
candidates = WORKFLOW_METRIC_NAMES
case _:
return []
return [m for m in candidates if m in _DEEPEVAL_METRIC_MAP]
def evaluate_llm(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.LLM)
def evaluate_retrieval(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.RETRIEVAL)
def evaluate_agent(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.AGENT)
def evaluate_workflow(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.WORKFLOW)
def _evaluate(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
category: EvaluationCategory,
) -> list[EvaluationItemResult]:
"""Core evaluation logic using DeepEval."""
model_wrapper = DifyModelWrapper(model_provider, model_name, tenant_id)
requested_metrics = metric_names or self.get_supported_metrics(category)
try:
return self._evaluate_with_deepeval(items, requested_metrics, category)
except ImportError:
logger.warning("DeepEval not installed, falling back to simple evaluation")
return self._evaluate_simple(items, requested_metrics, model_wrapper)
def _evaluate_with_deepeval(
self,
items: list[EvaluationItemInput],
requested_metrics: list[str],
category: EvaluationCategory,
) -> list[EvaluationItemResult]:
"""Evaluate using DeepEval library.
Builds LLMTestCase differently per category:
- LLM/Workflow: input=prompt, actual_output=output, retrieval_context=context
- Retrieval: input=query, actual_output=output, expected_output, retrieval_context=context
- Agent: input=query, actual_output=output
"""
metric_pairs = _build_deepeval_metrics(requested_metrics)
if not metric_pairs:
logger.warning("No valid DeepEval metrics found for: %s", requested_metrics)
return [EvaluationItemResult(index=item.index) for item in items]
results: list[EvaluationItemResult] = []
for item in items:
test_case = self._build_test_case(item, category)
metrics: list[EvaluationMetric] = []
for canonical_name, metric in metric_pairs:
try:
metric.measure(test_case)
if metric.score is not None:
metrics.append(EvaluationMetric(name=canonical_name, value=float(metric.score)))
except Exception:
logger.exception(
"Failed to compute metric %s for item %d",
canonical_name,
item.index,
)
results.append(EvaluationItemResult(index=item.index, metrics=metrics))
return results
@staticmethod
def _build_test_case(item: EvaluationItemInput, category: EvaluationCategory) -> Any:
"""Build a deepeval LLMTestCase with the correct fields per category."""
from deepeval.test_case import LLMTestCase
user_input = _format_input(item.inputs, category)
match category:
case EvaluationCategory.LLM | EvaluationCategory.WORKFLOW:
# faithfulness needs: input, actual_output, retrieval_context
# answer_relevancy needs: input, actual_output
return LLMTestCase(
input=user_input,
actual_output=item.output,
expected_output=item.expected_output or None,
retrieval_context=item.context or None,
)
case EvaluationCategory.RETRIEVAL:
# contextual_precision/recall needs: input, actual_output, expected_output, retrieval_context
return LLMTestCase(
input=user_input,
actual_output=item.output or "",
expected_output=item.expected_output or "",
retrieval_context=item.context or [],
)
case _:
return LLMTestCase(
input=user_input,
actual_output=item.output,
)
def _evaluate_simple(
self,
items: list[EvaluationItemInput],
requested_metrics: list[str],
model_wrapper: DifyModelWrapper,
) -> list[EvaluationItemResult]:
"""Simple LLM-as-judge fallback when DeepEval is not available."""
results: list[EvaluationItemResult] = []
for item in items:
metrics: list[EvaluationMetric] = []
for m_name in requested_metrics:
try:
score = self._judge_with_llm(model_wrapper, m_name, item)
metrics.append(EvaluationMetric(name=m_name, value=score))
except Exception:
logger.exception("Failed to compute metric %s for item %d", m_name, item.index)
results.append(EvaluationItemResult(index=item.index, metrics=metrics))
return results
def _judge_with_llm(
self,
model_wrapper: DifyModelWrapper,
metric_name: str,
item: EvaluationItemInput,
) -> float:
"""Use the LLM to judge a single metric for a single item."""
prompt = self._build_judge_prompt(metric_name, item)
response = model_wrapper.invoke(prompt)
return self._parse_score(response)
@staticmethod
def _build_judge_prompt(metric_name: str, item: EvaluationItemInput) -> str:
"""Build a scoring prompt for the LLM judge."""
parts = [
f"Evaluate the following on the metric '{metric_name}' using a scale of 0.0 to 1.0.",
f"\nInput: {item.inputs}",
f"\nOutput: {item.output}",
]
if item.expected_output:
parts.append(f"\nExpected Output: {item.expected_output}")
if item.context:
parts.append(f"\nContext: {'; '.join(item.context)}")
parts.append("\nRespond with ONLY a single floating point number between 0.0 and 1.0, nothing else.")
return "\n".join(parts)
@staticmethod
def _parse_score(response: str) -> float:
"""Parse a float score from LLM response."""
import re
cleaned = response.strip()
try:
score = float(cleaned)
return max(0.0, min(1.0, score))
except ValueError:
match = re.search(r"(\d+\.?\d*)", cleaned)
if match:
score = float(match.group(1))
return max(0.0, min(1.0, score))
return 0.0
def _format_input(inputs: dict[str, Any], category: EvaluationCategory) -> str:
"""Extract the user-facing input string from the inputs dict."""
match category:
case EvaluationCategory.LLM | EvaluationCategory.WORKFLOW:
return str(inputs.get("prompt", ""))
case EvaluationCategory.RETRIEVAL:
return str(inputs.get("query", ""))
case _:
return str(next(iter(inputs.values()), "")) if inputs else ""
def _build_deepeval_metrics(requested_metrics: list[str]) -> list[tuple[str, Any]]:
"""Build DeepEval metric instances from canonical metric names.
Returns a list of (canonical_name, metric_instance) pairs so that callers
can record the canonical name rather than the framework-internal class name.
"""
try:
from deepeval.metrics import (
AnswerRelevancyMetric,
ContextualPrecisionMetric,
ContextualRecallMetric,
ContextualRelevancyMetric,
FaithfulnessMetric,
TaskCompletionMetric,
ToolCorrectnessMetric,
)
# Maps canonical name → deepeval metric class
deepeval_class_map: dict[str, Any] = {
EvaluationMetricName.FAITHFULNESS: FaithfulnessMetric,
EvaluationMetricName.ANSWER_RELEVANCY: AnswerRelevancyMetric,
EvaluationMetricName.CONTEXT_PRECISION: ContextualPrecisionMetric,
EvaluationMetricName.CONTEXT_RECALL: ContextualRecallMetric,
EvaluationMetricName.CONTEXT_RELEVANCE: ContextualRelevancyMetric,
EvaluationMetricName.TOOL_CORRECTNESS: ToolCorrectnessMetric,
EvaluationMetricName.TASK_COMPLETION: TaskCompletionMetric,
}
pairs: list[tuple[str, Any]] = []
for name in requested_metrics:
metric_class = deepeval_class_map.get(name)
if metric_class:
pairs.append((name, metric_class(threshold=0.5)))
else:
logger.warning("Metric '%s' is not supported by DeepEval, skipping", name)
return pairs
except ImportError:
logger.warning("DeepEval metrics not available")
return []
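
# Illustrative sketch (not part of the original module): how the LLM-as-judge
# fallback's _parse_score behaves on typical responses. It clamps to [0.0, 1.0]
# and falls back to the first number found, or 0.0 if nothing parses.
if __name__ == "__main__":
    assert DeepEvalEvaluator._parse_score("0.85") == 0.85
    assert DeepEvalEvaluator._parse_score("Score: 0.7 (well grounded)") == 0.7
    assert DeepEvalEvaluator._parse_score("1.8") == 1.0  # clamped to the upper bound
    assert DeepEvalEvaluator._parse_score("no numeric score") == 0.0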

View File

@ -0,0 +1,312 @@
import logging
from typing import Any
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.config_entity import RagasConfig
from core.evaluation.entities.evaluation_entity import (
AGENT_METRIC_NAMES,
LLM_METRIC_NAMES,
RETRIEVAL_METRIC_NAMES,
WORKFLOW_METRIC_NAMES,
EvaluationCategory,
EvaluationItemInput,
EvaluationItemResult,
EvaluationMetric,
EvaluationMetricName,
)
from core.evaluation.frameworks.ragas.ragas_model_wrapper import DifyModelWrapper
logger = logging.getLogger(__name__)
# Maps canonical EvaluationMetricName to the corresponding ragas metric class.
# Metrics not listed here are unsupported by ragas and will be skipped.
_RAGAS_METRIC_MAP: dict[EvaluationMetricName, str] = {
EvaluationMetricName.FAITHFULNESS: "Faithfulness",
EvaluationMetricName.ANSWER_RELEVANCY: "AnswerRelevancy",
EvaluationMetricName.ANSWER_CORRECTNESS: "AnswerCorrectness",
EvaluationMetricName.SEMANTIC_SIMILARITY: "SemanticSimilarity",
EvaluationMetricName.CONTEXT_PRECISION: "ContextPrecision",
EvaluationMetricName.CONTEXT_RECALL: "ContextRecall",
EvaluationMetricName.CONTEXT_RELEVANCE: "ContextRelevance",
EvaluationMetricName.TOOL_CORRECTNESS: "ToolCallAccuracy",
}
class RagasEvaluator(BaseEvaluationInstance):
"""RAGAS framework adapter for evaluation."""
def __init__(self, config: RagasConfig):
self.config = config
def get_supported_metrics(self, category: EvaluationCategory) -> list[str]:
match category:
case EvaluationCategory.LLM:
candidates = LLM_METRIC_NAMES
case EvaluationCategory.RETRIEVAL:
candidates = RETRIEVAL_METRIC_NAMES
case EvaluationCategory.AGENT:
candidates = AGENT_METRIC_NAMES
case EvaluationCategory.WORKFLOW | EvaluationCategory.SNIPPET:
candidates = WORKFLOW_METRIC_NAMES
case _:
return []
return [m for m in candidates if m in _RAGAS_METRIC_MAP]
def evaluate_llm(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.LLM)
def evaluate_retrieval(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.RETRIEVAL)
def evaluate_agent(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.AGENT)
def evaluate_workflow(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.WORKFLOW)
def _evaluate(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
category: EvaluationCategory,
) -> list[EvaluationItemResult]:
"""Core evaluation logic using RAGAS."""
model_wrapper = DifyModelWrapper(model_provider, model_name, tenant_id)
requested_metrics = metric_names or self.get_supported_metrics(category)
try:
return self._evaluate_with_ragas(items, requested_metrics, model_wrapper, category)
except ImportError:
logger.warning("RAGAS not installed, falling back to simple evaluation")
return self._evaluate_simple(items, requested_metrics, model_wrapper)
def _evaluate_with_ragas(
self,
items: list[EvaluationItemInput],
requested_metrics: list[str],
model_wrapper: DifyModelWrapper,
category: EvaluationCategory,
) -> list[EvaluationItemResult]:
"""Evaluate using RAGAS library.
Builds SingleTurnSample differently per category to match ragas requirements:
- LLM/Workflow: user_input=prompt, response=output, reference=expected_output
- Retrieval: user_input=query, reference=expected_output, retrieved_contexts=context
- Agent: Not supported via EvaluationDataset (requires message-based API)
"""
from ragas import evaluate as ragas_evaluate
from ragas.dataset_schema import EvaluationDataset
samples: list[Any] = []
for item in items:
sample = self._build_sample(item, category)
samples.append(sample)
dataset = EvaluationDataset(samples=samples)
ragas_metrics = self._build_ragas_metrics(requested_metrics)
if not ragas_metrics:
logger.warning("No valid RAGAS metrics found for: %s", requested_metrics)
return [EvaluationItemResult(index=item.index) for item in items]
try:
result = ragas_evaluate(
dataset=dataset,
metrics=ragas_metrics,
)
results: list[EvaluationItemResult] = []
result_df = result.to_pandas()
for i, item in enumerate(items):
metrics: list[EvaluationMetric] = []
for m_name in requested_metrics:
if m_name in result_df.columns:
score = result_df.iloc[i][m_name]
# skip missing scores and NaN values (NaN != NaN)
if score is not None and not (isinstance(score, float) and score != score):
metrics.append(EvaluationMetric(name=m_name, value=float(score)))
results.append(EvaluationItemResult(index=item.index, metrics=metrics))
return results
except Exception:
logger.exception("RAGAS evaluation failed, falling back to simple evaluation")
return self._evaluate_simple(items, requested_metrics, model_wrapper)
@staticmethod
def _build_sample(item: EvaluationItemInput, category: EvaluationCategory) -> Any:
"""Build a ragas SingleTurnSample with the correct fields per category.
ragas metric field requirements:
- faithfulness: user_input, response, retrieved_contexts
- answer_relevancy: user_input, response
- answer_correctness: user_input, response, reference
- semantic_similarity: user_input, response, reference
- context_precision: user_input, reference, retrieved_contexts
- context_recall: user_input, reference, retrieved_contexts
- context_relevance: user_input, retrieved_contexts
"""
from ragas.dataset_schema import SingleTurnSample
user_input = _format_input(item.inputs, category)
match category:
case EvaluationCategory.LLM:
# response = actual LLM output, reference = expected output
return SingleTurnSample(
user_input=user_input,
response=item.output,
reference=item.expected_output or "",
retrieved_contexts=item.context or [],
)
case EvaluationCategory.RETRIEVAL:
# context_precision/recall only need reference + retrieved_contexts
return SingleTurnSample(
user_input=user_input,
reference=item.expected_output or "",
retrieved_contexts=item.context or [],
)
case _:
return SingleTurnSample(
user_input=user_input,
response=item.output,
)
def _evaluate_simple(
self,
items: list[EvaluationItemInput],
requested_metrics: list[str],
model_wrapper: DifyModelWrapper,
) -> list[EvaluationItemResult]:
"""Simple LLM-as-judge fallback when RAGAS is not available."""
results: list[EvaluationItemResult] = []
for item in items:
metrics: list[EvaluationMetric] = []
for m_name in requested_metrics:
try:
score = self._judge_with_llm(model_wrapper, m_name, item)
metrics.append(EvaluationMetric(name=m_name, value=score))
except Exception:
logger.exception("Failed to compute metric %s for item %d", m_name, item.index)
results.append(EvaluationItemResult(index=item.index, metrics=metrics))
return results
def _judge_with_llm(
self,
model_wrapper: DifyModelWrapper,
metric_name: str,
item: EvaluationItemInput,
) -> float:
"""Use the LLM to judge a single metric for a single item."""
prompt = self._build_judge_prompt(metric_name, item)
response = model_wrapper.invoke(prompt)
return self._parse_score(response)
@staticmethod
def _build_judge_prompt(metric_name: str, item: EvaluationItemInput) -> str:
"""Build a scoring prompt for the LLM judge."""
parts = [
f"Evaluate the following on the metric '{metric_name}' using a scale of 0.0 to 1.0.",
f"\nInput: {item.inputs}",
f"\nOutput: {item.output}",
]
if item.expected_output:
parts.append(f"\nExpected Output: {item.expected_output}")
if item.context:
parts.append(f"\nContext: {'; '.join(item.context)}")
parts.append("\nRespond with ONLY a single floating point number between 0.0 and 1.0, nothing else.")
return "\n".join(parts)
@staticmethod
def _parse_score(response: str) -> float:
"""Parse a float score from LLM response."""
import re
cleaned = response.strip()
try:
score = float(cleaned)
return max(0.0, min(1.0, score))
except ValueError:
match = re.search(r"(\d+\.?\d*)", cleaned)
if match:
score = float(match.group(1))
return max(0.0, min(1.0, score))
return 0.0
@staticmethod
def _build_ragas_metrics(requested_metrics: list[str]) -> list[Any]:
"""Build RAGAS metric instances from canonical metric names."""
try:
from ragas.metrics.collections import (
AnswerCorrectness,
AnswerRelevancy,
ContextPrecision,
ContextRecall,
ContextRelevance,
Faithfulness,
SemanticSimilarity,
ToolCallAccuracy,
)
# Maps canonical name → ragas metric class
ragas_class_map: dict[str, Any] = {
EvaluationMetricName.FAITHFULNESS: Faithfulness,
EvaluationMetricName.ANSWER_RELEVANCY: AnswerRelevancy,
EvaluationMetricName.ANSWER_CORRECTNESS: AnswerCorrectness,
EvaluationMetricName.SEMANTIC_SIMILARITY: SemanticSimilarity,
EvaluationMetricName.CONTEXT_PRECISION: ContextPrecision,
EvaluationMetricName.CONTEXT_RECALL: ContextRecall,
EvaluationMetricName.CONTEXT_RELEVANCE: ContextRelevance,
EvaluationMetricName.TOOL_CORRECTNESS: ToolCallAccuracy,
}
metrics = []
for name in requested_metrics:
metric_class = ragas_class_map.get(name)
if metric_class:
metrics.append(metric_class())
else:
logger.warning("Metric '%s' is not supported by RAGAS, skipping", name)
return metrics
except ImportError:
logger.warning("RAGAS metrics not available")
return []
def _format_input(inputs: dict[str, Any], category: EvaluationCategory) -> str:
"""Extract the user-facing input string from the inputs dict."""
match category:
case EvaluationCategory.LLM | EvaluationCategory.WORKFLOW:
return str(inputs.get("prompt", ""))
case EvaluationCategory.RETRIEVAL:
return str(inputs.get("query", ""))
case _:
return str(next(iter(inputs.values()), "")) if inputs else ""

View File

@ -0,0 +1,48 @@
import logging
from typing import Any
logger = logging.getLogger(__name__)
class DifyModelWrapper:
"""Wraps Dify's model invocation interface for use by RAGAS as an LLM judge.
RAGAS requires an LLM to compute certain metrics (faithfulness, answer_relevancy, etc.).
This wrapper bridges Dify's ModelInstance to a callable that RAGAS can use.
"""
def __init__(self, model_provider: str, model_name: str, tenant_id: str):
self.model_provider = model_provider
self.model_name = model_name
self.tenant_id = tenant_id
def _get_model_instance(self) -> Any:
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=self.tenant_id,
provider=self.model_provider,
model_type=ModelType.LLM,
model=self.model_name,
)
return model_instance
def invoke(self, prompt: str) -> str:
"""Invoke the model with a text prompt and return the text response."""
from core.model_runtime.entities.message_entities import (
SystemPromptMessage,
UserPromptMessage,
)
model_instance = self._get_model_instance()
result = model_instance.invoke_llm(
prompt_messages=[
SystemPromptMessage(content="You are an evaluation judge. Answer precisely and concisely."),
UserPromptMessage(content=prompt),
],
model_parameters={"temperature": 0.0, "max_tokens": 2048},
stream=False,
)
return result.message.content
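
# Illustrative usage sketch (not part of the original module). Provider, model,
# and tenant values are placeholders; the wrapper resolves the real ModelInstance
# through ModelManager only when invoke() is called.
if __name__ == "__main__":
    judge = DifyModelWrapper(model_provider="<provider>", model_name="<model>", tenant_id="<tenant-id>")
    reply = judge.invoke("Rate the faithfulness of the answer from 0.0 to 1.0. Respond with a number only.")
    print(reply)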

View File

View File

@ -0,0 +1,160 @@
"""Judgment condition processor for evaluation metrics.
Evaluates pass/fail judgment conditions against evaluation metric values.
Each condition uses ``variable_selector`` (``[node_id, metric_name]``) to
look up the metric value, then delegates the actual comparison to the
workflow condition engine (``graphon.utils.condition.processor``).
The processor is intentionally decoupled from evaluation frameworks and
runners. It operates on plain ``dict`` mappings and can be invoked
anywhere that already has per-item metric results.
"""
import logging
from collections.abc import Sequence
from typing import Any, cast
from core.evaluation.entities.judgment_entity import (
JudgmentCondition,
JudgmentConditionResult,
JudgmentConfig,
JudgmentResult,
)
from graphon.utils.condition.entities import SupportedComparisonOperator
from graphon.utils.condition.processor import _evaluate_condition # pyright: ignore[reportPrivateUsage]
logger = logging.getLogger(__name__)
_UNARY_OPERATORS = frozenset({"null", "not null", "empty", "not empty"})
class JudgmentProcessor:
@staticmethod
def evaluate(
metric_values: dict[tuple[str, str], Any],
config: JudgmentConfig,
) -> JudgmentResult:
"""Evaluate all judgment conditions against the given metric values.
Args:
metric_values: Mapping of ``(node_id, metric_name)`` to metric
value (e.g. ``{("node_abc", "faithfulness"): 0.85}``).
config: The judgment configuration with logical_operator and
conditions.
Returns:
JudgmentResult with overall pass/fail and per-condition details.
"""
if not config.conditions:
return JudgmentResult(
passed=True,
logical_operator=config.logical_operator,
condition_results=[],
)
condition_results: list[JudgmentConditionResult] = []
for condition in config.conditions:
result = JudgmentProcessor._evaluate_single_condition(metric_values, condition)
condition_results.append(result)
if config.logical_operator == "and" and not result.passed:
return JudgmentResult(
passed=False,
logical_operator=config.logical_operator,
condition_results=condition_results,
)
if config.logical_operator == "or" and result.passed:
return JudgmentResult(
passed=True,
logical_operator=config.logical_operator,
condition_results=condition_results,
)
if config.logical_operator == "and":
final_passed = all(r.passed for r in condition_results)
else:
final_passed = any(r.passed for r in condition_results)
return JudgmentResult(
passed=final_passed,
logical_operator=config.logical_operator,
condition_results=condition_results,
)
@staticmethod
def _evaluate_single_condition(
metric_values: dict[tuple[str, str], Any],
condition: JudgmentCondition,
) -> JudgmentConditionResult:
"""Evaluate a single judgment condition.
Steps:
1. Extract ``(node_id, metric_name)`` from ``variable_selector``.
2. Look up the metric value from ``metric_values``.
3. Delegate comparison to the workflow condition engine.
"""
selector = condition.variable_selector
if len(selector) < 2:
return JudgmentConditionResult(
variable_selector=selector,
comparison_operator=condition.comparison_operator,
expected_value=condition.value,
actual_value=None,
passed=False,
error=f"variable_selector must have at least 2 elements, got {len(selector)}",
)
node_id, metric_name = selector[0], selector[1]
actual_value = metric_values.get((node_id, metric_name))
if actual_value is None and condition.comparison_operator not in _UNARY_OPERATORS:
return JudgmentConditionResult(
variable_selector=selector,
comparison_operator=condition.comparison_operator,
expected_value=condition.value,
actual_value=None,
passed=False,
error=f"Metric '{metric_name}' on node '{node_id}' not found in evaluation results",
)
try:
expected = condition.value
# Numeric operators need the actual value coerced to int/float
# so that the workflow engine's numeric assertions work correctly.
coerced_actual: object = actual_value
if (
condition.comparison_operator in {"=", "≠", ">", "<", "≥", "≤"}
and actual_value is not None
and not isinstance(actual_value, (int, float, bool))
):
coerced_actual = float(actual_value)
passed = _evaluate_condition(
operator=cast(SupportedComparisonOperator, condition.comparison_operator),
value=coerced_actual,
expected=cast(str | Sequence[str] | bool | Sequence[bool] | None, expected),
)
return JudgmentConditionResult(
variable_selector=selector,
comparison_operator=condition.comparison_operator,
expected_value=expected,
actual_value=actual_value,
passed=passed,
)
except Exception as e:
logger.warning(
"Judgment condition evaluation failed for [%s, %s]: %s",
node_id,
metric_name,
str(e),
)
return JudgmentConditionResult(
variable_selector=selector,
comparison_operator=condition.comparison_operator,
expected_value=condition.value,
actual_value=actual_value,
passed=False,
error=str(e),
)
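
# Illustrative sketch (not part of the original module): a single ">" condition
# checked against one metric value. Node id and threshold are placeholders.
if __name__ == "__main__":
    example_config = JudgmentConfig(
        logical_operator="and",
        conditions=[
            JudgmentCondition(
                variable_selector=["node_abc", "faithfulness"],
                comparison_operator=">",
                value="0.8",
            )
        ],
    )
    example_result = JudgmentProcessor.evaluate({("node_abc", "faithfulness"): 0.85}, example_config)
    # Expected to pass, since 0.85 > 0.8; condition_results records the actual value.
    print(example_result.passed, example_result.condition_results[0].actual_value)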

View File

@ -0,0 +1,52 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from models import Account, App, CustomizedSnippet, TenantAccountJoin
def get_service_account_for_app(session: Session, app_id: str) -> Account:
"""Get the creator account for an app with tenant context set up.
This follows the same pattern as BaseTraceInstance.get_service_account_with_tenant().
"""
app = session.scalar(select(App).where(App.id == app_id))
if not app:
raise ValueError(f"App with id {app_id} not found")
if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator")
account = session.scalar(select(Account).where(Account.id == app.created_by))
if not account:
raise ValueError(f"Creator account not found for app {app_id}")
current_tenant = session.query(TenantAccountJoin).filter_by(account_id=account.id, current=True).first()
if not current_tenant:
raise ValueError(f"Current tenant not found for account {account.id}")
account.set_tenant_id(current_tenant.tenant_id)
return account
def get_service_account_for_snippet(session: Session, snippet_id: str) -> Account:
"""Get the creator account for a snippet with tenant context set up.
Mirrors :func:`get_service_account_for_app` but queries CustomizedSnippet.
"""
snippet = session.scalar(select(CustomizedSnippet).where(CustomizedSnippet.id == snippet_id))
if not snippet:
raise ValueError(f"Snippet with id {snippet_id} not found")
if not snippet.created_by:
raise ValueError(f"Snippet with id {snippet_id} has no creator")
account = session.scalar(select(Account).where(Account.id == snippet.created_by))
if not account:
raise ValueError(f"Creator account not found for snippet {snippet_id}")
current_tenant = session.query(TenantAccountJoin).filter_by(account_id=account.id, current=True).first()
if not current_tenant:
raise ValueError(f"Current tenant not found for account {account.id}")
account.set_tenant_id(current_tenant.tenant_id)
return account
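
# Illustrative sketch (not part of the original module): how a background task
# might resolve the acting account. The Session wiring via extensions.ext_database
# is an assumption about the caller, not something this module requires.
if __name__ == "__main__":
    from extensions.ext_database import db

    with Session(db.engine, expire_on_commit=False) as session:
        account = get_service_account_for_app(session, app_id="<app-id>")
        print(account.current_tenant_id)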

View File

@ -0,0 +1,62 @@
import logging
from collections.abc import Mapping
from typing import Any
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.evaluation_entity import (
DefaultMetric,
EvaluationItemInput,
EvaluationItemResult,
)
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
from graphon.node_events import NodeRunResult
logger = logging.getLogger(__name__)
class AgentEvaluationRunner(BaseEvaluationRunner):
"""Runner for agent evaluation: collects tool calls and final output."""
def __init__(self, evaluation_instance: BaseEvaluationInstance):
super().__init__(evaluation_instance)
def evaluate_metrics(
self,
node_run_result_list: list[NodeRunResult],
default_metric: DefaultMetric,
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
"""Compute agent evaluation metrics."""
if not node_run_result_list:
return []
merged_items = self._merge_results_into_items(node_run_result_list)
return self.evaluation_instance.evaluate_agent(
merged_items, [default_metric.metric], model_provider, model_name, tenant_id
)
@staticmethod
def _merge_results_into_items(items: list[NodeRunResult]) -> list[EvaluationItemInput]:
"""Create EvaluationItemInput list from NodeRunResult for agent evaluation."""
merged = []
for i, item in enumerate(items):
output = _extract_agent_output(item.outputs)
merged.append(
EvaluationItemInput(
index=i,
inputs=dict(item.inputs),
output=output,
)
)
return merged
def _extract_agent_output(outputs: Mapping[str, Any]) -> str:
"""Extract the primary output text from agent NodeRunResult.outputs."""
if "answer" in outputs:
return str(outputs["answer"])
if "text" in outputs:
return str(outputs["text"])
values = list(outputs.values())
return str(values[0]) if values else ""

View File

@ -0,0 +1,51 @@
"""Base evaluation runner.
Provides the abstract interface for metric computation. Each concrete runner
(LLM, Retrieval, Agent, Workflow, Snippet) implements ``evaluate_metrics``
to compute scores for a specific node type.
Orchestration (merging results from multiple runners, applying judgment, and
persisting to the database) is handled by the evaluation task, not the runner.
"""
import logging
from abc import ABC, abstractmethod
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.evaluation_entity import (
DefaultMetric,
EvaluationItemResult,
)
from graphon.node_events import NodeRunResult
logger = logging.getLogger(__name__)
class BaseEvaluationRunner(ABC):
"""Abstract base class for evaluation runners.
Runners are stateless metric calculators: they receive node execution
results and a metric specification, then return scored results. They
do **not** touch the database or apply judgment logic.
"""
def __init__(self, evaluation_instance: BaseEvaluationInstance):
self.evaluation_instance = evaluation_instance
@abstractmethod
def evaluate_metrics(
self,
node_run_result_list: list[NodeRunResult],
default_metric: DefaultMetric,
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
"""Compute evaluation metrics on the collected results.
The returned ``EvaluationItemResult.index`` values are positional
(0-based) and correspond to the order of *node_run_result_list*.
The caller is responsible for mapping them back to the original
dataset indices.
"""
...

View File

@ -0,0 +1,83 @@
import logging
from collections.abc import Mapping
from typing import Any
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.evaluation_entity import (
DefaultMetric,
EvaluationItemInput,
EvaluationItemResult,
)
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
from graphon.node_events import NodeRunResult
logger = logging.getLogger(__name__)
class LLMEvaluationRunner(BaseEvaluationRunner):
"""Runner for LLM evaluation: extracts prompts/outputs then evaluates."""
def __init__(self, evaluation_instance: BaseEvaluationInstance):
super().__init__(evaluation_instance)
def evaluate_metrics(
self,
node_run_result_list: list[NodeRunResult],
default_metric: DefaultMetric,
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
"""Use the evaluation instance to compute LLM metrics."""
if not node_run_result_list:
return []
merged_items = self._merge_results_into_items(node_run_result_list)
return self.evaluation_instance.evaluate_llm(
merged_items, [default_metric.metric], model_provider, model_name, tenant_id
)
@staticmethod
def _merge_results_into_items(
items: list[NodeRunResult],
) -> list[EvaluationItemInput]:
"""Create new items from NodeRunResult for ragas evaluation.
Extracts prompts from process_data and concatenates them into a single
string with role prefixes (e.g. "system: ...\nuser: ...\nassistant: ...").
The last assistant message in outputs is used as the actual output.
"""
merged = []
for i, item in enumerate(items):
prompt = _format_prompts(item.process_data.get("prompts", []))
output = _extract_llm_output(item.outputs)
merged.append(
EvaluationItemInput(
index=i,
inputs={"prompt": prompt},
output=output,
)
)
return merged
def _format_prompts(prompts: list[dict[str, Any]]) -> str:
"""Concatenate a list of prompt messages into a single string for evaluation.
Each message is formatted as "role: text" and joined with newlines.
"""
parts: list[str] = []
for msg in prompts:
role = msg.get("role", "unknown")
text = msg.get("text", "")
parts.append(f"{role}: {text}")
return "\n".join(parts)
def _extract_llm_output(outputs: Mapping[str, Any]) -> str:
"""Extract the LLM output text from NodeRunResult.outputs."""
if "text" in outputs:
return str(outputs["text"])
if "answer" in outputs:
return str(outputs["answer"])
values = list(outputs.values())
return str(values[0]) if values else ""
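
# Illustrative sketch (not part of the original module): the flattening applied to
# process_data["prompts"] before evaluation. Message contents are placeholders.
if __name__ == "__main__":
    example_prompts = [
        {"role": "system", "text": "You are a helpful assistant."},
        {"role": "user", "text": "What is the capital of France?"},
    ]
    # -> "system: You are a helpful assistant.\nuser: What is the capital of France?"
    print(_format_prompts(example_prompts))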

View File

@ -0,0 +1,61 @@
import logging
from typing import Any
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.evaluation_entity import (
DefaultMetric,
EvaluationItemInput,
EvaluationItemResult,
)
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
from graphon.node_events import NodeRunResult
logger = logging.getLogger(__name__)
class RetrievalEvaluationRunner(BaseEvaluationRunner):
"""Runner for retrieval evaluation: performs knowledge base retrieval, then evaluates."""
def __init__(self, evaluation_instance: BaseEvaluationInstance):
super().__init__(evaluation_instance)
def evaluate_metrics(
self,
node_run_result_list: list[NodeRunResult],
default_metric: DefaultMetric,
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
"""Compute retrieval evaluation metrics."""
if not node_run_result_list:
return []
merged_items = []
for i, node_result in enumerate(node_run_result_list):
outputs = node_result.outputs
query = self._extract_query(dict(node_result.inputs))
result_list = outputs.get("result", [])
contexts = [item.get("content", "") for item in result_list if item.get("content")]
output = "\n---\n".join(contexts)
merged_items.append(
EvaluationItemInput(
index=i,
inputs={"query": query},
output=output,
context=contexts,
)
)
return self.evaluation_instance.evaluate_retrieval(
merged_items, [default_metric.metric], model_provider, model_name, tenant_id
)
@staticmethod
def _extract_query(inputs: dict[str, Any]) -> str:
for key in ("query", "question", "input", "text"):
if key in inputs:
return str(inputs[key])
values = list(inputs.values())
return str(values[0]) if values else ""
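
# Illustrative sketch (not part of the original module): the knowledge-retrieval
# output shape this runner assumes. Only the "content" field of each chunk is used;
# other keys (such as a score) are ignored and appear here purely for illustration.
if __name__ == "__main__":
    example_outputs = {
        "result": [
            {"content": "Paris is the capital of France.", "score": 0.92},
            {"content": "France is in Western Europe.", "score": 0.75},
        ]
    }
    contexts = [chunk.get("content", "") for chunk in example_outputs["result"] if chunk.get("content")]
    print("\n---\n".join(contexts))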

View File

@ -0,0 +1,68 @@
"""Runner for Snippet evaluation.
Snippets are essentially workflows, so we reuse ``evaluate_workflow`` from
the evaluation instance for metric computation.
"""
import logging
from collections.abc import Mapping
from typing import Any
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.evaluation_entity import (
DefaultMetric,
EvaluationItemInput,
EvaluationItemResult,
)
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
from graphon.node_events import NodeRunResult
logger = logging.getLogger(__name__)
class SnippetEvaluationRunner(BaseEvaluationRunner):
"""Runner for snippet evaluation: evaluates a published Snippet workflow."""
def __init__(self, evaluation_instance: BaseEvaluationInstance):
super().__init__(evaluation_instance)
def evaluate_metrics(
self,
node_run_result_list: list[NodeRunResult],
default_metric: DefaultMetric,
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
"""Compute evaluation metrics for snippet outputs."""
if not node_run_result_list:
return []
merged_items = self._merge_results_into_items(node_run_result_list)
return self.evaluation_instance.evaluate_workflow(
merged_items, [default_metric.metric], model_provider, model_name, tenant_id
)
@staticmethod
def _merge_results_into_items(items: list[NodeRunResult]) -> list[EvaluationItemInput]:
"""Create EvaluationItemInput list from NodeRunResult for snippet evaluation."""
merged = []
for i, item in enumerate(items):
output = _extract_snippet_output(item.outputs)
merged.append(
EvaluationItemInput(
index=i,
inputs=dict(item.inputs),
output=output,
)
)
return merged
def _extract_snippet_output(outputs: Mapping[str, Any]) -> str:
"""Extract the primary output text from snippet NodeRunResult.outputs."""
if "answer" in outputs:
return str(outputs["answer"])
if "text" in outputs:
return str(outputs["text"])
values = list(outputs.values())
return str(values[0]) if values else ""

View File

@ -0,0 +1,62 @@
import logging
from collections.abc import Mapping
from typing import Any
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.evaluation_entity import (
DefaultMetric,
EvaluationItemInput,
EvaluationItemResult,
)
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
from graphon.node_events import NodeRunResult
logger = logging.getLogger(__name__)
class WorkflowEvaluationRunner(BaseEvaluationRunner):
"""Runner for workflow evaluation: executes workflow App in non-streaming mode."""
def __init__(self, evaluation_instance: BaseEvaluationInstance):
super().__init__(evaluation_instance)
def evaluate_metrics(
self,
node_run_result_list: list[NodeRunResult],
default_metric: DefaultMetric,
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
"""Compute workflow evaluation metrics (end-to-end)."""
if not node_run_result_list:
return []
merged_items = self._merge_results_into_items(node_run_result_list)
return self.evaluation_instance.evaluate_workflow(
merged_items, [default_metric.metric], model_provider, model_name, tenant_id
)
@staticmethod
def _merge_results_into_items(items: list[NodeRunResult]) -> list[EvaluationItemInput]:
"""Create EvaluationItemInput list from NodeRunResult for workflow evaluation."""
merged = []
for i, item in enumerate(items):
output = _extract_workflow_output(item.outputs)
merged.append(
EvaluationItemInput(
index=i,
inputs=dict(item.inputs),
output=output,
)
)
return merged
def _extract_workflow_output(outputs: Mapping[str, Any]) -> str:
"""Extract the primary output text from workflow NodeRunResult.outputs."""
if "answer" in outputs:
return str(outputs["answer"])
if "text" in outputs:
return str(outputs["text"])
values = list(outputs.values())
return str(values[0]) if values else ""

View File

@ -1,56 +1,17 @@
import logging
from dataclasses import dataclass
from enum import StrEnum, auto
logger = logging.getLogger(__name__)
@dataclass
class QuotaCharge:
"""
Result of a quota consumption operation.
Attributes:
success: Whether the quota charge succeeded
charge_id: UUID for refund, or None if failed/disabled
"""
success: bool
charge_id: str | None
_quota_type: "QuotaType"
def refund(self) -> None:
"""
Refund this quota charge.
Safe to call even if charge failed or was disabled.
This method guarantees no exceptions will be raised.
"""
if self.charge_id:
self._quota_type.refund(self.charge_id)
logger.info("Refunded quota for %s with charge_id: %s", self._quota_type.value, self.charge_id)
class QuotaType(StrEnum):
"""
Supported quota types for tenant feature usage.
Add additional types here whenever new billable features become available.
"""
# Trigger execution quota
TRIGGER = auto()
# Workflow execution quota
WORKFLOW = auto()
UNLIMITED = auto()
@property
def billing_key(self) -> str:
"""
Get the billing key for the feature.
"""
match self:
case QuotaType.TRIGGER:
return "trigger_event"
@ -58,152 +19,3 @@ class QuotaType(StrEnum):
return "api_rate_limit"
case _:
raise ValueError(f"Invalid quota type: {self}")
def consume(self, tenant_id: str, amount: int = 1) -> QuotaCharge:
"""
Consume quota for the feature.
Args:
tenant_id: The tenant identifier
amount: Amount to consume (default: 1)
Returns:
QuotaCharge with success status and charge_id for refund
Raises:
QuotaExceededError: When quota is insufficient
"""
from configs import dify_config
from services.billing_service import BillingService
from services.errors.app import QuotaExceededError
if not dify_config.BILLING_ENABLED:
logger.debug("Billing disabled, allowing request for %s", tenant_id)
return QuotaCharge(success=True, charge_id=None, _quota_type=self)
logger.info("Consuming %d %s quota for tenant %s", amount, self.value, tenant_id)
if amount <= 0:
raise ValueError("Amount to consume must be greater than 0")
try:
response = BillingService.update_tenant_feature_plan_usage(tenant_id, self.billing_key, delta=amount)
if response.get("result") != "success":
logger.warning(
"Failed to consume quota for %s, feature %s details: %s",
tenant_id,
self.value,
response.get("detail"),
)
raise QuotaExceededError(feature=self.value, tenant_id=tenant_id, required=amount)
charge_id = response.get("history_id")
logger.debug(
"Successfully consumed %d %s quota for tenant %s, charge_id: %s",
amount,
self.value,
tenant_id,
charge_id,
)
return QuotaCharge(success=True, charge_id=charge_id, _quota_type=self)
except QuotaExceededError:
raise
except Exception:
# fail-safe: allow request on billing errors
logger.exception("Failed to consume quota for %s, feature %s", tenant_id, self.value)
return unlimited()
def check(self, tenant_id: str, amount: int = 1) -> bool:
"""
Check if tenant has sufficient quota without consuming.
Args:
tenant_id: The tenant identifier
amount: Amount to check (default: 1)
Returns:
True if quota is sufficient, False otherwise
"""
from configs import dify_config
if not dify_config.BILLING_ENABLED:
return True
if amount <= 0:
raise ValueError("Amount to check must be greater than 0")
try:
remaining = self.get_remaining(tenant_id)
return remaining >= amount if remaining != -1 else True
except Exception:
logger.exception("Failed to check quota for %s, feature %s", tenant_id, self.value)
# fail-safe: allow request on billing errors
return True
def refund(self, charge_id: str) -> None:
"""
Refund quota using charge_id from consume().
This method guarantees no exceptions will be raised.
All errors are logged but silently handled.
Args:
charge_id: The UUID returned from consume()
"""
try:
from configs import dify_config
from services.billing_service import BillingService
if not dify_config.BILLING_ENABLED:
return
if not charge_id:
logger.warning("Cannot refund: charge_id is empty")
return
logger.info("Refunding %s quota with charge_id: %s", self.value, charge_id)
response = BillingService.refund_tenant_feature_plan_usage(charge_id)
if response.get("result") == "success":
logger.debug("Successfully refunded %s quota, charge_id: %s", self.value, charge_id)
else:
logger.warning("Refund failed for charge_id: %s", charge_id)
except Exception:
# Catch ALL exceptions - refund must never fail
logger.exception("Failed to refund quota for charge_id: %s", charge_id)
# Don't raise - refund is best-effort and must be silent
def get_remaining(self, tenant_id: str) -> int:
"""
Get remaining quota for the tenant.
Args:
tenant_id: The tenant identifier
Returns:
Remaining quota amount
"""
from services.billing_service import BillingService
try:
usage_info = BillingService.get_tenant_feature_plan_usage(tenant_id, self.billing_key)
# Assuming the API returns a dict with 'remaining' or 'limit' and 'used'
if isinstance(usage_info, dict):
return usage_info.get("remaining", 0)
# If it returns a simple number, treat it as remaining
return int(usage_info) if usage_info else 0
except Exception:
logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, self.value)
return -1
def unlimited() -> QuotaCharge:
"""
Return a quota charge for unlimited quota.
This is useful for features that are not subject to quota limits, such as the UNLIMITED quota type.
"""
return QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.UNLIMITED)

View File

@ -0,0 +1,45 @@
from flask_restx import fields
from fields.member_fields import simple_account_fields
from libs.helper import TimestampField
# Snippet list item fields (lightweight for list display)
snippet_list_fields = {
"id": fields.String,
"name": fields.String,
"description": fields.String,
"type": fields.String,
"version": fields.Integer,
"use_count": fields.Integer,
"is_published": fields.Boolean,
"icon_info": fields.Raw,
"created_at": TimestampField,
"updated_at": TimestampField,
}
# Full snippet fields (includes creator info and graph data)
snippet_fields = {
"id": fields.String,
"name": fields.String,
"description": fields.String,
"type": fields.String,
"version": fields.Integer,
"use_count": fields.Integer,
"is_published": fields.Boolean,
"icon_info": fields.Raw,
"graph": fields.Raw(attribute="graph_dict"),
"input_fields": fields.Raw(attribute="input_fields_list"),
"created_by": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
"created_at": TimestampField,
"updated_by": fields.Nested(simple_account_fields, attribute="updated_by_account", allow_null=True),
"updated_at": TimestampField,
}
# Pagination response fields
snippet_pagination_fields = {
"data": fields.List(fields.Nested(snippet_list_fields)),
"page": fields.Integer,
"limit": fields.Integer,
"total": fields.Integer,
"has_more": fields.Boolean,
}

View File

@ -14,6 +14,7 @@ workflow_app_log_partial_fields = {
"id": fields.String,
"workflow_run": fields.Nested(workflow_run_for_log_fields, attribute="workflow_run", allow_null=True),
"details": fields.Raw(attribute="details"),
"evaluation": fields.Raw(attribute="evaluation", default=None),
"created_from": fields.String,
"created_by_role": fields.String,
"created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),

View File

@ -0,0 +1,83 @@
"""add_customized_snippets_table
Revision ID: 1c05e80d2380
Revises: 788d3099ae3a
Create Date: 2026-01-29 12:00:00.000000
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
import models as models
def _is_pg(conn):
return conn.dialect.name == "postgresql"
# revision identifiers, used by Alembic.
revision = "1c05e80d2380"
down_revision = "788d3099ae3a"
branch_labels = None
depends_on = None
def upgrade():
conn = op.get_bind()
if _is_pg(conn):
op.create_table(
"customized_snippets",
sa.Column("id", models.types.StringUUID(), server_default=sa.text("uuidv7()"), nullable=False),
sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
sa.Column("name", sa.String(length=255), nullable=False),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("type", sa.String(length=50), server_default=sa.text("'node'"), nullable=False),
sa.Column("workflow_id", models.types.StringUUID(), nullable=True),
sa.Column("is_published", sa.Boolean(), server_default=sa.text("false"), nullable=False),
sa.Column("version", sa.Integer(), server_default=sa.text("1"), nullable=False),
sa.Column("use_count", sa.Integer(), server_default=sa.text("0"), nullable=False),
sa.Column("icon_info", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column("graph", sa.Text(), nullable=True),
sa.Column("input_fields", sa.Text(), nullable=True),
sa.Column("created_by", models.types.StringUUID(), nullable=True),
sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("updated_by", models.types.StringUUID(), nullable=True),
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.PrimaryKeyConstraint("id", name="customized_snippet_pkey"),
sa.UniqueConstraint("tenant_id", "name", name="customized_snippet_tenant_name_key"),
)
else:
op.create_table(
"customized_snippets",
sa.Column("id", models.types.StringUUID(), nullable=False),
sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
sa.Column("name", sa.String(length=255), nullable=False),
sa.Column("description", models.types.LongText(), nullable=True),
sa.Column("type", sa.String(length=50), server_default=sa.text("'node'"), nullable=False),
sa.Column("workflow_id", models.types.StringUUID(), nullable=True),
sa.Column("is_published", sa.Boolean(), server_default=sa.text("false"), nullable=False),
sa.Column("version", sa.Integer(), server_default=sa.text("1"), nullable=False),
sa.Column("use_count", sa.Integer(), server_default=sa.text("0"), nullable=False),
sa.Column("icon_info", models.types.AdjustedJSON(astext_type=sa.Text()), nullable=True),
sa.Column("graph", models.types.LongText(), nullable=True),
sa.Column("input_fields", models.types.LongText(), nullable=True),
sa.Column("created_by", models.types.StringUUID(), nullable=True),
sa.Column("created_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.Column("updated_by", models.types.StringUUID(), nullable=True),
sa.Column("updated_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.PrimaryKeyConstraint("id", name="customized_snippet_pkey"),
sa.UniqueConstraint("tenant_id", "name", name="customized_snippet_tenant_name_key"),
)
with op.batch_alter_table("customized_snippets", schema=None) as batch_op:
batch_op.create_index("customized_snippet_tenant_idx", ["tenant_id"], unique=False)
def downgrade():
with op.batch_alter_table("customized_snippets", schema=None) as batch_op:
batch_op.drop_index("customized_snippet_tenant_idx")
op.drop_table("customized_snippets")

View File

@ -0,0 +1,116 @@
"""add_evaluation_tables
Revision ID: a1b2c3d4e5f6
Revises: 1c05e80d2380
Create Date: 2026-03-03 00:01:00.000000
"""
import sqlalchemy as sa
from alembic import op
import models as models
# revision identifiers, used by Alembic.
revision = "a1b2c3d4e5f6"
down_revision = "1c05e80d2380"
branch_labels = None
depends_on = None
def upgrade():
# evaluation_configurations
op.create_table(
"evaluation_configurations",
sa.Column("id", models.types.StringUUID(), nullable=False),
sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
sa.Column("target_type", sa.String(length=20), nullable=False),
sa.Column("target_id", models.types.StringUUID(), nullable=False),
sa.Column("evaluation_model_provider", sa.String(length=255), nullable=True),
sa.Column("evaluation_model", sa.String(length=255), nullable=True),
sa.Column("metrics_config", models.types.LongText(), nullable=True),
sa.Column("judgement_conditions", models.types.LongText(), nullable=True),
sa.Column("created_by", models.types.StringUUID(), nullable=False),
sa.Column("updated_by", models.types.StringUUID(), nullable=False),
sa.Column("created_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.Column("updated_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.PrimaryKeyConstraint("id", name="evaluation_configuration_pkey"),
sa.UniqueConstraint("tenant_id", "target_type", "target_id", name="evaluation_configuration_unique"),
)
with op.batch_alter_table("evaluation_configurations", schema=None) as batch_op:
batch_op.create_index(
"evaluation_configuration_target_idx", ["tenant_id", "target_type", "target_id"], unique=False
)
# evaluation_runs
op.create_table(
"evaluation_runs",
sa.Column("id", models.types.StringUUID(), nullable=False),
sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
sa.Column("target_type", sa.String(length=20), nullable=False),
sa.Column("target_id", models.types.StringUUID(), nullable=False),
sa.Column("evaluation_config_id", models.types.StringUUID(), nullable=False),
sa.Column("status", sa.String(length=20), nullable=False, server_default=sa.text("'pending'")),
sa.Column("dataset_file_id", models.types.StringUUID(), nullable=True),
sa.Column("result_file_id", models.types.StringUUID(), nullable=True),
sa.Column("total_items", sa.Integer(), nullable=False, server_default=sa.text("0")),
sa.Column("completed_items", sa.Integer(), nullable=False, server_default=sa.text("0")),
sa.Column("failed_items", sa.Integer(), nullable=False, server_default=sa.text("0")),
sa.Column("metrics_summary", models.types.LongText(), nullable=True),
sa.Column("error", sa.Text(), nullable=True),
sa.Column("celery_task_id", sa.String(length=255), nullable=True),
sa.Column("created_by", models.types.StringUUID(), nullable=False),
sa.Column("started_at", sa.DateTime(), nullable=True),
sa.Column("completed_at", sa.DateTime(), nullable=True),
sa.Column("created_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.Column("updated_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.PrimaryKeyConstraint("id", name="evaluation_run_pkey"),
)
with op.batch_alter_table("evaluation_runs", schema=None) as batch_op:
batch_op.create_index(
"evaluation_run_target_idx", ["tenant_id", "target_type", "target_id"], unique=False
)
batch_op.create_index("evaluation_run_status_idx", ["tenant_id", "status"], unique=False)
# evaluation_run_items
op.create_table(
"evaluation_run_items",
sa.Column("id", models.types.StringUUID(), nullable=False),
sa.Column("evaluation_run_id", models.types.StringUUID(), nullable=False),
sa.Column("workflow_run_id", models.types.StringUUID(), nullable=True),
sa.Column("item_index", sa.Integer(), nullable=False),
sa.Column("inputs", models.types.LongText(), nullable=True),
sa.Column("expected_output", models.types.LongText(), nullable=True),
sa.Column("context", models.types.LongText(), nullable=True),
sa.Column("actual_output", models.types.LongText(), nullable=True),
sa.Column("metrics", models.types.LongText(), nullable=True),
sa.Column("metadata_json", models.types.LongText(), nullable=True),
sa.Column("error", sa.Text(), nullable=True),
sa.Column("overall_score", sa.Float(), nullable=True),
sa.Column("created_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.PrimaryKeyConstraint("id", name="evaluation_run_item_pkey"),
)
with op.batch_alter_table("evaluation_run_items", schema=None) as batch_op:
batch_op.create_index("evaluation_run_item_run_idx", ["evaluation_run_id"], unique=False)
batch_op.create_index(
"evaluation_run_item_index_idx", ["evaluation_run_id", "item_index"], unique=False
)
batch_op.create_index("evaluation_run_item_workflow_run_idx", ["workflow_run_id"], unique=False)
def downgrade():
with op.batch_alter_table("evaluation_run_items", schema=None) as batch_op:
batch_op.drop_index("evaluation_run_item_workflow_run_idx")
batch_op.drop_index("evaluation_run_item_index_idx")
batch_op.drop_index("evaluation_run_item_run_idx")
op.drop_table("evaluation_run_items")
with op.batch_alter_table("evaluation_runs", schema=None) as batch_op:
batch_op.drop_index("evaluation_run_status_idx")
batch_op.drop_index("evaluation_run_target_idx")
op.drop_table("evaluation_runs")
with op.batch_alter_table("evaluation_configurations", schema=None) as batch_op:
batch_op.drop_index("evaluation_configuration_target_idx")
op.drop_table("evaluation_configurations")

View File

@ -0,0 +1,25 @@
"""merge migration heads
Revision ID: 4c60d8d3ee74
Revises: fce013ca180e, a1b2c3d4e5f6
Create Date: 2026-03-17 17:21:12.105536
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '4c60d8d3ee74'
down_revision = ('fce013ca180e', 'a1b2c3d4e5f6')
branch_labels = None
depends_on = None
def upgrade():
pass
def downgrade():
pass

View File

@ -33,6 +33,13 @@ from .enums import (
WorkflowRunTriggeredFrom,
WorkflowTriggerStatus,
)
from .evaluation import (
EvaluationConfiguration,
EvaluationRun,
EvaluationRunItem,
EvaluationRunStatus,
EvaluationTargetType,
)
from .execution_extra_content import ExecutionExtraContent, HumanInputContent
from .human_input import HumanInputForm
from .model import (
@ -80,6 +87,7 @@ from .provider import (
TenantDefaultModel,
TenantPreferredModelProvider,
)
from .snippet import CustomizedSnippet, SnippetType
from .source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding
from .task import CeleryTask, CeleryTaskSet
from .tools import (
@ -139,6 +147,7 @@ __all__ = [
"Conversation",
"ConversationVariable",
"CreatorUserRole",
"CustomizedSnippet",
"DataSourceApiKeyAuthBinding",
"DataSourceOauthBinding",
"Dataset",
@ -156,6 +165,11 @@ __all__ = [
"DocumentSegment",
"Embedding",
"EndUser",
"EvaluationConfiguration",
"EvaluationRun",
"EvaluationRunItem",
"EvaluationRunStatus",
"EvaluationTargetType",
"ExecutionExtraContent",
"ExporleBanner",
"ExternalKnowledgeApis",
@ -183,6 +197,7 @@ __all__ = [
"RecommendedApp",
"SavedMessage",
"Site",
"SnippetType",
"Tag",
"TagBinding",
"Tenant",

api/models/evaluation.py (new file, 205 lines)
View File

@ -0,0 +1,205 @@
from __future__ import annotations
import json
from datetime import datetime
from enum import StrEnum
from typing import Any
import sqlalchemy as sa
from sqlalchemy import DateTime, Float, Integer, String, Text, func
from sqlalchemy.orm import Mapped, mapped_column
from libs.uuid_utils import uuidv7
from .base import Base
from .types import LongText, StringUUID
class EvaluationRunStatus(StrEnum):
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
class EvaluationTargetType(StrEnum):
APP = "app"
SNIPPETS = "snippets"
KNOWLEDGE_BASE = "knowledge_base"
class EvaluationConfiguration(Base):
"""Stores evaluation configuration for each target (App or Snippet)."""
__tablename__ = "evaluation_configurations"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="evaluation_configuration_pkey"),
sa.Index("evaluation_configuration_target_idx", "tenant_id", "target_type", "target_id"),
sa.Index("evaluation_configuration_workflow_idx", "customized_workflow_id"),
sa.UniqueConstraint("tenant_id", "target_type", "target_id", name="evaluation_configuration_unique"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
target_type: Mapped[str] = mapped_column(String(20), nullable=False)
target_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
evaluation_model_provider: Mapped[str | None] = mapped_column(String(255), nullable=True)
evaluation_model: Mapped[str | None] = mapped_column(String(255), nullable=True)
metrics_config: Mapped[str | None] = mapped_column(LongText, nullable=True)
judgement_conditions: Mapped[str | None] = mapped_column(LongText, nullable=True)
customized_workflow_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
updated_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
@property
def metrics_config_dict(self) -> dict[str, Any]:
if self.metrics_config:
return json.loads(self.metrics_config)
return {}
@metrics_config_dict.setter
def metrics_config_dict(self, value: dict[str, Any]) -> None:
self.metrics_config = json.dumps(value)
@property
def default_metrics_list(self) -> list[dict[str, Any]]:
"""Extract default_metrics from the stored metrics_config JSON."""
config = self.metrics_config_dict
return config.get("default_metrics", [])
@property
def customized_metrics_dict(self) -> dict[str, Any] | None:
"""Extract customized_metrics from the stored metrics_config JSON."""
config = self.metrics_config_dict
return config.get("customized_metrics")
@property
def judgment_config_dict(self) -> dict[str, Any] | None:
"""Return judgment config (stored in the judgement_conditions column)."""
if self.judgement_conditions:
parsed = json.loads(self.judgement_conditions)
return parsed if parsed else None
return None
@property
def judgement_conditions_dict(self) -> dict[str, Any]:
if self.judgement_conditions:
return json.loads(self.judgement_conditions)
return {}
@judgement_conditions_dict.setter
def judgement_conditions_dict(self, value: dict[str, Any]) -> None:
self.judgement_conditions = json.dumps(value)
def __repr__(self) -> str:
return f"<EvaluationConfiguration(id={self.id}, target={self.target_type}:{self.target_id})>"
class EvaluationRun(Base):
"""Stores each evaluation run record."""
__tablename__ = "evaluation_runs"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="evaluation_run_pkey"),
sa.Index("evaluation_run_target_idx", "tenant_id", "target_type", "target_id"),
sa.Index("evaluation_run_status_idx", "tenant_id", "status"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
target_type: Mapped[str] = mapped_column(String(20), nullable=False)
target_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
evaluation_config_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
status: Mapped[str] = mapped_column(String(20), nullable=False, default=EvaluationRunStatus.PENDING)
dataset_file_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
result_file_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
total_items: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
completed_items: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
failed_items: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
error: Mapped[str | None] = mapped_column(Text, nullable=True)
celery_task_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
started_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
@property
def progress(self) -> float:
if self.total_items == 0:
return 0.0
return (self.completed_items + self.failed_items) / self.total_items
def __repr__(self) -> str:
return f"<EvaluationRun(id={self.id}, status={self.status})>"
class EvaluationRunItem(Base):
"""Stores per-row evaluation results."""
__tablename__ = "evaluation_run_items"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="evaluation_run_item_pkey"),
sa.Index("evaluation_run_item_run_idx", "evaluation_run_id"),
sa.Index("evaluation_run_item_index_idx", "evaluation_run_id", "item_index"),
sa.Index("evaluation_run_item_workflow_run_idx", "workflow_run_id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()))
evaluation_run_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
workflow_run_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
item_index: Mapped[int] = mapped_column(Integer, nullable=False)
inputs: Mapped[str | None] = mapped_column(LongText, nullable=True)
expected_output: Mapped[str | None] = mapped_column(LongText, nullable=True)
context: Mapped[str | None] = mapped_column(LongText, nullable=True)
actual_output: Mapped[str | None] = mapped_column(LongText, nullable=True)
metrics: Mapped[str | None] = mapped_column(LongText, nullable=True)
judgment: Mapped[str | None] = mapped_column(LongText, nullable=True)
metadata_json: Mapped[str | None] = mapped_column(LongText, nullable=True)
error: Mapped[str | None] = mapped_column(Text, nullable=True)
overall_score: Mapped[float | None] = mapped_column(Float, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
@property
def inputs_dict(self) -> dict[str, Any]:
if self.inputs:
return json.loads(self.inputs)
return {}
@property
def metrics_list(self) -> list[dict[str, Any]]:
if self.metrics:
return json.loads(self.metrics)
return []
@property
def judgment_dict(self) -> dict[str, Any]:
if self.judgment:
return json.loads(self.judgment)
return {}
@property
def metadata_dict(self) -> dict[str, Any]:
if self.metadata_json:
return json.loads(self.metadata_json)
return {}
def __repr__(self) -> str:
return f"<EvaluationRunItem(id={self.id}, run={self.evaluation_run_id}, index={self.item_index})>"

api/models/snippet.py (new file, 101 lines)
View File

@ -0,0 +1,101 @@
import json
from datetime import datetime
from enum import StrEnum
from typing import Any
import sqlalchemy as sa
from sqlalchemy import DateTime, String, func
from sqlalchemy.orm import Mapped, mapped_column
from libs.uuid_utils import uuidv7
from .account import Account
from .base import Base
from .engine import db
from .types import AdjustedJSON, LongText, StringUUID
class SnippetType(StrEnum):
"""Snippet Type Enum"""
NODE = "node"
GROUP = "group"
class CustomizedSnippet(Base):
"""
Customized Snippet Model
Stores reusable workflow components (nodes or node groups) that can be
shared across applications within a workspace.
"""
__tablename__ = "customized_snippets"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="customized_snippet_pkey"),
sa.Index("customized_snippet_tenant_idx", "tenant_id"),
sa.UniqueConstraint("tenant_id", "name", name="customized_snippet_tenant_name_key"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
description: Mapped[str | None] = mapped_column(LongText, nullable=True)
type: Mapped[str] = mapped_column(String(50), nullable=False, server_default=sa.text("'node'"))
# Workflow reference for published version
workflow_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
# State flags
is_published: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
version: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("1"))
use_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
# Visual customization
icon_info: Mapped[dict | None] = mapped_column(AdjustedJSON, nullable=True)
# Snippet configuration (stored as JSON text)
input_fields: Mapped[str | None] = mapped_column(LongText, nullable=True)
# Audit fields
created_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
@property
def graph_dict(self) -> dict[str, Any]:
"""Get graph from associated workflow."""
if self.workflow_id:
from .workflow import Workflow
workflow = db.session.get(Workflow, self.workflow_id)
if workflow:
return json.loads(workflow.graph) if workflow.graph else {}
return {}
@property
def input_fields_list(self) -> list[dict[str, Any]]:
"""Parse input_fields JSON to list."""
return json.loads(self.input_fields) if self.input_fields else []
@property
def created_by_account(self) -> Account | None:
"""Get the account that created this snippet."""
if self.created_by:
return db.session.get(Account, self.created_by)
return None
@property
def updated_by_account(self) -> Account | None:
"""Get the account that last updated this snippet."""
if self.updated_by:
return db.session.get(Account, self.updated_by)
return None
@property
def version_str(self) -> str:
"""Get version as string for API response."""
return str(self.version)
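A brief illustration of reading a snippet's derived properties; the tenant id and snippet name are placeholders:

from models import CustomizedSnippet
from models.engine import db

snippet = (
    db.session.query(CustomizedSnippet)
    .filter_by(tenant_id="tenant-id", name="My Snippet")
    .first()
)
if snippet and snippet.is_published:
    graph = snippet.graph_dict            # resolved from the published workflow
    inputs = snippet.input_fields_list    # parsed from the stored JSON text
    author = snippet.created_by_account   # Account or None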

View File

@ -106,6 +106,8 @@ class WorkflowType(StrEnum):
WORKFLOW = "workflow"
CHAT = "chat"
RAG_PIPELINE = "rag-pipeline"
SNIPPET = "snippet"
EVALUATION = "evaluation"
@classmethod
def value_of(cls, value: str) -> "WorkflowType":

View File

@ -234,6 +234,12 @@ storage = [
############################################################
tools = ["cloudscraper~=1.2.71", "nltk~=3.9.1"]
############################################################
# [ Evaluation ] dependency group
# Required for evaluation frameworks
############################################################
evaluation = ["ragas>=0.2.0", "deepeval>=2.0.0"]
############################################################
# [ VDB ] workspace plugins — hollow packages under providers/vdb/*
# Each declares its own third-party deps and registers dify.vector_backends entry points.

View File

@ -18,12 +18,13 @@ from core.app.features.rate_limiting import RateLimit
from core.app.features.rate_limiting.rate_limit import rate_limit_context
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig
from core.db import session_factory
from enums.quota_type import QuotaType, unlimited
from enums.quota_type import QuotaType
from extensions.otel import AppGenerateHandler, trace_span
from models.model import Account, App, AppMode, EndUser
from models.workflow import Workflow, WorkflowRun
from services.errors.app import QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError
from services.quota_service import QuotaService, unlimited
from services.workflow_service import WorkflowService
from tasks.app_generate.workflow_execute_task import AppExecutionParams, workflow_based_app_execution_task
@ -106,7 +107,7 @@ class AppGenerateService:
quota_charge = unlimited()
if dify_config.BILLING_ENABLED:
try:
quota_charge = QuotaType.WORKFLOW.consume(app_model.tenant_id)
quota_charge = QuotaService.reserve(QuotaType.WORKFLOW, app_model.tenant_id)
except QuotaExceededError:
raise InvokeRateLimitError(f"Workflow execution quota limit reached for tenant {app_model.tenant_id}")
@ -116,6 +117,7 @@ class AppGenerateService:
request_id = RateLimit.gen_request_key()
try:
request_id = rate_limit.enter(request_id)
quota_charge.commit()
effective_mode = (
AppMode.AGENT_CHAT if app_model.is_agent and app_model.mode != AppMode.AGENT_CHAT else app_model.mode
)
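Taken together with the async-dispatch change below, the new quota flow follows a reserve → commit → refund lifecycle. A condensed sketch of that pattern, with the tenant id and work callable as placeholders (an illustration of the flow, not a call site from the PR):

from enums.quota_type import QuotaType
from services.errors.app import QuotaExceededError
from services.errors.llm import InvokeRateLimitError
from services.quota_service import QuotaService, unlimited

def run_with_workflow_quota(tenant_id: str, billing_enabled: bool, do_work):
    # With billing disabled, fall back to a charge that represents unlimited quota.
    quota_charge = unlimited()
    if billing_enabled:
        try:
            quota_charge = QuotaService.reserve(QuotaType.WORKFLOW, tenant_id)
        except QuotaExceededError:
            raise InvokeRateLimitError(f"Workflow execution quota limit reached for tenant {tenant_id}")
    try:
        result = do_work()        # dispatch or execute the workflow
        quota_charge.commit()     # turn the reservation into real consumption
        return result
    except Exception:
        quota_charge.refund()     # best-effort; refund never raises
        raise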

View File

@ -22,6 +22,7 @@ from models.trigger import WorkflowTriggerLog, WorkflowTriggerLogDict
from models.workflow import Workflow
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
from services.errors.app import QuotaExceededError, WorkflowNotFoundError, WorkflowQuotaLimitError
from services.quota_service import QuotaService, unlimited
from services.workflow.entities import AsyncTriggerResponse, TriggerData, WorkflowTaskData
from services.workflow.queue_dispatcher import QueueDispatcherManager, QueuePriority
from services.workflow_service import WorkflowService
@ -131,9 +132,10 @@ class AsyncWorkflowService:
trigger_log = trigger_log_repo.create(trigger_log)
session.commit()
# 7. Check and consume quota
# 7. Reserve quota (commit after successful dispatch)
quota_charge = unlimited()
try:
QuotaType.WORKFLOW.consume(trigger_data.tenant_id)
quota_charge = QuotaService.reserve(QuotaType.WORKFLOW, trigger_data.tenant_id)
except QuotaExceededError as e:
# Update trigger log status
trigger_log.status = WorkflowTriggerStatus.RATE_LIMITED
@ -153,13 +155,18 @@ class AsyncWorkflowService:
# 9. Dispatch to appropriate queue
task_data_dict = task_data.model_dump(mode="json")
task: AsyncResult[Any] | None = None
if queue_name == QueuePriority.PROFESSIONAL:
task = execute_workflow_professional.delay(task_data_dict)
elif queue_name == QueuePriority.TEAM:
task = execute_workflow_team.delay(task_data_dict)
else: # SANDBOX
task = execute_workflow_sandbox.delay(task_data_dict)
try:
task: AsyncResult[Any] | None = None
if queue_name == QueuePriority.PROFESSIONAL:
task = execute_workflow_professional.delay(task_data_dict)
elif queue_name == QueuePriority.TEAM:
task = execute_workflow_team.delay(task_data_dict)
else: # SANDBOX
task = execute_workflow_sandbox.delay(task_data_dict)
quota_charge.commit()
except Exception:
quota_charge.refund()
raise
# 10. Update trigger log with task info
trigger_log.status = WorkflowTriggerStatus.QUEUED

View File

@ -32,6 +32,102 @@ class SubscriptionPlan(TypedDict):
expiration_date: int
class QuotaReserveResult(TypedDict):
reservation_id: str
available: int
reserved: int
class QuotaCommitResult(TypedDict):
available: int
reserved: int
refunded: int
class QuotaReleaseResult(TypedDict):
available: int
reserved: int
released: int
_quota_reserve_adapter = TypeAdapter(QuotaReserveResult)
_quota_commit_adapter = TypeAdapter(QuotaCommitResult)
_quota_release_adapter = TypeAdapter(QuotaReleaseResult)
class _BillingQuota(TypedDict):
size: int
limit: int
class _VectorSpaceQuota(TypedDict):
size: float
limit: int
class _KnowledgeRateLimit(TypedDict):
# NOTE (hj24):
# 1. Returned for sandbox users but null for other plans; it's defined but never used.
# 2. Keep it for compatibility for now, can be deprecated in future versions.
size: NotRequired[int]
# NOTE END
limit: int
class _BillingSubscription(TypedDict):
plan: str
interval: str
education: bool
class BillingInfo(TypedDict):
"""Response of /subscription/info.
NOTE (hj24):
- Fields not listed here (e.g. trigger_event, api_rate_limit) are stripped by TypeAdapter.validate_python()
- To preserve precision, billing may serialize int fields as strings; be careful when using TypeAdapter:
1. validate_python in non-strict mode will coerce them to the expected type
2. In strict mode, it will raise ValidationError
3. To preserve compatibility, always keep non-strict mode here and avoid strict mode
"""
enabled: bool
subscription: _BillingSubscription
members: _BillingQuota
apps: _BillingQuota
vector_space: _VectorSpaceQuota
knowledge_rate_limit: _KnowledgeRateLimit
documents_upload_quota: _BillingQuota
annotation_quota_limit: _BillingQuota
docs_processing: str
can_replace_logo: bool
model_load_balancing_enabled: bool
knowledge_pipeline_publish_enabled: bool
next_credit_reset_date: NotRequired[int]
_billing_info_adapter = TypeAdapter(BillingInfo)
class _TenantFeatureQuota(TypedDict):
usage: int
limit: int
reset_date: NotRequired[int]
class TenantFeatureQuotaInfo(TypedDict):
"""Response of /quota/info.
NOTE (hj24):
- Same convention as BillingInfo: billing may return int fields as str;
always keep non-strict mode so they are auto-coerced.
"""
trigger_event: _TenantFeatureQuota
api_rate_limit: _TenantFeatureQuota
_tenant_feature_quota_info_adapter = TypeAdapter(TenantFeatureQuotaInfo)
class _BillingQuota(TypedDict):
size: int
limit: int
@ -149,11 +245,63 @@ class BillingService:
@classmethod
def get_tenant_feature_plan_usage_info(cls, tenant_id: str):
"""Deprecated: Use get_quota_info instead."""
params = {"tenant_id": tenant_id}
usage_info = cls._send_request("GET", "/tenant-feature-usage/info", params=params)
return usage_info
@classmethod
def get_quota_info(cls, tenant_id: str) -> TenantFeatureQuotaInfo:
params = {"tenant_id": tenant_id}
return _tenant_feature_quota_info_adapter.validate_python(
cls._send_request("GET", "/quota/info", params=params)
)
@classmethod
def quota_reserve(
cls, tenant_id: str, feature_key: str, request_id: str, amount: int = 1, meta: dict | None = None
) -> QuotaReserveResult:
"""Reserve quota before task execution."""
payload: dict = {
"tenant_id": tenant_id,
"feature_key": feature_key,
"request_id": request_id,
"amount": amount,
}
if meta:
payload["meta"] = meta
return _quota_reserve_adapter.validate_python(cls._send_request("POST", "/quota/reserve", json=payload))
@classmethod
def quota_commit(
cls, tenant_id: str, feature_key: str, reservation_id: str, actual_amount: int, meta: dict | None = None
) -> QuotaCommitResult:
"""Commit a reservation with actual consumption."""
payload: dict = {
"tenant_id": tenant_id,
"feature_key": feature_key,
"reservation_id": reservation_id,
"actual_amount": actual_amount,
}
if meta:
payload["meta"] = meta
return _quota_commit_adapter.validate_python(cls._send_request("POST", "/quota/commit", json=payload))
@classmethod
def quota_release(cls, tenant_id: str, feature_key: str, reservation_id: str) -> QuotaReleaseResult:
"""Release a reservation (cancel, return frozen quota)."""
return _quota_release_adapter.validate_python(
cls._send_request(
"POST",
"/quota/release",
json={
"tenant_id": tenant_id,
"feature_key": feature_key,
"reservation_id": reservation_id,
},
)
)
@classmethod
def get_knowledge_rate_limit(cls, tenant_id: str) -> KnowledgeRateLimitDict:
params = {"tenant_id": tenant_id}

View File

@ -0,0 +1,21 @@
from services.errors.base import BaseServiceError
class EvaluationFrameworkNotConfiguredError(BaseServiceError):
def __init__(self, description: str | None = None):
super().__init__(description or "Evaluation framework is not configured. Set EVALUATION_FRAMEWORK env var.")
class EvaluationNotFoundError(BaseServiceError):
def __init__(self, description: str | None = None):
super().__init__(description or "Evaluation not found.")
class EvaluationDatasetInvalidError(BaseServiceError):
def __init__(self, description: str | None = None):
super().__init__(description or "Evaluation dataset is invalid.")
class EvaluationMaxConcurrentRunsError(BaseServiceError):
def __init__(self, description: str | None = None):
super().__init__(description or "Maximum number of concurrent evaluation runs reached.")

View File

@ -0,0 +1,985 @@
import io
import json
import logging
from collections.abc import Mapping
from typing import Any, Union
from openpyxl import Workbook, load_workbook
from openpyxl.styles import Alignment, Border, Font, PatternFill, Side
from openpyxl.utils import get_column_letter
from sqlalchemy.orm import Session
from configs import dify_config
from core.evaluation.entities.evaluation_entity import (
METRIC_NODE_TYPE_MAPPING,
METRIC_VALUE_TYPE_MAPPING,
DefaultMetric,
EvaluationCategory,
EvaluationConfigData,
EvaluationDatasetInput,
EvaluationMetricName,
EvaluationRunData,
EvaluationRunRequest,
NodeInfo,
)
from core.evaluation.evaluation_manager import EvaluationManager
from graphon.enums import WorkflowNodeExecutionMetadataKey
from graphon.node_events.base import NodeRunResult
from models.evaluation import (
EvaluationConfiguration,
EvaluationRun,
EvaluationRunItem,
EvaluationRunStatus,
)
from models.model import App, AppMode
from models.snippet import CustomizedSnippet
from models.workflow import Workflow
from services.errors.evaluation import (
EvaluationDatasetInvalidError,
EvaluationFrameworkNotConfiguredError,
EvaluationMaxConcurrentRunsError,
EvaluationNotFoundError,
)
from services.snippet_service import SnippetService
from services.workflow_service import WorkflowService
logger = logging.getLogger(__name__)
class EvaluationService:
"""
Service for evaluation-related operations.
Provides dataset template generation, evaluation configuration management,
run orchestration, and result queries for Apps, Snippets, and knowledge-base retrieval.
"""
# Excluded app modes that don't support evaluation templates
EXCLUDED_APP_MODES = {AppMode.RAG_PIPELINE}
@classmethod
def generate_dataset_template(
cls,
target: Union[App, CustomizedSnippet],
target_type: str,
) -> tuple[bytes, str]:
"""
Generate evaluation dataset template as XLSX bytes.
Creates an XLSX file with headers based on the evaluation target's input parameters.
The first column is index, followed by input parameter columns.
:param target: App or CustomizedSnippet instance
:param target_type: Target type string ("app" or "snippet")
:return: Tuple of (xlsx_content_bytes, filename)
:raises ValueError: If target type is not supported or app mode is excluded
"""
# Validate target type
if target_type == "app":
if not isinstance(target, App):
raise ValueError("Invalid target: expected App instance")
if AppMode.value_of(target.mode) in cls.EXCLUDED_APP_MODES:
raise ValueError(f"App mode '{target.mode}' does not support evaluation templates")
input_fields = cls._get_app_input_fields(target)
elif target_type == "snippet":
if not isinstance(target, CustomizedSnippet):
raise ValueError("Invalid target: expected CustomizedSnippet instance")
input_fields = cls._get_snippet_input_fields(target)
else:
raise ValueError(f"Unsupported target type: {target_type}")
# Generate XLSX template
xlsx_content = cls._generate_xlsx_template(input_fields, target.name)
# Build filename
truncated_name = target.name[:10] + "..." if len(target.name) > 10 else target.name
filename = f"{truncated_name}-evaluation-dataset.xlsx"
return xlsx_content, filename
@classmethod
def _get_app_input_fields(cls, app: App) -> list[dict]:
"""
Get input fields from App's workflow.
:param app: App instance
:return: List of input field definitions
"""
workflow_service = WorkflowService()
workflow = workflow_service.get_published_workflow(app_model=app)
if not workflow:
workflow = workflow_service.get_draft_workflow(app_model=app)
if not workflow:
return []
# Get user input form from workflow
user_input_form = workflow.user_input_form()
return user_input_form
@classmethod
def _get_snippet_input_fields(cls, snippet: CustomizedSnippet) -> list[dict]:
"""
Get input fields from Snippet.
Tries to get from snippet's own input_fields first,
then falls back to workflow's user_input_form.
:param snippet: CustomizedSnippet instance
:return: List of input field definitions
"""
# Try snippet's own input_fields first
input_fields = snippet.input_fields_list
if input_fields:
return input_fields
# Fallback to workflow's user_input_form
snippet_service = SnippetService()
workflow = snippet_service.get_published_workflow(snippet=snippet)
if not workflow:
workflow = snippet_service.get_draft_workflow(snippet=snippet)
if workflow:
return workflow.user_input_form()
return []
@classmethod
def _generate_xlsx_template(cls, input_fields: list[dict], target_name: str) -> bytes:
"""
Generate XLSX template file content.
Creates a workbook with:
- First row as header row with "index" and input field names
- Styled header with background color and borders
- Empty data rows ready for user input
:param input_fields: List of input field definitions
:param target_name: Name of the target (for sheet name)
:return: XLSX file content as bytes
"""
wb = Workbook()
ws = wb.active
if ws is None:
ws = wb.create_sheet("Evaluation Dataset")
sheet_name = "Evaluation Dataset"
ws.title = sheet_name
header_font = Font(bold=True, color="FFFFFF")
header_fill = PatternFill(start_color="4472C4", end_color="4472C4", fill_type="solid")
header_alignment = Alignment(horizontal="center", vertical="center")
thin_border = Border(
left=Side(style="thin"),
right=Side(style="thin"),
top=Side(style="thin"),
bottom=Side(style="thin"),
)
# Build header row
headers = ["index"]
for field in input_fields:
field_label = str(field.get("label") or field.get("variable") or "")
headers.append(field_label)
# Write header row
for col_idx, header in enumerate(headers, start=1):
cell = ws.cell(row=1, column=col_idx, value=header)
cell.font = header_font
cell.fill = header_fill
cell.alignment = header_alignment
cell.border = thin_border
# Set column widths
ws.column_dimensions["A"].width = 10 # index column
for col_idx in range(2, len(headers) + 1):
ws.column_dimensions[get_column_letter(col_idx)].width = 20
# Add one empty row with row number for user reference
for col_idx in range(1, len(headers) + 1):
cell = ws.cell(row=2, column=col_idx, value="")
cell.border = thin_border
if col_idx == 1:
cell.value = 1
cell.alignment = Alignment(horizontal="center")
# Save to bytes
output = io.BytesIO()
wb.save(output)
output.seek(0)
return output.getvalue()
@classmethod
def generate_retrieval_dataset_template(cls) -> tuple[bytes, str]:
"""Generate evaluation dataset XLSX template for knowledge base retrieval.
The template contains three columns: ``index``, ``query``, and
``expected_output``. Callers upload a filled copy and start an
evaluation run with ``target_type="dataset"``.
:returns: (xlsx_content_bytes, filename)
"""
wb = Workbook()
ws = wb.active
if ws is None:
ws = wb.create_sheet("Evaluation Dataset")
ws.title = "Evaluation Dataset"
header_font = Font(bold=True, color="FFFFFF")
header_fill = PatternFill(start_color="4472C4", end_color="4472C4", fill_type="solid")
header_alignment = Alignment(horizontal="center", vertical="center")
thin_border = Border(
left=Side(style="thin"),
right=Side(style="thin"),
top=Side(style="thin"),
bottom=Side(style="thin"),
)
headers = ["index", "query", "expected_output"]
for col_idx, header in enumerate(headers, start=1):
cell = ws.cell(row=1, column=col_idx, value=header)
cell.font = header_font
cell.fill = header_fill
cell.alignment = header_alignment
cell.border = thin_border
ws.column_dimensions["A"].width = 10
ws.column_dimensions["B"].width = 30
ws.column_dimensions["C"].width = 30
# Add one sample row
for col_idx in range(1, len(headers) + 1):
cell = ws.cell(row=2, column=col_idx, value="")
cell.border = thin_border
if col_idx == 1:
cell.value = 1
cell.alignment = Alignment(horizontal="center")
output = io.BytesIO()
wb.save(output)
output.seek(0)
return output.getvalue(), "retrieval-evaluation-dataset.xlsx"
# ---- Evaluation Configuration CRUD ----
@classmethod
def get_evaluation_config(
cls,
session: Session,
tenant_id: str,
target_type: str,
target_id: str,
) -> EvaluationConfiguration | None:
return (
session.query(EvaluationConfiguration)
.filter_by(tenant_id=tenant_id, target_type=target_type, target_id=target_id)
.first()
)
@classmethod
def save_evaluation_config(
cls,
session: Session,
tenant_id: str,
target_type: str,
target_id: str,
account_id: str,
data: EvaluationConfigData,
) -> EvaluationConfiguration:
config = cls.get_evaluation_config(session, tenant_id, target_type, target_id)
if config is None:
config = EvaluationConfiguration(
tenant_id=tenant_id,
target_type=target_type,
target_id=target_id,
created_by=account_id,
updated_by=account_id,
)
session.add(config)
config.evaluation_model_provider = data.evaluation_model_provider
config.evaluation_model = data.evaluation_model
config.metrics_config = json.dumps(
{
"default_metrics": [m.model_dump() for m in data.default_metrics],
"customized_metrics": data.customized_metrics.model_dump() if data.customized_metrics else None,
}
)
config.judgement_conditions = json.dumps(data.judgment_config.model_dump() if data.judgment_config else {})
config.customized_workflow_id = (
data.customized_metrics.evaluation_workflow_id if data.customized_metrics else None
)
config.updated_by = account_id
session.commit()
session.refresh(config)
return config
@classmethod
def list_targets_by_customized_workflow(
cls,
session: Session,
tenant_id: str,
customized_workflow_id: str,
) -> list[EvaluationConfiguration]:
"""Return all evaluation configs that reference the given workflow as customized metrics."""
from sqlalchemy import select
return list(
session.scalars(
select(EvaluationConfiguration).where(
EvaluationConfiguration.tenant_id == tenant_id,
EvaluationConfiguration.customized_workflow_id == customized_workflow_id,
)
).all()
)
# ---- Evaluation Run Management ----
@classmethod
def start_evaluation_run(
cls,
session: Session,
tenant_id: str,
target_type: str,
target_id: str,
account_id: str,
dataset_file_content: bytes,
run_request: EvaluationRunRequest,
) -> EvaluationRun:
"""Validate dataset, create run record, dispatch Celery task.
Saves the provided parameters as the latest EvaluationConfiguration
before creating the run.
"""
# Check framework is configured
evaluation_instance = EvaluationManager.get_evaluation_instance()
if evaluation_instance is None:
raise EvaluationFrameworkNotConfiguredError()
# Save as latest EvaluationConfiguration
config = cls.save_evaluation_config(
session=session,
tenant_id=tenant_id,
target_type=target_type,
target_id=target_id,
account_id=account_id,
data=run_request,
)
# Check concurrent run limit
active_runs = (
session.query(EvaluationRun)
.filter_by(tenant_id=tenant_id)
.filter(EvaluationRun.status.in_([EvaluationRunStatus.PENDING, EvaluationRunStatus.RUNNING]))
.count()
)
max_concurrent = dify_config.EVALUATION_MAX_CONCURRENT_RUNS
if active_runs >= max_concurrent:
raise EvaluationMaxConcurrentRunsError(f"Maximum concurrent runs ({max_concurrent}) reached.")
# Parse dataset
items = cls._parse_dataset(dataset_file_content)
max_rows = dify_config.EVALUATION_MAX_DATASET_ROWS
if len(items) > max_rows:
raise EvaluationDatasetInvalidError(f"Dataset has {len(items)} rows, max is {max_rows}.")
# Create evaluation run
evaluation_run = EvaluationRun(
tenant_id=tenant_id,
target_type=target_type,
target_id=target_id,
evaluation_config_id=config.id,
status=EvaluationRunStatus.PENDING,
total_items=len(items),
created_by=account_id,
)
session.add(evaluation_run)
session.commit()
session.refresh(evaluation_run)
# Build Celery task data
run_data = EvaluationRunData(
evaluation_run_id=evaluation_run.id,
tenant_id=tenant_id,
target_type=target_type,
target_id=target_id,
evaluation_model_provider=run_request.evaluation_model_provider,
evaluation_model=run_request.evaluation_model,
default_metrics=run_request.default_metrics,
customized_metrics=run_request.customized_metrics,
judgment_config=run_request.judgment_config,
input_list=items,
)
# Dispatch Celery task
from tasks.evaluation_task import run_evaluation
task = run_evaluation.delay(run_data.model_dump())
evaluation_run.celery_task_id = task.id
session.commit()
return evaluation_run
@classmethod
def get_evaluation_runs(
cls,
session: Session,
tenant_id: str,
target_type: str,
target_id: str,
page: int = 1,
page_size: int = 20,
) -> tuple[list[EvaluationRun], int]:
"""Query evaluation run history with pagination."""
query = (
session.query(EvaluationRun)
.filter_by(tenant_id=tenant_id, target_type=target_type, target_id=target_id)
.order_by(EvaluationRun.created_at.desc())
)
total = query.count()
runs = query.offset((page - 1) * page_size).limit(page_size).all()
return runs, total
@classmethod
def get_evaluation_run_detail(
cls,
session: Session,
tenant_id: str,
run_id: str,
) -> EvaluationRun:
run = session.query(EvaluationRun).filter_by(id=run_id, tenant_id=tenant_id).first()
if not run:
raise EvaluationNotFoundError("Evaluation run not found.")
return run
@classmethod
def get_evaluation_run_items(
cls,
session: Session,
run_id: str,
page: int = 1,
page_size: int = 50,
) -> tuple[list[EvaluationRunItem], int]:
"""Query evaluation run items with pagination."""
query = (
session.query(EvaluationRunItem)
.filter_by(evaluation_run_id=run_id)
.order_by(EvaluationRunItem.item_index.asc())
)
total = query.count()
items = query.offset((page - 1) * page_size).limit(page_size).all()
return items, total
@classmethod
def cancel_evaluation_run(
cls,
session: Session,
tenant_id: str,
run_id: str,
) -> EvaluationRun:
run = cls.get_evaluation_run_detail(session, tenant_id, run_id)
if run.status not in (EvaluationRunStatus.PENDING, EvaluationRunStatus.RUNNING):
raise ValueError(f"Cannot cancel evaluation run in status: {run.status}")
run.status = EvaluationRunStatus.CANCELLED
# Revoke Celery task if running
if run.celery_task_id:
try:
from celery import current_app as celery_app
celery_app.control.revoke(run.celery_task_id, terminate=True)
except Exception:
logger.exception("Failed to revoke Celery task %s", run.celery_task_id)
session.commit()
return run
@classmethod
def get_supported_metrics(cls, category: EvaluationCategory) -> list[str]:
return EvaluationManager.get_supported_metrics(category)
@staticmethod
def get_available_metrics() -> list[str]:
"""Return the centrally-defined list of evaluation metrics."""
return [m.value for m in EvaluationMetricName]
@classmethod
def _nodes_for_metrics_from_workflow(
cls,
workflow: Workflow,
metrics: list[str],
) -> dict[str, list[dict[str, str]]]:
node_type_to_nodes: dict[str, list[dict[str, str]]] = {}
for node_id, node_data in workflow.walk_nodes():
ntype = node_data.get("type", "")
node_type_to_nodes.setdefault(ntype, []).append(
NodeInfo(node_id=node_id, type=ntype, title=node_data.get("title", "")).model_dump()
)
result: dict[str, list[dict[str, str]]] = {}
for metric in metrics:
required_node_type = METRIC_NODE_TYPE_MAPPING.get(metric)
if required_node_type is None:
result[metric] = []
continue
result[metric] = node_type_to_nodes.get(required_node_type, [])
return result
@classmethod
def _union_supported_metric_names(cls) -> list[str]:
"""Metric names the current evaluation framework supports for any :class:`EvaluationCategory`."""
ordered: list[str] = []
seen: set[str] = set()
for category in EvaluationCategory:
for name in cls.get_supported_metrics(category):
if name not in seen:
seen.add(name)
ordered.append(name)
return ordered
@classmethod
def get_default_metrics_with_nodes_for_published_target(
cls,
target: Union[App, CustomizedSnippet],
target_type: str,
) -> list[DefaultMetric]:
"""List default metrics and matching nodes using only the *published* workflow graph.
Metrics are those supported by the configured evaluation framework and present in
:data:`METRIC_NODE_TYPE_MAPPING`. Node lists are derived from the published workflow only
(no draft fallback).
"""
workflow = cls._resolve_published_workflow(target, target_type)
if not workflow:
return []
supported = cls._union_supported_metric_names()
metric_names = sorted(m for m in supported if m in METRIC_NODE_TYPE_MAPPING)
if not metric_names:
return []
nodes_by_metric = cls._nodes_for_metrics_from_workflow(workflow, metric_names)
return [
DefaultMetric(
metric=m,
value_type=METRIC_VALUE_TYPE_MAPPING.get(m, "number"),
node_info_list=[NodeInfo.model_validate(n) for n in nodes_by_metric.get(m, [])],
)
for m in metric_names
]
@classmethod
def get_nodes_for_metrics(
cls,
target: Union[App, CustomizedSnippet],
target_type: str,
metrics: list[str] | None = None,
) -> dict[str, list[dict[str, str]]]:
"""Return node info grouped by metric (or all nodes when *metrics* is empty).
:param target: App or CustomizedSnippet instance.
:param target_type: ``"app"`` or ``"snippets"``.
:param metrics: Optional list of metric names to filter by.
When *None* or empty, returns ``{"all": [<every node>]}``.
:returns: ``{metric_name: [NodeInfo dict, ...]}`` or
``{"all": [NodeInfo dict, ...]}``.
"""
workflow = cls._resolve_workflow(target, target_type)
if not workflow:
return {"all": []} if not metrics else {m: [] for m in metrics}
if not metrics:
all_nodes = [
NodeInfo(node_id=node_id, type=node_data.get("type", ""), title=node_data.get("title", "")).model_dump()
for node_id, node_data in workflow.walk_nodes()
]
return {"all": all_nodes}
return cls._nodes_for_metrics_from_workflow(workflow, metrics)
@classmethod
def _resolve_published_workflow(
cls,
target: Union[App, CustomizedSnippet],
target_type: str,
) -> Workflow | None:
"""Resolve only the published workflow for the target (no draft fallback)."""
if target_type == "snippets" and isinstance(target, CustomizedSnippet):
return SnippetService().get_published_workflow(snippet=target)
if target_type == "app" and isinstance(target, App):
return WorkflowService().get_published_workflow(app_model=target)
return None
@classmethod
def _resolve_workflow(
cls,
target: Union[App, CustomizedSnippet],
target_type: str,
) -> Workflow | None:
"""Resolve the *published* (preferred) or *draft* workflow for the target."""
if target_type == "snippets" and isinstance(target, CustomizedSnippet):
snippet_service = SnippetService()
workflow = snippet_service.get_published_workflow(snippet=target)
if not workflow:
workflow = snippet_service.get_draft_workflow(snippet=target)
return workflow
elif target_type == "app" and isinstance(target, App):
workflow_service = WorkflowService()
workflow = workflow_service.get_published_workflow(app_model=target)
if not workflow:
workflow = workflow_service.get_draft_workflow(app_model=target)
return workflow
return None
# ---- Category Resolution ----
@classmethod
def _resolve_evaluation_category(cls, default_metrics: list[DefaultMetric]) -> EvaluationCategory:
"""Derive evaluation category from default_metrics node_info types.
Uses the type of the first node_info found in default_metrics.
Falls back to LLM if no metrics are provided.
"""
for metric in default_metrics:
for node_info in metric.node_info_list:
try:
return EvaluationCategory(node_info.type)
except ValueError:
continue
return EvaluationCategory.LLM
@classmethod
def execute_targets(
cls,
tenant_id: str,
target_type: str,
target_id: str,
input_list: list[EvaluationDatasetInput],
max_workers: int = 5,
) -> tuple[list[dict[str, NodeRunResult]], list[str | None]]:
"""Execute the evaluation target for every test-data item in parallel.
:param tenant_id: Workspace / tenant ID.
:param target_type: ``"app"`` or ``"snippet"``.
:param target_id: ID of the App or CustomizedSnippet.
:param input_list: All test-data items parsed from the dataset.
:param max_workers: Maximum number of parallel worker threads.
:return: Tuple of (node_results, workflow_run_ids).
node_results: ordered list of ``{node_id: NodeRunResult}`` mappings;
the *i*-th element corresponds to ``input_list[i]``.
workflow_run_ids: ordered list of workflow_run_id strings (or None)
for each input item.
"""
from concurrent.futures import ThreadPoolExecutor
from flask import Flask, current_app
flask_app: Flask = current_app._get_current_object() # type: ignore
def _worker(item: EvaluationDatasetInput) -> tuple[dict[str, NodeRunResult], str | None]:
with flask_app.app_context():
from models.engine import db
with Session(db.engine, expire_on_commit=False) as thread_session:
try:
response = cls._run_single_target(
session=thread_session,
target_type=target_type,
target_id=target_id,
item=item,
)
workflow_run_id = cls._extract_workflow_run_id(response)
if not workflow_run_id:
logger.warning(
"No workflow_run_id for item %d (target=%s)",
item.index,
target_id,
)
return {}, None
node_results = cls._query_node_run_results(
session=thread_session,
tenant_id=tenant_id,
app_id=target_id,
workflow_run_id=workflow_run_id,
)
return node_results, workflow_run_id
except Exception:
logger.exception(
"Target execution failed for item %d (target=%s)",
item.index,
target_id,
)
return {}, None
with ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = [executor.submit(_worker, item) for item in input_list]
ordered_results: list[dict[str, NodeRunResult]] = []
ordered_workflow_run_ids: list[str | None] = []
for future in futures:
try:
node_result, wf_run_id = future.result()
ordered_results.append(node_result)
ordered_workflow_run_ids.append(wf_run_id)
except Exception:
logger.exception("Unexpected error collecting target execution result")
ordered_results.append({})
ordered_workflow_run_ids.append(None)
return ordered_results, ordered_workflow_run_ids
@classmethod
def _run_single_target(
cls,
session: Session,
target_type: str,
target_id: str,
item: EvaluationDatasetInput,
) -> Mapping[str, object]:
"""Execute a single evaluation target with one test-data item.
Dispatches to the appropriate execution service based on
``target_type``:
* ``"snippet"`` :meth:`SnippetGenerateService.run_published`
* ``"app"`` :meth:`WorkflowAppGenerator().generate` (blocking mode)
:returns: The blocking response mapping from the workflow engine.
:raises ValueError: If the target is not found or not published.
"""
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from core.evaluation.runners import get_service_account_for_app, get_service_account_for_snippet
if target_type == "snippet":
from services.snippet_generate_service import SnippetGenerateService
snippet = session.query(CustomizedSnippet).filter_by(id=target_id).first()
if not snippet:
raise ValueError(f"Snippet {target_id} not found")
service_account = get_service_account_for_snippet(session, target_id)
return SnippetGenerateService.run_published(
snippet=snippet,
user=service_account,
args={"inputs": item.inputs},
invoke_from=InvokeFrom.SERVICE_API,
)
else:
# target_type == "app"
app = session.query(App).filter_by(id=target_id).first()
if not app:
raise ValueError(f"App {target_id} not found")
service_account = get_service_account_for_app(session, target_id)
workflow_service = WorkflowService()
workflow = workflow_service.get_published_workflow(app_model=app)
if not workflow:
raise ValueError(f"No published workflow for app {target_id}")
response: Mapping[str, object] = WorkflowAppGenerator().generate(
app_model=app,
workflow=workflow,
user=service_account,
args={"inputs": item.inputs},
invoke_from=InvokeFrom.SERVICE_API,
streaming=False,
call_depth=0,
)
return response
@staticmethod
def _extract_workflow_run_id(response: Mapping[str, object]) -> str | None:
"""Extract ``workflow_run_id`` from a blocking workflow response."""
wf_run_id = response.get("workflow_run_id")
if wf_run_id:
return str(wf_run_id)
data = response.get("data")
if isinstance(data, Mapping) and data.get("id"):
return str(data["id"])
return None
@staticmethod
def _query_node_run_results(
session: Session,
tenant_id: str,
app_id: str,
workflow_run_id: str,
) -> dict[str, NodeRunResult]:
"""Query all node execution records for a workflow run."""
from sqlalchemy import asc, select
from graphon.enums import WorkflowNodeExecutionStatus
from models.workflow import WorkflowNodeExecutionModel
stmt = (
WorkflowNodeExecutionModel.preload_offload_data(select(WorkflowNodeExecutionModel))
.where(
WorkflowNodeExecutionModel.tenant_id == tenant_id,
WorkflowNodeExecutionModel.app_id == app_id,
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
)
.order_by(asc(WorkflowNodeExecutionModel.created_at))
)
node_models: list[WorkflowNodeExecutionModel] = list(session.execute(stmt).scalars().all())
result: dict[str, NodeRunResult] = {}
for node in node_models:
# Convert string-keyed metadata to WorkflowNodeExecutionMetadataKey-keyed
raw_metadata = node.execution_metadata_dict
typed_metadata: dict[WorkflowNodeExecutionMetadataKey, object] = {}
for key, val in raw_metadata.items():
try:
typed_metadata[WorkflowNodeExecutionMetadataKey(key)] = val
except ValueError:
pass # skip unknown metadata keys
result[node.node_id] = NodeRunResult(
status=WorkflowNodeExecutionStatus(node.status),
inputs=node.inputs_dict or {},
process_data=node.process_data_dict or {},
outputs=node.outputs_dict or {},
metadata=typed_metadata,
error=node.error or "",
)
return result
# ---- Dataset Parsing ----
@classmethod
def _parse_dataset(cls, xlsx_content: bytes) -> list[EvaluationDatasetInput]:
"""Parse evaluation dataset from XLSX bytes."""
wb = load_workbook(io.BytesIO(xlsx_content), read_only=True)
ws = wb.active
if ws is None:
raise EvaluationDatasetInvalidError("XLSX file has no active worksheet.")
rows = list(ws.iter_rows(values_only=True))
if len(rows) < 2:
raise EvaluationDatasetInvalidError("Dataset must have at least a header row and one data row.")
headers = [str(h).strip() if h is not None else "" for h in rows[0]]
if not headers or headers[0].lower() != "index":
raise EvaluationDatasetInvalidError("First column header must be 'index'.")
input_headers = headers[1:] # Skip 'index'
items = []
for row_idx, row in enumerate(rows[1:], start=1):
values = list(row)
if all(v is None or str(v).strip() == "" for v in values):
continue # Skip empty rows
index_val = values[0] if values else row_idx
try:
index = int(str(index_val))
except (TypeError, ValueError):
index = row_idx
inputs: dict[str, Any] = {}
for col_idx, header in enumerate(input_headers):
val = values[col_idx + 1] if col_idx + 1 < len(values) else None
inputs[header] = str(val) if val is not None else ""
# Extract expected_output column into dedicated field
expected_output = inputs.pop("expected_output", None)
items.append(
EvaluationDatasetInput(
index=index,
inputs=inputs,
expected_output=expected_output,
)
)
wb.close()
return items
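# Illustrative dataset layout accepted by the parser above (a sketch; every column other
# than "index" becomes a workflow input, and "expected_output" is split into its own field):
#
#     index | query            | expected_output
#     ------+------------------+-----------------------------
#     1     | What is Dify?    | An open-source LLM platform
#     2     | How to deploy?   | Use docker compose
#
# Row 1 would parse to EvaluationDatasetInput(index=1, inputs={"query": "What is Dify?"},
# expected_output="An open-source LLM platform").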
@classmethod
def execute_retrieval_test_targets(
cls,
dataset_id: str,
account_id: str,
input_list: list[EvaluationDatasetInput],
max_workers: int = 5,
) -> list[NodeRunResult]:
"""Run hit testing against a knowledge base for every input item in parallel.
Each item must supply a ``query`` key in its ``inputs`` dict. The
retrieved segments are normalised into the same ``NodeRunResult`` format
that :class:`RetrievalEvaluationRunner` expects:
.. code-block:: python
NodeRunResult(
inputs={"query": "..."},
outputs={"result": [{"content": "...", "score": ...}, ...]},
)
:returns: Ordered list of ``NodeRunResult``, one per input item.
If retrieval fails for an item, the result has an empty ``result``
list so the runner can still persist a (metric-less) row.
"""
from concurrent.futures import ThreadPoolExecutor
from flask import current_app
flask_app = current_app._get_current_object() # type: ignore
def _worker(item: EvaluationDatasetInput) -> NodeRunResult:
with flask_app.app_context():
from extensions.ext_database import db as flask_db
from models.account import Account
from models.dataset import Dataset
from services.hit_testing_service import HitTestingService
dataset = flask_db.session.query(Dataset).filter_by(id=dataset_id).first()
if not dataset:
raise ValueError(f"Dataset {dataset_id} not found")
account = flask_db.session.query(Account).filter_by(id=account_id).first()
if not account:
raise ValueError(f"Account {account_id} not found")
query = str(item.inputs.get("query", ""))
response = HitTestingService.retrieve(
dataset=dataset,
query=query,
account=account,
retrieval_model=None, # Use dataset's configured retrieval model
external_retrieval_model={},
limit=10,
)
records = response.get("records", [])
result_list = [
{
"content": r.get("segment", {}).get("content", "") or r.get("content", ""),
"score": r.get("score"),
}
for r in records
if r.get("segment", {}).get("content") or r.get("content")
]
return NodeRunResult(
inputs={"query": query},
outputs={"result": result_list},
)
with ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = [executor.submit(_worker, item) for item in input_list]
results: list[NodeRunResult] = []
for item, future in zip(input_list, futures):
try:
results.append(future.result())
except Exception:
logger.exception("Retrieval test failed for item %d (dataset=%s)", item.index, dataset_id)
results.append(NodeRunResult(inputs={}, outputs={"result": []}))
return results
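# Illustrative usage (a sketch; the enclosing service class name and all IDs below are
# hypothetical, not taken from this file):
#
#     items = [
#         EvaluationDatasetInput(index=1, inputs={"query": "refund policy"}, expected_output=None),
#     ]
#     results = EvaluationService.execute_retrieval_test_targets(
#         dataset_id="<dataset-uuid>",
#         account_id="<account-uuid>",
#         input_list=items,
#     )
#     # results[0].outputs["result"] is a list of {"content": ..., "score": ...} dicts,
#     # or an empty list when retrieval failed for that item.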

View File

@ -281,7 +281,7 @@ class FeatureService:
def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str):
billing_info = BillingService.get_info(tenant_id)
features_usage_info = BillingService.get_tenant_feature_plan_usage_info(tenant_id)
features_usage_info = BillingService.get_quota_info(tenant_id)
features.billing.enabled = billing_info["enabled"]
features.billing.subscription.plan = billing_info["subscription"]["plan"]

View File

@ -0,0 +1,233 @@
from __future__ import annotations
import logging
import uuid
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from configs import dify_config
if TYPE_CHECKING:
from enums.quota_type import QuotaType
logger = logging.getLogger(__name__)
@dataclass
class QuotaCharge:
"""
Result of a quota reservation (Reserve phase).
Lifecycle:
charge = QuotaService.consume(QuotaType.TRIGGER, tenant_id)
try:
do_work()
charge.commit() # Confirm consumption
except:
charge.refund() # Release frozen quota
If neither commit() nor refund() is called, the billing system's
cleanup CronJob will auto-release the reservation within ~75 seconds.
"""
success: bool
charge_id: str | None # reservation_id
_quota_type: QuotaType
_tenant_id: str | None = None
_feature_key: str | None = None
_amount: int = 0
_committed: bool = field(default=False, repr=False)
def commit(self, actual_amount: int | None = None) -> None:
"""
Confirm the consumption with actual amount.
Args:
actual_amount: Actual amount consumed. Defaults to the reserved amount.
If less than reserved, the difference is refunded automatically.
"""
if self._committed or not self.charge_id or not self._tenant_id or not self._feature_key:
return
try:
from services.billing_service import BillingService
amount = actual_amount if actual_amount is not None else self._amount
BillingService.quota_commit(
tenant_id=self._tenant_id,
feature_key=self._feature_key,
reservation_id=self.charge_id,
actual_amount=amount,
)
self._committed = True
logger.debug(
"Committed %s quota for tenant %s, reservation_id: %s, amount: %d",
self._quota_type,
self._tenant_id,
self.charge_id,
amount,
)
except Exception:
logger.exception("Failed to commit quota, reservation_id: %s", self.charge_id)
def refund(self) -> None:
"""
Release the reserved quota (cancel the charge).
Safe to call even if:
- charge failed or was disabled (charge_id is None)
- already committed (Release after Commit is a no-op)
- already refunded (idempotent)
This method guarantees no exceptions will be raised.
"""
if not self.charge_id or not self._tenant_id or not self._feature_key:
return
QuotaService.release(self._quota_type, self.charge_id, self._tenant_id, self._feature_key)
def unlimited() -> QuotaCharge:
from enums.quota_type import QuotaType
return QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.UNLIMITED)
class QuotaService:
"""Orchestrates quota reserve / commit / release lifecycle via BillingService."""
@staticmethod
def consume(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> QuotaCharge:
"""
Reserve + immediate Commit (one-shot mode).
The returned QuotaCharge supports .refund() which calls Release.
For two-phase usage (e.g. streaming), use reserve() directly.
"""
charge = QuotaService.reserve(quota_type, tenant_id, amount)
if charge.success and charge.charge_id:
charge.commit()
return charge
@staticmethod
def reserve(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> QuotaCharge:
"""
Reserve quota before task execution (Reserve phase only).
The caller MUST call charge.commit() after the task succeeds,
or charge.refund() if the task fails.
Raises:
QuotaExceededError: When quota is insufficient
"""
from services.billing_service import BillingService
from services.errors.app import QuotaExceededError
if not dify_config.BILLING_ENABLED:
logger.debug("Billing disabled, allowing request for %s", tenant_id)
return QuotaCharge(success=True, charge_id=None, _quota_type=quota_type)
logger.info("Reserving %d %s quota for tenant %s", amount, quota_type.value, tenant_id)
if amount <= 0:
raise ValueError("Amount to reserve must be greater than 0")
request_id = str(uuid.uuid4())
feature_key = quota_type.billing_key
try:
reserve_resp = BillingService.quota_reserve(
tenant_id=tenant_id,
feature_key=feature_key,
request_id=request_id,
amount=amount,
)
reservation_id = reserve_resp.get("reservation_id")
if not reservation_id:
logger.warning(
"Reserve returned no reservation_id for %s, feature %s, response: %s",
tenant_id,
quota_type.value,
reserve_resp,
)
raise QuotaExceededError(feature=quota_type.value, tenant_id=tenant_id, required=amount)
logger.debug(
"Reserved %d %s quota for tenant %s, reservation_id: %s",
amount,
quota_type.value,
tenant_id,
reservation_id,
)
return QuotaCharge(
success=True,
charge_id=reservation_id,
_quota_type=quota_type,
_tenant_id=tenant_id,
_feature_key=feature_key,
_amount=amount,
)
except QuotaExceededError:
raise
except ValueError:
raise
except Exception:
logger.exception("Failed to reserve quota for %s, feature %s", tenant_id, quota_type.value)
return unlimited()
@staticmethod
def check(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> bool:
if not dify_config.BILLING_ENABLED:
return True
if amount <= 0:
raise ValueError("Amount to check must be greater than 0")
try:
remaining = QuotaService.get_remaining(quota_type, tenant_id)
return remaining >= amount if remaining != -1 else True
except Exception:
logger.exception("Failed to check quota for %s, feature %s", tenant_id, quota_type.value)
return True
@staticmethod
def release(quota_type: QuotaType, reservation_id: str, tenant_id: str, feature_key: str) -> None:
"""Release a reservation. Guarantees no exceptions."""
try:
from services.billing_service import BillingService
if not dify_config.BILLING_ENABLED:
return
if not reservation_id:
return
logger.info("Releasing %s quota, reservation_id: %s", quota_type.value, reservation_id)
BillingService.quota_release(
tenant_id=tenant_id,
feature_key=feature_key,
reservation_id=reservation_id,
)
except Exception:
logger.exception("Failed to release quota, reservation_id: %s", reservation_id)
@staticmethod
def get_remaining(quota_type: QuotaType, tenant_id: str) -> int:
from services.billing_service import BillingService
try:
usage_info = BillingService.get_quota_info(tenant_id)
if isinstance(usage_info, dict):
feature_info = usage_info.get(quota_type.billing_key, {})
if isinstance(feature_info, dict):
limit = feature_info.get("limit", 0)
usage = feature_info.get("usage", 0)
if limit == -1:
return -1
return max(0, limit - usage)
return 0
except Exception:
logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, quota_type.value)
return -1
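# Illustrative two-phase usage of the service above (a sketch; `run_task` and the
# surrounding variables are hypothetical):
#
#     charge = QuotaService.reserve(QuotaType.TRIGGER, tenant_id, amount=1)
#     try:
#         run_task()
#         charge.commit()    # confirm consumption of the reserved amount
#     except Exception:
#         charge.refund()    # release the frozen quota
#         raise
#
# One-shot callers can use QuotaService.consume(), which reserves and commits in a single
# call; the returned charge still supports refund(), which maps to a quota release.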

View File

@ -0,0 +1,555 @@
import json
import logging
import uuid
from collections.abc import Mapping
from datetime import UTC, datetime
from enum import StrEnum
from urllib.parse import urlparse
import yaml # type: ignore
from packaging import version
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.helper import ssrf_proxy
from core.plugin.entities.plugin import PluginDependency
from extensions.ext_redis import redis_client
from graphon.enums import BuiltinNodeTypes
from graphon.model_runtime.utils.encoders import jsonable_encoder
from models import Account
from models.snippet import CustomizedSnippet, SnippetType
from models.workflow import Workflow
from services.plugin.dependencies_analysis import DependenciesAnalysisService
from services.snippet_service import SNIPPET_FORBIDDEN_NODE_TYPES, SnippetService
logger = logging.getLogger(__name__)
IMPORT_INFO_REDIS_KEY_PREFIX = "snippet_import_info:"
CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "snippet_check_dependencies:"
IMPORT_INFO_REDIS_EXPIRY = 10 * 60 # 10 minutes
DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB
CURRENT_DSL_VERSION = "0.1.0"
class ImportMode(StrEnum):
YAML_CONTENT = "yaml-content"
YAML_URL = "yaml-url"
class ImportStatus(StrEnum):
COMPLETED = "completed"
COMPLETED_WITH_WARNINGS = "completed-with-warnings"
PENDING = "pending"
FAILED = "failed"
class SnippetImportInfo(BaseModel):
id: str
status: ImportStatus
snippet_id: str | None = None
current_dsl_version: str = CURRENT_DSL_VERSION
imported_dsl_version: str = ""
error: str = ""
class CheckDependenciesResult(BaseModel):
leaked_dependencies: list[PluginDependency] = Field(default_factory=list)
def _check_version_compatibility(imported_version: str) -> ImportStatus:
"""Determine import status based on version comparison"""
try:
current_ver = version.parse(CURRENT_DSL_VERSION)
imported_ver = version.parse(imported_version)
except version.InvalidVersion:
return ImportStatus.FAILED
# If imported version is newer than current, always return PENDING
if imported_ver > current_ver:
return ImportStatus.PENDING
# If imported version is older than current's major, return PENDING
if imported_ver.major < current_ver.major:
return ImportStatus.PENDING
# If imported version is older than current's minor, return COMPLETED_WITH_WARNINGS
if imported_ver.minor < current_ver.minor:
return ImportStatus.COMPLETED_WITH_WARNINGS
# Same major and minor; differences at the micro level are safe, return COMPLETED
return ImportStatus.COMPLETED
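# Worked examples for the comparison above, assuming CURRENT_DSL_VERSION == "0.1.0"
# (illustrative only):
#
#     _check_version_compatibility("0.2.0")   # newer than current  -> PENDING
#     _check_version_compatibility("0.1.0")   # same major.minor    -> COMPLETED
#     _check_version_compatibility("0.0.9")   # older minor         -> COMPLETED_WITH_WARNINGS
#     _check_version_compatibility("oops")    # unparsable          -> FAILED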
class SnippetPendingData(BaseModel):
import_mode: str
yaml_content: str
snippet_id: str | None
class CheckDependenciesPendingData(BaseModel):
dependencies: list[PluginDependency]
snippet_id: str | None
class SnippetDslService:
def __init__(self, session: Session):
self._session = session
def import_snippet(
self,
*,
account: Account,
import_mode: str,
yaml_content: str | None = None,
yaml_url: str | None = None,
snippet_id: str | None = None,
name: str | None = None,
description: str | None = None,
) -> SnippetImportInfo:
"""Import a snippet from YAML content or URL."""
import_id = str(uuid.uuid4())
# Validate import mode
try:
mode = ImportMode(import_mode)
except ValueError:
raise ValueError(f"Invalid import_mode: {import_mode}")
# Get YAML content
content: str = ""
if mode == ImportMode.YAML_URL:
if not yaml_url:
return SnippetImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error="yaml_url is required when import_mode is yaml-url",
)
try:
parsed_url = urlparse(yaml_url)
if parsed_url.scheme not in ["http", "https"]:
return SnippetImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error="Invalid URL scheme, only http and https are allowed",
)
response = ssrf_proxy.get(yaml_url, timeout=(10, 30))
if response.status_code != 200:
return SnippetImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error=f"Failed to fetch YAML from URL: {response.status_code}",
)
content = response.text
if len(content) > DSL_MAX_SIZE:
return SnippetImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error=f"YAML content size exceeds maximum limit of {DSL_MAX_SIZE} bytes",
)
except Exception as e:
logger.exception("Failed to fetch YAML from URL")
return SnippetImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error=f"Failed to fetch YAML from URL: {str(e)}",
)
elif mode == ImportMode.YAML_CONTENT:
if not yaml_content:
return SnippetImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error="yaml_content is required when import_mode is yaml-content",
)
content = yaml_content
if len(content) > DSL_MAX_SIZE:
return SnippetImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error=f"YAML content size exceeds maximum limit of {DSL_MAX_SIZE} bytes",
)
try:
# Parse YAML
data = yaml.safe_load(content)
if not isinstance(data, dict):
return SnippetImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error="Invalid YAML format: expected a dictionary",
)
# Validate and fix DSL version
if not data.get("version"):
data["version"] = "0.1.0"
# Strictly validate kind field
kind = data.get("kind")
if not kind:
return SnippetImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error="Missing 'kind' field in DSL. Expected 'kind: snippet'.",
)
if kind != "snippet":
return SnippetImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error=f"Invalid DSL kind: expected 'snippet', got '{kind}'. This DSL is for {kind}, not snippet.",
)
imported_version = data.get("version", "0.1.0")
if not isinstance(imported_version, str):
raise ValueError(f"Invalid version type, expected str, got {type(imported_version)}")
status = _check_version_compatibility(imported_version)
# Extract snippet data
snippet_data = data.get("snippet")
if not snippet_data:
return SnippetImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error="Missing snippet data in YAML content",
)
# Validate workflow nodes - check for forbidden node types
workflow_data = data.get("workflow", {})
if workflow_data:
graph = workflow_data.get("graph", {})
nodes = graph.get("nodes", [])
forbidden_nodes_found = []
for node in nodes:
node_data = node.get("data", {})
if not node_data:
continue
node_type = node_data.get("type", "")
if node_type in SNIPPET_FORBIDDEN_NODE_TYPES:
forbidden_nodes_found.append(node_type)
if forbidden_nodes_found:
forbidden_types_str = ", ".join(set(forbidden_nodes_found))
return SnippetImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error=f"Snippet cannot contain the following node types: {forbidden_types_str}",
)
# If snippet_id is provided, check if it exists
snippet = None
if snippet_id:
stmt = select(CustomizedSnippet).where(
CustomizedSnippet.id == snippet_id,
CustomizedSnippet.tenant_id == account.current_tenant_id,
)
snippet = self._session.scalar(stmt)
if not snippet:
return SnippetImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error="Snippet not found",
)
# If the version requires confirmation (status PENDING), store import info in Redis
if status == ImportStatus.PENDING:
pending_data = SnippetPendingData(
import_mode=import_mode,
yaml_content=content,
snippet_id=snippet_id,
)
redis_client.setex(
f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}",
IMPORT_INFO_REDIS_EXPIRY,
pending_data.model_dump_json(),
)
return SnippetImportInfo(
id=import_id,
status=status,
snippet_id=snippet_id,
imported_dsl_version=imported_version,
)
# Extract dependencies
dependencies = data.get("dependencies", [])
check_dependencies_pending_data = None
if dependencies:
check_dependencies_pending_data = [PluginDependency.model_validate(d) for d in dependencies]
# Create or update snippet
snippet = self._create_or_update_snippet(
snippet=snippet,
data=data,
account=account,
name=name,
description=description,
dependencies=check_dependencies_pending_data,
)
return SnippetImportInfo(
id=import_id,
status=status,
snippet_id=snippet.id,
imported_dsl_version=imported_version,
)
except yaml.YAMLError as e:
return SnippetImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error=f"Invalid YAML format: {str(e)}",
)
except Exception as e:
logger.exception("Failed to import snippet")
return SnippetImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error=str(e),
)
def confirm_import(self, *, import_id: str, account: Account) -> SnippetImportInfo:
"""
Confirm an import that requires confirmation
"""
redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}"
pending_data = redis_client.get(redis_key)
if not pending_data:
return SnippetImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error="Import information expired or does not exist",
)
try:
if not isinstance(pending_data, str | bytes):
return SnippetImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error="Invalid import information",
)
pending_data_str = pending_data.decode("utf-8") if isinstance(pending_data, bytes) else pending_data
pending = SnippetPendingData.model_validate_json(pending_data_str)
# Re-import with the pending data
return self.import_snippet(
account=account,
import_mode=pending.import_mode,
yaml_content=pending.yaml_content,
snippet_id=pending.snippet_id,
)
except Exception as e:
logger.exception("Failed to confirm import")
return SnippetImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error=str(e),
)
def check_dependencies(self, snippet: CustomizedSnippet) -> CheckDependenciesResult:
"""
Check dependencies for a snippet
"""
snippet_service = SnippetService()
workflow = snippet_service.get_draft_workflow(snippet=snippet)
if not workflow:
return CheckDependenciesResult(leaked_dependencies=[])
dependencies = self._extract_dependencies_from_workflow(workflow)
leaked_dependencies = DependenciesAnalysisService.generate_dependencies(
tenant_id=snippet.tenant_id, dependencies=dependencies
)
return CheckDependenciesResult(leaked_dependencies=leaked_dependencies)
def _create_or_update_snippet(
self,
*,
snippet: CustomizedSnippet | None,
data: dict,
account: Account,
name: str | None = None,
description: str | None = None,
dependencies: list[PluginDependency] | None = None,
) -> CustomizedSnippet:
"""
Create or update snippet from DSL data
"""
snippet_data = data.get("snippet", {})
workflow_data = data.get("workflow", {})
# Extract snippet info
snippet_name = name or snippet_data.get("name") or "Untitled Snippet"
snippet_description = description or snippet_data.get("description") or ""
snippet_type_str = snippet_data.get("type", "node")
try:
snippet_type = SnippetType(snippet_type_str)
except ValueError:
snippet_type = SnippetType.NODE
icon_info = snippet_data.get("icon_info", {})
input_fields = snippet_data.get("input_fields", [])
# Create or update snippet
if snippet:
# Update existing snippet
snippet.name = snippet_name
snippet.description = snippet_description
snippet.type = snippet_type.value
snippet.icon_info = icon_info or None
snippet.input_fields = json.dumps(input_fields) if input_fields else None
snippet.updated_by = account.id
snippet.updated_at = datetime.now(UTC).replace(tzinfo=None)
else:
# Create new snippet
snippet = CustomizedSnippet(
tenant_id=account.current_tenant_id,
name=snippet_name,
description=snippet_description,
type=snippet_type.value,
icon_info=icon_info or None,
input_fields=json.dumps(input_fields) if input_fields else None,
created_by=account.id,
)
self._session.add(snippet)
self._session.flush()
# Create or update draft workflow
if workflow_data:
graph = workflow_data.get("graph", {})
snippet_service = SnippetService()
# Get existing workflow hash if exists
existing_workflow = snippet_service.get_draft_workflow(snippet=snippet)
unique_hash = existing_workflow.unique_hash if existing_workflow else None
snippet_service.sync_draft_workflow(
snippet=snippet,
graph=graph,
unique_hash=unique_hash,
account=account,
input_fields=input_fields,
)
self._session.commit()
return snippet
def export_snippet_dsl(self, snippet: CustomizedSnippet, include_secret: bool = False) -> str:
"""
Export snippet as DSL
:param snippet: CustomizedSnippet instance
:param include_secret: Whether to include secret variables
:return: YAML string
"""
snippet_service = SnippetService()
workflow = snippet_service.get_draft_workflow(snippet=snippet)
if not workflow:
raise ValueError("Missing draft workflow configuration, please check.")
icon_info = snippet.icon_info or {}
export_data = {
"version": CURRENT_DSL_VERSION,
"kind": "snippet",
"snippet": {
"name": snippet.name,
"description": snippet.description or "",
"type": snippet.type,
"icon_info": icon_info,
"input_fields": snippet.input_fields_list,
},
}
self._append_workflow_export_data(
export_data=export_data, snippet=snippet, workflow=workflow, include_secret=include_secret
)
return yaml.dump(export_data, allow_unicode=True) # type: ignore
def _append_workflow_export_data(
self, *, export_data: dict, snippet: CustomizedSnippet, workflow: Workflow, include_secret: bool
) -> None:
"""
Append workflow export data
"""
workflow_dict = workflow.to_dict(include_secret=include_secret)
# Filter workspace related data from nodes
workflow_dict["environment_variables"] = []
workflow_dict["conversation_variables"] = []
for node in workflow_dict.get("graph", {}).get("nodes", []):
node_data = node.get("data", {})
if not node_data:
continue
data_type = node_data.get("type", "")
if data_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
dataset_ids = node_data.get("dataset_ids", [])
node["data"]["dataset_ids"] = [
self._encrypt_dataset_id(dataset_id=dataset_id, tenant_id=snippet.tenant_id)
for dataset_id in dataset_ids
]
# filter credential id from tool node
if not include_secret and data_type == BuiltinNodeTypes.TOOL:
node_data.pop("credential_id", None)
# filter credential id from agent node
if not include_secret and data_type == BuiltinNodeTypes.AGENT:
for tool in node_data.get("agent_parameters", {}).get("tools", {}).get("value", []):
tool.pop("credential_id", None)
export_data["workflow"] = workflow_dict
dependencies = self._extract_dependencies_from_workflow(workflow)
export_data["dependencies"] = [
jsonable_encoder(d.model_dump())
for d in DependenciesAnalysisService.generate_dependencies(
tenant_id=snippet.tenant_id, dependencies=dependencies
)
]
def _encrypt_dataset_id(self, *, dataset_id: str, tenant_id: str) -> str:
"""
Encrypt dataset ID for export
"""
# For now, just return the dataset_id as-is
# In the future, we might want to encrypt it
return dataset_id
def _extract_dependencies_from_workflow(self, workflow: Workflow) -> list[str]:
"""
Extract dependencies from workflow
:param workflow: Workflow instance
:return: dependency list in a format like ["langgenius/google"]
"""
graph = workflow.graph_dict
dependencies = self._extract_dependencies_from_workflow_graph(graph)
return dependencies
def _extract_dependencies_from_workflow_graph(self, graph: Mapping) -> list[str]:
"""
Extract dependencies from workflow graph
:param graph: Workflow graph
:return: dependency list in a format like ["langgenius/google"]
"""
dependencies = []
for node in graph.get("nodes", []):
node_data = node.get("data", {})
if not node_data:
continue
data_type = node_data.get("type", "")
if data_type == BuiltinNodeTypes.TOOL:
tool_config = node_data.get("tool_configurations", {})
provider_type = tool_config.get("provider_type")
provider_name = tool_config.get("provider")
if provider_type and provider_name:
dependencies.append(f"{provider_name}/{provider_name}")
elif data_type == BuiltinNodeTypes.AGENT:
agent_parameters = node_data.get("agent_parameters", {})
tools = agent_parameters.get("tools", {}).get("value", [])
for tool in tools:
provider_type = tool.get("provider_type")
provider_name = tool.get("provider")
if provider_type and provider_name:
dependencies.append(f"{provider_name}/{provider_name}")
return dependencies
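# A minimal snippet DSL document accepted by import_snippet above (illustrative sketch;
# all field values are hypothetical):
#
#     version: "0.1.0"
#     kind: snippet
#     snippet:
#       name: Summarize text
#       description: Reusable summarization step
#       type: node
#       icon_info: {}
#       input_fields:
#         - variable: text
#           label: Text
#           type: text-input
#           required: true
#     workflow:
#       graph:
#         nodes: []
#         edges: []
#     dependencies: []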

View File

@ -0,0 +1,421 @@
"""
Service for generating snippet workflow executions.
Uses an adapter pattern to bridge CustomizedSnippet with the App-based
WorkflowAppGenerator. The adapter (_SnippetAsApp) provides the minimal App-like
interface needed by the generator, avoiding modifications to core workflow
infrastructure.
Key invariants:
- Snippets always run as WORKFLOW mode (not CHAT or ADVANCED_CHAT).
- The adapter maps snippet.id to app_id in workflow execution records.
- Snippet debugging has no rate limiting (max_active_requests = 0).
Supported execution modes:
- Full workflow run (generate): Runs the entire draft workflow as SSE stream.
- Single node run (run_draft_node): Synchronous single-step debugging for regular nodes.
- Single iteration run (generate_single_iteration): SSE stream for iteration container nodes.
- Single loop run (generate_single_loop): SSE stream for loop container nodes.
"""
import json
import logging
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Union
from sqlalchemy.orm import make_transient
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from factories import file_factory
from graphon.file.models import File
from models import Account
from models.model import AppMode, EndUser
from models.snippet import CustomizedSnippet
from models.workflow import Workflow, WorkflowNodeExecutionModel
from services.snippet_service import SnippetService
from services.workflow_service import WorkflowService
logger = logging.getLogger(__name__)
class _SnippetAsApp:
"""
Minimal adapter that wraps a CustomizedSnippet to satisfy the App-like
interface required by WorkflowAppGenerator, WorkflowAppConfigManager,
and WorkflowService.run_draft_workflow_node.
Used properties:
- id: maps to snippet.id (stored as app_id in workflows table)
- tenant_id: maps to snippet.tenant_id
- mode: hardcoded to AppMode.WORKFLOW since snippets always run as workflows
- max_active_requests: defaults to 0 (no limit) for snippet debugging
- app_model_config_id: None (snippets don't have app model configs)
"""
id: str
tenant_id: str
mode: str
max_active_requests: int
app_model_config_id: str | None
def __init__(self, snippet: CustomizedSnippet) -> None:
self.id = snippet.id
self.tenant_id = snippet.tenant_id
self.mode = AppMode.WORKFLOW.value
self.max_active_requests = 0
self.app_model_config_id = None
class SnippetGenerateService:
"""
Service for running snippet workflow executions.
Adapts CustomizedSnippet to work with the existing App-based
WorkflowAppGenerator infrastructure, avoiding duplication of the
complex workflow execution pipeline.
"""
# Specific ID for the injected virtual Start node so it can be recognised
_VIRTUAL_START_NODE_ID = "__snippet_virtual_start__"
@classmethod
def generate(
cls,
snippet: CustomizedSnippet,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
"""
Run a snippet's draft workflow.
Retrieves the draft workflow, adapts the snippet to an App-like proxy,
then delegates execution to WorkflowAppGenerator.
If the workflow graph has no Start node, a virtual Start node is injected
in-memory so that:
1. Graph validation passes (root node must have execution_type=ROOT).
2. User inputs are processed into the variable pool by the StartNode logic.
:param snippet: CustomizedSnippet instance
:param user: Account or EndUser initiating the run
:param args: Workflow inputs (must include "inputs" key)
:param invoke_from: Source of invocation (typically DEBUGGER)
:param streaming: Whether to stream the response
:return: Blocking response mapping or SSE streaming generator
:raises ValueError: If the snippet has no draft workflow
"""
snippet_service = SnippetService()
workflow = snippet_service.get_draft_workflow(snippet=snippet)
if not workflow:
raise ValueError("Workflow not initialized")
# Inject a virtual Start node when the graph doesn't have one.
workflow = cls._ensure_start_node(workflow, snippet)
# Adapt snippet to App-like interface for WorkflowAppGenerator
app_proxy = _SnippetAsApp(snippet)
return WorkflowAppGenerator.convert_to_event_stream(
WorkflowAppGenerator().generate(
app_model=app_proxy, # type: ignore[arg-type]
workflow=workflow,
user=user,
args=args,
invoke_from=invoke_from,
streaming=streaming,
call_depth=0,
)
)
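# Illustrative call (a sketch; the account variable and input values are hypothetical):
#
#     stream = SnippetGenerateService.generate(
#         snippet=snippet,
#         user=current_account,
#         args={"inputs": {"text": "hello"}},
#         invoke_from=InvokeFrom.DEBUGGER,
#         streaming=True,
#     )
#     # With streaming=True the return value is an SSE-style generator; iterate it to
#     # forward workflow events to the client.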
@classmethod
def run_published(
cls,
snippet: CustomizedSnippet,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
) -> Mapping[str, Any]:
"""
Run a snippet's published workflow in non-streaming (blocking) mode.
Similar to :meth:`generate` but targets the published workflow instead
of the draft, and returns the raw blocking response without SSE
wrapping. Designed for programmatic callers such as evaluation runners.
:param snippet: CustomizedSnippet instance (must be published)
:param user: Account or EndUser initiating the run
:param args: Workflow inputs (must include "inputs" key)
:param invoke_from: Source of invocation
:return: Blocking response mapping with workflow outputs
:raises ValueError: If the snippet has no published workflow
"""
snippet_service = SnippetService()
workflow = snippet_service.get_published_workflow(snippet)
if not workflow:
raise ValueError("No published workflow found for snippet")
# Inject a virtual Start node when the graph doesn't have one.
workflow = cls._ensure_start_node(workflow, snippet)
app_proxy = _SnippetAsApp(snippet)
response: Mapping[str, Any] = WorkflowAppGenerator().generate(
app_model=app_proxy, # type: ignore[arg-type]
workflow=workflow,
user=user,
args=args,
invoke_from=invoke_from,
streaming=False,
)
return response
@classmethod
def ensure_start_node_for_worker(cls, workflow: Workflow, snippet: CustomizedSnippet) -> Workflow:
"""Public wrapper for worker-thread start-node injection."""
return cls._ensure_start_node(workflow, snippet)
@classmethod
def _ensure_start_node(cls, workflow: Workflow, snippet: CustomizedSnippet) -> Workflow:
"""
Return *workflow* with a Start node.
If the graph already contains a Start node, the original workflow is
returned unchanged. Otherwise a virtual Start node is injected and the
workflow object is detached from the SQLAlchemy session so the in-memory
change is never flushed to the database.
"""
graph_dict = workflow.graph_dict
nodes: list[dict[str, Any]] = graph_dict.get("nodes", [])
has_start = any(node.get("data", {}).get("type") == "start" for node in nodes)
if has_start:
return workflow
modified_graph = cls._inject_virtual_start_node(
graph_dict=graph_dict,
input_fields=snippet.input_fields_list,
)
# Detach from session to prevent accidental DB persistence of the
# modified graph. All attributes remain accessible for read.
make_transient(workflow)
workflow.graph = json.dumps(modified_graph)
return workflow
@classmethod
def _inject_virtual_start_node(
cls,
graph_dict: Mapping[str, Any],
input_fields: list[dict[str, Any]],
) -> dict[str, Any]:
"""
Build a new graph dict with a virtual Start node prepended.
The virtual Start node is wired to every existing node that has no
incoming edges (i.e. the current root candidates). This guarantees that graph
validation passes (the injected node is the single root) and that user inputs are
processed into the variable pool by the Start node logic.
:param graph_dict: Original graph configuration.
:param input_fields: Snippet input field definitions from
``CustomizedSnippet.input_fields_list``.
:return: New graph dict containing the virtual Start node and edges.
"""
nodes: list[dict[str, Any]] = list(graph_dict.get("nodes", []))
edges: list[dict[str, Any]] = list(graph_dict.get("edges", []))
# Identify nodes with no incoming edges.
nodes_with_incoming: set[str] = set()
for edge in edges:
target = edge.get("target")
if isinstance(target, str):
nodes_with_incoming.add(target)
root_candidate_ids = [n["id"] for n in nodes if n["id"] not in nodes_with_incoming]
# Build Start node ``variables`` from snippet input fields.
start_variables: list[dict[str, Any]] = []
for field in input_fields:
var: dict[str, Any] = {
"variable": field.get("variable", ""),
"label": field.get("label", field.get("variable", "")),
"type": field.get("type", "text-input"),
"required": field.get("required", False),
"options": field.get("options", []),
}
if field.get("max_length") is not None:
var["max_length"] = field["max_length"]
start_variables.append(var)
virtual_start_node: dict[str, Any] = {
"id": cls._VIRTUAL_START_NODE_ID,
"data": {
"type": "start",
"title": "Start",
"variables": start_variables,
},
}
# Create edges from virtual Start to each root candidate.
new_edges: list[dict[str, Any]] = [
{
"source": cls._VIRTUAL_START_NODE_ID,
"sourceHandle": "source",
"target": root_id,
"targetHandle": "target",
}
for root_id in root_candidate_ids
]
return {
**graph_dict,
"nodes": [virtual_start_node, *nodes],
"edges": [*edges, *new_edges],
}
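# Illustrative transformation performed above (a sketch): for a graph whose only node
# "llm_1" has no incoming edges and a single input field {"variable": "text"}, the
# returned dict is roughly:
#
#     {
#         "nodes": [
#             {"id": "__snippet_virtual_start__",
#              "data": {"type": "start", "title": "Start",
#                       "variables": [{"variable": "text", "label": "text",
#                                      "type": "text-input", "required": False,
#                                      "options": []}]}},
#             {"id": "llm_1", ...},
#         ],
#         "edges": [{"source": "__snippet_virtual_start__", "sourceHandle": "source",
#                    "target": "llm_1", "targetHandle": "target"}],
#     }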
@classmethod
def run_draft_node(
cls,
snippet: CustomizedSnippet,
node_id: str,
user_inputs: Mapping[str, Any],
account: Account,
query: str = "",
files: Sequence[File] | None = None,
) -> WorkflowNodeExecutionModel:
"""
Run a single node in a snippet's draft workflow (single-step debugging).
Retrieves the draft workflow, adapts the snippet to an App-like proxy,
then delegates to WorkflowService.run_draft_workflow_node; file inputs should already
be parsed (e.g. via :meth:`parse_files`).
:param snippet: CustomizedSnippet instance
:param node_id: ID of the node to run
:param user_inputs: User input values for the node
:param account: Account initiating the run
:param query: Optional query string
:param files: Optional parsed file objects
:return: WorkflowNodeExecutionModel with execution results
:raises ValueError: If the snippet has no draft workflow
"""
snippet_service = SnippetService()
draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
if not draft_workflow:
raise ValueError("Workflow not initialized")
app_proxy = _SnippetAsApp(snippet)
workflow_service = WorkflowService()
return workflow_service.run_draft_workflow_node(
app_model=app_proxy, # type: ignore[arg-type]
draft_workflow=draft_workflow,
node_id=node_id,
user_inputs=user_inputs,
account=account,
query=query,
files=files,
)
@classmethod
def generate_single_iteration(
cls,
snippet: CustomizedSnippet,
user: Union[Account, EndUser],
node_id: str,
args: Mapping[str, Any],
streaming: bool = True,
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
"""
Run a single iteration node in a snippet's draft workflow.
Iteration nodes are container nodes that execute their sub-graph multiple
times, producing many events. Therefore, this uses the full WorkflowAppGenerator
pipeline with SSE streaming (unlike regular single-step node run).
:param snippet: CustomizedSnippet instance
:param user: Account or EndUser initiating the run
:param node_id: ID of the iteration node to run
:param args: Dict containing 'inputs' key with iteration input data
:param streaming: Whether to stream the response (should be True)
:return: SSE streaming generator
:raises ValueError: If the snippet has no draft workflow
"""
snippet_service = SnippetService()
workflow = snippet_service.get_draft_workflow(snippet=snippet)
if not workflow:
raise ValueError("Workflow not initialized")
app_proxy = _SnippetAsApp(snippet)
return WorkflowAppGenerator.convert_to_event_stream(
WorkflowAppGenerator().single_iteration_generate(
app_model=app_proxy, # type: ignore[arg-type]
workflow=workflow,
node_id=node_id,
user=user,
args=args,
streaming=streaming,
)
)
@classmethod
def generate_single_loop(
cls,
snippet: CustomizedSnippet,
user: Union[Account, EndUser],
node_id: str,
args: Any,
streaming: bool = True,
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
"""
Run a single loop node in a snippet's draft workflow.
Loop nodes are container nodes that execute their sub-graph repeatedly,
producing many events. Therefore, this uses the full WorkflowAppGenerator
pipeline with SSE streaming (unlike regular single-step node run).
:param snippet: CustomizedSnippet instance
:param user: Account or EndUser initiating the run
:param node_id: ID of the loop node to run
:param args: Pydantic model with 'inputs' attribute containing loop input data
:param streaming: Whether to stream the response (should be True)
:return: SSE streaming generator
:raises ValueError: If the snippet has no draft workflow
"""
snippet_service = SnippetService()
workflow = snippet_service.get_draft_workflow(snippet=snippet)
if not workflow:
raise ValueError("Workflow not initialized")
app_proxy = _SnippetAsApp(snippet)
return WorkflowAppGenerator.convert_to_event_stream(
WorkflowAppGenerator().single_loop_generate(
app_model=app_proxy, # type: ignore[arg-type]
workflow=workflow,
node_id=node_id,
user=user,
args=args, # type: ignore[arg-type]
streaming=streaming,
)
)
@staticmethod
def parse_files(workflow: Workflow, files: list[dict] | None = None) -> Sequence[File]:
"""
Parse file mappings into File objects based on workflow configuration.
:param workflow: Workflow instance for file upload config
:param files: Raw file mapping dicts
:return: Parsed File objects
"""
files = files or []
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
if file_extra_config is None:
return []
return file_factory.build_from_mappings(
mappings=files,
tenant_id=workflow.tenant_id,
config=file_extra_config,
)

View File

@ -0,0 +1,608 @@
import json
import logging
from collections.abc import Mapping, Sequence
from datetime import UTC, datetime
from typing import Any
from sqlalchemy import func, select
from sqlalchemy.orm import Session, sessionmaker
from core.workflow.node_factory import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from extensions.ext_database import db
from graphon.enums import BuiltinNodeTypes, NodeType
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models import Account
from models.enums import WorkflowRunTriggeredFrom
from models.snippet import CustomizedSnippet, SnippetType
from models.workflow import (
Workflow,
WorkflowNodeExecutionModel,
WorkflowRun,
WorkflowType,
)
from repositories.factory import DifyAPIRepositoryFactory
from services.errors.app import WorkflowHashNotEqualError
logger = logging.getLogger(__name__)
# Node types not allowed in snippet workflows (sync, publish, DSL import).
SNIPPET_FORBIDDEN_NODE_TYPES: frozenset[str] = frozenset(
{
BuiltinNodeTypes.START,
BuiltinNodeTypes.HUMAN_INPUT,
BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL,
}
)
class SnippetService:
"""Service for managing customized snippets."""
def __init__(self, session_maker: sessionmaker | None = None):
"""Initialize SnippetService with repository dependencies."""
if session_maker is None:
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
session_maker
)
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
@staticmethod
def validate_snippet_graph_forbidden_nodes(graph: Mapping[str, Any]) -> None:
"""Reject graphs that contain node types not allowed in snippets."""
nodes = graph.get("nodes") or []
disallowed: list[tuple[str, str]] = []
for node in nodes:
if not isinstance(node, dict):
continue
node_data = node.get("data") or {}
node_type = node_data.get("type")
if not isinstance(node_type, str):
continue
if node_type in SNIPPET_FORBIDDEN_NODE_TYPES:
node_id = node.get("id")
disallowed.append((str(node_id) if node_id is not None else "?", node_type))
if not disallowed:
return
detail = ", ".join(f"{nid}:{t}" for nid, t in disallowed)
raise ValueError(
"Snippet workflow cannot contain start, human-input, or knowledge-retrieval nodes. "
f"Found: {detail}"
)
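# Illustrative rejection (a sketch): a graph containing a start node fails validation.
#
#     SnippetService.validate_snippet_graph_forbidden_nodes(
#         {"nodes": [{"id": "n1", "data": {"type": "start"}}]}
#     )
#     # raises ValueError: "Snippet workflow cannot contain start, human-input, or
#     # knowledge-retrieval nodes. Found: n1:start"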
# --- CRUD Operations ---
@staticmethod
def get_snippets(
*,
tenant_id: str,
page: int = 1,
limit: int = 20,
keyword: str | None = None,
is_published: bool | None = None,
creators: list[str] | None = None,
) -> tuple[Sequence[CustomizedSnippet], int, bool]:
"""
Get paginated list of snippets with optional search.
:param tenant_id: Tenant ID
:param page: Page number (1-indexed)
:param limit: Number of items per page
:param keyword: Optional search keyword for name/description
:param is_published: Optional filter by published status (True/False/None for all)
:param creators: Optional filter by creator account IDs
:return: Tuple of (snippets list, total count, has_more flag)
"""
stmt = (
select(CustomizedSnippet)
.where(CustomizedSnippet.tenant_id == tenant_id)
.order_by(CustomizedSnippet.created_at.desc())
)
if keyword:
stmt = stmt.where(
CustomizedSnippet.name.ilike(f"%{keyword}%") | CustomizedSnippet.description.ilike(f"%{keyword}%")
)
if is_published is not None:
stmt = stmt.where(CustomizedSnippet.is_published == is_published)
if creators:
stmt = stmt.where(CustomizedSnippet.created_by.in_(creators))
# Get total count
count_stmt = select(func.count()).select_from(stmt.subquery())
total = db.session.scalar(count_stmt) or 0
# Apply pagination
stmt = stmt.limit(limit + 1).offset((page - 1) * limit)
snippets = list(db.session.scalars(stmt).all())
has_more = len(snippets) > limit
if has_more:
snippets = snippets[:-1]
return snippets, total, has_more
@staticmethod
def get_snippet_by_id(
*,
snippet_id: str,
tenant_id: str,
) -> CustomizedSnippet | None:
"""
Get snippet by ID with tenant isolation.
:param snippet_id: Snippet ID
:param tenant_id: Tenant ID
:return: CustomizedSnippet or None
"""
return (
db.session.query(CustomizedSnippet)
.where(
CustomizedSnippet.id == snippet_id,
CustomizedSnippet.tenant_id == tenant_id,
)
.first()
)
@staticmethod
def create_snippet(
*,
tenant_id: str,
name: str,
description: str | None,
snippet_type: SnippetType,
icon_info: dict | None,
input_fields: list[dict] | None,
account: Account,
) -> CustomizedSnippet:
"""
Create a new snippet.
:param tenant_id: Tenant ID
:param name: Snippet name (must be unique per tenant)
:param description: Snippet description
:param snippet_type: Type of snippet (node or group)
:param icon_info: Icon information
:param input_fields: Input field definitions
:param account: Creator account
:return: Created CustomizedSnippet
:raises ValueError: If name already exists
"""
# Check if name already exists for this tenant
existing = (
db.session.query(CustomizedSnippet)
.where(
CustomizedSnippet.tenant_id == tenant_id,
CustomizedSnippet.name == name,
)
.first()
)
if existing:
raise ValueError(f"Snippet with name '{name}' already exists")
snippet = CustomizedSnippet(
tenant_id=tenant_id,
name=name,
description=description or "",
type=snippet_type.value,
icon_info=icon_info,
input_fields=json.dumps(input_fields) if input_fields else None,
created_by=account.id,
)
db.session.add(snippet)
db.session.commit()
return snippet
@staticmethod
def update_snippet(
*,
session: Session,
snippet: CustomizedSnippet,
account_id: str,
data: dict,
) -> CustomizedSnippet:
"""
Update snippet attributes.
:param session: Database session
:param snippet: Snippet to update
:param account_id: ID of account making the update
:param data: Dictionary of fields to update
:return: Updated CustomizedSnippet
"""
if "name" in data:
# Check if new name already exists for this tenant
existing = (
session.query(CustomizedSnippet)
.where(
CustomizedSnippet.tenant_id == snippet.tenant_id,
CustomizedSnippet.name == data["name"],
CustomizedSnippet.id != snippet.id,
)
.first()
)
if existing:
raise ValueError(f"Snippet with name '{data['name']}' already exists")
snippet.name = data["name"]
if "description" in data:
snippet.description = data["description"]
if "icon_info" in data:
snippet.icon_info = data["icon_info"]
snippet.updated_by = account_id
snippet.updated_at = datetime.now(UTC).replace(tzinfo=None)
session.add(snippet)
return snippet
@staticmethod
def delete_snippet(
*,
session: Session,
snippet: CustomizedSnippet,
) -> bool:
"""
Delete a snippet.
:param session: Database session
:param snippet: Snippet to delete
:return: True if deleted successfully
"""
session.delete(snippet)
return True
# --- Workflow Operations ---
def get_draft_workflow(self, snippet: CustomizedSnippet) -> Workflow | None:
"""
Get draft workflow for snippet.
:param snippet: CustomizedSnippet instance
:return: Draft Workflow or None
"""
workflow = (
db.session.query(Workflow)
.where(
Workflow.tenant_id == snippet.tenant_id,
Workflow.app_id == snippet.id,
Workflow.type == WorkflowType.SNIPPET.value,
Workflow.version == "draft",
)
.first()
)
return workflow
def get_published_workflow(self, snippet: CustomizedSnippet) -> Workflow | None:
"""
Get published workflow for snippet.
:param snippet: CustomizedSnippet instance
:return: Published Workflow or None
"""
if not snippet.workflow_id:
return None
workflow = (
db.session.query(Workflow)
.where(
Workflow.tenant_id == snippet.tenant_id,
Workflow.app_id == snippet.id,
Workflow.type == WorkflowType.SNIPPET.value,
Workflow.id == snippet.workflow_id,
)
.first()
)
return workflow
def sync_draft_workflow(
self,
*,
snippet: CustomizedSnippet,
graph: dict,
unique_hash: str | None,
account: Account,
input_fields: list[dict] | None = None,
) -> Workflow:
"""
Sync draft workflow for snippet.
Snippet workflows never persist environment variables or conversation variables;
both are always stored as empty lists.
:param snippet: CustomizedSnippet instance
:param graph: Workflow graph configuration
:param unique_hash: Hash for conflict detection
:param account: Account making the change
:param input_fields: Input fields for snippet
:return: Synced Workflow
:raises WorkflowHashNotEqualError: If hash mismatch
"""
SnippetService.validate_snippet_graph_forbidden_nodes(graph)
workflow = self.get_draft_workflow(snippet=snippet)
if workflow and workflow.unique_hash != unique_hash:
raise WorkflowHashNotEqualError()
# Create draft workflow if not found
if not workflow:
workflow = Workflow(
tenant_id=snippet.tenant_id,
app_id=snippet.id,
features="{}",
type=WorkflowType.SNIPPET.value,
version="draft",
graph=json.dumps(graph),
created_by=account.id,
environment_variables=[],
conversation_variables=[],
)
db.session.add(workflow)
db.session.flush()
else:
# Update existing draft workflow
workflow.graph = json.dumps(graph)
workflow.updated_by = account.id
workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
workflow.environment_variables = []
workflow.conversation_variables = []
# Update snippet's input_fields if provided
if input_fields is not None:
snippet.input_fields = json.dumps(input_fields)
snippet.updated_by = account.id
snippet.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()
return workflow
def publish_workflow(
self,
*,
session: Session,
snippet: CustomizedSnippet,
account: Account,
) -> Workflow:
"""
Publish the draft workflow as a new version.
:param session: Database session
:param snippet: CustomizedSnippet instance
:param account: Account making the change
:return: Published Workflow
:raises ValueError: If no draft workflow exists
"""
draft_workflow_stmt = select(Workflow).where(
Workflow.tenant_id == snippet.tenant_id,
Workflow.app_id == snippet.id,
Workflow.type == WorkflowType.SNIPPET.value,
Workflow.version == "draft",
)
draft_workflow = session.scalar(draft_workflow_stmt)
if not draft_workflow:
raise ValueError("No valid workflow found.")
SnippetService.validate_snippet_graph_forbidden_nodes(draft_workflow.graph_dict)
# Create new published workflow
workflow = Workflow.new(
tenant_id=snippet.tenant_id,
app_id=snippet.id,
type=draft_workflow.type,
version=str(datetime.now(UTC).replace(tzinfo=None)),
graph=draft_workflow.graph,
features=draft_workflow.features,
created_by=account.id,
environment_variables=[],
conversation_variables=[],
rag_pipeline_variables=draft_workflow.rag_pipeline_variables,
marked_name="",
marked_comment="",
)
session.add(workflow)
# Update snippet version
snippet.version += 1
snippet.is_published = True
snippet.workflow_id = workflow.id
snippet.updated_by = account.id
session.add(snippet)
return workflow
def get_all_published_workflows(
self,
*,
session: Session,
snippet: CustomizedSnippet,
page: int,
limit: int,
) -> tuple[Sequence[Workflow], bool]:
"""
Get all published workflow versions for snippet.
:param session: Database session
:param snippet: CustomizedSnippet instance
:param page: Page number
:param limit: Items per page
:return: Tuple of (workflows list, has_more flag)
"""
if not snippet.workflow_id:
return [], False
stmt = (
select(Workflow)
.where(
Workflow.app_id == snippet.id,
Workflow.type == WorkflowType.SNIPPET.value,
Workflow.version != "draft",
)
.order_by(Workflow.version.desc())
.limit(limit + 1)
.offset((page - 1) * limit)
)
workflows = list(session.scalars(stmt).all())
has_more = len(workflows) > limit
if has_more:
workflows = workflows[:-1]
return workflows, has_more
# --- Default Block Configs ---
def get_default_block_configs(self) -> list[dict]:
"""
Get default block configurations for all node types.
:return: List of default configurations
"""
default_block_configs: list[dict[str, Any]] = []
for node_class_mapping in NODE_TYPE_CLASSES_MAPPING.values():
node_class = node_class_mapping[LATEST_VERSION]
default_config = node_class.get_default_config()
if default_config:
default_block_configs.append(dict(default_config))
return default_block_configs
def get_default_block_config(self, node_type: str, filters: dict | None = None) -> Mapping[str, object] | None:
"""
Get default config for specific node type.
:param node_type: Node type string
:param filters: Optional filters
:return: Default configuration or None
"""
node_type_enum = NodeType(node_type)
if node_type_enum not in NODE_TYPE_CLASSES_MAPPING:
return None
node_class = NODE_TYPE_CLASSES_MAPPING[node_type_enum][LATEST_VERSION]
default_config = node_class.get_default_config(filters=filters)
if not default_config:
return None
return default_config
# --- Workflow Run Operations ---
def get_snippet_workflow_runs(
self,
*,
snippet: CustomizedSnippet,
args: dict,
) -> InfiniteScrollPagination:
"""
Get paginated workflow runs for snippet.
:param snippet: CustomizedSnippet instance
:param args: Request arguments (last_id, limit)
:return: InfiniteScrollPagination result
"""
limit = int(args.get("limit", 20))
last_id = args.get("last_id")
triggered_from_values = [
WorkflowRunTriggeredFrom.DEBUGGING,
]
return self._workflow_run_repo.get_paginated_workflow_runs(
tenant_id=snippet.tenant_id,
app_id=snippet.id,
triggered_from=triggered_from_values,
limit=limit,
last_id=last_id,
)
def get_snippet_workflow_run(
self,
*,
snippet: CustomizedSnippet,
run_id: str,
) -> WorkflowRun | None:
"""
Get workflow run details.
:param snippet: CustomizedSnippet instance
:param run_id: Workflow run ID
:return: WorkflowRun or None
"""
return self._workflow_run_repo.get_workflow_run_by_id(
tenant_id=snippet.tenant_id,
app_id=snippet.id,
run_id=run_id,
)
def get_snippet_workflow_run_node_executions(
self,
*,
snippet: CustomizedSnippet,
run_id: str,
) -> Sequence[WorkflowNodeExecutionModel]:
"""
Get workflow run node execution list.
:param snippet: CustomizedSnippet instance
:param run_id: Workflow run ID
:return: List of WorkflowNodeExecutionModel
"""
workflow_run = self.get_snippet_workflow_run(snippet=snippet, run_id=run_id)
if not workflow_run:
return []
node_executions = self._node_execution_service_repo.get_executions_by_workflow_run(
tenant_id=snippet.tenant_id,
app_id=snippet.id,
workflow_run_id=workflow_run.id,
)
return node_executions
# --- Node Execution Operations ---
def get_snippet_node_last_run(
self,
*,
snippet: CustomizedSnippet,
workflow: Workflow,
node_id: str,
) -> WorkflowNodeExecutionModel | None:
"""
Get the most recent execution for a specific node in a snippet workflow.
:param snippet: CustomizedSnippet instance
:param workflow: Workflow instance
:param node_id: Node identifier
:return: WorkflowNodeExecutionModel or None
"""
return self._node_execution_service_repo.get_node_last_execution(
tenant_id=snippet.tenant_id,
app_id=snippet.id,
workflow_id=workflow.id,
node_id=node_id,
)
# --- Use Count ---
@staticmethod
def increment_use_count(
*,
session: Session,
snippet: CustomizedSnippet,
) -> None:
"""
Increment the use_count when snippet is used.
:param session: Database session
:param snippet: CustomizedSnippet instance
"""
snippet.use_count += 1
session.add(snippet)

View File

@ -38,6 +38,7 @@ from models.workflow import Workflow
from services.async_workflow_service import AsyncWorkflowService
from services.end_user_service import EndUserService
from services.errors.app import QuotaExceededError
from services.quota_service import QuotaService
from services.trigger.app_trigger_service import AppTriggerService
from services.workflow.entities import WebhookTriggerData
@ -819,9 +820,9 @@ class WebhookService:
user_id=None,
)
# consume quota before triggering workflow execution
# reserve quota before triggering workflow execution
try:
QuotaType.TRIGGER.consume(webhook_trigger.tenant_id)
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, webhook_trigger.tenant_id)
except QuotaExceededError:
AppTriggerService.mark_tenant_triggers_rate_limited(webhook_trigger.tenant_id)
logger.info(
@ -832,11 +833,16 @@ class WebhookService:
raise
# Trigger workflow execution asynchronously
AsyncWorkflowService.trigger_workflow_async(
session,
end_user,
trigger_data,
)
try:
AsyncWorkflowService.trigger_workflow_async(
session,
end_user,
trigger_data,
)
quota_charge.commit()
except Exception:
quota_charge.refund()
raise
except Exception:
logger.exception("Failed to trigger workflow for webhook %s", webhook_trigger.webhook_id)

View File

@ -23,17 +23,28 @@ class LogView:
"""Lightweight wrapper for WorkflowAppLog with computed details.
- Exposes `details_` for marshalling to `details` in API response
- Exposes `evaluation_` for marshalling evaluation metrics in API response
- Proxies all other attributes to the underlying `WorkflowAppLog`
"""
def __init__(self, log: WorkflowAppLog, details: LogViewDetails | None):
def __init__(
self,
log: WorkflowAppLog,
details: LogViewDetails | None,
evaluation: list[dict] | None = None,
):
self.log = log
self.details_ = details
self.evaluation_ = evaluation
@property
def details(self) -> LogViewDetails | None:
return self.details_
@property
def evaluation(self) -> list[dict] | None:
return self.evaluation_
def __getattr__(self, name):
return getattr(self.log, name)
@ -170,12 +181,20 @@ class WorkflowAppService:
# Execute query and get items
if detail:
rows = session.execute(offset_stmt).all()
items = [
LogView(log, {"trigger_metadata": self.handle_trigger_metadata(app_model.tenant_id, meta_val)})
logs_with_details = [
(log, {"trigger_metadata": self.handle_trigger_metadata(app_model.tenant_id, meta_val)})
for log, meta_val in rows
]
else:
items = [LogView(log, None) for log in session.scalars(offset_stmt).all()]
logs_with_details = [(log, None) for log in session.scalars(offset_stmt).all()]
workflow_run_ids = [log.workflow_run_id for log, _ in logs_with_details]
eval_map = self._batch_query_evaluation_metrics(session, workflow_run_ids)
items = [
LogView(log, details, evaluation=eval_map.get(log.workflow_run_id))
for log, details in logs_with_details
]
return {
"page": page,
"limit": limit,
@ -257,6 +276,45 @@ class WorkflowAppService:
"data": items,
}
@staticmethod
def _batch_query_evaluation_metrics(
session: Session,
workflow_run_ids: list[str],
) -> dict[str, list[dict[str, Any]]]:
"""Return evaluation metrics keyed by workflow_run_id.
Only returns metrics from completed evaluation runs. If a workflow
run was not part of any evaluation (or the evaluation has not
completed), it will be absent from the result dict.
"""
from models.evaluation import EvaluationRun, EvaluationRunItem, EvaluationRunStatus
if not workflow_run_ids:
return {}
non_null_ids = [wid for wid in workflow_run_ids if wid]
if not non_null_ids:
return {}
stmt = (
select(EvaluationRunItem.workflow_run_id, EvaluationRunItem.metrics)
.join(EvaluationRun, EvaluationRun.id == EvaluationRunItem.evaluation_run_id)
.where(
EvaluationRunItem.workflow_run_id.in_(non_null_ids),
EvaluationRun.status == EvaluationRunStatus.COMPLETED,
)
)
rows = session.execute(stmt).all()
result: dict[str, list[dict[str, Any]]] = {}
for wf_run_id, metrics_json in rows:
if wf_run_id and metrics_json:
parsed: list[dict[str, Any]] = json.loads(metrics_json)
existing = result.get(wf_run_id, [])
existing.extend(parsed)
result[wf_run_id] = existing
return result
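# Illustrative return shape (values are made up): metrics from every completed
# evaluation run that touched a workflow run are concatenated under its id, and
# runs without completed evaluations are simply absent, so callers use
# eval_map.get(workflow_run_id) and tolerate None.
example_eval_map = {
    "wf-run-1": [
        {"name": "faithfulness", "value": 0.91},
        {"name": "answer_relevancy", "value": 0.80},
    ],
}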
def handle_trigger_metadata(self, tenant_id: str, meta_val: str | None) -> dict[str, Any]:
metadata: dict[str, Any] | None = self._safe_json_loads(meta_val)
if not metadata:

View File

@ -1,7 +1,7 @@
import dataclasses
import json
import logging
from collections.abc import Mapping, Sequence
from collections.abc import Mapping, Sequence, Set
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from enum import StrEnum
@ -271,12 +271,20 @@ class WorkflowDraftVariableService:
)
def list_variables_without_values(
self, app_id: str, page: int, limit: int, user_id: str
self,
app_id: str,
page: int,
limit: int,
user_id: str,
*,
exclude_node_ids: Set[str] | None = None,
) -> WorkflowDraftVariableList:
criteria = [
WorkflowDraftVariable.app_id == app_id,
WorkflowDraftVariable.user_id == user_id,
]
if exclude_node_ids:
criteria.append(WorkflowDraftVariable.node_id.notin_(list(exclude_node_ids)))
total = None
base_stmt = select(WorkflowDraftVariable).where(*criteria)
if page == 1:

View File

@ -5,7 +5,7 @@ import uuid
from collections.abc import Callable, Generator, Mapping, Sequence
from typing import Any, cast
from graphon.entities import WorkflowNodeExecution
from graphon.entities import GraphInitParams, WorkflowNodeExecution
from graphon.entities.graph_config import NodeConfigDict
from graphon.entities.pause_reason import HumanInputRequired
from graphon.enums import (
@ -30,7 +30,7 @@ from graphon.variable_loader import load_into_variable_pool
from graphon.variables import VariableBase
from graphon.variables.input_entities import VariableEntityType
from graphon.variables.variables import Variable
from sqlalchemy import exists, select
from sqlalchemy import and_, exists, or_, select
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
@ -42,7 +42,7 @@ from core.entities import PluginCredentialType
from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly, create_plugin_provider_manager
from core.repositories import DifyCoreRepositoryFactory
from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl
from core.trigger.constants import is_trigger_node_type
from core.trigger.constants import TRIGGER_NODE_TYPES, is_trigger_node_type
from core.workflow.human_input_compat import (
DeliveryChannelConfig,
normalize_human_input_node_data_for_graph,
@ -65,6 +65,7 @@ from extensions.ext_database import db
from extensions.ext_storage import storage
from factories.file_factory import build_from_mapping, build_from_mappings
from libs.datetime_utils import naive_utc_now
from libs.helper import escape_like_pattern
from models import Account
from models.human_input import HumanInputFormRecipient, RecipientType
from models.model import App, AppMode
@ -99,6 +100,15 @@ class WorkflowService:
Workflow Service
"""
# Centralized unsupported node types for evaluation workflow publishing.
# Keep this set updated when evaluation workflow constraints change.
EVALUATION_UNSUPPORTED_NODE_TYPES: frozenset[str] = frozenset(
{
BuiltinNodeTypes.HUMAN_INPUT,
*TRIGGER_NODE_TYPES,
}
)
def __init__(self, session_maker: sessionmaker | None = None):
"""Initialize WorkflowService with repository dependencies."""
if session_maker is None:
@ -237,6 +247,59 @@ class WorkflowService:
return workflows, has_more
def list_published_evaluation_workflows(
self,
*,
session: Session,
tenant_id: str,
page: int,
limit: int,
user_id: str | None,
named_only: bool = False,
keyword: str | None = None,
) -> tuple[Sequence[Workflow], bool]:
"""
List published evaluation-type workflows for a tenant (cross-app), excluding draft rows.
When ``keyword`` is non-empty, match workflows whose marked name or parent app name contains
the substring (case-insensitive, LIKE wildcards escaped).
"""
stmt = select(Workflow).where(
Workflow.tenant_id == tenant_id,
Workflow.type == WorkflowType.EVALUATION,
Workflow.version != Workflow.VERSION_DRAFT,
)
if user_id:
stmt = stmt.where(Workflow.created_by == user_id)
if named_only:
stmt = stmt.where(Workflow.marked_name != "")
keyword_stripped = keyword.strip() if keyword else ""
if keyword_stripped:
escaped = escape_like_pattern(keyword_stripped)
pattern = f"%{escaped}%"
stmt = stmt.join(
App,
and_(Workflow.app_id == App.id, App.tenant_id == tenant_id),
).where(
or_(
Workflow.marked_name.ilike(pattern, escape="\\"),
App.name.ilike(pattern, escape="\\"),
)
)
stmt = stmt.order_by(Workflow.created_at.desc()).limit(limit + 1).offset((page - 1) * limit)
workflows = session.scalars(stmt).all()
has_more = len(workflows) > limit
if has_more:
workflows = workflows[:-1]
return workflows, has_more
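# Assumed behaviour of the keyword handling above: escape_like_pattern is expected
# to backslash-escape LIKE wildcards so user input matches literally, and the
# pattern is then applied case-insensitively with ilike(..., escape="\\").
pattern = f"%{escape_like_pattern('50%_off')}%"  # e.g. '%50\\%\\_off%'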
def sync_draft_workflow(
self,
*,
@ -400,6 +463,127 @@ class WorkflowService:
# return new workflow
return workflow
def publish_evaluation_workflow(
self,
*,
session: Session,
app_model: App,
account: Account,
marked_name: str = "",
marked_comment: str = "",
) -> Workflow:
"""Publish draft workflow as an evaluation workflow version.
Compared to standard publish:
- force published workflow type to ``evaluation``;
- reject graphs containing trigger or human-input nodes.
"""
draft_workflow_stmt = select(Workflow).where(
Workflow.tenant_id == app_model.tenant_id,
Workflow.app_id == app_model.id,
Workflow.version == Workflow.VERSION_DRAFT,
)
draft_workflow = session.scalar(draft_workflow_stmt)
if not draft_workflow:
raise ValueError("No valid workflow found.")
# Validate credentials before publishing to enforce the credential policy check

from services.feature_service import FeatureService
if FeatureService.get_system_features().plugin_manager.enabled:
self._validate_workflow_credentials(draft_workflow)
# validate graph structure
self.validate_graph_structure(graph=draft_workflow.graph_dict)
self._validate_evaluation_workflow_nodes(draft_workflow)
workflow = Workflow.new(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type=WorkflowType.EVALUATION.value,
version=Workflow.version_from_datetime(naive_utc_now()),
graph=draft_workflow.graph,
created_by=account.id,
environment_variables=draft_workflow.environment_variables,
conversation_variables=draft_workflow.conversation_variables,
marked_name=marked_name,
marked_comment=marked_comment,
rag_pipeline_variables=draft_workflow.rag_pipeline_variables,
features=draft_workflow.features,
)
session.add(workflow)
# trigger app workflow events
app_published_workflow_was_updated.send(app_model, published_workflow=workflow)
return workflow
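# Hedged usage sketch: publish the current draft of an app as an evaluation
# workflow inside one session. `db`, `app_model`, and `account` are assumed to be
# provided by the caller; only the service call itself comes from this diff.
from sqlalchemy.orm import Session

with Session(db.engine) as session:
    published = WorkflowService().publish_evaluation_workflow(
        session=session,
        app_model=app_model,
        account=account,
        marked_name="nightly-eval",
    )
    session.commit()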
def convert_published_workflow_type(
self,
*,
session: Session,
app_model: App,
target_type: WorkflowType,
account: Account,
) -> Workflow:
"""
Convert a published workflow type in-place.
This endpoint only supports conversion between standard workflow and evaluation workflow.
"""
if target_type not in {WorkflowType.WORKFLOW, WorkflowType.EVALUATION}:
raise ValueError("target_type must be either 'workflow' or 'evaluation'")
if not app_model.workflow_id:
raise WorkflowNotFoundError("Published workflow not found")
stmt = select(Workflow).where(
Workflow.tenant_id == app_model.tenant_id,
Workflow.app_id == app_model.id,
Workflow.id == app_model.workflow_id,
)
workflow = session.scalar(stmt)
if not workflow:
raise WorkflowNotFoundError("Published workflow not found")
if workflow.version == Workflow.VERSION_DRAFT:
raise IsDraftWorkflowError("Current effective workflow cannot be a draft version.")
if workflow.type == target_type:
return workflow
if target_type == WorkflowType.EVALUATION:
self._validate_evaluation_workflow_nodes(workflow)
workflow.type = target_type
workflow.updated_by = account.id
workflow.updated_at = naive_utc_now()
app_published_workflow_was_updated.send(app_model, published_workflow=workflow)
return workflow
@staticmethod
def _validate_evaluation_workflow_nodes(workflow: Workflow) -> None:
"""Ensure evaluation workflows do not contain unsupported node types."""
disallowed_nodes: list[tuple[str, str]] = []
for node_id, node_data in workflow.walk_nodes():
node_type = node_data.get("type")
if not isinstance(node_type, str):
continue
if node_type in WorkflowService.EVALUATION_UNSUPPORTED_NODE_TYPES:
disallowed_nodes.append((node_id, node_type))
if not disallowed_nodes:
return
formatted_nodes = ", ".join(f"{node_id}:{node_type}" for node_id, node_type in disallowed_nodes)
raise ValueError(
"Evaluation workflow cannot contain trigger or human-input nodes. "
f"Found disallowed nodes: {formatted_nodes}"
)
def _validate_workflow_credentials(self, workflow: Workflow) -> None:
"""
Validate all credentials in workflow nodes before publishing.
@ -1550,8 +1734,8 @@ def _setup_variable_pool(
"workflow_execution_id": str(uuid.uuid4()),
}
# Only add chatflow-specific variables for non-workflow types.
if workflow.type != WorkflowType.WORKFLOW:
# Only add chatflow-specific variables for chat-like workflow types.
if workflow.type not in {WorkflowType.WORKFLOW, WorkflowType.EVALUATION}:
system_variable_values.update(
{
"query": query,

View File

@ -0,0 +1,541 @@
import io
import json
import logging
from typing import Any
from celery import shared_task
from openpyxl import Workbook
from openpyxl.styles import Alignment, Border, Font, PatternFill, Side
from openpyxl.utils import get_column_letter
from configs import dify_config
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.evaluation_entity import (
EvaluationCategory,
EvaluationDatasetInput,
EvaluationItemResult,
EvaluationRunData,
NodeInfo,
)
from core.evaluation.entities.judgment_entity import JudgmentConfig
from core.evaluation.evaluation_manager import EvaluationManager
from core.evaluation.judgment.processor import JudgmentProcessor
from core.evaluation.runners.agent_evaluation_runner import AgentEvaluationRunner
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
from core.evaluation.runners.llm_evaluation_runner import LLMEvaluationRunner
from core.evaluation.runners.retrieval_evaluation_runner import RetrievalEvaluationRunner
from core.evaluation.runners.snippet_evaluation_runner import SnippetEvaluationRunner
from core.evaluation.runners.workflow_evaluation_runner import WorkflowEvaluationRunner
from extensions.ext_database import db
from graphon.node_events import NodeRunResult
from libs.datetime_utils import naive_utc_now
from models.enums import CreatorUserRole
from models.evaluation import EvaluationRun, EvaluationRunItem, EvaluationRunStatus
from models.model import UploadFile
from services.evaluation_service import EvaluationService
logger = logging.getLogger(__name__)
@shared_task(queue="evaluation")
def run_evaluation(run_data_dict: dict[str, Any]) -> None:
"""Celery task for running evaluations asynchronously.
Workflow:
1. Deserialize EvaluationRunData
2. Execute target and collect node results
3. Evaluate metrics via runners (one per metric-node pair)
4. Merge results per test-data row (1 item = 1 EvaluationRunItem)
5. Apply judgment conditions
6. Persist results + generate result XLSX
7. Update EvaluationRun status to COMPLETED
"""
run_data = EvaluationRunData.model_validate(run_data_dict)
with db.engine.connect() as connection:
from sqlalchemy.orm import Session
session = Session(bind=connection)
try:
_execute_evaluation(session, run_data)
except Exception as e:
logger.exception("Evaluation run %s failed", run_data.evaluation_run_id)
_mark_run_failed(session, run_data.evaluation_run_id, str(e))
finally:
session.close()
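# Hedged sketch of how the task is expected to be enqueued: the caller builds an
# EvaluationRunData, serialises it, and dispatches onto the "evaluation" queue.
# Construction of `run_data` is elided; model_dump() mirrors the
# EvaluationRunData.model_validate() call at the top of the task.
run_evaluation.delay(run_data.model_dump(mode="json"))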
def _execute_evaluation(session: Any, run_data: EvaluationRunData) -> None:
"""Core evaluation execution logic."""
evaluation_run = session.query(EvaluationRun).filter_by(id=run_data.evaluation_run_id).first()
if not evaluation_run:
logger.error("EvaluationRun %s not found", run_data.evaluation_run_id)
return
if evaluation_run.status == EvaluationRunStatus.CANCELLED:
logger.info("EvaluationRun %s was cancelled", run_data.evaluation_run_id)
return
evaluation_instance = EvaluationManager.get_evaluation_instance()
if evaluation_instance is None:
raise ValueError("Evaluation framework not configured")
# Mark as running
evaluation_run.status = EvaluationRunStatus.RUNNING
evaluation_run.started_at = naive_utc_now()
session.commit()
if run_data.target_type == "dataset":
results: list[EvaluationItemResult] = _execute_retrieval_test(
session=session,
evaluation_run=evaluation_run,
run_data=run_data,
evaluation_instance=evaluation_instance,
)
else:
evaluation_service = EvaluationService()
node_run_result_mapping_list, workflow_run_ids = evaluation_service.execute_targets(
tenant_id=run_data.tenant_id,
target_type=run_data.target_type,
target_id=run_data.target_id,
input_list=run_data.input_list,
)
workflow_run_id_map = {
item.index: wf_run_id
for item, wf_run_id in zip(run_data.input_list, workflow_run_ids)
if wf_run_id
}
results = _execute_evaluation_runner(
session=session,
run_data=run_data,
evaluation_instance=evaluation_instance,
node_run_result_mapping_list=node_run_result_mapping_list,
workflow_run_id_map=workflow_run_id_map,
)
# Compute summary metrics
metrics_summary = _compute_metrics_summary(results, run_data.judgment_config)
# Generate result XLSX
result_xlsx = _generate_result_xlsx(run_data.input_list, results)
# Store result file
result_file_id = _store_result_file(run_data.tenant_id, run_data.evaluation_run_id, result_xlsx, session)
# Update run to completed
evaluation_run = session.query(EvaluationRun).filter_by(id=run_data.evaluation_run_id).first()
if evaluation_run:
evaluation_run.status = EvaluationRunStatus.COMPLETED
evaluation_run.completed_at = naive_utc_now()
evaluation_run.metrics_summary = json.dumps(metrics_summary)
if result_file_id:
evaluation_run.result_file_id = result_file_id
session.commit()
logger.info("Evaluation run %s completed successfully", run_data.evaluation_run_id)
# ---------------------------------------------------------------------------
# Evaluation orchestration — merge + judgment + persist
# ---------------------------------------------------------------------------
def _execute_evaluation_runner(
session: Any,
run_data: EvaluationRunData,
evaluation_instance: BaseEvaluationInstance,
node_run_result_mapping_list: list[dict[str, NodeRunResult]],
workflow_run_id_map: dict[int, str] | None = None,
) -> list[EvaluationItemResult]:
"""Evaluate all metrics, merge per-item, apply judgment, persist once.
Ensures 1 test-data row = 1 EvaluationRunItem with all metrics combined.
"""
results_by_index: dict[int, EvaluationItemResult] = {}
# Phase 1: Default metrics — one batch per (metric, node) pair
for default_metric in run_data.default_metrics:
for node_info in default_metric.node_info_list:
node_run_result_list: list[NodeRunResult] = []
item_indices: list[int] = []
for i, mapping in enumerate(node_run_result_mapping_list):
node_result = mapping.get(node_info.node_id)
if node_result is not None:
node_run_result_list.append(node_result)
item_indices.append(i)
if not node_run_result_list:
continue
runner = _create_runner(EvaluationCategory(node_info.type), evaluation_instance)
try:
evaluated = runner.evaluate_metrics(
node_run_result_list=node_run_result_list,
default_metric=default_metric,
model_provider=run_data.evaluation_model_provider,
model_name=run_data.evaluation_model,
tenant_id=run_data.tenant_id,
)
except Exception:
logger.exception(
"Failed metrics for %s on node %s", default_metric.metric, node_info.node_id
)
continue
_stamp_and_merge(evaluated, item_indices, node_info, results_by_index)
# Phase 2: Customized metrics
if run_data.customized_metrics:
try:
customized_results = evaluation_instance.evaluate_with_customized_workflow(
node_run_result_mapping_list=node_run_result_mapping_list,
customized_metrics=run_data.customized_metrics,
tenant_id=run_data.tenant_id,
)
for result in customized_results:
_merge_result(results_by_index, result.index, result)
except Exception:
logger.exception("Failed customized metrics for run %s", run_data.evaluation_run_id)
results = list(results_by_index.values())
# Phase 3: Judgment
if run_data.judgment_config:
results = _apply_judgment(results, run_data.judgment_config)
# Phase 4: Persist — one EvaluationRunItem per test-data row
_persist_results(
session, run_data.evaluation_run_id, results, run_data.input_list, workflow_run_id_map
)
return results
def _execute_retrieval_test(
session: Any,
evaluation_run: EvaluationRun,
run_data: EvaluationRunData,
evaluation_instance: BaseEvaluationInstance,
) -> list[EvaluationItemResult]:
"""Execute knowledge base retrieval for all items, then evaluate metrics.
Unlike the workflow-based path, there are no workflow nodes to traverse.
Hit testing is run directly for each dataset item and the results are fed
straight into :class:`RetrievalEvaluationRunner`.
"""
node_run_result_list = EvaluationService.execute_retrieval_test_targets(
dataset_id=run_data.target_id,
account_id=evaluation_run.created_by,
input_list=run_data.input_list,
)
results_by_index: dict[int, EvaluationItemResult] = {}
runner = RetrievalEvaluationRunner(evaluation_instance)
for default_metric in run_data.default_metrics:
try:
evaluated = runner.evaluate_metrics(
node_run_result_list=node_run_result_list,
default_metric=default_metric,
model_provider=run_data.evaluation_model_provider,
model_name=run_data.evaluation_model,
tenant_id=run_data.tenant_id,
)
item_indices = list(range(len(node_run_result_list)))
_stamp_and_merge(evaluated, item_indices, None, results_by_index)
except Exception:
logger.exception("Failed retrieval metrics for run %s", run_data.evaluation_run_id)
results = list(results_by_index.values())
if run_data.judgment_config:
results = _apply_judgment(results, run_data.judgment_config)
_persist_results(session, run_data.evaluation_run_id, results, run_data.input_list)
return results
# ---------------------------------------------------------------------------
# Helpers — merge, judgment, persist
# ---------------------------------------------------------------------------
def _stamp_and_merge(
evaluated: list[EvaluationItemResult],
item_indices: list[int],
node_info: NodeInfo | None,
results_by_index: dict[int, EvaluationItemResult],
) -> None:
"""Attach node_info to each metric and merge into results_by_index."""
for result in evaluated:
original_index = item_indices[result.index]
if node_info is not None:
for metric in result.metrics:
metric.node_info = node_info
_merge_result(results_by_index, original_index, result)
def _merge_result(
results_by_index: dict[int, EvaluationItemResult],
index: int,
new_result: EvaluationItemResult,
) -> None:
"""Merge new metrics into an existing result for the same index."""
existing = results_by_index.get(index)
if existing:
merged_metrics = existing.metrics + new_result.metrics
actual = existing.actual_output or new_result.actual_output
results_by_index[index] = existing.model_copy(
update={"metrics": merged_metrics, "actual_output": actual}
)
else:
results_by_index[index] = new_result.model_copy(update={"index": index})
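# Tiny illustration of the merge semantics (the metric objects are assumed
# EvaluationMetric instances built elsewhere): two metric batches for the same
# dataset row collapse into a single EvaluationItemResult.
by_index: dict[int, EvaluationItemResult] = {}
_merge_result(by_index, 0, EvaluationItemResult(index=0, metrics=[faithfulness_metric]))
_merge_result(by_index, 0, EvaluationItemResult(index=0, metrics=[relevancy_metric]))
assert len(by_index) == 1 and len(by_index[0].metrics) == 2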
def _apply_judgment(
results: list[EvaluationItemResult],
judgment_config: JudgmentConfig,
) -> list[EvaluationItemResult]:
"""Evaluate pass/fail judgment conditions on each result's metrics."""
judged: list[EvaluationItemResult] = []
for result in results:
if result.error is not None or not result.metrics:
judged.append(result)
continue
metric_values: dict[tuple[str, str], object] = {
(m.node_info.node_id, m.name): m.value for m in result.metrics if m.node_info
}
judgment_result = JudgmentProcessor.evaluate(metric_values, judgment_config)
judged.append(result.model_copy(update={"judgment": judgment_result}))
return judged
def _persist_results(
session: Any,
evaluation_run_id: str,
results: list[EvaluationItemResult],
input_list: list[EvaluationDatasetInput],
workflow_run_id_map: dict[int, str] | None = None,
) -> None:
"""Persist evaluation results — one EvaluationRunItem per test-data row."""
dataset_map = {item.index: item for item in input_list}
wf_map = workflow_run_id_map or {}
for result in results:
item_input = dataset_map.get(result.index)
run_item = EvaluationRunItem(
evaluation_run_id=evaluation_run_id,
workflow_run_id=wf_map.get(result.index),
item_index=result.index,
inputs=json.dumps(item_input.inputs) if item_input else None,
expected_output=item_input.expected_output if item_input else None,
actual_output=result.actual_output,
metrics=json.dumps([m.model_dump() for m in result.metrics]) if result.metrics else None,
judgment=json.dumps(result.judgment.model_dump()) if result.judgment else None,
metadata_json=json.dumps(result.metadata) if result.metadata else None,
error=result.error,
overall_score=getattr(result, "overall_score", None),
)
session.add(run_item)
session.commit()
def _create_runner(
category: EvaluationCategory,
evaluation_instance: BaseEvaluationInstance,
) -> BaseEvaluationRunner:
"""Create the appropriate runner for the evaluation category."""
match category:
case EvaluationCategory.LLM:
return LLMEvaluationRunner(evaluation_instance)
case EvaluationCategory.RETRIEVAL | EvaluationCategory.KNOWLEDGE_BASE:
return RetrievalEvaluationRunner(evaluation_instance)
case EvaluationCategory.AGENT:
return AgentEvaluationRunner(evaluation_instance)
case EvaluationCategory.WORKFLOW:
return WorkflowEvaluationRunner(evaluation_instance)
case EvaluationCategory.SNIPPET:
return SnippetEvaluationRunner(evaluation_instance)
case _:
raise ValueError(f"Unknown evaluation category: {category}")
# ---------------------------------------------------------------------------
# Status / summary / XLSX / storage helpers (unchanged logic)
# ---------------------------------------------------------------------------
def _mark_run_failed(session: Any, run_id: str, error: str) -> None:
"""Mark an evaluation run as failed."""
try:
evaluation_run = session.query(EvaluationRun).filter_by(id=run_id).first()
if evaluation_run:
evaluation_run.status = EvaluationRunStatus.FAILED
evaluation_run.error = error[:2000]
evaluation_run.completed_at = naive_utc_now()
session.commit()
except Exception:
logger.exception("Failed to mark run %s as failed", run_id)
def _compute_metrics_summary(
results: list[EvaluationItemResult],
judgment_config: JudgmentConfig | None,
) -> dict[str, Any]:
"""Compute aggregate metric and judgment summaries for an evaluation run."""
summary: dict[str, Any] = {}
if judgment_config is not None and judgment_config.conditions:
evaluated_results: list[EvaluationItemResult] = [
result for result in results if result.error is None and result.metrics
]
passed_items = sum(1 for result in evaluated_results if result.judgment.passed)
evaluated_items = len(evaluated_results)
summary["_judgment"] = {
"enabled": True,
"logical_operator": judgment_config.logical_operator,
"configured_conditions": len(judgment_config.conditions),
"evaluated_items": evaluated_items,
"passed_items": passed_items,
"failed_items": evaluated_items - passed_items,
"pass_rate": passed_items / evaluated_items if evaluated_items else 0.0,
}
return summary
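# Illustrative summary produced above when judgment is configured (numbers made up):
example_summary = {
    "_judgment": {
        "enabled": True,
        "logical_operator": "and",
        "configured_conditions": 2,
        "evaluated_items": 10,
        "passed_items": 7,
        "failed_items": 3,
        "pass_rate": 0.7,
    }
}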
def _generate_result_xlsx(
input_list: list[EvaluationDatasetInput],
results: list[EvaluationItemResult],
) -> bytes:
"""Generate result XLSX with input data, actual output, metric scores, and judgment."""
wb = Workbook()
ws = wb.active
if ws is None:
ws = wb.create_sheet("Evaluation Results")
ws.title = "Evaluation Results"
header_font = Font(bold=True, color="FFFFFF")
header_fill = PatternFill(start_color="4472C4", end_color="4472C4", fill_type="solid")
header_alignment = Alignment(horizontal="center", vertical="center")
thin_border = Border(
left=Side(style="thin"),
right=Side(style="thin"),
top=Side(style="thin"),
bottom=Side(style="thin"),
)
# Collect all metric names
all_metric_names: list[str] = []
for result in results:
for metric in result.metrics:
if metric.name not in all_metric_names:
all_metric_names.append(metric.name)
# Collect all input keys
input_keys: list[str] = []
for item in input_list:
for key in item.inputs:
if key not in input_keys:
input_keys.append(key)
has_judgment = any(bool(r.judgment.condition_results) for r in results)
judgment_headers = ["judgment"] if has_judgment else []
headers = (
["index"] + input_keys + ["expected_output", "actual_output"] + all_metric_names + judgment_headers + ["error"]
)
for col_idx, header in enumerate(headers, start=1):
cell = ws.cell(row=1, column=col_idx, value=header)
cell.font = header_font
cell.fill = header_fill
cell.alignment = header_alignment
cell.border = thin_border
ws.column_dimensions["A"].width = 10
for col_idx in range(2, len(headers) + 1):
ws.column_dimensions[get_column_letter(col_idx)].width = 25
result_by_index = {r.index: r for r in results}
for row_idx, item in enumerate(input_list, start=2):
result = result_by_index.get(item.index)
col = 1
ws.cell(row=row_idx, column=col, value=item.index).border = thin_border
col += 1
for key in input_keys:
val = item.inputs.get(key, "")
ws.cell(row=row_idx, column=col, value=str(val)).border = thin_border
col += 1
ws.cell(row=row_idx, column=col, value=item.expected_output or "").border = thin_border
col += 1
ws.cell(row=row_idx, column=col, value=result.actual_output if result else "").border = thin_border
col += 1
metric_scores = {m.name: m.value for m in result.metrics} if result else {}
for metric_name in all_metric_names:
score = metric_scores.get(metric_name)
ws.cell(row=row_idx, column=col, value=score if score is not None else "").border = thin_border
col += 1
if has_judgment:
if result and result.judgment.condition_results:
judgment_value = "Pass" if result.judgment.passed else "Fail"
else:
judgment_value = ""
ws.cell(row=row_idx, column=col, value=judgment_value).border = thin_border
col += 1
ws.cell(row=row_idx, column=col, value=result.error if result else "").border = thin_border
output = io.BytesIO()
wb.save(output)
output.seek(0)
return output.getvalue()
def _store_result_file(
tenant_id: str,
run_id: str,
xlsx_content: bytes,
session: Any,
) -> str | None:
"""Store result XLSX file and return the UploadFile ID."""
try:
from extensions.ext_storage import storage
from libs.uuid_utils import uuidv7
filename = f"evaluation-result-{run_id[:8]}.xlsx"
storage_key = f"evaluation_results/{tenant_id}/{str(uuidv7())}.xlsx"
storage.save(storage_key, xlsx_content)
upload_file: UploadFile = UploadFile(
tenant_id=tenant_id,
storage_type=dify_config.STORAGE_TYPE,
key=storage_key,
name=filename,
size=len(xlsx_content),
extension="xlsx",
mime_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
created_by_role=CreatorUserRole.ACCOUNT,
created_by="system",
created_at=naive_utc_now(),
used=False,
)
session.add(upload_file)
session.commit()
return upload_file.id
except Exception:
logger.exception("Failed to store result file for run %s", run_id)
return None

View File

@ -28,7 +28,7 @@ from core.trigger.entities.entities import TriggerProviderEntity
from core.trigger.provider import PluginTriggerProviderController
from core.trigger.trigger_manager import TriggerManager
from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData
from enums.quota_type import QuotaType, unlimited
from enums.quota_type import QuotaType
from models.enums import (
AppTriggerType,
CreatorUserRole,
@ -42,6 +42,7 @@ from models.workflow import Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom,
from services.async_workflow_service import AsyncWorkflowService
from services.end_user_service import EndUserService
from services.errors.app import QuotaExceededError
from services.quota_service import QuotaService, unlimited
from services.trigger.app_trigger_service import AppTriggerService
from services.trigger.trigger_provider_service import TriggerProviderService
from services.trigger.trigger_request_service import TriggerHttpRequestCachingService
@ -298,10 +299,10 @@ def dispatch_triggered_workflow(
icon_dark_filename=trigger_entity.identity.icon_dark or "",
)
# consume quota before invoking trigger
# reserve quota before invoking trigger
quota_charge = unlimited()
try:
quota_charge = QuotaType.TRIGGER.consume(subscription.tenant_id)
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, subscription.tenant_id)
except QuotaExceededError:
AppTriggerService.mark_tenant_triggers_rate_limited(subscription.tenant_id)
logger.info(
@ -387,6 +388,7 @@ def dispatch_triggered_workflow(
raise ValueError(f"End user not found for app {plugin_trigger.app_id}")
AsyncWorkflowService.trigger_workflow_async(session=session, user=end_user, trigger_data=trigger_data)
quota_charge.commit()
dispatched_count += 1
logger.info(
"Triggered workflow for app %s with trigger event %s",

View File

@ -8,10 +8,11 @@ from core.workflow.nodes.trigger_schedule.exc import (
ScheduleNotFoundError,
TenantOwnerNotFoundError,
)
from enums.quota_type import QuotaType, unlimited
from enums.quota_type import QuotaType
from models.trigger import WorkflowSchedulePlan
from services.async_workflow_service import AsyncWorkflowService
from services.errors.app import QuotaExceededError
from services.quota_service import QuotaService, unlimited
from services.trigger.app_trigger_service import AppTriggerService
from services.trigger.schedule_service import ScheduleService
from services.workflow.entities import ScheduleTriggerData
@ -43,7 +44,7 @@ def run_schedule_trigger(schedule_id: str) -> None:
quota_charge = unlimited()
try:
quota_charge = QuotaType.TRIGGER.consume(schedule.tenant_id)
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, schedule.tenant_id)
except QuotaExceededError:
AppTriggerService.mark_tenant_triggers_rate_limited(schedule.tenant_id)
logger.info("Tenant %s rate limited, skipping schedule trigger %s", schedule.tenant_id, schedule_id)
@ -61,6 +62,7 @@ def run_schedule_trigger(schedule_id: str) -> None:
tenant_id=schedule.tenant_id,
),
)
quota_charge.commit()
logger.info("Schedule %s triggered workflow: %s", schedule_id, response.workflow_trigger_log_id)
except Exception as e:
quota_charge.refund()

View File

@ -36,12 +36,19 @@ class TestAppGenerateService:
) as mock_message_based_generator,
patch("services.account_service.FeatureService", autospec=True) as mock_account_feature_service,
patch("services.app_generate_service.dify_config", autospec=True) as mock_dify_config,
patch("services.quota_service.dify_config", autospec=True) as mock_quota_dify_config,
patch("configs.dify_config", autospec=True) as mock_global_dify_config,
):
# Setup default mock returns for billing service
mock_billing_service.update_tenant_feature_plan_usage.return_value = {
"result": "success",
"history_id": "test_history_id",
mock_billing_service.quota_reserve.return_value = {
"reservation_id": "test-reservation-id",
"available": 100,
"reserved": 1,
}
mock_billing_service.quota_commit.return_value = {
"available": 99,
"reserved": 0,
"refunded": 0,
}
# Setup default mock returns for workflow service
@ -101,6 +108,8 @@ class TestAppGenerateService:
mock_dify_config.APP_DEFAULT_ACTIVE_REQUESTS = 100
mock_dify_config.APP_DAILY_RATE_LIMIT = 1000
mock_quota_dify_config.BILLING_ENABLED = False
mock_global_dify_config.BILLING_ENABLED = False
mock_global_dify_config.APP_MAX_ACTIVE_REQUESTS = 100
mock_global_dify_config.APP_DAILY_RATE_LIMIT = 1000
@ -118,6 +127,7 @@ class TestAppGenerateService:
"message_based_generator": mock_message_based_generator,
"account_feature_service": mock_account_feature_service,
"dify_config": mock_dify_config,
"quota_dify_config": mock_quota_dify_config,
"global_dify_config": mock_global_dify_config,
}
@ -465,6 +475,7 @@ class TestAppGenerateService:
# Set BILLING_ENABLED to True for this test
mock_external_service_dependencies["dify_config"].BILLING_ENABLED = True
mock_external_service_dependencies["quota_dify_config"].BILLING_ENABLED = True
mock_external_service_dependencies["global_dify_config"].BILLING_ENABLED = True
# Setup test arguments
@ -478,8 +489,10 @@ class TestAppGenerateService:
# Verify the result
assert result == ["test_response"]
# Verify billing service was called to consume quota
mock_external_service_dependencies["billing_service"].update_tenant_feature_plan_usage.assert_called_once()
# Verify billing two-phase quota (reserve + commit)
billing = mock_external_service_dependencies["billing_service"]
billing.quota_reserve.assert_called_once()
billing.quota_commit.assert_called_once()
def test_generate_with_invalid_app_mode(
self, db_session_with_containers: Session, mock_external_service_dependencies

View File

@ -605,9 +605,9 @@ def test_schedule_trigger_creates_trigger_log(
)
# Mock quota to avoid rate limiting
from enums import quota_type
from services import quota_service
monkeypatch.setattr(quota_type.QuotaType.TRIGGER, "consume", lambda _tenant_id: quota_type.unlimited())
monkeypatch.setattr(quota_service.QuotaService, "reserve", lambda *_args, **_kwargs: quota_service.unlimited())
# Execute schedule trigger
workflow_schedule_tasks.run_schedule_trigger(plan.id)

View File

@ -0,0 +1,145 @@
"""Unit tests for metric-based judgment evaluation."""
from core.evaluation.entities.judgment_entity import JudgmentCondition, JudgmentConfig
from core.evaluation.judgment.processor import JudgmentProcessor
def test_evaluate_uses_and_conditions_against_metric_values() -> None:
"""All conditions must pass when the logical operator is ``and``."""
config = JudgmentConfig(
logical_operator="and",
conditions=[
JudgmentCondition(
variable_selector=["llm_node_1", "faithfulness"],
comparison_operator=">",
value="0.8",
),
JudgmentCondition(
variable_selector=["llm_node_1", "answer_relevancy"],
comparison_operator="",
value="0.7",
),
],
)
result = JudgmentProcessor.evaluate(
{
("llm_node_1", "faithfulness"): 0.9,
("llm_node_1", "answer_relevancy"): 0.75,
},
config,
)
assert result.passed is True
assert len(result.condition_results) == 2
assert all(condition_result.passed for condition_result in result.condition_results)
def test_evaluate_sets_passed_false_when_any_and_condition_fails() -> None:
"""A failed metric comparison should make the overall judgment fail."""
config = JudgmentConfig(
logical_operator="and",
conditions=[
JudgmentCondition(
variable_selector=["llm_node_1", "faithfulness"],
comparison_operator=">",
value="0.8",
),
JudgmentCondition(
variable_selector=["llm_node_1", "answer_relevancy"],
comparison_operator="",
value="0.7",
),
],
)
result = JudgmentProcessor.evaluate(
{
("llm_node_1", "faithfulness"): 0.9,
("llm_node_1", "answer_relevancy"): 0.6,
},
config,
)
assert result.passed is False
assert result.condition_results[-1].passed is False
def test_evaluate_with_different_nodes_same_metric() -> None:
"""Conditions can target different nodes even with the same metric name."""
config = JudgmentConfig(
logical_operator="and",
conditions=[
JudgmentCondition(
variable_selector=["llm_node_1", "faithfulness"],
comparison_operator=">",
value="0.8",
),
JudgmentCondition(
variable_selector=["llm_node_2", "faithfulness"],
comparison_operator=">",
value="0.5",
),
],
)
result = JudgmentProcessor.evaluate(
{
("llm_node_1", "faithfulness"): 0.9,
("llm_node_2", "faithfulness"): 0.6,
},
config,
)
assert result.passed is True
assert len(result.condition_results) == 2
def test_evaluate_or_operator_passes_when_one_condition_met() -> None:
"""With ``or`` logical operator, one passing condition should suffice."""
config = JudgmentConfig(
logical_operator="or",
conditions=[
JudgmentCondition(
variable_selector=["node_a", "score"],
comparison_operator=">",
value="0.9",
),
JudgmentCondition(
variable_selector=["node_b", "score"],
comparison_operator=">",
value="0.5",
),
],
)
result = JudgmentProcessor.evaluate(
{
("node_a", "score"): 0.3,
("node_b", "score"): 0.8,
},
config,
)
assert result.passed is True
def test_evaluate_string_contains_operator() -> None:
"""String operators should work correctly via workflow engine delegation."""
config = JudgmentConfig(
logical_operator="and",
conditions=[
JudgmentCondition(
variable_selector=["node_a", "status"],
comparison_operator="contains",
value="success",
),
],
)
result = JudgmentProcessor.evaluate(
{("node_a", "status"): "evaluation_success_done"},
config,
)
assert result.passed is True

View File

@ -0,0 +1,78 @@
"""Tests for judgment application logic (moved from BaseEvaluationRunner to evaluation_task)."""
from core.evaluation.entities.evaluation_entity import EvaluationItemResult, EvaluationMetric, NodeInfo
from core.evaluation.entities.judgment_entity import JudgmentCondition, JudgmentConfig
from tasks.evaluation_task import _apply_judgment
_NODE_INFO = NodeInfo(node_id="llm_1", type="llm", title="LLM Node")
def test_apply_judgment_marks_passing_result() -> None:
"""Items whose metrics satisfy the judgment conditions should be marked as passed."""
results = [
EvaluationItemResult(
index=0,
actual_output="result",
metrics=[EvaluationMetric(name="faithfulness", value=0.91, node_info=_NODE_INFO)],
)
]
judgment_config = JudgmentConfig(
logical_operator="and",
conditions=[
JudgmentCondition(
variable_selector=["llm_1", "faithfulness"],
comparison_operator=">",
value="0.8",
)
],
)
judged = _apply_judgment(results, judgment_config)
assert judged[0].judgment.passed is True
def test_apply_judgment_marks_failing_result() -> None:
"""Items whose metrics do NOT satisfy the conditions should be marked as failed."""
results = [
EvaluationItemResult(
index=0,
metrics=[EvaluationMetric(name="faithfulness", value=0.5, node_info=_NODE_INFO)],
)
]
judgment_config = JudgmentConfig(
logical_operator="and",
conditions=[
JudgmentCondition(
variable_selector=["llm_1", "faithfulness"],
comparison_operator=">",
value="0.8",
)
],
)
judged = _apply_judgment(results, judgment_config)
assert judged[0].judgment.passed is False
def test_apply_judgment_skips_errored_items() -> None:
"""Items with errors should be passed through without judgment evaluation."""
results = [
EvaluationItemResult(index=0, error="timeout"),
]
judgment_config = JudgmentConfig(
logical_operator="and",
conditions=[
JudgmentCondition(
variable_selector=["llm_1", "faithfulness"],
comparison_operator=">",
value="0.8",
)
],
)
judged = _apply_judgment(results, judgment_config)
assert judged[0].error == "timeout"
assert judged[0].judgment.passed is False

View File

View File

@ -0,0 +1,349 @@
"""Unit tests for QuotaType, QuotaService, and QuotaCharge."""
from unittest.mock import patch
import pytest
from enums.quota_type import QuotaType
from services.quota_service import QuotaCharge, QuotaService, unlimited
class TestQuotaType:
def test_billing_key_trigger(self):
assert QuotaType.TRIGGER.billing_key == "trigger_event"
def test_billing_key_workflow(self):
assert QuotaType.WORKFLOW.billing_key == "api_rate_limit"
def test_billing_key_unlimited_raises(self):
with pytest.raises(ValueError, match="Invalid quota type"):
_ = QuotaType.UNLIMITED.billing_key
class TestQuotaService:
def test_reserve_billing_disabled(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService"),
):
mock_cfg.BILLING_ENABLED = False
charge = QuotaService.reserve(QuotaType.TRIGGER, "t1")
assert charge.success is True
assert charge.charge_id is None
def test_reserve_zero_amount_raises(self):
with patch("services.quota_service.dify_config") as mock_cfg:
mock_cfg.BILLING_ENABLED = True
with pytest.raises(ValueError, match="greater than 0"):
QuotaService.reserve(QuotaType.TRIGGER, "t1", amount=0)
def test_reserve_success(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService") as mock_bs,
):
mock_cfg.BILLING_ENABLED = True
mock_bs.quota_reserve.return_value = {"reservation_id": "rid-1", "available": 99}
charge = QuotaService.reserve(QuotaType.TRIGGER, "t1", amount=1)
assert charge.success is True
assert charge.charge_id == "rid-1"
assert charge._tenant_id == "t1"
assert charge._feature_key == "trigger_event"
assert charge._amount == 1
mock_bs.quota_reserve.assert_called_once()
def test_reserve_no_reservation_id_raises(self):
from services.errors.app import QuotaExceededError
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService") as mock_bs,
):
mock_cfg.BILLING_ENABLED = True
mock_bs.quota_reserve.return_value = {}
with pytest.raises(QuotaExceededError):
QuotaService.reserve(QuotaType.TRIGGER, "t1")
def test_reserve_quota_exceeded_propagates(self):
from services.errors.app import QuotaExceededError
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService") as mock_bs,
):
mock_cfg.BILLING_ENABLED = True
mock_bs.quota_reserve.side_effect = QuotaExceededError(feature="trigger", tenant_id="t1", required=1)
with pytest.raises(QuotaExceededError):
QuotaService.reserve(QuotaType.TRIGGER, "t1")
def test_reserve_api_exception_returns_unlimited(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService") as mock_bs,
):
mock_cfg.BILLING_ENABLED = True
mock_bs.quota_reserve.side_effect = RuntimeError("network")
charge = QuotaService.reserve(QuotaType.TRIGGER, "t1")
assert charge.success is True
assert charge.charge_id is None
def test_consume_calls_reserve_and_commit(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService") as mock_bs,
):
mock_cfg.BILLING_ENABLED = True
mock_bs.quota_reserve.return_value = {"reservation_id": "rid-c"}
mock_bs.quota_commit.return_value = {}
charge = QuotaService.consume(QuotaType.TRIGGER, "t1")
assert charge.success is True
mock_bs.quota_commit.assert_called_once()
def test_check_billing_disabled(self):
with patch("services.quota_service.dify_config") as mock_cfg:
mock_cfg.BILLING_ENABLED = False
assert QuotaService.check(QuotaType.TRIGGER, "t1") is True
def test_check_zero_amount_raises(self):
with patch("services.quota_service.dify_config") as mock_cfg:
mock_cfg.BILLING_ENABLED = True
with pytest.raises(ValueError, match="greater than 0"):
QuotaService.check(QuotaType.TRIGGER, "t1", amount=0)
def test_check_sufficient_quota(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch.object(QuotaService, "get_remaining", return_value=100),
):
mock_cfg.BILLING_ENABLED = True
assert QuotaService.check(QuotaType.TRIGGER, "t1", amount=50) is True
def test_check_insufficient_quota(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch.object(QuotaService, "get_remaining", return_value=5),
):
mock_cfg.BILLING_ENABLED = True
assert QuotaService.check(QuotaType.TRIGGER, "t1", amount=10) is False
def test_check_unlimited_quota(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch.object(QuotaService, "get_remaining", return_value=-1),
):
mock_cfg.BILLING_ENABLED = True
assert QuotaService.check(QuotaType.TRIGGER, "t1", amount=999) is True
def test_check_exception_returns_true(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch.object(QuotaService, "get_remaining", side_effect=RuntimeError),
):
mock_cfg.BILLING_ENABLED = True
assert QuotaService.check(QuotaType.TRIGGER, "t1") is True
def test_release_billing_disabled(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService") as mock_bs,
):
mock_cfg.BILLING_ENABLED = False
QuotaService.release(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event")
mock_bs.quota_release.assert_not_called()
def test_release_empty_reservation(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService") as mock_bs,
):
mock_cfg.BILLING_ENABLED = True
QuotaService.release(QuotaType.TRIGGER, "", "t1", "trigger_event")
mock_bs.quota_release.assert_not_called()
def test_release_success(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService") as mock_bs,
):
mock_cfg.BILLING_ENABLED = True
mock_bs.quota_release.return_value = {}
QuotaService.release(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event")
mock_bs.quota_release.assert_called_once_with(
tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1"
)
def test_release_exception_swallowed(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService") as mock_bs,
):
mock_cfg.BILLING_ENABLED = True
mock_bs.quota_release.side_effect = RuntimeError("fail")
QuotaService.release(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event")
def test_get_remaining_normal(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.get_quota_info.return_value = {"trigger_event": {"limit": 100, "usage": 30}}
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 70
def test_get_remaining_unlimited(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.get_quota_info.return_value = {"trigger_event": {"limit": -1, "usage": 0}}
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == -1
def test_get_remaining_over_limit_returns_zero(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.get_quota_info.return_value = {"trigger_event": {"limit": 10, "usage": 15}}
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0
def test_get_remaining_exception_returns_neg1(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.get_quota_info.side_effect = RuntimeError
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == -1
def test_get_remaining_empty_response(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.get_quota_info.return_value = {}
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0
def test_get_remaining_non_dict_response(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.get_quota_info.return_value = "invalid"
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0
def test_get_remaining_feature_not_in_response(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.get_quota_info.return_value = {"other_feature": {"limit": 100, "usage": 0}}
remaining = QuotaService.get_remaining(QuotaType.TRIGGER, "t1")
assert remaining == 0
def test_get_remaining_non_dict_feature_info(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.get_quota_info.return_value = {"trigger_event": "not_a_dict"}
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0
class TestQuotaCharge:
def test_commit_success(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.quota_commit.return_value = {}
charge = QuotaCharge(
success=True,
charge_id="rid-1",
_quota_type=QuotaType.TRIGGER,
_tenant_id="t1",
_feature_key="trigger_event",
_amount=1,
)
charge.commit()
mock_bs.quota_commit.assert_called_once_with(
tenant_id="t1",
feature_key="trigger_event",
reservation_id="rid-1",
actual_amount=1,
)
assert charge._committed is True
def test_commit_with_actual_amount(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.quota_commit.return_value = {}
charge = QuotaCharge(
success=True,
charge_id="rid-1",
_quota_type=QuotaType.TRIGGER,
_tenant_id="t1",
_feature_key="trigger_event",
_amount=10,
)
charge.commit(actual_amount=5)
call_kwargs = mock_bs.quota_commit.call_args[1]
assert call_kwargs["actual_amount"] == 5
def test_commit_idempotent(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.quota_commit.return_value = {}
charge = QuotaCharge(
success=True,
charge_id="rid-1",
_quota_type=QuotaType.TRIGGER,
_tenant_id="t1",
_feature_key="trigger_event",
_amount=1,
)
charge.commit()
charge.commit()
assert mock_bs.quota_commit.call_count == 1
def test_commit_no_charge_id_noop(self):
with patch("services.billing_service.BillingService") as mock_bs:
charge = QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.TRIGGER)
charge.commit()
mock_bs.quota_commit.assert_not_called()
def test_commit_no_tenant_id_noop(self):
with patch("services.billing_service.BillingService") as mock_bs:
charge = QuotaCharge(
success=True,
charge_id="rid-1",
_quota_type=QuotaType.TRIGGER,
_tenant_id=None,
_feature_key="trigger_event",
)
charge.commit()
mock_bs.quota_commit.assert_not_called()
def test_commit_exception_swallowed(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.quota_commit.side_effect = RuntimeError("fail")
charge = QuotaCharge(
success=True,
charge_id="rid-1",
_quota_type=QuotaType.TRIGGER,
_tenant_id="t1",
_feature_key="trigger_event",
_amount=1,
)
charge.commit()
def test_refund_success(self):
with patch.object(QuotaService, "release") as mock_rel:
charge = QuotaCharge(
success=True,
charge_id="rid-1",
_quota_type=QuotaType.TRIGGER,
_tenant_id="t1",
_feature_key="trigger_event",
)
charge.refund()
mock_rel.assert_called_once_with(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event")
def test_refund_no_charge_id_noop(self):
with patch.object(QuotaService, "release") as mock_rel:
charge = QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.TRIGGER)
charge.refund()
mock_rel.assert_not_called()
def test_refund_no_tenant_id_noop(self):
with patch.object(QuotaService, "release") as mock_rel:
charge = QuotaCharge(
success=True,
charge_id="rid-1",
_quota_type=QuotaType.TRIGGER,
_tenant_id=None,
)
charge.refund()
mock_rel.assert_not_called()
class TestUnlimited:
def test_unlimited_returns_success_with_no_charge_id(self):
charge = unlimited()
assert charge.success is True
assert charge.charge_id is None
assert charge._quota_type == QuotaType.UNLIMITED

View File

@ -23,6 +23,7 @@ import pytest
import services.app_generate_service as ags_module
from core.app.entities.app_invoke_entities import InvokeFrom
from enums.quota_type import QuotaType
from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowIdFormatError, WorkflowNotFoundError
@ -447,8 +448,8 @@ class TestGenerateBilling:
def test_billing_enabled_consumes_quota(self, mocker, monkeypatch):
monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True)
quota_charge = MagicMock()
consume_mock = mocker.patch(
"services.app_generate_service.QuotaType.WORKFLOW.consume",
reserve_mock = mocker.patch(
"services.app_generate_service.QuotaService.reserve",
return_value=quota_charge,
)
mocker.patch(
@ -467,7 +468,8 @@ class TestGenerateBilling:
invoke_from=InvokeFrom.SERVICE_API,
streaming=False,
)
consume_mock.assert_called_once_with("tenant-id")
reserve_mock.assert_called_once_with(QuotaType.WORKFLOW, "tenant-id")
quota_charge.commit.assert_called_once()
def test_billing_quota_exceeded_raises_rate_limit_error(self, mocker, monkeypatch):
from services.errors.app import QuotaExceededError
@ -475,7 +477,7 @@ class TestGenerateBilling:
monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True)
mocker.patch(
"services.app_generate_service.QuotaType.WORKFLOW.consume",
"services.app_generate_service.QuotaService.reserve",
side_effect=QuotaExceededError(feature="workflow", tenant_id="t", required=1),
)
@ -492,7 +494,7 @@ class TestGenerateBilling:
monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True)
quota_charge = MagicMock()
mocker.patch(
"services.app_generate_service.QuotaType.WORKFLOW.consume",
"services.app_generate_service.QuotaService.reserve",
return_value=quota_charge,
)
mocker.patch(

View File

@ -57,7 +57,7 @@ class TestAsyncWorkflowService:
- repo: SQLAlchemyWorkflowTriggerLogRepository
- dispatcher_manager_class: QueueDispatcherManager class
- dispatcher: dispatcher instance
- quota_workflow: QuotaType.WORKFLOW
- quota_service: QuotaService mock
- get_workflow: AsyncWorkflowService._get_workflow method
- professional_task: execute_workflow_professional
- team_task: execute_workflow_team
@ -72,7 +72,12 @@ class TestAsyncWorkflowService:
mock_repo.create.side_effect = _create_side_effect
mock_dispatcher = MagicMock()
quota_workflow = MagicMock()
mock_quota_service = MagicMock()
mock_get_workflow = MagicMock()
mock_professional_task = MagicMock()
mock_team_task = MagicMock()
mock_sandbox_task = MagicMock()
with (
patch.object(
@ -88,8 +93,8 @@ class TestAsyncWorkflowService:
) as mock_get_workflow,
patch.object(
async_workflow_service_module,
"QuotaType",
new=SimpleNamespace(WORKFLOW=quota_workflow),
"QuotaService",
new=mock_quota_service,
),
patch.object(async_workflow_service_module, "execute_workflow_professional") as mock_professional_task,
patch.object(async_workflow_service_module, "execute_workflow_team") as mock_team_task,
@ -102,7 +107,7 @@ class TestAsyncWorkflowService:
"repo": mock_repo,
"dispatcher_manager_class": mock_dispatcher_manager_class,
"dispatcher": mock_dispatcher,
"quota_workflow": quota_workflow,
"quota_service": mock_quota_service,
"get_workflow": mock_get_workflow,
"professional_task": mock_professional_task,
"team_task": mock_team_task,
@ -141,6 +146,9 @@ class TestAsyncWorkflowService:
mocks["team_task"].delay.return_value = task_result
mocks["sandbox_task"].delay.return_value = task_result
quota_charge_mock = MagicMock()
mocks["quota_service"].reserve.return_value = quota_charge_mock
class DummyAccount:
def __init__(self, user_id: str):
self.id = user_id
@ -158,7 +166,8 @@ class TestAsyncWorkflowService:
assert result.status == "queued"
assert result.queue == queue_name
mocks["quota_workflow"].consume.assert_called_once_with("tenant-123")
mocks["quota_service"].reserve.assert_called_once()
quota_charge_mock.commit.assert_called_once()
assert session.commit.call_count == 2
created_log = mocks["repo"].create.call_args[0][0]
@ -245,7 +254,7 @@ class TestAsyncWorkflowService:
mocks = async_workflow_trigger_mocks
mocks["dispatcher"].get_queue_name.return_value = QueuePriority.TEAM
mocks["get_workflow"].return_value = workflow
mocks["quota_workflow"].consume.side_effect = QuotaExceededError(
mocks["quota_service"].reserve.side_effect = QuotaExceededError(
feature="workflow",
tenant_id="tenant-123",
required=1,

View File

@ -425,7 +425,7 @@ class TestBillingServiceUsageCalculation:
yield mock
def test_get_tenant_feature_plan_usage_info(self, mock_send_request):
"""Test retrieval of tenant feature plan usage information."""
"""Test retrieval of tenant feature plan usage information (legacy endpoint)."""
# Arrange
tenant_id = "tenant-123"
expected_response = {"features": {"trigger": {"used": 50, "limit": 100}, "workflow": {"used": 20, "limit": 50}}}
@ -438,6 +438,20 @@ class TestBillingServiceUsageCalculation:
assert result == expected_response
mock_send_request.assert_called_once_with("GET", "/tenant-feature-usage/info", params={"tenant_id": tenant_id})
def test_get_quota_info(self, mock_send_request):
"""Test retrieval of quota info from new endpoint."""
# Arrange
tenant_id = "tenant-123"
expected_response = {"trigger_event": {"limit": 100, "usage": 30}, "api_rate_limit": {"limit": -1, "usage": 0}}
mock_send_request.return_value = expected_response
# Act
result = BillingService.get_quota_info(tenant_id)
# Assert
assert result == expected_response
mock_send_request.assert_called_once_with("GET", "/quota/info", params={"tenant_id": tenant_id})
def test_update_tenant_feature_plan_usage_positive_delta(self, mock_send_request):
"""Test updating tenant feature usage with positive delta (adding credits)."""
# Arrange
@ -515,6 +529,150 @@ class TestBillingServiceUsageCalculation:
)
class TestBillingServiceQuotaOperations:
"""Unit tests for quota reserve/commit/release operations."""
@pytest.fixture
def mock_send_request(self):
with patch.object(BillingService, "_send_request") as mock:
yield mock
def test_quota_reserve_success(self, mock_send_request):
expected = {"reservation_id": "rid-1", "available": 99, "reserved": 1}
mock_send_request.return_value = expected
result = BillingService.quota_reserve(tenant_id="t1", feature_key="trigger_event", request_id="req-1", amount=1)
assert result == expected
mock_send_request.assert_called_once_with(
"POST",
"/quota/reserve",
json={"tenant_id": "t1", "feature_key": "trigger_event", "request_id": "req-1", "amount": 1},
)
def test_quota_reserve_coerces_string_to_int(self, mock_send_request):
"""Test that TypeAdapter coerces string values to int."""
mock_send_request.return_value = {"reservation_id": "rid-str", "available": "99", "reserved": "1"}
result = BillingService.quota_reserve(tenant_id="t1", feature_key="trigger_event", request_id="req-s", amount=1)
assert result["available"] == 99
assert isinstance(result["available"], int)
assert result["reserved"] == 1
assert isinstance(result["reserved"], int)
def test_quota_reserve_with_meta(self, mock_send_request):
mock_send_request.return_value = {"reservation_id": "rid-2", "available": 98, "reserved": 1}
meta = {"source": "webhook"}
BillingService.quota_reserve(
tenant_id="t1", feature_key="trigger_event", request_id="req-2", amount=1, meta=meta
)
call_json = mock_send_request.call_args[1]["json"]
assert call_json["meta"] == {"source": "webhook"}
def test_quota_commit_success(self, mock_send_request):
expected = {"available": 98, "reserved": 0, "refunded": 0}
mock_send_request.return_value = expected
result = BillingService.quota_commit(
tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1", actual_amount=1
)
assert result == expected
mock_send_request.assert_called_once_with(
"POST",
"/quota/commit",
json={
"tenant_id": "t1",
"feature_key": "trigger_event",
"reservation_id": "rid-1",
"actual_amount": 1,
},
)
def test_quota_commit_coerces_string_to_int(self, mock_send_request):
"""Test that TypeAdapter coerces string values to int."""
mock_send_request.return_value = {"available": "97", "reserved": "0", "refunded": "1"}
result = BillingService.quota_commit(
tenant_id="t1", feature_key="trigger_event", reservation_id="rid-s", actual_amount=1
)
assert result["available"] == 97
assert isinstance(result["available"], int)
assert result["refunded"] == 1
assert isinstance(result["refunded"], int)
def test_quota_commit_with_meta(self, mock_send_request):
mock_send_request.return_value = {"available": 97, "reserved": 0, "refunded": 0}
meta = {"reason": "partial"}
BillingService.quota_commit(
tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1", actual_amount=1, meta=meta
)
call_json = mock_send_request.call_args[1]["json"]
assert call_json["meta"] == {"reason": "partial"}
def test_quota_release_success(self, mock_send_request):
expected = {"available": 100, "reserved": 0, "released": 1}
mock_send_request.return_value = expected
result = BillingService.quota_release(tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1")
assert result == expected
mock_send_request.assert_called_once_with(
"POST",
"/quota/release",
json={"tenant_id": "t1", "feature_key": "trigger_event", "reservation_id": "rid-1"},
)
def test_quota_release_coerces_string_to_int(self, mock_send_request):
"""Test that TypeAdapter coerces string values to int."""
mock_send_request.return_value = {"available": "100", "reserved": "0", "released": "1"}
result = BillingService.quota_release(tenant_id="t1", feature_key="trigger_event", reservation_id="rid-s")
assert result["available"] == 100
assert isinstance(result["available"], int)
assert result["released"] == 1
assert isinstance(result["released"], int)
def test_get_quota_info_coerces_string_to_int(self, mock_send_request):
"""Test that TypeAdapter coerces string values to int for get_quota_info."""
mock_send_request.return_value = {
"trigger_event": {"usage": "42", "limit": "3000", "reset_date": "1700000000"},
"api_rate_limit": {"usage": "10", "limit": "-1", "reset_date": "-1"},
}
result = BillingService.get_quota_info("t1")
assert result["trigger_event"]["usage"] == 42
assert isinstance(result["trigger_event"]["usage"], int)
assert result["trigger_event"]["limit"] == 3000
assert isinstance(result["trigger_event"]["limit"], int)
assert result["trigger_event"]["reset_date"] == 1700000000
assert isinstance(result["trigger_event"]["reset_date"], int)
assert result["api_rate_limit"]["limit"] == -1
assert isinstance(result["api_rate_limit"]["limit"], int)
def test_get_quota_info_accepts_int_values(self, mock_send_request):
"""Test that get_quota_info works with native int values."""
expected = {
"trigger_event": {"usage": 42, "limit": 3000, "reset_date": 1700000000},
"api_rate_limit": {"usage": 0, "limit": -1},
}
mock_send_request.return_value = expected
result = BillingService.get_quota_info("t1")
assert result["trigger_event"]["usage"] == 42
assert result["trigger_event"]["limit"] == 3000
assert result["api_rate_limit"]["limit"] == -1
class TestBillingServiceRateLimitEnforcement:
"""Unit tests for rate limit enforcement mechanisms.

View File

@ -559,3 +559,772 @@ class TestWebhookServiceUnit:
result = _prepare_webhook_execution("test_webhook", is_debug=True)
assert result == (mock_trigger, mock_workflow, mock_config, mock_data, None)
# === Merged from test_webhook_service_additional.py ===
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock
import pytest
from flask import Flask
from graphon.variables.types import SegmentType
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import RequestEntityTooLarge
from core.workflow.nodes.trigger_webhook.entities import (
ContentType,
WebhookBodyParameter,
WebhookData,
WebhookParameter,
)
from models.enums import AppTriggerStatus
from models.model import App
from models.trigger import WorkflowWebhookTrigger
from models.workflow import Workflow
from services.errors.app import QuotaExceededError
from services.trigger import webhook_service as service_module
from services.trigger.webhook_service import WebhookService
class _FakeQuery:
def __init__(self, result: Any) -> None:
self._result = result
def where(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
return self
def filter(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
return self
def order_by(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
return self
def first(self) -> Any:
return self._result
class _SessionContext:
def __init__(self, session: Any) -> None:
self._session = session
def __enter__(self) -> Any:
return self._session
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
return False
class _SessionmakerContext:
def __init__(self, session: Any) -> None:
self._session = session
def begin(self) -> "_SessionmakerContext":
return self
def __enter__(self) -> Any:
return self._session
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
return False
@pytest.fixture
def flask_app() -> Flask:
return Flask(__name__)
def _patch_session(monkeypatch: pytest.MonkeyPatch, session: Any) -> None:
monkeypatch.setattr(service_module, "db", SimpleNamespace(engine=MagicMock(), session=MagicMock()))
monkeypatch.setattr(service_module, "Session", lambda *args, **kwargs: _SessionContext(session))
monkeypatch.setattr(service_module, "sessionmaker", lambda *args, **kwargs: _SessionmakerContext(session))
def _workflow_trigger(**kwargs: Any) -> WorkflowWebhookTrigger:
return cast(WorkflowWebhookTrigger, SimpleNamespace(**kwargs))
def _workflow(**kwargs: Any) -> Workflow:
return cast(Workflow, SimpleNamespace(**kwargs))
def _app(**kwargs: Any) -> App:
return cast(App, SimpleNamespace(**kwargs))
def test_get_webhook_trigger_and_workflow_should_raise_when_webhook_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
# Arrange
fake_session = MagicMock()
fake_session.scalar.return_value = None
_patch_session(monkeypatch, fake_session)
# Act / Assert
with pytest.raises(ValueError, match="Webhook not found"):
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_not_found(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
fake_session = MagicMock()
fake_session.scalar.side_effect = [webhook_trigger, None]
_patch_session(monkeypatch, fake_session)
# Act / Assert
with pytest.raises(ValueError, match="App trigger not found"):
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_rate_limited(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
app_trigger = SimpleNamespace(status=AppTriggerStatus.RATE_LIMITED)
fake_session = MagicMock()
fake_session.scalar.side_effect = [webhook_trigger, app_trigger]
_patch_session(monkeypatch, fake_session)
# Act / Assert
with pytest.raises(ValueError, match="rate limited"):
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_disabled(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
app_trigger = SimpleNamespace(status=AppTriggerStatus.DISABLED)
fake_session = MagicMock()
fake_session.scalar.side_effect = [webhook_trigger, app_trigger]
_patch_session(monkeypatch, fake_session)
# Act / Assert
with pytest.raises(ValueError, match="disabled"):
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
def test_get_webhook_trigger_and_workflow_should_raise_when_workflow_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
# Arrange
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED)
fake_session = MagicMock()
fake_session.scalar.side_effect = [webhook_trigger, app_trigger, None]
_patch_session(monkeypatch, fake_session)
# Act / Assert
with pytest.raises(ValueError, match="Workflow not found"):
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
def test_get_webhook_trigger_and_workflow_should_return_values_for_non_debug_mode(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED)
workflow = MagicMock()
workflow.get_node_config_by_id.return_value = {"data": {"key": "value"}}
fake_session = MagicMock()
fake_session.scalar.side_effect = [webhook_trigger, app_trigger, workflow]
_patch_session(monkeypatch, fake_session)
# Act
got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow("webhook-1")
# Assert
assert got_trigger is webhook_trigger
assert got_workflow is workflow
assert got_node_config == {"data": {"key": "value"}}
def test_get_webhook_trigger_and_workflow_should_return_values_for_debug_mode(monkeypatch: pytest.MonkeyPatch) -> None:
# Arrange
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
workflow = MagicMock()
workflow.get_node_config_by_id.return_value = {"data": {"mode": "debug"}}
fake_session = MagicMock()
fake_session.scalar.side_effect = [webhook_trigger, workflow]
_patch_session(monkeypatch, fake_session)
# Act
got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow(
"webhook-1", is_debug=True
)
# Assert
assert got_trigger is webhook_trigger
assert got_workflow is workflow
assert got_node_config == {"data": {"mode": "debug"}}
def test_extract_webhook_data_should_use_text_fallback_for_unknown_content_type(
flask_app: Flask,
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
warning_mock = MagicMock()
monkeypatch.setattr(service_module.logger, "warning", warning_mock)
webhook_trigger = MagicMock()
# Act
with flask_app.test_request_context(
"/webhook",
method="POST",
headers={"Content-Type": "application/vnd.custom"},
data="plain content",
):
result = WebhookService.extract_webhook_data(webhook_trigger)
# Assert
assert result["body"] == {"raw": "plain content"}
warning_mock.assert_called_once()
def test_extract_webhook_data_should_raise_for_request_too_large(
flask_app: Flask,
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
monkeypatch.setattr(service_module.dify_config, "WEBHOOK_REQUEST_BODY_MAX_SIZE", 1)
# Act / Assert
with flask_app.test_request_context("/webhook", method="POST", data="ab"):
with pytest.raises(RequestEntityTooLarge):
WebhookService.extract_webhook_data(MagicMock())
def test_extract_octet_stream_body_should_return_none_when_empty_payload(flask_app: Flask) -> None:
# Arrange
webhook_trigger = MagicMock()
# Act
with flask_app.test_request_context("/webhook", method="POST", data=b""):
body, files = WebhookService._extract_octet_stream_body(webhook_trigger)
# Assert
assert body == {"raw": None}
assert files == {}
def test_extract_octet_stream_body_should_return_none_when_processing_raises(
flask_app: Flask,
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
webhook_trigger = MagicMock()
monkeypatch.setattr(WebhookService, "_detect_binary_mimetype", MagicMock(return_value="application/octet-stream"))
monkeypatch.setattr(WebhookService, "_create_file_from_binary", MagicMock(side_effect=RuntimeError("boom")))
# Act
with flask_app.test_request_context("/webhook", method="POST", data=b"abc"):
body, files = WebhookService._extract_octet_stream_body(webhook_trigger)
# Assert
assert body == {"raw": None}
assert files == {}
def test_extract_text_body_should_return_empty_string_when_request_read_fails(
flask_app: Flask,
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
monkeypatch.setattr("flask.wrappers.Request.get_data", MagicMock(side_effect=RuntimeError("read error")))
# Act
with flask_app.test_request_context("/webhook", method="POST", data="abc"):
body, files = WebhookService._extract_text_body()
# Assert
assert body == {"raw": ""}
assert files == {}
def test_detect_binary_mimetype_should_fallback_when_magic_raises(monkeypatch: pytest.MonkeyPatch) -> None:
# Arrange
fake_magic = MagicMock()
fake_magic.from_buffer.side_effect = RuntimeError("magic failed")
monkeypatch.setattr(service_module, "magic", fake_magic)
# Act
result = WebhookService._detect_binary_mimetype(b"binary")
# Assert
assert result == "application/octet-stream"
def test_process_file_uploads_should_use_octet_stream_fallback_when_mimetype_unknown(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
webhook_trigger = _workflow_trigger(created_by="user-1", tenant_id="tenant-1")
file_obj = MagicMock()
file_obj.to_dict.return_value = {"id": "f-1"}
monkeypatch.setattr(WebhookService, "_create_file_from_binary", MagicMock(return_value=file_obj))
monkeypatch.setattr(service_module.mimetypes, "guess_type", MagicMock(return_value=(None, None)))
uploaded = MagicMock()
uploaded.filename = "file.unknown"
uploaded.content_type = None
uploaded.read.return_value = b"content"
# Act
result = WebhookService._process_file_uploads({"f": uploaded}, webhook_trigger)
# Assert
assert result == {"f": {"id": "f-1"}}
def test_create_file_from_binary_should_call_tool_file_manager_and_file_factory(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
webhook_trigger = _workflow_trigger(created_by="user-1", tenant_id="tenant-1")
manager = MagicMock()
manager.create_file_by_raw.return_value = SimpleNamespace(id="tool-file-1")
monkeypatch.setattr(service_module, "ToolFileManager", MagicMock(return_value=manager))
expected_file = MagicMock()
monkeypatch.setattr(service_module.file_factory, "build_from_mapping", MagicMock(return_value=expected_file))
# Act
result = WebhookService._create_file_from_binary(b"abc", "text/plain", webhook_trigger)
# Assert
assert result is expected_file
manager.create_file_by_raw.assert_called_once()
@pytest.mark.parametrize(
("raw_value", "param_type", "expected"),
[
("42", SegmentType.NUMBER, 42),
("3.14", SegmentType.NUMBER, 3.14),
("yes", SegmentType.BOOLEAN, True),
("no", SegmentType.BOOLEAN, False),
],
)
def test_convert_form_value_should_convert_supported_types(
raw_value: str,
param_type: str,
expected: Any,
) -> None:
# Arrange
# Act
result = WebhookService._convert_form_value("param", raw_value, param_type)
# Assert
assert result == expected
def test_convert_form_value_should_raise_for_unsupported_type() -> None:
# Arrange
# Act / Assert
with pytest.raises(ValueError, match="Unsupported type"):
WebhookService._convert_form_value("p", "x", SegmentType.FILE)
def test_validate_json_value_should_return_original_for_unmapped_supported_segment_type(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
warning_mock = MagicMock()
monkeypatch.setattr(service_module.logger, "warning", warning_mock)
# Act
result = WebhookService._validate_json_value("param", {"x": 1}, "unsupported-type")
# Assert
assert result == {"x": 1}
warning_mock.assert_called_once()
def test_validate_and_convert_value_should_wrap_conversion_errors() -> None:
# Arrange
# Act / Assert
with pytest.raises(ValueError, match="validation failed"):
WebhookService._validate_and_convert_value("param", "bad", SegmentType.NUMBER, is_form_data=True)
def test_process_parameters_should_raise_when_required_parameter_missing() -> None:
# Arrange
raw_params = {"optional": "x"}
config = [WebhookParameter(name="required_param", type=SegmentType.STRING, required=True)]
# Act / Assert
with pytest.raises(ValueError, match="Required parameter missing"):
WebhookService._process_parameters(raw_params, config, is_form_data=True)
def test_process_parameters_should_include_unconfigured_parameters() -> None:
# Arrange
raw_params = {"known": "1", "unknown": "x"}
config = [WebhookParameter(name="known", type=SegmentType.NUMBER, required=False)]
# Act
result = WebhookService._process_parameters(raw_params, config, is_form_data=True)
# Assert
assert result == {"known": 1, "unknown": "x"}
def test_process_body_parameters_should_raise_when_required_text_raw_is_missing() -> None:
# Arrange
# Act / Assert
with pytest.raises(ValueError, match="Required body content missing"):
WebhookService._process_body_parameters(
raw_body={"raw": ""},
body_configs=[WebhookBodyParameter(name="raw", required=True)],
content_type=ContentType.TEXT,
)
def test_process_body_parameters_should_skip_file_config_for_multipart_form_data() -> None:
# Arrange
raw_body = {"message": "hello", "extra": "x"}
body_configs = [
WebhookBodyParameter(name="upload", type=SegmentType.FILE, required=True),
WebhookBodyParameter(name="message", type=SegmentType.STRING, required=True),
]
# Act
result = WebhookService._process_body_parameters(raw_body, body_configs, ContentType.FORM_DATA)
# Assert
assert result == {"message": "hello", "extra": "x"}
def test_validate_required_headers_should_accept_sanitized_header_names() -> None:
# Arrange
headers = {"x_api_key": "123"}
configs = [WebhookParameter(name="x-api-key", required=True)]
# Act
WebhookService._validate_required_headers(headers, configs)
# Assert: no ValueError raised means the sanitized header name "x_api_key" satisfied the required "x-api-key" config
assert True
def test_validate_required_headers_should_raise_when_required_header_missing() -> None:
# Arrange
headers = {"x-other": "123"}
configs = [WebhookParameter(name="x-api-key", required=True)]
# Act / Assert
with pytest.raises(ValueError, match="Required header missing"):
WebhookService._validate_required_headers(headers, configs)
def test_validate_http_metadata_should_return_content_type_mismatch_error() -> None:
# Arrange
webhook_data = {"method": "POST", "headers": {"Content-Type": "application/json"}}
node_data = WebhookData(method="post", content_type=ContentType.TEXT)
# Act
result = WebhookService._validate_http_metadata(webhook_data, node_data)
# Assert
assert result["valid"] is False
assert "Content-type mismatch" in result["error"]
def test_extract_content_type_should_fallback_to_lowercase_header_key() -> None:
# Arrange
headers = {"content-type": "application/json; charset=utf-8"}
# Act
result = WebhookService._extract_content_type(headers)
# Assert
assert result == "application/json"
def test_build_workflow_inputs_should_include_expected_keys() -> None:
# Arrange
webhook_data = {"headers": {"h": "v"}, "query_params": {"q": 1}, "body": {"b": 2}}
# Act
result = WebhookService.build_workflow_inputs(webhook_data)
# Assert
assert result["webhook_data"] == webhook_data
assert result["webhook_headers"] == {"h": "v"}
assert result["webhook_query_params"] == {"q": 1}
assert result["webhook_body"] == {"b": 2}
def test_trigger_workflow_execution_should_trigger_async_workflow_successfully(monkeypatch: pytest.MonkeyPatch) -> None:
# Arrange
webhook_trigger = _workflow_trigger(
app_id="app-1",
node_id="node-1",
tenant_id="tenant-1",
webhook_id="webhook-1",
)
workflow = _workflow(id="wf-1")
webhook_data = {"body": {"x": 1}}
session = MagicMock()
_patch_session(monkeypatch, session)
end_user = SimpleNamespace(id="end-user-1")
monkeypatch.setattr(
service_module.EndUserService, "get_or_create_end_user_by_type", MagicMock(return_value=end_user)
)
quota_type = SimpleNamespace(TRIGGER=SimpleNamespace(consume=MagicMock()))
monkeypatch.setattr(service_module, "QuotaType", quota_type)
trigger_async_mock = MagicMock()
monkeypatch.setattr(service_module.AsyncWorkflowService, "trigger_workflow_async", trigger_async_mock)
# Act
WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow)
# Assert
trigger_async_mock.assert_called_once()
def test_trigger_workflow_execution_should_mark_tenant_rate_limited_when_quota_exceeded(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
webhook_trigger = _workflow_trigger(
app_id="app-1",
node_id="node-1",
tenant_id="tenant-1",
webhook_id="webhook-1",
)
workflow = _workflow(id="wf-1")
session = MagicMock()
_patch_session(monkeypatch, session)
monkeypatch.setattr(
service_module.EndUserService,
"get_or_create_end_user_by_type",
MagicMock(return_value=SimpleNamespace(id="end-user-1")),
)
monkeypatch.setattr(
service_module.QuotaService,
"reserve",
MagicMock(side_effect=QuotaExceededError(feature="trigger", tenant_id="tenant-1", required=1)),
)
mark_rate_limited_mock = MagicMock()
monkeypatch.setattr(service_module.AppTriggerService, "mark_tenant_triggers_rate_limited", mark_rate_limited_mock)
# Act / Assert
with pytest.raises(QuotaExceededError):
WebhookService.trigger_workflow_execution(webhook_trigger, {"body": {}}, workflow)
mark_rate_limited_mock.assert_called_once_with("tenant-1")
def test_trigger_workflow_execution_should_log_and_reraise_unexpected_errors(monkeypatch: pytest.MonkeyPatch) -> None:
# Arrange
webhook_trigger = _workflow_trigger(
app_id="app-1",
node_id="node-1",
tenant_id="tenant-1",
webhook_id="webhook-1",
)
workflow = _workflow(id="wf-1")
session = MagicMock()
_patch_session(monkeypatch, session)
monkeypatch.setattr(
service_module.EndUserService, "get_or_create_end_user_by_type", MagicMock(side_effect=RuntimeError("boom"))
)
logger_exception_mock = MagicMock()
monkeypatch.setattr(service_module.logger, "exception", logger_exception_mock)
# Act / Assert
with pytest.raises(RuntimeError, match="boom"):
WebhookService.trigger_workflow_execution(webhook_trigger, {"body": {}}, workflow)
logger_exception_mock.assert_called_once()
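Taken together, the three trigger_workflow_execution tests above pin down a reserve-or-fail-fast pattern. The stub below is a self-contained illustration of that control flow only (plain callables standing in for the real Dify services), included to make the expected error handling easier to read; it is not the production implementation.

import logging
from collections.abc import Callable

logger = logging.getLogger(__name__)


class QuotaExceededError(Exception):
    """Stand-in for services.errors.app.QuotaExceededError."""


def guarded_trigger(
    get_end_user: Callable[[str], str],
    reserve: Callable[[str], None],
    mark_rate_limited: Callable[[str], None],
    dispatch: Callable[[str, str], None],
    tenant_id: str,
) -> None:
    try:
        end_user = get_end_user(tenant_id)
        reserve(tenant_id)  # may raise QuotaExceededError
        dispatch(tenant_id, end_user)
    except QuotaExceededError:
        # Quota exhausted: flag every trigger for this tenant, then surface the error.
        mark_rate_limited(tenant_id)
        raise
    except Exception:
        # Anything else is logged and re-raised unchanged.
        logger.exception("webhook-triggered workflow failed to start")
        raise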
def test_sync_webhook_relationships_should_raise_when_workflow_exceeds_node_limit() -> None:
# Arrange
app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
workflow = _workflow(
walk_nodes=lambda _node_type: [
(f"node-{i}", {}) for i in range(WebhookService.MAX_WEBHOOK_NODES_PER_WORKFLOW + 1)
]
)
# Act / Assert
with pytest.raises(ValueError, match="maximum webhook node limit"):
WebhookService.sync_webhook_relationships(app, workflow)
def test_sync_webhook_relationships_should_raise_when_lock_not_acquired(monkeypatch: pytest.MonkeyPatch) -> None:
# Arrange
app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
workflow = _workflow(walk_nodes=lambda _node_type: [("node-1", {})])
lock = MagicMock()
lock.acquire.return_value = False
monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))
# Act / Assert
with pytest.raises(RuntimeError, match="Failed to acquire lock"):
WebhookService.sync_webhook_relationships(app, workflow)
def test_sync_webhook_relationships_should_create_missing_records_and_delete_stale_records(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
workflow = _workflow(walk_nodes=lambda _node_type: [("node-new", {})])
class _WorkflowWebhookTrigger:
app_id = "app_id"
tenant_id = "tenant_id"
webhook_id = "webhook_id"
node_id = "node_id"
def __init__(self, app_id: str, tenant_id: str, node_id: str, webhook_id: str, created_by: str) -> None:
self.id = None
self.app_id = app_id
self.tenant_id = tenant_id
self.node_id = node_id
self.webhook_id = webhook_id
self.created_by = created_by
class _Select:
def where(self, *args: Any, **kwargs: Any) -> "_Select":
return self
class _Session:
def __init__(self) -> None:
self.added: list[Any] = []
self.deleted: list[Any] = []
self.commit_count = 0
self.existing_records = [SimpleNamespace(node_id="node-stale")]
def scalars(self, _stmt: Any) -> Any:
return SimpleNamespace(all=lambda: self.existing_records)
def add(self, obj: Any) -> None:
self.added.append(obj)
def flush(self) -> None:
for idx, obj in enumerate(self.added, start=1):
if obj.id is None:
obj.id = f"rec-{idx}"
def commit(self) -> None:
self.commit_count += 1
def delete(self, obj: Any) -> None:
self.deleted.append(obj)
lock = MagicMock()
lock.acquire.return_value = True
lock.release.return_value = None
fake_session = _Session()
monkeypatch.setattr(service_module, "WorkflowWebhookTrigger", _WorkflowWebhookTrigger)
monkeypatch.setattr(service_module, "select", MagicMock(return_value=_Select()))
monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))
redis_set_mock = MagicMock()
redis_delete_mock = MagicMock()
monkeypatch.setattr(service_module.redis_client, "set", redis_set_mock)
monkeypatch.setattr(service_module.redis_client, "delete", redis_delete_mock)
monkeypatch.setattr(WebhookService, "generate_webhook_id", MagicMock(return_value="generated-webhook-id"))
_patch_session(monkeypatch, fake_session)
# Act
WebhookService.sync_webhook_relationships(app, workflow)
# Assert
assert len(fake_session.added) == 1
assert len(fake_session.deleted) == 1
redis_set_mock.assert_called_once()
redis_delete_mock.assert_called_once()
lock.release.assert_called_once()
def test_sync_webhook_relationships_should_log_when_lock_release_fails(monkeypatch: pytest.MonkeyPatch) -> None:
# Arrange
app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
workflow = _workflow(walk_nodes=lambda _node_type: [])
class _Select:
def where(self, *args: Any, **kwargs: Any) -> "_Select":
return self
class _Session:
def scalars(self, _stmt: Any) -> Any:
return SimpleNamespace(all=lambda: [])
def commit(self) -> None:
return None
lock = MagicMock()
lock.acquire.return_value = True
lock.release.side_effect = RuntimeError("release failed")
logger_exception_mock = MagicMock()
monkeypatch.setattr(service_module, "select", MagicMock(return_value=_Select()))
monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))
monkeypatch.setattr(service_module.logger, "exception", logger_exception_mock)
_patch_session(monkeypatch, _Session())
# Act
WebhookService.sync_webhook_relationships(app, workflow)
# Assert
assert logger_exception_mock.call_count == 1
def test_generate_webhook_response_should_fallback_when_response_body_is_not_json() -> None:
# Arrange
node_config = {"data": {"status_code": 200, "response_body": "{bad-json"}}
# Act
body, status = WebhookService.generate_webhook_response(node_config)
# Assert
assert status == 200
assert "message" in body
def test_generate_webhook_id_should_return_24_character_identifier() -> None:
# Arrange
# Act
webhook_id = WebhookService.generate_webhook_id()
# Assert
assert isinstance(webhook_id, str)
assert len(webhook_id) == 24
def test_sanitize_key_should_return_original_value_for_non_string_input() -> None:
# Arrange
# Act
result = WebhookService._sanitize_key(123) # type: ignore[arg-type]
# Assert
assert result == 123

View File

@ -0,0 +1,97 @@
"""Unit tests for evaluation task helpers."""
from core.evaluation.entities.evaluation_entity import EvaluationItemResult, EvaluationMetric, NodeInfo
from core.evaluation.entities.judgment_entity import (
JudgmentCondition,
JudgmentConfig,
JudgmentResult,
)
from tasks.evaluation_task import _compute_metrics_summary, _merge_result, _stamp_and_merge
_NODE_INFO = NodeInfo(node_id="llm_1", type="llm", title="LLM Node")
def test_compute_metrics_summary_includes_judgment_counts() -> None:
"""Summary should expose pass/fail counts when judgment rules are configured."""
judgment_config = JudgmentConfig(
logical_operator="and",
conditions=[
JudgmentCondition(
variable_selector=["llm_1", "faithfulness"],
comparison_operator=">",
value="0.8",
)
],
)
results = [
EvaluationItemResult(
index=0,
metrics=[EvaluationMetric(name="faithfulness", value=0.9, node_info=_NODE_INFO)],
judgment=JudgmentResult(passed=True, logical_operator="and", condition_results=[]),
),
EvaluationItemResult(
index=1,
metrics=[EvaluationMetric(name="faithfulness", value=0.4, node_info=_NODE_INFO)],
judgment=JudgmentResult(passed=False, logical_operator="and", condition_results=[]),
),
EvaluationItemResult(index=2, error="timeout"),
]
summary = _compute_metrics_summary(results, judgment_config)
assert summary["_judgment"] == {
"enabled": True,
"logical_operator": "and",
"configured_conditions": 1,
"evaluated_items": 2,
"passed_items": 1,
"failed_items": 1,
"pass_rate": 0.5,
}
def test_merge_result_combines_metrics_for_same_index() -> None:
"""Merging two results with the same index should concatenate their metrics."""
results_by_index: dict[int, EvaluationItemResult] = {}
first = EvaluationItemResult(
index=0,
actual_output="output_1",
metrics=[EvaluationMetric(name="faithfulness", value=0.9)],
)
_merge_result(results_by_index, 0, first)
second = EvaluationItemResult(
index=0,
actual_output="output_2",
metrics=[EvaluationMetric(name="context_precision", value=0.7)],
)
_merge_result(results_by_index, 0, second)
merged = results_by_index[0]
assert len(merged.metrics) == 2
assert merged.metrics[0].name == "faithfulness"
assert merged.metrics[1].name == "context_precision"
assert merged.actual_output == "output_1"
def test_stamp_and_merge_attaches_node_info() -> None:
"""_stamp_and_merge should set node_info on every metric and remap indices."""
results_by_index: dict[int, EvaluationItemResult] = {}
node_info = NodeInfo(node_id="llm_1", type="llm", title="GPT-4")
evaluated = [
EvaluationItemResult(
index=0,
metrics=[EvaluationMetric(name="faithfulness", value=0.85)],
)
]
item_indices = [3]
_stamp_and_merge(evaluated, item_indices, node_info, results_by_index)
assert 3 in results_by_index
metric = results_by_index[3].metrics[0]
assert metric.node_info is not None
assert metric.node_info.node_id == "llm_1"
assert metric.node_info.type == "llm"
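As a quick sanity check of the summary asserted in the first test of this file: the item that errored out is excluded from the judgment tally, and the pass rate is computed over the remaining evaluated items. The arithmetic below mirrors that expected dict; it is not lifted from the task module itself.

passed_items, failed_items, errored_items = 1, 1, 1
evaluated_items = passed_items + failed_items  # the errored item is not counted
pass_rate = passed_items / evaluated_items
assert (evaluated_items, passed_items, pass_rate) == (2, 1, 0.5)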

494
api/uv.lock generated
View File

@ -320,6 +320,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/15/b3/9b1a8074496371342ec1e796a96f99c82c945a339cd81a8e73de28b4cf9e/anyio-4.11.0-py3-none-any.whl", hash = "sha256:0287e96f4d26d4149305414d4e3bc32f0dcd0862365a4bddea19d7a1ec38c4fc", size = 109097, upload-time = "2025-09-23T09:19:10.601Z" },
]
[[package]]
name = "appdirs"
version = "1.4.4"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/d7/d8/05696357e0311f5b5c316d7b95f46c669dd9c15aaeecbb48c7d0aeb88c40/appdirs-1.4.4.tar.gz", hash = "sha256:7d5d0167b2b1ba821647616af46a749d1c653740dd0d2415100fe26e27afdf41", size = 13470, upload-time = "2020-05-11T07:59:51.037Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/3b/00/2344469e2084fb287c2e0b57b72910309874c3245463acd6cf5e3db69324/appdirs-1.4.4-py2.py3-none-any.whl", hash = "sha256:a841dacd6b99318a741b166adb07e19ee71a274450e68237b4650ca1055ab128", size = 9566, upload-time = "2020-05-11T07:59:49.499Z" },
]
[[package]]
name = "apscheduler"
version = "3.11.2"
@ -1243,6 +1252,45 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/a7/27/b822b474aaefb684d11df358d52e012699a2a8af231f9b47c54b73f280cb/databricks_sdk-0.73.0-py3-none-any.whl", hash = "sha256:a4d3cfd19357a2b459d2dc3101454d7f0d1b62865ce099c35d0c342b66ac64ff", size = 753896, upload-time = "2025-11-05T06:52:56.451Z" },
]
[[package]]
name = "dataclasses-json"
version = "0.6.7"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "marshmallow" },
{ name = "typing-inspect" },
]
sdist = { url = "https://files.pythonhosted.org/packages/64/a4/f71d9cf3a5ac257c993b5ca3f93df5f7fb395c725e7f1e6479d2514173c3/dataclasses_json-0.6.7.tar.gz", hash = "sha256:b6b3e528266ea45b9535223bc53ca645f5208833c29229e847b3f26a1cc55fc0", size = 32227, upload-time = "2024-06-09T16:20:19.103Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c3/be/d0d44e092656fe7a06b55e6103cbce807cdbdee17884a5367c68c9860853/dataclasses_json-0.6.7-py3-none-any.whl", hash = "sha256:0dbf33f26c8d5305befd61b39d2b3414e8a407bedc2834dea9b8d642666fb40a", size = 28686, upload-time = "2024-06-09T16:20:16.715Z" },
]
[[package]]
name = "datasets"
version = "2.19.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "aiohttp" },
{ name = "dill" },
{ name = "filelock" },
{ name = "fsspec", extra = ["http"] },
{ name = "huggingface-hub" },
{ name = "multiprocess" },
{ name = "numpy" },
{ name = "packaging" },
{ name = "pandas" },
{ name = "pyarrow" },
{ name = "pyarrow-hotfix" },
{ name = "pyyaml" },
{ name = "requests" },
{ name = "tqdm" },
{ name = "xxhash" },
]
sdist = { url = "https://files.pythonhosted.org/packages/52/e7/6ee66732f74e4fb1c8915e58b3c253aded777ad0fa457f3f831dd0cd09b4/datasets-2.19.2.tar.gz", hash = "sha256:eccb82fb3bb5ee26ccc6d7a15b7f1f834e2cc4e59b7cff7733a003552bad51ef", size = 2215337, upload-time = "2024-06-03T05:11:44.756Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/3f/59/46818ebeb708234a60e42ccf409d20709e482519d2aa450b501ddbba4594/datasets-2.19.2-py3-none-any.whl", hash = "sha256:e07ff15d75b1af75c87dd96323ba2a361128d495136652f37fd62f918d17bb4e", size = 542113, upload-time = "2024-06-03T05:11:41.151Z" },
]
[[package]]
name = "dateparser"
version = "1.2.2"
@ -1267,6 +1315,45 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/4e/8c/f3147f5c4b73e7550fe5f9352eaa956ae838d5c51eb58e7a25b9f3e2643b/decorator-5.2.1-py3-none-any.whl", hash = "sha256:d316bb415a2d9e2d2b3abcc4084c6502fc09240e292cd76a76afc106a1c8e04a", size = 9190, upload-time = "2025-02-24T04:41:32.565Z" },
]
[[package]]
name = "deepeval"
version = "3.9.6"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "aiohttp" },
{ name = "click" },
{ name = "grpcio" },
{ name = "jinja2" },
{ name = "nest-asyncio" },
{ name = "openai" },
{ name = "opentelemetry-api" },
{ name = "opentelemetry-sdk" },
{ name = "portalocker" },
{ name = "posthog" },
{ name = "pydantic" },
{ name = "pydantic-settings" },
{ name = "pyfiglet" },
{ name = "pytest" },
{ name = "pytest-asyncio" },
{ name = "pytest-repeat" },
{ name = "pytest-rerunfailures" },
{ name = "pytest-xdist" },
{ name = "python-dotenv" },
{ name = "requests" },
{ name = "rich" },
{ name = "sentry-sdk" },
{ name = "setuptools" },
{ name = "tabulate" },
{ name = "tenacity" },
{ name = "tqdm" },
{ name = "typer" },
{ name = "wheel" },
]
sdist = { url = "https://files.pythonhosted.org/packages/c0/64/dd6c1bce14324c211d72facd6d4deb1453321d7e32a82b7d5f1ead7ba817/deepeval-3.9.6.tar.gz", hash = "sha256:07eaad40f17a21b809bd6719f9f4ae61d8880753edddaee2abbcd4a27cfb225d", size = 614512, upload-time = "2026-04-07T15:27:41.533Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/0e/3e/cac5dad55d9a9540b5288dbe853554638302df43c459e0f64663d5ad129a/deepeval-3.9.6-py3-none-any.whl", hash = "sha256:6a368cdd7fda7bf94ef1bcb65b40914271096dbdc54aa8988aa62bda41c68f94", size = 843386, upload-time = "2026-04-07T15:27:39.134Z" },
]
[[package]]
name = "defusedxml"
version = "0.7.1"
@ -1459,6 +1546,10 @@ dev = [
{ name = "types-ujson" },
{ name = "xinference-client" },
]
evaluation = [
{ name = "deepeval" },
{ name = "ragas" },
]
storage = [
{ name = "azure-storage-blob" },
{ name = "bce-python-sdk" },
@ -1756,6 +1847,10 @@ dev = [
{ name = "types-ujson", specifier = ">=5.10.0" },
{ name = "xinference-client", specifier = "~=2.4.0" },
]
evaluation = [
{ name = "deepeval", specifier = ">=2.0.0" },
{ name = "ragas", specifier = ">=0.2.0" },
]
storage = [
{ name = "azure-storage-blob", specifier = "==12.28.0" },
{ name = "bce-python-sdk", specifier = "~=0.9.69" },
@ -2167,6 +2262,24 @@ dependencies = [
[package.metadata]
requires-dist = [{ name = "weaviate-client", specifier = "==4.20.5" }]
[[package]]
name = "dill"
version = "0.3.8"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/17/4d/ac7ffa80c69ea1df30a8aa11b3578692a5118e7cd1aa157e3ef73b092d15/dill-0.3.8.tar.gz", hash = "sha256:3ebe3c479ad625c4553aca177444d89b486b1d84982eeacded644afc0cf797ca", size = 184847, upload-time = "2024-01-27T23:42:16.145Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c9/7a/cef76fd8438a42f96db64ddaa85280485a9c395e7df3db8158cfec1eee34/dill-0.3.8-py3-none-any.whl", hash = "sha256:c36ca9ffb54365bdd2f8eb3eff7d2a21237f8452b57ace88b1ac615b7e815bd7", size = 116252, upload-time = "2024-01-27T23:42:14.239Z" },
]
[[package]]
name = "diskcache"
version = "5.6.3"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/3f/21/1c1ffc1a039ddcc459db43cc108658f32c57d271d7289a2794e401d0fdb6/diskcache-5.6.3.tar.gz", hash = "sha256:2c3a3fa2743d8535d832ec61c2054a1641f41775aa7c556758a109941e33e4fc", size = 67916, upload-time = "2023-08-31T06:12:00.316Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/3f/27/4570e78fc0bf5ea0ca45eb1de3818a23787af9b390c0b0a0033a1b8236f9/diskcache-5.6.3-py3-none-any.whl", hash = "sha256:5e31b2d5fbad117cc363ebaf6b689474db18a1f6438bc82358b024abd4c2ca19", size = 45550, upload-time = "2023-08-31T06:11:58.822Z" },
]
[[package]]
name = "diskcache-weave"
version = "5.6.3.post1"
@ -2546,11 +2659,16 @@ wheels = [
[[package]]
name = "fsspec"
version = "2025.10.0"
version = "2024.3.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/24/7f/2747c0d332b9acfa75dc84447a066fdf812b5a6b8d30472b74d309bfe8cb/fsspec-2025.10.0.tar.gz", hash = "sha256:b6789427626f068f9a83ca4e8a3cc050850b6c0f71f99ddb4f542b8266a26a59", size = 309285, upload-time = "2025-10-30T14:58:44.036Z" }
sdist = { url = "https://files.pythonhosted.org/packages/8b/b8/e3ba21f03c00c27adc9a8cd1cab8adfb37b6024757133924a9a4eab63a83/fsspec-2024.3.1.tar.gz", hash = "sha256:f39780e282d7d117ffb42bb96992f8a90795e4d0fb0f661a70ca39fe9c43ded9", size = 170742, upload-time = "2024-03-18T19:35:13.995Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/eb/02/a6b21098b1d5d6249b7c5ab69dde30108a71e4e819d4a9778f1de1d5b70d/fsspec-2025.10.0-py3-none-any.whl", hash = "sha256:7c7712353ae7d875407f97715f0e1ffcc21e33d5b24556cb1e090ae9409ec61d", size = 200966, upload-time = "2025-10-30T14:58:42.53Z" },
{ url = "https://files.pythonhosted.org/packages/93/6d/66d48b03460768f523da62a57a7e14e5e95fdf339d79e996ce3cecda2cdb/fsspec-2024.3.1-py3-none-any.whl", hash = "sha256:918d18d41bf73f0e2b261824baeb1b124bcf771767e3a26425cd7dec3332f512", size = 171991, upload-time = "2024-03-18T19:35:11.259Z" },
]
[package.optional-dependencies]
http = [
{ name = "aiohttp" },
]
[[package]]
@ -3333,6 +3451,28 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/e5/ca/1172b6638d52f2d6caa2dd262ec4c811ba59eee96d54a7701930726bce18/installer-0.7.0-py3-none-any.whl", hash = "sha256:05d1933f0a5ba7d8d6296bb6d5018e7c94fa473ceb10cf198a92ccea19c27b53", size = 453838, upload-time = "2023-03-17T20:39:36.219Z" },
]
[[package]]
name = "instructor"
version = "1.15.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "aiohttp" },
{ name = "docstring-parser" },
{ name = "jinja2" },
{ name = "jiter" },
{ name = "openai" },
{ name = "pydantic" },
{ name = "pydantic-core" },
{ name = "requests" },
{ name = "rich" },
{ name = "tenacity" },
{ name = "typer" },
]
sdist = { url = "https://files.pythonhosted.org/packages/dc/a4/832cfb15420360e26d2d85bd9d5fe1e4b839d52587574d389bc31284bf6f/instructor-1.15.1.tar.gz", hash = "sha256:c72406469d9025b742e83cf0c13e914b317db2089d08d889944e74fcd659ef94", size = 69948370, upload-time = "2026-04-03T01:51:30.107Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/d8/c8/36c5d9b80aaf40ba9a7084a8fc18c967db6bf248a4cc8d0f0816b14284be/instructor-1.15.1-py3-none-any.whl", hash = "sha256:be81d17ba2b154a04ab4720808f24f9d6b598f80992f82eaf9cc79006099cf6c", size = 178156, upload-time = "2026-04-03T01:51:23.098Z" },
]
[[package]]
name = "intersystems-irispython"
version = "5.3.2"
@ -3442,6 +3582,27 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/e1/03/7afcecb4242d93b684708b47fb014abdc1922a01b38c0e30f1117ae74a83/json_repair-0.59.2-py3-none-any.whl", hash = "sha256:6ca6238519c24f671bcb05d1f38a0d6a452bb4ca5af82137595c5c2f1a0fb785", size = 46918, upload-time = "2026-04-11T15:55:39.817Z" },
]
[[package]]
name = "jsonpatch"
version = "1.33"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "jsonpointer" },
]
sdist = { url = "https://files.pythonhosted.org/packages/42/78/18813351fe5d63acad16aec57f94ec2b70a09e53ca98145589e185423873/jsonpatch-1.33.tar.gz", hash = "sha256:9fcd4009c41e6d12348b4a0ff2563ba56a2923a7dfee731d004e212e1ee5030c", size = 21699, upload-time = "2023-06-26T12:07:29.144Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/73/07/02e16ed01e04a374e644b575638ec7987ae846d25ad97bcc9945a3ee4b0e/jsonpatch-1.33-py2.py3-none-any.whl", hash = "sha256:0ae28c0cd062bbd8b8ecc26d7d164fbbea9652a1a3693f3b956c1eae5145dade", size = 12898, upload-time = "2023-06-16T21:01:28.466Z" },
]
[[package]]
name = "jsonpointer"
version = "3.1.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/18/c7/af399a2e7a67fd18d63c40c5e62d3af4e67b836a2107468b6a5ea24c4304/jsonpointer-3.1.1.tar.gz", hash = "sha256:0b801c7db33a904024f6004d526dcc53bbb8a4a0f4e32bfd10beadf60adf1900", size = 9068, upload-time = "2026-03-23T22:32:32.458Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/9e/6a/a83720e953b1682d2d109d3c2dbb0bc9bf28cc1cbc205be4ef4be5da709d/jsonpointer-3.1.1-py3-none-any.whl", hash = "sha256:8ff8b95779d071ba472cf5bc913028df06031797532f08a7d5b602d8b2a488ca", size = 7659, upload-time = "2026-03-23T22:32:31.568Z" },
]
[[package]]
name = "jsonschema"
version = "4.25.1"
@ -3515,6 +3676,106 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/89/43/d9bebfc3db7dea6ec80df5cb2aad8d274dd18ec2edd6c4f21f32c237cbbb/kubernetes-33.1.0-py2.py3-none-any.whl", hash = "sha256:544de42b24b64287f7e0aa9513c93cb503f7f40eea39b20f66810011a86eabc5", size = 1941335, upload-time = "2025-06-09T21:57:56.327Z" },
]
[[package]]
name = "langchain"
version = "1.2.15"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "langchain-core" },
{ name = "langgraph" },
{ name = "pydantic" },
]
sdist = { url = "https://files.pythonhosted.org/packages/98/3f/888a7099d2bd2917f8b0c3ffc7e347f1e664cf64267820b0b923c4f339fc/langchain-1.2.15.tar.gz", hash = "sha256:1717b6719daefae90b2728314a5e2a117ff916291e2862595b6c3d6fba33d652", size = 574732, upload-time = "2026-04-03T14:26:03.994Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/3f/e8/a3b8cb0005553f6a876865073c81ef93bd7c5b18381bcb9ba4013af96ebc/langchain-1.2.15-py3-none-any.whl", hash = "sha256:e349db349cb3e9550c4044077cf90a1717691756cc236438404b23500e615874", size = 112714, upload-time = "2026-04-03T14:26:02.557Z" },
]
[[package]]
name = "langchain-classic"
version = "1.0.3"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "langchain-core" },
{ name = "langchain-text-splitters" },
{ name = "langsmith" },
{ name = "pydantic" },
{ name = "pyyaml" },
{ name = "requests" },
{ name = "sqlalchemy" },
]
sdist = { url = "https://files.pythonhosted.org/packages/32/04/b01c09e37414bab9f209efa311502841a3c0de5bc6c35e729c8d8a9893c9/langchain_classic-1.0.3.tar.gz", hash = "sha256:168ef1dfbfb18cae5a9ff0accecc9413a5b5aa3464b53fa841561a3384b6324a", size = 10534933, upload-time = "2026-03-13T13:56:11.96Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/ab/e6/cfdeedec0537ffbf5041773590d25beb7f2aa467cc6630e788c9c7c72c3e/langchain_classic-1.0.3-py3-none-any.whl", hash = "sha256:26df1ec9806b1fbff19d9085a747ea7d8d82d7e3fb1d25132859979de627ef79", size = 1041335, upload-time = "2026-03-13T13:56:09.677Z" },
]
[[package]]
name = "langchain-community"
version = "0.4.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "aiohttp" },
{ name = "dataclasses-json" },
{ name = "httpx-sse" },
{ name = "langchain-classic" },
{ name = "langchain-core" },
{ name = "langsmith" },
{ name = "numpy" },
{ name = "pydantic-settings" },
{ name = "pyyaml" },
{ name = "requests" },
{ name = "sqlalchemy" },
{ name = "tenacity" },
]
sdist = { url = "https://files.pythonhosted.org/packages/53/97/a03585d42b9bdb6fbd935282d6e3348b10322a24e6ce12d0c99eb461d9af/langchain_community-0.4.1.tar.gz", hash = "sha256:f3b211832728ee89f169ddce8579b80a085222ddb4f4ed445a46e977d17b1e85", size = 33241144, upload-time = "2025-10-27T15:20:32.504Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/f0/a4/c4fde67f193401512337456cabc2148f2c43316e445f5decd9f8806e2992/langchain_community-0.4.1-py3-none-any.whl", hash = "sha256:2135abb2c7748a35c84613108f7ebf30f8505b18c3c18305ffaecfc7651f6c6a", size = 2533285, upload-time = "2025-10-27T15:20:30.767Z" },
]
[[package]]
name = "langchain-core"
version = "1.2.28"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "jsonpatch" },
{ name = "langsmith" },
{ name = "packaging" },
{ name = "pydantic" },
{ name = "pyyaml" },
{ name = "tenacity" },
{ name = "typing-extensions" },
{ name = "uuid-utils" },
]
sdist = { url = "https://files.pythonhosted.org/packages/f8/a4/317a1a3ac1df33a64adb3670bf88bbe3b3d5baa274db6863a979db472897/langchain_core-1.2.28.tar.gz", hash = "sha256:271a3d8bd618f795fdeba112b0753980457fc90537c46a0c11998516a74dc2cb", size = 846119, upload-time = "2026-04-08T18:19:34.867Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/a8/92/32f785f077c7e898da97064f113c73fbd9ad55d1e2169cf3a391b183dedb/langchain_core-1.2.28-py3-none-any.whl", hash = "sha256:80764232581eaf8057bcefa71dbf8adc1f6a28d257ebd8b95ba9b8b452e8c6ac", size = 508727, upload-time = "2026-04-08T18:19:32.823Z" },
]
[[package]]
name = "langchain-openai"
version = "1.1.9"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "langchain-core" },
{ name = "openai" },
{ name = "tiktoken" },
]
sdist = { url = "https://files.pythonhosted.org/packages/49/ae/1dbeb49ab8f098f78ec52e21627e705e5d7c684dc8826c2c34cc2746233a/langchain_openai-1.1.9.tar.gz", hash = "sha256:fdee25dcf4b0685d8e2f59856f4d5405431ef9e04ab53afe19e2e8360fed8234", size = 1004828, upload-time = "2026-02-10T21:03:21.615Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/52/a1/8a20d19f69d022c10d34afa42d972cc50f971b880d0eb4a828cf3dd824a8/langchain_openai-1.1.9-py3-none-any.whl", hash = "sha256:ca2482b136c45fb67c0db84a9817de675e0eb8fb2203a33914c1b7a96f273940", size = 85769, upload-time = "2026-02-10T21:03:20.333Z" },
]
[[package]]
name = "langchain-text-splitters"
version = "1.1.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "langchain-core" },
]
sdist = { url = "https://files.pythonhosted.org/packages/85/38/14121ead61e0e75f79c3a35e5148ac7c2fe754a55f76eab3eed573269524/langchain_text_splitters-1.1.1.tar.gz", hash = "sha256:34861abe7c07d9e49d4dc852d0129e26b32738b60a74486853ec9b6d6a8e01d2", size = 279352, upload-time = "2026-02-18T23:02:42.798Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/84/66/d9e0c3b83b0ad75ee746c51ba347cacecb8d656b96e1d513f3e334d1ccab/langchain_text_splitters-1.1.1-py3-none-any.whl", hash = "sha256:5ed0d7bf314ba925041e7d7d17cd8b10f688300d5415fb26c29442f061e329dc", size = 35734, upload-time = "2026-02-18T23:02:41.913Z" },
]
[[package]]
name = "langdetect"
version = "1.0.9"
@ -3543,6 +3804,62 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/be/0a/b84e3e68a690ccfe6d64953c572772c685fcb0915b7f2ee3a87c22e388ab/langfuse-4.2.0-py3-none-any.whl", hash = "sha256:bfd760bf10fd0228f297f6369436620f76d16b589de46393d65706b27e4e4082", size = 475449, upload-time = "2026-04-10T11:55:23.624Z" },
]
[[package]]
name = "langgraph"
version = "1.1.6"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "langchain-core" },
{ name = "langgraph-checkpoint" },
{ name = "langgraph-prebuilt" },
{ name = "langgraph-sdk" },
{ name = "pydantic" },
{ name = "xxhash" },
]
sdist = { url = "https://files.pythonhosted.org/packages/5c/e5/d3f72ead3c7f15769d5a9c07e373628f1fbaf6cbe7735694d7085859acf6/langgraph-1.1.6.tar.gz", hash = "sha256:1783f764b08a607e9f288dbcf6da61caeb0dd40b337e5c9fb8b412341fbc0b60", size = 549634, upload-time = "2026-04-03T19:01:32.561Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/71/e6/b36ecdb3ff4ba9a290708d514bae89ebbe2f554b6abbe4642acf3fddbe51/langgraph-1.1.6-py3-none-any.whl", hash = "sha256:fdbf5f54fa5a5a4c4b09b7b5e537f1b2fa283d2f0f610d3457ddeecb479458b9", size = 169755, upload-time = "2026-04-03T19:01:30.686Z" },
]
[[package]]
name = "langgraph-checkpoint"
version = "4.0.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "langchain-core" },
{ name = "ormsgpack" },
]
sdist = { url = "https://files.pythonhosted.org/packages/b1/44/a8df45d1e8b4637e29789fa8bae1db022c953cc7ac80093cfc52e923547e/langgraph_checkpoint-4.0.1.tar.gz", hash = "sha256:b433123735df11ade28829e40ce25b9be614930cd50245ff2af60629234befd9", size = 158135, upload-time = "2026-02-27T21:06:16.092Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/65/4c/09a4a0c42f5d2fc38d6c4d67884788eff7fd2cfdf367fdf7033de908b4c0/langgraph_checkpoint-4.0.1-py3-none-any.whl", hash = "sha256:e3adcd7a0e0166f3b48b8cf508ce0ea366e7420b5a73aa81289888727769b034", size = 50453, upload-time = "2026-02-27T21:06:14.293Z" },
]
[[package]]
name = "langgraph-prebuilt"
version = "1.0.9"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "langchain-core" },
{ name = "langgraph-checkpoint" },
]
sdist = { url = "https://files.pythonhosted.org/packages/99/4c/06dac899f4945bedb0c3a1583c19484c2cc894114ea30d9a538dd270086e/langgraph_prebuilt-1.0.9.tar.gz", hash = "sha256:93de7512e9caade4b77ead92428f6215c521fdb71b8ffda8cd55f0ad814e64de", size = 165850, upload-time = "2026-04-03T14:06:37.721Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/1d/a2/8368ac187b75e7f9d938ca075d34f116683f5cfc48d924029ee79aea147b/langgraph_prebuilt-1.0.9-py3-none-any.whl", hash = "sha256:776c8e3154a5aef5ad0e5bf3f263f2dcaab3983786cc20014b7f955d99d2d1b2", size = 35958, upload-time = "2026-04-03T14:06:36.58Z" },
]
[[package]]
name = "langgraph-sdk"
version = "0.3.13"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "httpx" },
{ name = "orjson" },
]
sdist = { url = "https://files.pythonhosted.org/packages/0e/db/77a45127dddcfea5e4256ba916182903e4c31dc4cfca305b8c386f0a9e53/langgraph_sdk-0.3.13.tar.gz", hash = "sha256:419ca5663eec3cec192ad194ac0647c0c826866b446073eb40f384f950986cd5", size = 196360, upload-time = "2026-04-07T20:34:18.766Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/fe/ef/64d64e9f8eea47ce7b939aa6da6863b674c8d418647813c20111645fcc62/langgraph_sdk-0.3.13-py3-none-any.whl", hash = "sha256:aee09e345c90775f6de9d6f4c7b847cfc652e49055c27a2aed0d981af2af3bd0", size = 96668, upload-time = "2026-04-07T20:34:17.866Z" },
]
[[package]]
name = "langsmith"
version = "0.7.30"
@ -3731,6 +4048,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/e5/f1/216fc1bbfd74011693a4fd837e7026152e89c4bcf3e77b6692fba9923123/markupsafe-3.0.3-cp312-cp312-win_arm64.whl", hash = "sha256:35add3b638a5d900e807944a078b51922212fb3dedb01633a8defc4b01a3c85f", size = 13906, upload-time = "2025-09-27T18:36:40.689Z" },
]
[[package]]
name = "marshmallow"
version = "3.26.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "packaging" },
]
sdist = { url = "https://files.pythonhosted.org/packages/55/79/de6c16cc902f4fc372236926b0ce2ab7845268dcc30fb2fbb7f71b418631/marshmallow-3.26.2.tar.gz", hash = "sha256:bbe2adb5a03e6e3571b573f42527c6fe926e17467833660bebd11593ab8dfd57", size = 222095, upload-time = "2025-12-22T06:53:53.309Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/be/2f/5108cb3ee4ba6501748c4908b908e55f42a5b66245b4cfe0c99326e1ef6e/marshmallow-3.26.2-py3-none-any.whl", hash = "sha256:013fa8a3c4c276c24d26d84ce934dc964e2aa794345a0f8c7e5a7191482c8a73", size = 50964, upload-time = "2025-12-22T06:53:51.801Z" },
]
[[package]]
name = "mdurl"
version = "0.1.2"
@ -3870,6 +4199,22 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/b7/da/7d22601b625e241d4f23ef1ebff8acfc60da633c9e7e7922e24d10f592b3/multidict-6.7.0-py3-none-any.whl", hash = "sha256:394fc5c42a333c9ffc3e421a4c85e08580d990e08b99f6bf35b4132114c5dcb3", size = 12317, upload-time = "2025-10-06T14:52:29.272Z" },
]
[[package]]
name = "multiprocess"
version = "0.70.16"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "dill" },
]
sdist = { url = "https://files.pythonhosted.org/packages/b5/ae/04f39c5d0d0def03247c2893d6f2b83c136bf3320a2154d7b8858f2ba72d/multiprocess-0.70.16.tar.gz", hash = "sha256:161af703d4652a0e1410be6abccecde4a7ddffd19341be0a7011b94aeb171ac1", size = 1772603, upload-time = "2024-01-28T18:52:34.85Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/bc/f7/7ec7fddc92e50714ea3745631f79bd9c96424cb2702632521028e57d3a36/multiprocess-0.70.16-py310-none-any.whl", hash = "sha256:c4a9944c67bd49f823687463660a2d6daae94c289adff97e0f9d696ba6371d02", size = 134824, upload-time = "2024-01-28T18:52:26.062Z" },
{ url = "https://files.pythonhosted.org/packages/50/15/b56e50e8debaf439f44befec5b2af11db85f6e0f344c3113ae0be0593a91/multiprocess-0.70.16-py311-none-any.whl", hash = "sha256:af4cabb0dac72abfb1e794fa7855c325fd2b55a10a44628a3c1ad3311c04127a", size = 143519, upload-time = "2024-01-28T18:52:28.115Z" },
{ url = "https://files.pythonhosted.org/packages/0a/7d/a988f258104dcd2ccf1ed40fdc97e26c4ac351eeaf81d76e266c52d84e2f/multiprocess-0.70.16-py312-none-any.whl", hash = "sha256:fc0544c531920dde3b00c29863377f87e1632601092ea2daca74e4beb40faa2e", size = 146741, upload-time = "2024-01-28T18:52:29.395Z" },
{ url = "https://files.pythonhosted.org/packages/ea/89/38df130f2c799090c978b366cfdf5b96d08de5b29a4a293df7f7429fa50b/multiprocess-0.70.16-py38-none-any.whl", hash = "sha256:a71d82033454891091a226dfc319d0cfa8019a4e888ef9ca910372a446de4435", size = 132628, upload-time = "2024-01-28T18:52:30.853Z" },
{ url = "https://files.pythonhosted.org/packages/da/d9/f7f9379981e39b8c2511c9e0326d212accacb82f12fbfdc1aa2ce2a7b2b6/multiprocess-0.70.16-py39-none-any.whl", hash = "sha256:a0bafd3ae1b732eac64be2e72038231c1ba97724b60b09400d68f229fcc2fbf3", size = 133351, upload-time = "2024-01-28T18:52:31.981Z" },
]
[[package]]
name = "murmurhash"
version = "1.0.15"
@ -3940,6 +4285,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/15/dd/b3250826c29cee7816de4409a2fe5e469a68b9a89f6bfaa5eed74f05532c/mysql_connector_python-9.6.0-py2.py3-none-any.whl", hash = "sha256:44b0fb57207ebc6ae05b5b21b7968a9ed33b29187fe87b38951bad2a334d75d5", size = 480527, upload-time = "2026-02-10T12:04:36.176Z" },
]
[[package]]
name = "nest-asyncio"
version = "1.6.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/83/f8/51569ac65d696c8ecbee95938f89d4abf00f47d58d48f6fbabfe8f0baefe/nest_asyncio-1.6.0.tar.gz", hash = "sha256:6f172d5449aca15afd6c646851f4e31e02c598d553a667e38cafa997cfec55fe", size = 7418, upload-time = "2024-01-21T14:25:19.227Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/a0/c4/c2971a3ba4c6103a3d10c4b0f24f461ddc027f0f09763220cf35ca1401b3/nest_asyncio-1.6.0-py3-none-any.whl", hash = "sha256:87af6efd6b5e897c81050477ef65c62e2b2f35d51703cae01aff2905b1852e1c", size = 5195, upload-time = "2024-01-21T14:25:17.223Z" },
]
[[package]]
name = "networkx"
version = "3.6"
@ -4567,6 +4921,23 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/04/3d/b4fefad8bdf91e0fe212eb04975aeb36ea92997269d68857efcc7eb1dda3/orjson-3.11.6-cp312-cp312-win_arm64.whl", hash = "sha256:65dfa096f4e3a5e02834b681f539a87fbe85adc82001383c0db907557f666bfc", size = 135212, upload-time = "2026-01-29T15:12:18.3Z" },
]
[[package]]
name = "ormsgpack"
version = "1.12.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/12/0c/f1761e21486942ab9bb6feaebc610fa074f7c5e496e6962dea5873348077/ormsgpack-1.12.2.tar.gz", hash = "sha256:944a2233640273bee67521795a73cf1e959538e0dfb7ac635505010455e53b33", size = 39031, upload-time = "2026-01-18T20:55:28.023Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/4c/36/16c4b1921c308a92cef3bf6663226ae283395aa0ff6e154f925c32e91ff5/ormsgpack-1.12.2-cp312-cp312-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:7a29d09b64b9694b588ff2f80e9826bdceb3a2b91523c5beae1fab27d5c940e7", size = 378618, upload-time = "2026-01-18T20:55:50.835Z" },
{ url = "https://files.pythonhosted.org/packages/c0/68/468de634079615abf66ed13bb5c34ff71da237213f29294363beeeca5306/ormsgpack-1.12.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0b39e629fd2e1c5b2f46f99778450b59454d1f901bc507963168985e79f09c5d", size = 203186, upload-time = "2026-01-18T20:56:11.163Z" },
{ url = "https://files.pythonhosted.org/packages/73/a9/d756e01961442688b7939bacd87ce13bfad7d26ce24f910f6028178b2cc8/ormsgpack-1.12.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:958dcb270d30a7cb633a45ee62b9444433fa571a752d2ca484efdac07480876e", size = 210738, upload-time = "2026-01-18T20:56:09.181Z" },
{ url = "https://files.pythonhosted.org/packages/7b/ba/795b1036888542c9113269a3f5690ab53dd2258c6fb17676ac4bd44fcf94/ormsgpack-1.12.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58d379d72b6c5e964851c77cfedfb386e474adee4fd39791c2c5d9efb53505cc", size = 212569, upload-time = "2026-01-18T20:56:06.135Z" },
{ url = "https://files.pythonhosted.org/packages/6c/aa/bff73c57497b9e0cba8837c7e4bcab584b1a6dbc91a5dd5526784a5030c8/ormsgpack-1.12.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8463a3fc5f09832e67bdb0e2fda6d518dc4281b133166146a67f54c08496442e", size = 387166, upload-time = "2026-01-18T20:55:36.738Z" },
{ url = "https://files.pythonhosted.org/packages/d3/cf/f8283cba44bcb7b14f97b6274d449db276b3a86589bdb363169b51bc12de/ormsgpack-1.12.2-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:eddffb77eff0bad4e67547d67a130604e7e2dfbb7b0cde0796045be4090f35c6", size = 482498, upload-time = "2026-01-18T20:55:29.626Z" },
{ url = "https://files.pythonhosted.org/packages/05/be/71e37b852d723dfcbe952ad04178c030df60d6b78eba26bfd14c9a40575e/ormsgpack-1.12.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:fcd55e5f6ba0dbce624942adf9f152062135f991a0126064889f68eb850de0dd", size = 425518, upload-time = "2026-01-18T20:55:49.556Z" },
{ url = "https://files.pythonhosted.org/packages/7a/0c/9803aa883d18c7ef197213cd2cbf73ba76472a11fe100fb7dab2884edf48/ormsgpack-1.12.2-cp312-cp312-win_amd64.whl", hash = "sha256:d024b40828f1dde5654faebd0d824f9cc29ad46891f626272dd5bfd7af2333a4", size = 117462, upload-time = "2026-01-18T20:55:47.726Z" },
{ url = "https://files.pythonhosted.org/packages/c8/9e/029e898298b2cc662f10d7a15652a53e3b525b1e7f07e21fef8536a09bb8/ormsgpack-1.12.2-cp312-cp312-win_arm64.whl", hash = "sha256:da538c542bac7d1c8f3f2a937863dba36f013108ce63e55745941dda4b75dbb6", size = 111559, upload-time = "2026-01-18T20:55:54.273Z" },
]
[[package]]
name = "oss2"
version = "2.19.1"
@ -4793,7 +5164,7 @@ wheels = [
[[package]]
name = "posthog"
version = "7.0.1"
version = "5.4.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "backoff" },
@ -4801,11 +5172,10 @@ dependencies = [
{ name = "python-dateutil" },
{ name = "requests" },
{ name = "six" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/a2/d4/b9afe855a8a7a1bf4459c28ae4c300b40338122dc850acabefcf2c3df24d/posthog-7.0.1.tar.gz", hash = "sha256:21150562c2630a599c1d7eac94bc5c64eb6f6acbf3ff52ccf1e57345706db05a", size = 126985, upload-time = "2025-11-15T12:44:22.465Z" }
sdist = { url = "https://files.pythonhosted.org/packages/48/20/60ae67bb9d82f00427946218d49e2e7e80fb41c15dc5019482289ec9ce8d/posthog-5.4.0.tar.gz", hash = "sha256:701669261b8d07cdde0276e5bc096b87f9e200e3b9589c5ebff14df658c5893c", size = 88076, upload-time = "2025-06-20T23:19:23.485Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/05/0c/8b6b20b0be71725e6e8a32dcd460cdbf62fe6df9bc656a650150dc98fedd/posthog-7.0.1-py3-none-any.whl", hash = "sha256:efe212d8d88a9ba80a20c588eab4baf4b1a5e90e40b551160a5603bb21e96904", size = 145234, upload-time = "2025-11-15T12:44:21.247Z" },
{ url = "https://files.pythonhosted.org/packages/4f/98/e480cab9a08d1c09b1c59a93dade92c1bb7544826684ff2acbfd10fcfbd4/posthog-5.4.0-py3-none-any.whl", hash = "sha256:284dfa302f64353484420b52d4ad81ff5c2c2d1d607c4e2db602ac72761831bd", size = 105364, upload-time = "2025-06-20T23:19:22.001Z" },
]
[[package]]
@ -5000,6 +5370,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/f6/70/1fdda42d65b28b078e93d75d371b2185a61da89dda4def8ba6ba41ebdeb4/pyarrow-23.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:07deae7783782ac7250989a7b2ecde9b3c343a643f82e8a4df03d93b633006f0", size = 27620678, upload-time = "2026-02-16T10:10:39.31Z" },
]
[[package]]
name = "pyarrow-hotfix"
version = "0.7"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/d2/ed/c3e8677f7abf3981838c2af7b5ac03e3589b3ef94fcb31d575426abae904/pyarrow_hotfix-0.7.tar.gz", hash = "sha256:59399cd58bdd978b2e42816a4183a55c6472d4e33d183351b6069f11ed42661d", size = 9910, upload-time = "2025-04-25T10:17:06.247Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/2e/c3/94ade4906a2f88bc935772f59c934013b4205e773bcb4239db114a6da136/pyarrow_hotfix-0.7-py3-none-any.whl", hash = "sha256:3236f3b5f1260f0e2ac070a55c1a7b339c4bb7267839bd2015e283234e758100", size = 7923, upload-time = "2025-04-25T10:17:05.224Z" },
]
[[package]]
name = "pyasn1"
version = "0.6.3"
@ -5120,6 +5499,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/00/4b/ccc026168948fec4f7555b9164c724cf4125eac006e176541483d2c959be/pydantic_settings-2.13.1-py3-none-any.whl", hash = "sha256:d56fd801823dbeae7f0975e1f8c8e25c258eb75d278ea7abb5d9cebb01b56237", size = 58929, upload-time = "2026-02-19T13:45:06.034Z" },
]
[[package]]
name = "pyfiglet"
version = "1.0.4"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/c8/e3/0a86276ad2c383ce08d76110a8eec2fe22e7051c4b8ba3fa163a0b08c428/pyfiglet-1.0.4.tar.gz", hash = "sha256:db9c9940ed1bf3048deff534ed52ff2dafbbc2cd7610b17bb5eca1df6d4278ef", size = 1560615, upload-time = "2025-08-15T18:32:47.302Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/9f/5c/fe9f95abd5eaedfa69f31e450f7e2768bef121dbdf25bcddee2cd3087a16/pyfiglet-1.0.4-py3-none-any.whl", hash = "sha256:65b57b7a8e1dff8a67dc8e940a117238661d5e14c3e49121032bd404d9b2b39f", size = 1806118, upload-time = "2025-08-15T18:32:45.556Z" },
]
[[package]]
name = "pygments"
version = "2.20.0"
@ -5329,6 +5717,19 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/d4/24/a372aaf5c9b7208e7112038812994107bc65a84cd00e0354a88c2c77a617/pytest-9.0.3-py3-none-any.whl", hash = "sha256:2c5efc453d45394fdd706ade797c0a81091eccd1d6e4bccfcd476e2b8e0ab5d9", size = 375249, upload-time = "2026-04-07T17:16:16.13Z" },
]
[[package]]
name = "pytest-asyncio"
version = "1.3.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pytest" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/90/2c/8af215c0f776415f3590cac4f9086ccefd6fd463befeae41cd4d3f193e5a/pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5", size = 50087, upload-time = "2025-11-10T16:07:47.256Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/e5/35/f8b19922b6a25bc0880171a2f1a003eaeb93657475193ab516fd87cac9da/pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5", size = 15075, upload-time = "2025-11-10T16:07:45.537Z" },
]
[[package]]
name = "pytest-benchmark"
version = "5.2.3"
@ -5381,6 +5782,31 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/5a/cc/06253936f4a7fa2e0f48dfe6d851d9c56df896a9ab09ac019d70b760619c/pytest_mock-3.15.1-py3-none-any.whl", hash = "sha256:0a25e2eb88fe5168d535041d09a4529a188176ae608a6d249ee65abc0949630d", size = 10095, upload-time = "2025-09-16T16:37:25.734Z" },
]
[[package]]
name = "pytest-repeat"
version = "0.9.4"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pytest" },
]
sdist = { url = "https://files.pythonhosted.org/packages/80/d4/69e9dbb9b8266df0b157c72be32083403c412990af15c7c15f7a3fd1b142/pytest_repeat-0.9.4.tar.gz", hash = "sha256:d92ac14dfaa6ffcfe6917e5d16f0c9bc82380c135b03c2a5f412d2637f224485", size = 6488, upload-time = "2025-04-07T14:59:53.077Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/73/d4/8b706b81b07b43081bd68a2c0359fe895b74bf664b20aca8005d2bb3be71/pytest_repeat-0.9.4-py3-none-any.whl", hash = "sha256:c1738b4e412a6f3b3b9e0b8b29fcd7a423e50f87381ad9307ef6f5a8601139f3", size = 4180, upload-time = "2025-04-07T14:59:51.492Z" },
]
[[package]]
name = "pytest-rerunfailures"
version = "16.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "packaging" },
{ name = "pytest" },
]
sdist = { url = "https://files.pythonhosted.org/packages/de/04/71e9520551fc8fe2cf5c1a1842e4e600265b0815f2016b7c27ec85688682/pytest_rerunfailures-16.1.tar.gz", hash = "sha256:c38b266db8a808953ebd71ac25c381cb1981a78ff9340a14bcb9f1b9bff1899e", size = 30889, upload-time = "2025-10-10T07:06:01.238Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/77/54/60eabb34445e3db3d3d874dc1dfa72751bfec3265bd611cb13c8b290adea/pytest_rerunfailures-16.1-py3-none-any.whl", hash = "sha256:5d11b12c0ca9a1665b5054052fcc1084f8deadd9328962745ef6b04e26382e86", size = 14093, upload-time = "2025-10-10T07:06:00.019Z" },
]
[[package]]
name = "pytest-timeout"
version = "2.4.0"
@ -5581,6 +6007,35 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/3a/fa/5abd82cde353f1009c068cca820195efd94e403d261b787e78ea7a9c8318/qdrant_client-1.9.0-py3-none-any.whl", hash = "sha256:ee02893eab1f642481b1ac1e38eb68ec30bab0f673bef7cc05c19fa5d2cbf43e", size = 229258, upload-time = "2024-04-22T13:35:46.81Z" },
]
[[package]]
name = "ragas"
version = "0.3.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "appdirs" },
{ name = "datasets" },
{ name = "diskcache" },
{ name = "gitpython" },
{ name = "instructor" },
{ name = "langchain" },
{ name = "langchain-community" },
{ name = "langchain-core" },
{ name = "langchain-openai" },
{ name = "nest-asyncio" },
{ name = "numpy" },
{ name = "openai" },
{ name = "pillow" },
{ name = "pydantic" },
{ name = "rich" },
{ name = "tiktoken" },
{ name = "tqdm" },
{ name = "typer" },
]
sdist = { url = "https://files.pythonhosted.org/packages/5c/3f/44ffcfcd10b885c330f5e07f327edcf8965674103f81a8766a01672c3f27/ragas-0.3.2.tar.gz", hash = "sha256:a800a5326f0d5bfa086cf8832d4b1031e4b21ff063f83d7d9388ee36e1a93fe6", size = 257091, upload-time = "2025-08-19T12:04:18.655Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/1f/1a/d08bb5fd256c375dae885e508c00636cba5692e772616e49c6d4cfdec52c/ragas-0.3.2-py3-none-any.whl", hash = "sha256:ccf4447fdc6daf69b5fafb26d8baa5095e4194a4dc87091b2a118f68322e7f1b", size = 277283, upload-time = "2025-08-19T12:04:17.329Z" },
]
[[package]]
name = "rapidfuzz"
version = "3.14.3"
@ -6876,6 +7331,19 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/18/67/36e9267722cc04a6b9f15c7f3441c2363321a3ea07da7ae0c0707beb2a9c/typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548", size = 44614, upload-time = "2025-08-25T13:49:24.86Z" },
]
[[package]]
name = "typing-inspect"
version = "0.9.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "mypy-extensions" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/dc/74/1789779d91f1961fa9438e9a8710cdae6bd138c80d7303996933d117264a/typing_inspect-0.9.0.tar.gz", hash = "sha256:b23fc42ff6f6ef6954e4852c1fb512cdd18dbea03134f91f856a95ccc9461f78", size = 13825, upload-time = "2023-05-24T20:25:47.612Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/65/f3/107a22063bf27bdccf2024833d3445f4eea42b2e598abfbd46f6a63b6cb0/typing_inspect-0.9.0-py3-none-any.whl", hash = "sha256:9ee6fc59062311ef8547596ab6b955e1b8aa46242d854bfc78f4f6b0eff35f9f", size = 8827, upload-time = "2023-05-24T20:25:45.287Z" },
]
[[package]]
name = "typing-inspection"
version = "0.4.2"
@ -7330,6 +7798,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/4d/ec/d58832f89ede95652fd01f4f24236af7d32b70cab2196dfcc2d2fd13c5c2/werkzeug-3.1.6-py3-none-any.whl", hash = "sha256:7ddf3357bb9564e407607f988f683d72038551200c704012bb9a4c523d42f131", size = 225166, upload-time = "2026-02-19T15:17:17.475Z" },
]
[[package]]
name = "wheel"
version = "0.46.3"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "packaging" },
]
sdist = { url = "https://files.pythonhosted.org/packages/89/24/a2eb353a6edac9a0303977c4cb048134959dd2a51b48a269dfc9dde00c8a/wheel-0.46.3.tar.gz", hash = "sha256:e3e79874b07d776c40bd6033f8ddf76a7dad46a7b8aa1b2787a83083519a1803", size = 60605, upload-time = "2026-01-22T12:39:49.136Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/87/22/b76d483683216dde3d67cba61fb2444be8d5be289bf628c13fc0fd90e5f9/wheel-0.46.3-py3-none-any.whl", hash = "sha256:4b399d56c9d9338230118d705d9737a2a468ccca63d5e813e2a4fc7815d8bc4d", size = 30557, upload-time = "2026-01-22T12:39:48.099Z" },
]
[[package]]
name = "wrapt"
version = "1.16.0"

View File

@ -9,6 +9,7 @@ import {
EDUCATION_VERIFYING_LOCALSTORAGE_ITEM,
} from '@/app/education-apply/constants'
import { usePathname, useRouter, useSearchParams } from '@/next/navigation'
import { rememberCreateAppExternalAttribution } from '@/utils/create-app-tracking'
import { sendGAEvent } from '@/utils/gtag'
import { fetchSetupStatusWithCache } from '@/utils/setup-status'
import { resolvePostLoginRedirect } from '../signin/utils/post-login-redirect'
@ -45,6 +46,8 @@ export const AppInitializer = ({
(async () => {
const action = searchParams.get('action')
rememberCreateAppExternalAttribution({ searchParams })
if (oauthNewUser) {
let utmInfo = null
const utmInfoStr = Cookies.get('utm_info')

View File

@ -4,7 +4,6 @@ import { AppModeEnum } from '@/types/app'
import Apps from '../index'
const mockUseExploreAppList = vi.fn()
const mockTrackEvent = vi.fn()
const mockImportDSL = vi.fn()
const mockFetchAppDetail = vi.fn()
const mockHandleCheckPluginDependencies = vi.fn()
@ -12,6 +11,7 @@ const mockGetRedirection = vi.fn()
const mockPush = vi.fn()
const mockToastSuccess = vi.fn()
const mockToastError = vi.fn()
const mockTrackCreateApp = vi.fn()
let latestDebounceFn = () => {}
vi.mock('ahooks', () => ({
@ -92,8 +92,8 @@ vi.mock('@/app/components/base/ui/toast', () => ({
error: (...args: unknown[]) => mockToastError(...args),
},
}))
vi.mock('@/app/components/base/amplitude', () => ({
trackEvent: (...args: unknown[]) => mockTrackEvent(...args),
vi.mock('@/utils/create-app-tracking', () => ({
trackCreateApp: (...args: unknown[]) => mockTrackCreateApp(...args),
}))
vi.mock('@/service/apps', () => ({
importDSL: (...args: unknown[]) => mockImportDSL(...args),
@ -246,10 +246,9 @@ describe('Apps', () => {
}))
})
expect(mockTrackEvent).toHaveBeenCalledWith('create_app_with_template', expect.objectContaining({
template_id: 'Alpha',
template_name: 'Alpha',
}))
expect(mockTrackCreateApp).toHaveBeenCalledWith({
appMode: AppModeEnum.CHAT,
})
expect(mockToastSuccess).toHaveBeenCalledWith('app.newApp.appCreated')
expect(onSuccess).toHaveBeenCalled()
expect(mockHandleCheckPluginDependencies).toHaveBeenCalledWith('created-app-id')

View File

@ -8,7 +8,6 @@ import * as React from 'react'
import { useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import AppTypeSelector from '@/app/components/app/type-selector'
import { trackEvent } from '@/app/components/base/amplitude'
import Divider from '@/app/components/base/divider'
import Input from '@/app/components/base/input'
import Loading from '@/app/components/base/loading'
@ -25,6 +24,7 @@ import { useExploreAppList } from '@/service/use-explore'
import { AppModeEnum } from '@/types/app'
import { getRedirection } from '@/utils/app-redirection'
import { cn } from '@/utils/classnames'
import { trackCreateApp } from '@/utils/create-app-tracking'
import AppCard from '../app-card'
import Sidebar, { AppCategories, AppCategoryLabel } from './sidebar'
@ -127,14 +127,7 @@ const Apps = ({
icon_background,
description,
})
// Track app creation from template
trackEvent('create_app_with_template', {
app_mode: mode,
template_id: currApp?.app.id,
template_name: currApp?.app.name,
description,
})
trackCreateApp({ appMode: mode })
setIsShowCreateModal(false)
toast.success(t('newApp.appCreated', { ns: 'app' }))

View File

@ -1,7 +1,6 @@
import type { App } from '@/types/app'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { afterAll, beforeEach, describe, expect, it, vi } from 'vitest'
import { trackEvent } from '@/app/components/base/amplitude'
import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
import { useAppContext } from '@/context/app-context'
@ -10,6 +9,7 @@ import { useRouter } from '@/next/navigation'
import { createApp } from '@/service/apps'
import { AppModeEnum } from '@/types/app'
import { getRedirection } from '@/utils/app-redirection'
import { trackCreateApp } from '@/utils/create-app-tracking'
import CreateAppModal from '../index'
const ahooksMocks = vi.hoisted(() => ({
@ -31,8 +31,8 @@ vi.mock('ahooks', () => ({
vi.mock('@/next/navigation', () => ({
useRouter: vi.fn(),
}))
vi.mock('@/app/components/base/amplitude', () => ({
trackEvent: vi.fn(),
vi.mock('@/utils/create-app-tracking', () => ({
trackCreateApp: vi.fn(),
}))
vi.mock('@/service/apps', () => ({
createApp: vi.fn(),
@ -87,7 +87,7 @@ vi.mock('@/hooks/use-theme', () => ({
const mockUseRouter = vi.mocked(useRouter)
const mockPush = vi.fn()
const mockCreateApp = vi.mocked(createApp)
const mockTrackEvent = vi.mocked(trackEvent)
const mockTrackCreateApp = vi.mocked(trackCreateApp)
const mockGetRedirection = vi.mocked(getRedirection)
const mockUseProviderContext = vi.mocked(useProviderContext)
const mockUseAppContext = vi.mocked(useAppContext)
@ -178,10 +178,7 @@ describe('CreateAppModal', () => {
mode: AppModeEnum.ADVANCED_CHAT,
}))
expect(mockTrackEvent).toHaveBeenCalledWith('create_app', {
app_mode: AppModeEnum.ADVANCED_CHAT,
description: '',
})
expect(mockTrackCreateApp).toHaveBeenCalledWith({ appMode: AppModeEnum.ADVANCED_CHAT })
expect(mockToastSuccess).toHaveBeenCalledWith('app.newApp.appCreated')
expect(onSuccess).toHaveBeenCalled()
expect(onClose).toHaveBeenCalled()

View File

@ -6,7 +6,6 @@ import { RiArrowRightLine, RiArrowRightSLine, RiExchange2Fill } from '@remixicon
import { useDebounceFn, useKeyPress } from 'ahooks'
import { useCallback, useEffect, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { trackEvent } from '@/app/components/base/amplitude'
import AppIcon from '@/app/components/base/app-icon'
import Divider from '@/app/components/base/divider'
import FullScreenModal from '@/app/components/base/fullscreen-modal'
@ -25,6 +24,7 @@ import { createApp } from '@/service/apps'
import { AppModeEnum } from '@/types/app'
import { getRedirection } from '@/utils/app-redirection'
import { cn } from '@/utils/classnames'
import { trackCreateApp } from '@/utils/create-app-tracking'
import { basePath } from '@/utils/var'
import AppIconPicker from '../../base/app-icon-picker'
import ShortcutsName from '../../workflow/shortcuts-name'
@ -80,11 +80,7 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate, defaultAppMode }:
mode: appMode,
})
// Track app creation success
trackEvent('create_app', {
app_mode: appMode,
description,
})
trackCreateApp({ appMode: app.mode })
toast.success(t('newApp.appCreated', { ns: 'app' }))
onSuccess()

View File

@ -2,12 +2,13 @@
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
import { DSLImportMode, DSLImportStatus } from '@/models/app'
import { AppModeEnum } from '@/types/app'
import CreateFromDSLModal, { CreateFromDSLModalTab } from '../index'
const mockPush = vi.fn()
const mockImportDSL = vi.fn()
const mockImportDSLConfirm = vi.fn()
const mockTrackEvent = vi.fn()
const mockTrackCreateApp = vi.fn()
const mockHandleCheckPluginDependencies = vi.fn()
const mockGetRedirection = vi.fn()
const toastMocks = vi.hoisted(() => ({
@ -43,8 +44,8 @@ vi.mock('@/next/navigation', () => ({
}),
}))
vi.mock('@/app/components/base/amplitude', () => ({
trackEvent: (...args: unknown[]) => mockTrackEvent(...args),
vi.mock('@/utils/create-app-tracking', () => ({
trackCreateApp: (...args: unknown[]) => mockTrackCreateApp(...args),
}))
vi.mock('@/service/apps', () => ({
@ -172,7 +173,7 @@ describe('CreateFromDSLModal', () => {
id: 'import-1',
status: DSLImportStatus.COMPLETED,
app_id: 'app-1',
app_mode: 'chat',
app_mode: AppModeEnum.CHAT,
})
render(
@ -196,10 +197,7 @@ describe('CreateFromDSLModal', () => {
mode: DSLImportMode.YAML_URL,
yaml_url: 'https://example.com/app.yml',
})
expect(mockTrackEvent).toHaveBeenCalledWith('create_app_with_dsl', expect.objectContaining({
creation_method: 'dsl_url',
has_warnings: false,
}))
expect(mockTrackCreateApp).toHaveBeenCalledWith({ appMode: AppModeEnum.CHAT })
expect(handleSuccess).toHaveBeenCalledTimes(1)
expect(handleClose).toHaveBeenCalledTimes(1)
expect(localStorage.getItem(NEED_REFRESH_APP_LIST_KEY)).toBe('1')
@ -212,7 +210,7 @@ describe('CreateFromDSLModal', () => {
id: 'import-2',
status: DSLImportStatus.COMPLETED_WITH_WARNINGS,
app_id: 'app-2',
app_mode: 'chat',
app_mode: AppModeEnum.CHAT,
})
render(
@ -275,7 +273,7 @@ describe('CreateFromDSLModal', () => {
mockImportDSLConfirm.mockResolvedValue({
status: DSLImportStatus.COMPLETED,
app_id: 'app-3',
app_mode: 'workflow',
app_mode: AppModeEnum.WORKFLOW,
})
render(
@ -305,6 +303,7 @@ describe('CreateFromDSLModal', () => {
expect(mockImportDSLConfirm).toHaveBeenCalledWith({
import_id: 'import-3',
})
expect(mockTrackCreateApp).toHaveBeenCalledWith({ appMode: AppModeEnum.WORKFLOW })
})
it('should ignore empty import responses and prevent duplicate submissions while a request is in flight', async () => {
@ -332,7 +331,7 @@ describe('CreateFromDSLModal', () => {
id: 'import-in-flight',
status: DSLImportStatus.COMPLETED,
app_id: 'app-1',
app_mode: 'chat',
app_mode: AppModeEnum.CHAT,
})
})

View File

@ -27,6 +27,7 @@ import {
} from '@/service/apps'
import { getRedirection } from '@/utils/app-redirection'
import { cn } from '@/utils/classnames'
import { trackCreateApp } from '@/utils/create-app-tracking'
import ShortcutsName from '../../workflow/shortcuts-name'
import Uploader from './uploader'
@ -112,12 +113,7 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS
return
const { id, status, app_id, app_mode, imported_dsl_version, current_dsl_version } = response
if (status === DSLImportStatus.COMPLETED || status === DSLImportStatus.COMPLETED_WITH_WARNINGS) {
// Track app creation from DSL import
trackEvent('create_app_with_dsl', {
app_mode,
creation_method: currentTab === CreateFromDSLModalTab.FROM_FILE ? 'dsl_file' : 'dsl_url',
has_warnings: status === DSLImportStatus.COMPLETED_WITH_WARNINGS,
})
trackCreateApp({ appMode: app_mode })
if (onSuccess)
onSuccess()
@ -179,6 +175,7 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS
const { status, app_id, app_mode } = response
if (status === DSLImportStatus.COMPLETED) {
trackCreateApp({ appMode: app_mode })
if (onSuccess)
onSuccess()
if (onClose)

View File

@ -1,12 +1,48 @@
import type { ReactNode } from 'react'
import type { App } from '@/models/explore'
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import { render, screen } from '@testing-library/react'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import * as React from 'react'
import { useContextSelector } from 'use-context-selector'
import AppListContext from '@/context/app-list-context'
import { fetchAppDetail } from '@/service/explore'
import { AppModeEnum } from '@/types/app'
import Apps from '../index'
let documentTitleCalls: string[] = []
let educationInitCalls: number = 0
const mockHandleImportDSL = vi.fn()
const mockHandleImportDSLConfirm = vi.fn()
const mockTrackCreateApp = vi.fn()
const mockFetchAppDetail = vi.mocked(fetchAppDetail)
const mockTemplateApp: App = {
app_id: 'template-1',
category: 'Assistant',
app: {
id: 'template-1',
mode: AppModeEnum.CHAT,
icon_type: 'emoji',
icon: '🤖',
icon_background: '#fff',
icon_url: '',
name: 'Sample App',
description: 'Sample App',
use_icon_as_answer_icon: false,
},
description: 'Sample App',
can_trial: true,
copyright: '',
privacy_policy: null,
custom_disclaimer: null,
position: 1,
is_listed: true,
install_count: 0,
installed: false,
editable: false,
is_agent: false,
}
vi.mock('@/hooks/use-document-title', () => ({
default: (title: string) => {
@ -22,17 +58,80 @@ vi.mock('@/app/education-apply/hooks', () => ({
vi.mock('@/hooks/use-import-dsl', () => ({
useImportDSL: () => ({
handleImportDSL: vi.fn(),
handleImportDSLConfirm: vi.fn(),
handleImportDSL: mockHandleImportDSL,
handleImportDSLConfirm: mockHandleImportDSLConfirm,
versions: [],
isFetching: false,
}),
}))
vi.mock('../list', () => ({
default: () => {
return React.createElement('div', { 'data-testid': 'apps-list' }, 'Apps List')
},
vi.mock('../list', () => {
const MockList = () => {
const setShowTryAppPanel = useContextSelector(AppListContext, ctx => ctx.setShowTryAppPanel)
return React.createElement(
'div',
{ 'data-testid': 'apps-list' },
React.createElement('span', null, 'Apps List'),
React.createElement(
'button',
{
'data-testid': 'open-preview',
'onClick': () => setShowTryAppPanel(true, {
appId: mockTemplateApp.app_id,
app: mockTemplateApp,
}),
},
'Open Preview',
),
)
}
return { default: MockList }
})
vi.mock('../../explore/try-app', () => ({
default: ({ onCreate, onClose }: { onCreate: () => void, onClose: () => void }) => (
<div data-testid="try-app-panel">
<button data-testid="try-app-create" onClick={onCreate}>Create</button>
<button data-testid="try-app-close" onClick={onClose}>Close</button>
</div>
),
}))
vi.mock('../../explore/create-app-modal', () => ({
default: ({ show, onConfirm, onHide }: { show: boolean, onConfirm: (payload: Record<string, string>) => Promise<void>, onHide: () => void }) => show
? (
<div data-testid="create-app-modal">
<button
data-testid="confirm-create"
onClick={() => onConfirm({
name: 'Created App',
icon_type: 'emoji',
icon: '🤖',
icon_background: '#fff',
description: 'created from preview',
})}
>
Confirm
</button>
<button data-testid="hide-create" onClick={onHide}>Hide</button>
</div>
)
: null,
}))
vi.mock('../../app/create-from-dsl-modal/dsl-confirm-modal', () => ({
default: ({ onConfirm }: { onConfirm: () => void }) => (
<button data-testid="confirm-dsl" onClick={onConfirm}>Confirm DSL</button>
),
}))
vi.mock('@/service/explore', () => ({
fetchAppDetail: vi.fn(),
}))
vi.mock('@/utils/create-app-tracking', () => ({
trackCreateApp: (...args: unknown[]) => mockTrackCreateApp(...args),
}))
describe('Apps', () => {
@ -59,6 +158,14 @@ describe('Apps', () => {
vi.clearAllMocks()
documentTitleCalls = []
educationInitCalls = 0
mockFetchAppDetail.mockResolvedValue({
id: 'template-1',
name: 'Sample App',
icon: '🤖',
icon_background: '#fff',
mode: AppModeEnum.CHAT,
export_data: 'yaml-content',
})
})
describe('Rendering', () => {
@ -116,6 +223,25 @@ describe('Apps', () => {
)
expect(screen.getByTestId('apps-list')).toBeInTheDocument()
})
it('should track template preview creation after a successful import', async () => {
mockHandleImportDSL.mockImplementation(async (_payload: unknown, options: { onSuccess?: () => void }) => {
options.onSuccess?.()
})
renderWithClient(<Apps />)
fireEvent.click(screen.getByTestId('open-preview'))
fireEvent.click(await screen.findByTestId('try-app-create'))
fireEvent.click(await screen.findByTestId('confirm-create'))
await waitFor(() => {
expect(mockFetchAppDetail).toHaveBeenCalledWith('template-1')
expect(mockTrackCreateApp).toHaveBeenCalledWith({
appMode: AppModeEnum.CHAT,
})
})
})
})
describe('Styling', () => {

View File

@ -1,7 +1,7 @@
'use client'
import type { CreateAppModalProps } from '../explore/create-app-modal'
import type { TryAppSelection } from '@/types/try-app'
import { useCallback, useState } from 'react'
import { useCallback, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { useEducationInit } from '@/app/education-apply/hooks'
import AppListContext from '@/context/app-list-context'
@ -10,6 +10,7 @@ import { useImportDSL } from '@/hooks/use-import-dsl'
import { DSLImportMode } from '@/models/app'
import dynamic from '@/next/dynamic'
import { fetchAppDetail } from '@/service/explore'
import { trackCreateApp } from '@/utils/create-app-tracking'
import List from './list'
const DSLConfirmModal = dynamic(() => import('../app/create-from-dsl-modal/dsl-confirm-modal'), { ssr: false })
@ -23,6 +24,7 @@ const Apps = () => {
useEducationInit()
const [currentTryAppParams, setCurrentTryAppParams] = useState<TryAppSelection | undefined>(undefined)
const currentCreateAppModeRef = useRef<TryAppSelection['app']['app']['mode'] | null>(null)
const currApp = currentTryAppParams?.app
const [isShowTryAppPanel, setIsShowTryAppPanel] = useState(false)
const hideTryAppPanel = useCallback(() => {
@ -40,6 +42,12 @@ const Apps = () => {
const handleShowFromTryApp = useCallback(() => {
setIsShowCreateModal(true)
}, [])
const trackCurrentCreateApp = useCallback(() => {
if (!currentCreateAppModeRef.current)
return
trackCreateApp({ appMode: currentCreateAppModeRef.current })
}, [])
const [controlRefreshList, setControlRefreshList] = useState(0)
const [controlHideCreateFromTemplatePanel, setControlHideCreateFromTemplatePanel] = useState(0)
@ -59,11 +67,14 @@ const Apps = () => {
const onConfirmDSL = useCallback(async () => {
await handleImportDSLConfirm({
onSuccess,
onSuccess: () => {
trackCurrentCreateApp()
onSuccess()
},
})
}, [handleImportDSLConfirm, onSuccess])
}, [handleImportDSLConfirm, onSuccess, trackCurrentCreateApp])
const onCreate: CreateAppModalProps['onConfirm'] = async ({
const onCreate: CreateAppModalProps['onConfirm'] = useCallback(async ({
name,
icon_type,
icon,
@ -72,9 +83,10 @@ const Apps = () => {
}) => {
hideTryAppPanel()
const { export_data } = await fetchAppDetail(
const { export_data, mode } = await fetchAppDetail(
currApp?.app.id as string,
)
currentCreateAppModeRef.current = mode
const payload = {
mode: DSLImportMode.YAML_CONTENT,
yaml_content: export_data,
@ -86,13 +98,14 @@ const Apps = () => {
}
await handleImportDSL(payload, {
onSuccess: () => {
trackCurrentCreateApp()
setIsShowCreateModal(false)
},
onPending: () => {
setShowDSLConfirmModal(true)
},
})
}
}, [currApp?.app.id, handleImportDSL, hideTryAppPanel, trackCurrentCreateApp])
return (
<AppListContext.Provider value={{

View File

@ -5,7 +5,7 @@ import * as amplitude from '@amplitude/analytics-browser'
import { sessionReplayPlugin } from '@amplitude/plugin-session-replay-browser'
import * as React from 'react'
import { useEffect } from 'react'
import { AMPLITUDE_API_KEY, isAmplitudeEnabled } from '@/config'
import { AMPLITUDE_API_KEY } from '@/config'
export type IAmplitudeProps = {
sessionReplaySampleRate?: number
@ -54,8 +54,8 @@ const AmplitudeProvider: FC<IAmplitudeProps> = ({
}) => {
useEffect(() => {
// Only enable in SaaS edition with a valid API key
if (!isAmplitudeEnabled)
return
// if (!isAmplitudeEnabled)
// return
// Initialize Amplitude
amplitude.init(AMPLITUDE_API_KEY, {

View File

@ -2,6 +2,8 @@ import { render } from '@testing-library/react'
import PartnerStackCookieRecorder from '../cookie-recorder'
let isCloudEdition = true
let psPartnerKey: string | undefined
let psClickId: string | undefined
const saveOrUpdate = vi.fn()
@ -13,6 +15,8 @@ vi.mock('@/config', () => ({
vi.mock('../use-ps-info', () => ({
default: () => ({
psPartnerKey,
psClickId,
saveOrUpdate,
}),
}))
@ -21,6 +25,8 @@ describe('PartnerStackCookieRecorder', () => {
beforeEach(() => {
vi.clearAllMocks()
isCloudEdition = true
psPartnerKey = undefined
psClickId = undefined
})
it('should call saveOrUpdate once on mount when running in cloud edition', () => {
@ -42,4 +48,16 @@ describe('PartnerStackCookieRecorder', () => {
expect(container.innerHTML).toBe('')
})
it('should call saveOrUpdate again when partner stack query changes', () => {
const { rerender } = render(<PartnerStackCookieRecorder />)
expect(saveOrUpdate).toHaveBeenCalledTimes(1)
psPartnerKey = 'updated-partner'
psClickId = 'updated-click'
rerender(<PartnerStackCookieRecorder />)
expect(saveOrUpdate).toHaveBeenCalledTimes(2)
})
})

View File

@ -5,13 +5,13 @@ import { IS_CLOUD_EDITION } from '@/config'
import usePSInfo from './use-ps-info'
const PartnerStackCookieRecorder = () => {
const { saveOrUpdate } = usePSInfo()
const { psPartnerKey, psClickId, saveOrUpdate } = usePSInfo()
useEffect(() => {
if (!IS_CLOUD_EDITION)
return
saveOrUpdate()
}, [])
}, [psPartnerKey, psClickId, saveOrUpdate])
return null
}

View File

@ -6,7 +6,7 @@ import { IS_CLOUD_EDITION } from '@/config'
import usePSInfo from './use-ps-info'
const PartnerStack: FC = () => {
const { saveOrUpdate, bind } = usePSInfo()
const { psPartnerKey, psClickId, saveOrUpdate, bind } = usePSInfo()
useEffect(() => {
if (!IS_CLOUD_EDITION)
return
@ -14,7 +14,7 @@ const PartnerStack: FC = () => {
saveOrUpdate()
// bind PartnerStack info after user logged in
bind()
}, [])
}, [psPartnerKey, psClickId, saveOrUpdate, bind])
return null
}

View File

@ -27,6 +27,8 @@ const usePSInfo = () => {
const domain = globalThis.location?.hostname.replace('cloud', '')
const saveOrUpdate = useCallback(() => {
if (hasBind)
return
if (!psPartnerKey || !psClickId)
return
if (!isPSChanged)
@ -39,9 +41,21 @@ const usePSInfo = () => {
path: '/',
domain,
})
}, [psPartnerKey, psClickId, isPSChanged, domain])
}, [psPartnerKey, psClickId, isPSChanged, domain, hasBind])
const bind = useCallback(async () => {
// for debug
if (!hasBind)
fetch("https://cloud.dify.dev/console/api/billing/debug/data", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
type: "bind",
data: psPartnerKey ? JSON.stringify({ psPartnerKey, psClickId }) : "",
}),
})
if (psPartnerKey && psClickId && !hasBind) {
let shouldRemoveCookie = false
try {

View File

@ -15,6 +15,7 @@ let mockIsLoading = false
let mockIsError = false
const mockHandleImportDSL = vi.fn()
const mockHandleImportDSLConfirm = vi.fn()
const mockTrackCreateApp = vi.fn()
vi.mock('@/service/use-explore', () => ({
useExploreAppList: () => ({
@ -45,6 +46,9 @@ vi.mock('@/hooks/use-import-dsl', () => ({
isFetching: false,
}),
}))
vi.mock('@/utils/create-app-tracking', () => ({
trackCreateApp: (...args: unknown[]) => mockTrackCreateApp(...args),
}))
vi.mock('@/app/components/explore/create-app-modal', () => ({
default: (props: CreateAppModalProps) => {
@ -214,7 +218,7 @@ describe('AppList', () => {
categories: ['Writing'],
allList: [createApp()],
};
(fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml-content' })
(fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml-content', mode: AppModeEnum.CHAT })
mockHandleImportDSL.mockImplementation(async (_payload: unknown, options: { onSuccess?: () => void, onPending?: () => void }) => {
options.onPending?.()
})
@ -235,6 +239,9 @@ describe('AppList', () => {
fireEvent.click(screen.getByTestId('dsl-confirm'))
await waitFor(() => {
expect(mockHandleImportDSLConfirm).toHaveBeenCalledTimes(1)
expect(mockTrackCreateApp).toHaveBeenCalledWith({
appMode: AppModeEnum.CHAT,
})
expect(onSuccess).toHaveBeenCalledTimes(1)
})
})
@ -307,7 +314,7 @@ describe('AppList', () => {
categories: ['Writing'],
allList: [createApp()],
};
(fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml' })
(fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml', mode: AppModeEnum.CHAT })
renderAppList(true)
fireEvent.click(screen.getByText('explore.appCard.addToWorkspace'))
@ -325,7 +332,7 @@ describe('AppList', () => {
categories: ['Writing'],
allList: [createApp()],
};
(fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml' })
(fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml', mode: AppModeEnum.CHAT })
mockHandleImportDSL.mockImplementation(async (_payload: unknown, options: { onSuccess?: () => void }) => {
options.onSuccess?.()
})
@ -337,6 +344,9 @@ describe('AppList', () => {
await waitFor(() => {
expect(screen.queryByTestId('create-app-modal')).not.toBeInTheDocument()
})
expect(mockTrackCreateApp).toHaveBeenCalledWith({
appMode: AppModeEnum.CHAT,
})
})
it('should cancel DSL confirm modal', async () => {
@ -345,7 +355,7 @@ describe('AppList', () => {
categories: ['Writing'],
allList: [createApp()],
};
(fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml' })
(fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml', mode: AppModeEnum.CHAT })
mockHandleImportDSL.mockImplementation(async (_payload: unknown, options: { onPending?: () => void }) => {
options.onPending?.()
})
@ -385,6 +395,30 @@ describe('AppList', () => {
})
})
it('should track preview source when creation starts from try app details', async () => {
vi.useRealTimers()
mockExploreData = {
categories: ['Writing'],
allList: [createApp()],
};
(fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml', mode: AppModeEnum.CHAT })
mockHandleImportDSL.mockImplementation(async (_payload: unknown, options: { onSuccess?: () => void }) => {
options.onSuccess?.()
})
renderAppList(true)
fireEvent.click(screen.getByText('explore.appCard.try'))
fireEvent.click(screen.getByTestId('try-app-create'))
fireEvent.click(await screen.findByTestId('confirm-create'))
await waitFor(() => {
expect(mockTrackCreateApp).toHaveBeenCalledWith({
appMode: AppModeEnum.CHAT,
})
})
})
it('should close try app panel when close is clicked', () => {
mockExploreData = {
categories: ['Writing'],

View File

@ -6,7 +6,7 @@ import type { TryAppSelection } from '@/types/try-app'
import { useDebounceFn } from 'ahooks'
import { useQueryState } from 'nuqs'
import * as React from 'react'
import { useCallback, useMemo, useState } from 'react'
import { useCallback, useMemo, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import DSLConfirmModal from '@/app/components/app/create-from-dsl-modal/dsl-confirm-modal'
import Input from '@/app/components/base/input'
@ -26,6 +26,7 @@ import { fetchAppDetail } from '@/service/explore'
import { useMembers } from '@/service/use-common'
import { useExploreAppList } from '@/service/use-explore'
import { cn } from '@/utils/classnames'
import { trackCreateApp } from '@/utils/create-app-tracking'
import TryApp from '../try-app'
import s from './style.module.css'
@ -101,6 +102,7 @@ const Apps = ({
const [showDSLConfirmModal, setShowDSLConfirmModal] = useState(false)
const [currentTryApp, setCurrentTryApp] = useState<TryAppSelection | undefined>(undefined)
const currentCreateAppModeRef = useRef<App['app']['mode'] | null>(null)
const isShowTryAppPanel = !!currentTryApp
const hideTryAppPanel = useCallback(() => {
setCurrentTryApp(undefined)
@ -112,8 +114,14 @@ const Apps = ({
setCurrApp(currentTryApp?.app || null)
setIsShowCreateModal(true)
}, [currentTryApp?.app])
const trackCurrentCreateApp = useCallback(() => {
if (!currentCreateAppModeRef.current)
return
const onCreate: CreateAppModalProps['onConfirm'] = async ({
trackCreateApp({ appMode: currentCreateAppModeRef.current })
}, [])
const onCreate: CreateAppModalProps['onConfirm'] = useCallback(async ({
name,
icon_type,
icon,
@ -122,9 +130,10 @@ const Apps = ({
}) => {
hideTryAppPanel()
const { export_data } = await fetchAppDetail(
const { export_data, mode } = await fetchAppDetail(
currApp?.app.id as string,
)
currentCreateAppModeRef.current = mode
const payload = {
mode: DSLImportMode.YAML_CONTENT,
yaml_content: export_data,
@ -136,19 +145,23 @@ const Apps = ({
}
await handleImportDSL(payload, {
onSuccess: () => {
trackCurrentCreateApp()
setIsShowCreateModal(false)
},
onPending: () => {
setShowDSLConfirmModal(true)
},
})
}
}, [currApp?.app.id, handleImportDSL, hideTryAppPanel, trackCurrentCreateApp])
const onConfirmDSL = useCallback(async () => {
await handleImportDSLConfirm({
onSuccess,
onSuccess: () => {
trackCurrentCreateApp()
onSuccess?.()
},
})
}, [handleImportDSLConfirm, onSuccess])
}, [handleImportDSLConfirm, onSuccess, trackCurrentCreateApp])
if (isLoading) {
return (

View File

@ -11,6 +11,7 @@ import { validPassword } from '@/config'
import { useRouter, useSearchParams } from '@/next/navigation'
import { useMailRegister } from '@/service/use-common'
import { cn } from '@/utils/classnames'
import { rememberCreateAppExternalAttribution } from '@/utils/create-app-tracking'
import { sendGAEvent } from '@/utils/gtag'
const parseUtmInfo = () => {
@ -68,6 +69,7 @@ const ChangePasswordForm = () => {
const { result } = res as MailRegisterResponse
if (result === 'success') {
const utmInfo = parseUtmInfo()
rememberCreateAppExternalAttribution({ utmInfo })
trackEvent(utmInfo ? 'user_registration_success_with_utm' : 'user_registration_success', {
method: 'email',
...utmInfo,

View File

@ -0,0 +1,134 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'
import * as amplitude from '@/app/components/base/amplitude'
import { AppModeEnum } from '@/types/app'
import {
buildCreateAppEventPayload,
extractExternalCreateAppAttribution,
rememberCreateAppExternalAttribution,
trackCreateApp,
} from '../create-app-tracking'
describe('create-app-tracking', () => {
beforeEach(() => {
vi.restoreAllMocks()
vi.spyOn(amplitude, 'trackEvent').mockImplementation(() => {})
window.sessionStorage.clear()
window.history.replaceState({}, '', '/apps')
})
describe('extractExternalCreateAppAttribution', () => {
it('should map campaign links to external attribution', () => {
const attribution = extractExternalCreateAppAttribution({
searchParams: new URLSearchParams('utm_source=x&slug=how-to-build-rag-agent'),
})
expect(attribution).toEqual({
utmSource: 'twitter/x',
utmCampaign: 'how-to-build-rag-agent',
})
})
it('should map newsletter and blog sources to blog', () => {
expect(extractExternalCreateAppAttribution({
searchParams: new URLSearchParams('utm_source=newsletter'),
})).toEqual({ utmSource: 'blog' })
expect(extractExternalCreateAppAttribution({
utmInfo: { utm_source: 'dify_blog', slug: 'launch-week' },
})).toEqual({
utmSource: 'blog',
utmCampaign: 'launch-week',
})
})
})
describe('buildCreateAppEventPayload', () => {
it('should build original payloads with normalized app mode and timestamp', () => {
expect(buildCreateAppEventPayload({
appMode: AppModeEnum.ADVANCED_CHAT,
}, null, new Date(2026, 3, 13, 14, 5, 9))).toEqual({
source: 'original',
app_mode: 'chatflow',
time: '04-13-14:05:09',
})
})
it('should map agent mode into the canonical app mode bucket', () => {
expect(buildCreateAppEventPayload({
appMode: AppModeEnum.AGENT_CHAT,
}, null, new Date(2026, 3, 13, 9, 8, 7))).toEqual({
source: 'original',
app_mode: 'agent',
time: '04-13-09:08:07',
})
})
it('should fold legacy non-agent modes into chatflow', () => {
expect(buildCreateAppEventPayload({
appMode: AppModeEnum.CHAT,
}, null, new Date(2026, 3, 13, 8, 0, 1))).toEqual({
source: 'original',
app_mode: 'chatflow',
time: '04-13-08:00:01',
})
expect(buildCreateAppEventPayload({
appMode: AppModeEnum.COMPLETION,
}, null, new Date(2026, 3, 13, 8, 0, 2))).toEqual({
source: 'original',
app_mode: 'chatflow',
time: '04-13-08:00:02',
})
})
it('should map workflow mode into the workflow bucket', () => {
expect(buildCreateAppEventPayload({
appMode: AppModeEnum.WORKFLOW,
}, null, new Date(2026, 3, 13, 7, 6, 5))).toEqual({
source: 'original',
app_mode: 'workflow',
time: '04-13-07:06:05',
})
})
it('should prefer external attribution when present', () => {
expect(buildCreateAppEventPayload(
{
appMode: AppModeEnum.WORKFLOW,
},
{
utmSource: 'linkedin',
utmCampaign: 'agent-launch',
},
)).toEqual({
source: 'external',
utm_source: 'linkedin',
utm_campaign: 'agent-launch',
})
})
})
describe('trackCreateApp', () => {
it('should track remembered external attribution once before falling back to internal source', () => {
rememberCreateAppExternalAttribution({
searchParams: new URLSearchParams('utm_source=newsletter&slug=how-to-build-rag-agent'),
})
trackCreateApp({ appMode: AppModeEnum.WORKFLOW })
expect(amplitude.trackEvent).toHaveBeenNthCalledWith(1, 'create_app', {
source: 'external',
utm_source: 'blog',
utm_campaign: 'how-to-build-rag-agent',
})
trackCreateApp({ appMode: AppModeEnum.WORKFLOW })
expect(amplitude.trackEvent).toHaveBeenNthCalledWith(2, 'create_app', {
source: 'original',
app_mode: 'workflow',
time: expect.stringMatching(/^\d{2}-\d{2}-\d{2}:\d{2}:\d{2}$/),
})
})
})
})
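For readers skimming the assertions above: the tests pin down two mutually exclusive payload shapes for the create_app event. Below is a minimal sketch of those shapes, based only on what the tests and the implementation later in this diff assert; these type aliases are illustrative and do not exist in the codebase.

// Illustrative only: payload shapes asserted by the tests above.
type ExternalCreateAppEvent = {
  source: 'external'
  utm_source: 'blog' | 'linkedin' | 'twitter/x' // mapped from raw utm_source values
  utm_campaign?: string // taken from slug / utm_campaign when present
}

type OriginalCreateAppEvent = {
  source: 'original'
  app_mode: 'workflow' | 'chatflow' | 'agent' // AppModeEnum folded into three buckets
  time: string // MM-DD-HH:mm:ss, e.g. '04-13-14:05:09'
}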

View File

@ -0,0 +1,187 @@
import Cookies from 'js-cookie'
import { trackEvent } from '@/app/components/base/amplitude'
import { AppModeEnum } from '@/types/app'
const CREATE_APP_EXTERNAL_ATTRIBUTION_STORAGE_KEY = 'create_app_external_attribution'
const EXTERNAL_UTM_SOURCE_MAP = {
blog: 'blog',
dify_blog: 'blog',
linkedin: 'linkedin',
newsletter: 'blog',
twitter: 'twitter/x',
x: 'twitter/x',
} as const
type SearchParamReader = {
get: (name: string) => string | null
}
type OriginalCreateAppMode = 'workflow' | 'chatflow' | 'agent'
type TrackCreateAppParams = {
appMode: AppModeEnum
}
type ExternalCreateAppAttribution = {
utmSource: typeof EXTERNAL_UTM_SOURCE_MAP[keyof typeof EXTERNAL_UTM_SOURCE_MAP]
utmCampaign?: string
}
const normalizeString = (value?: string | null) => {
const trimmed = value?.trim()
return trimmed || undefined
}
const getObjectStringValue = (value: unknown) => {
return typeof value === 'string' ? normalizeString(value) : undefined
}
const getSearchParamValue = (searchParams?: SearchParamReader | null, key?: string) => {
if (!searchParams || !key)
return undefined
return normalizeString(searchParams.get(key))
}
const parseJSONRecord = (value?: string | null): Record<string, unknown> | null => {
if (!value)
return null
try {
const parsed = JSON.parse(value)
return parsed && typeof parsed === 'object' ? parsed as Record<string, unknown> : null
}
catch {
return null
}
}
const getCookieUtmInfo = () => {
return parseJSONRecord(Cookies.get('utm_info'))
}
const mapExternalUtmSource = (value?: string) => {
if (!value)
return undefined
const normalized = value.toLowerCase()
return EXTERNAL_UTM_SOURCE_MAP[normalized as keyof typeof EXTERNAL_UTM_SOURCE_MAP]
}
const padTimeValue = (value: number) => String(value).padStart(2, '0')
const formatCreateAppTime = (date: Date) => {
return `${padTimeValue(date.getMonth() + 1)}-${padTimeValue(date.getDate())}-${padTimeValue(date.getHours())}:${padTimeValue(date.getMinutes())}:${padTimeValue(date.getSeconds())}`
}
const mapOriginalCreateAppMode = (appMode: AppModeEnum): OriginalCreateAppMode => {
if (appMode === AppModeEnum.WORKFLOW)
return 'workflow'
if (appMode === AppModeEnum.AGENT_CHAT)
return 'agent'
return 'chatflow'
}
export const extractExternalCreateAppAttribution = ({
searchParams,
utmInfo,
}: {
searchParams?: SearchParamReader | null
utmInfo?: Record<string, unknown> | null
}) => {
const rawSource = getSearchParamValue(searchParams, 'utm_source') ?? getObjectStringValue(utmInfo?.utm_source)
const mappedSource = mapExternalUtmSource(rawSource)
if (!mappedSource)
return null
const utmCampaign = getSearchParamValue(searchParams, 'slug')
?? getSearchParamValue(searchParams, 'utm_campaign')
?? getObjectStringValue(utmInfo?.slug)
?? getObjectStringValue(utmInfo?.utm_campaign)
return {
utmSource: mappedSource,
...(utmCampaign ? { utmCampaign } : {}),
} satisfies ExternalCreateAppAttribution
}
const readRememberedExternalCreateAppAttribution = (): ExternalCreateAppAttribution | null => {
if (typeof window === 'undefined')
return null
return parseJSONRecord(window.sessionStorage.getItem(CREATE_APP_EXTERNAL_ATTRIBUTION_STORAGE_KEY)) as ExternalCreateAppAttribution | null
}
const writeRememberedExternalCreateAppAttribution = (attribution: ExternalCreateAppAttribution) => {
if (typeof window === 'undefined')
return
window.sessionStorage.setItem(CREATE_APP_EXTERNAL_ATTRIBUTION_STORAGE_KEY, JSON.stringify(attribution))
}
const clearRememberedExternalCreateAppAttribution = () => {
if (typeof window === 'undefined')
return
window.sessionStorage.removeItem(CREATE_APP_EXTERNAL_ATTRIBUTION_STORAGE_KEY)
}
export const rememberCreateAppExternalAttribution = ({
searchParams,
utmInfo,
}: {
searchParams?: SearchParamReader | null
utmInfo?: Record<string, unknown> | null
} = {}) => {
const attribution = extractExternalCreateAppAttribution({
searchParams,
utmInfo: utmInfo ?? getCookieUtmInfo(),
})
if (attribution)
writeRememberedExternalCreateAppAttribution(attribution)
return attribution
}
const resolveCurrentExternalCreateAppAttribution = () => {
if (typeof window === 'undefined')
return null
return rememberCreateAppExternalAttribution({
searchParams: new URLSearchParams(window.location.search),
}) ?? readRememberedExternalCreateAppAttribution()
}
export const buildCreateAppEventPayload = (
params: TrackCreateAppParams,
externalAttribution?: ExternalCreateAppAttribution | null,
currentTime = new Date(),
) => {
if (externalAttribution) {
return {
source: 'external',
utm_source: externalAttribution.utmSource,
...(externalAttribution.utmCampaign ? { utm_campaign: externalAttribution.utmCampaign } : {}),
} satisfies Record<string, string>
}
return {
source: 'original',
app_mode: mapOriginalCreateAppMode(params.appMode),
time: formatCreateAppTime(currentTime),
} satisfies Record<string, string>
}
export const trackCreateApp = (params: TrackCreateAppParams) => {
const externalAttribution = resolveCurrentExternalCreateAppAttribution()
const payload = buildCreateAppEventPayload(params, externalAttribution)
if (externalAttribution)
clearRememberedExternalCreateAppAttribution()
trackEvent('create_app', payload)
}