Mirror of https://github.com/langgenius/dify.git (synced 2026-05-06 18:27:19 +08:00)
feat: evaluation (#35251)
Co-authored-by: jyong <718720800@qq.com>
Co-authored-by: Yansong Zhang <916125788@qq.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: hj24 <mambahj24@gmail.com>
Co-authored-by: hj24 <huangjian@dify.ai>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: Stephen Zhou <38493346+hyoban@users.noreply.github.com>
Co-authored-by: CodingOnStar <hanxujiang@dify.com>
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
This commit is contained in:
parent 12b1cc3d2e
commit f9b76f0f52
@@ -1366,6 +1366,32 @@ class SandboxExpiredRecordsCleanConfig(BaseSettings):
    )


class EvaluationConfig(BaseSettings):
    """
    Configuration for evaluation runtime
    """

    EVALUATION_FRAMEWORK: str = Field(
        description="Evaluation framework to use (ragas/deepeval/none)",
        default="none",
    )

    EVALUATION_MAX_CONCURRENT_RUNS: PositiveInt = Field(
        description="Maximum number of concurrent evaluation runs per tenant",
        default=3,
    )

    EVALUATION_MAX_DATASET_ROWS: PositiveInt = Field(
        description="Maximum number of rows allowed in an evaluation dataset",
        default=500,
    )

    EVALUATION_TASK_TIMEOUT: PositiveInt = Field(
        description="Timeout in seconds for a single evaluation task",
        default=3600,
    )


class FeatureConfig(
    # place the configs in alphabet order
    AppExecutionConfig,

@@ -1378,6 +1404,7 @@ class FeatureConfig(
    MarketplaceConfig,
    DataSetConfig,
    EndpointConfig,
    EvaluationConfig,
    FileAccessConfig,
    FileUploadConfig,
    HttpConfig,
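
These settings follow the pydantic-settings convention used across Dify's config classes: each field resolves from an environment variable of the same name, with the declared default as the fallback. A minimal, self-contained sketch of that resolution (the two fields are copied from the hunk above; the rest is illustrative):

```python
import os

# Assumed environment, e.g. from docker-compose or a .env file.
os.environ["EVALUATION_FRAMEWORK"] = "ragas"
os.environ["EVALUATION_MAX_CONCURRENT_RUNS"] = "5"

from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings


class EvaluationConfig(BaseSettings):
    EVALUATION_FRAMEWORK: str = Field(default="none")
    EVALUATION_MAX_CONCURRENT_RUNS: PositiveInt = Field(default=3)


config = EvaluationConfig()
assert config.EVALUATION_FRAMEWORK == "ragas"
assert config.EVALUATION_MAX_CONCURRENT_RUNS == 5  # coerced from the env string
```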
@@ -107,6 +107,9 @@ from .datasets.rag_pipeline import (
    rag_pipeline_workflow,
)

# Import evaluation controllers
from .evaluation import evaluation

# Import explore controllers
from .explore import (
    banner,

@@ -117,6 +120,9 @@ from .explore import (
    trial,
)

# Import snippet controllers
from .snippets import snippet_workflow, snippet_workflow_draft_variable

# Import tag controllers
from .tag import tags

@@ -130,6 +136,7 @@ from .workspace import (
    model_providers,
    models,
    plugin,
    snippets,
    tool_providers,
    trigger_providers,
    workspace,

@@ -167,6 +174,7 @@ __all__ = [
    "datasource_content_preview",
    "email_register",
    "endpoint",
    "evaluation",
    "extension",
    "external",
    "feature",

@@ -201,6 +209,9 @@ __all__ = [
    "saved_message",
    "setup",
    "site",
    "snippet_workflow",
    "snippet_workflow_draft_variable",
    "snippets",
    "spec",
    "statistic",
    "tags",
@@ -329,6 +329,7 @@ class AppPartial(ResponseModel):
    create_user_name: str | None = None
    author_name: str | None = None
    has_draft_trigger: bool | None = None
    workflow_type: str | None = None

    @computed_field(return_type=str | None)  # type: ignore
    @property

@@ -363,6 +364,7 @@ class AppDetail(ResponseModel):
    updated_by: str | None = None
    updated_at: int | None = None
    access_mode: str | None = None
    workflow_type: str | None = None
    tags: list[Tag] = Field(default_factory=list)

    @field_validator("created_at", "updated_at", mode="before")
@@ -505,6 +507,17 @@ class AppListApi(Resource):
        for app in app_pagination.items:
            app.has_draft_trigger = str(app.id) in draft_trigger_app_ids

        workflow_ids = [str(app.workflow_id) for app in app_pagination.items if app.workflow_id]
        workflow_type_map: dict[str, str] = {}
        if workflow_ids:
            rows = db.session.execute(
                select(Workflow.id, Workflow.type).where(Workflow.id.in_(workflow_ids))
            ).all()
            workflow_type_map = {str(row.id): row.type for row in rows}

        for app in app_pagination.items:
            app.workflow_type = workflow_type_map.get(str(app.workflow_id)) if app.workflow_id else None

        pagination_model = AppPagination.model_validate(app_pagination, from_attributes=True)
        return pagination_model.model_dump(mode="json"), 200
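
The list handler above resolves workflow types for a whole page in a single IN query instead of one query per app. A standalone sketch of the same batching pattern (the model and session arguments are stand-ins, not Dify's actual imports):

```python
from sqlalchemy import select


def workflow_types_for(session, workflow_model, workflow_ids: list[str]) -> dict[str, str]:
    """Map workflow id -> type with one round trip, regardless of page size."""
    if not workflow_ids:
        return {}
    rows = session.execute(
        select(workflow_model.id, workflow_model.type).where(workflow_model.id.in_(workflow_ids))
    ).all()
    return {str(row.id): row.type for row in rows}
```

One query per page keeps the listing endpoint constant in query count, where a per-app lookup would grow linearly with page size.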
@@ -551,6 +564,14 @@ class AppApi(Resource):
            app_setting = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=str(app_model.id))
            app_model.access_mode = app_setting.access_mode

        if app_model.workflow_id:
            row = db.session.execute(
                select(Workflow.type).where(Workflow.id == app_model.workflow_id)
            ).scalar()
            app_model.workflow_type = row if row else None
        else:
            app_model.workflow_type = None

        response_model = AppDetailWithSite.model_validate(app_model, from_attributes=True)
        return response_model.model_dump(mode="json")
@@ -1,7 +1,7 @@
import json
import logging
from collections.abc import Sequence
-from typing import Any
+from typing import Any, Literal

from flask import abort, request
from flask_restx import Resource, fields, marshal, marshal_with

@@ -46,7 +46,7 @@ from libs.helper import TimestampField, uuid_value
from libs.login import current_account_with_tenant, login_required
from models import App
from models.model import AppMode
-from models.workflow import Workflow
+from models.workflow import Workflow, WorkflowType
from services.app_generate_service import AppGenerateService
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError
@@ -150,6 +150,24 @@ class ConvertToWorkflowPayload(BaseModel):
    icon_background: str | None = None


class WorkflowListQuery(BaseModel):
    page: int = Field(default=1, ge=1, le=99999)
    limit: int = Field(default=10, ge=1, le=100)
    user_id: str | None = None
    named_only: bool = False
    keyword: str | None = Field(default=None, max_length=255)


class WorkflowUpdatePayload(BaseModel):
    marked_name: str | None = Field(default=None, max_length=20)
    marked_comment: str | None = Field(default=None, max_length=100)


class WorkflowTypeConvertQuery(BaseModel):
    target_type: Literal["workflow", "evaluation"]


class DraftWorkflowTriggerRunPayload(BaseModel):
    node_id: str

@@ -173,6 +191,7 @@ reg(DefaultBlockConfigQuery)
reg(ConvertToWorkflowPayload)
reg(WorkflowListQuery)
reg(WorkflowUpdatePayload)
reg(WorkflowTypeConvertQuery)
reg(DraftWorkflowTriggerRunPayload)
reg(DraftWorkflowTriggerRunAllPayload)
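
The Literal-typed `target_type` means pydantic rejects anything other than the two allowed strings before the handler runs. A small standalone demonstration of that validation behavior:

```python
from typing import Literal

from pydantic import BaseModel, ValidationError


class WorkflowTypeConvertQuery(BaseModel):
    target_type: Literal["workflow", "evaluation"]


WorkflowTypeConvertQuery.model_validate({"target_type": "evaluation"})  # accepted

try:
    WorkflowTypeConvertQuery.model_validate({"target_type": "chat"})
except ValidationError as e:
    print(e.errors()[0]["type"])  # "literal_error"
```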
@@ -845,6 +864,54 @@ class PublishedWorkflowApi(Resource):
        }


@console_ns.route("/apps/<uuid:app_id>/workflows/publish/evaluation")
class EvaluationPublishedWorkflowApi(Resource):
    @console_ns.doc("publish_evaluation_workflow")
    @console_ns.doc(description="Publish draft workflow as evaluation workflow")
    @console_ns.doc(params={"app_id": "Application ID"})
    @console_ns.expect(console_ns.models[PublishWorkflowPayload.__name__])
    @console_ns.response(200, "Evaluation workflow published successfully")
    @console_ns.response(400, "Invalid workflow or unsupported node type")
    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
    @edit_permission_required
    def post(self, app_model: App):
        """
        Publish draft workflow as evaluation workflow.

        Evaluation workflows cannot include trigger or human-input nodes.
        """
        current_user, _ = current_account_with_tenant()
        args = PublishWorkflowPayload.model_validate(console_ns.payload or {})

        workflow_service = WorkflowService()
        with Session(db.engine) as session:
            workflow = workflow_service.publish_evaluation_workflow(
                session=session,
                app_model=app_model,
                account=current_user,
                marked_name=args.marked_name or "",
                marked_comment=args.marked_comment or "",
            )

            # Keep workflow_id aligned with the latest published workflow.
            app_model_in_session = session.get(App, app_model.id)
            if app_model_in_session:
                app_model_in_session.workflow_id = workflow.id
                app_model_in_session.updated_by = current_user.id
                app_model_in_session.updated_at = naive_utc_now()

            workflow_created_at = TimestampField().format(workflow.created_at)
            session.commit()

        return {
            "result": "success",
            "created_at": workflow_created_at,
        }


@console_ns.route("/apps/<uuid:app_id>/workflows/default-workflow-block-configs")
class DefaultBlockConfigsApi(Resource):
    @console_ns.doc("get_default_block_configs")
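
A hypothetical client call for the publish endpoint above; the host, token, and app ID are placeholders, and the JSON body mirrors the marked_name/marked_comment fields the handler reads from PublishWorkflowPayload:

```python
import requests

resp = requests.post(
    "https://dify.example.com/console/api/apps/<app_id>/workflows/publish/evaluation",
    headers={"Authorization": "Bearer <console-session-token>"},
    json={"marked_name": "eval-v1", "marked_comment": "baseline for retrieval metrics"},
    timeout=30,
)
resp.raise_for_status()
print(resp.json())  # expected: {"result": "success", "created_at": <timestamp>}
```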
@@ -1016,6 +1083,51 @@ class DraftWorkflowRestoreApi(Resource):
        }


@console_ns.route("/apps/<uuid:app_id>/workflows/convert-type")
class WorkflowTypeConvertApi(Resource):
    @console_ns.doc("convert_published_workflow_type")
    @console_ns.doc(description="Convert current effective published workflow type in-place")
    @console_ns.doc(params={"app_id": "Application ID"})
    @console_ns.expect(console_ns.models[WorkflowTypeConvertQuery.__name__])
    @console_ns.response(200, "Workflow type converted successfully")
    @console_ns.response(400, "Invalid workflow type or unsupported workflow graph")
    @console_ns.response(404, "Workflow not found")
    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
    @edit_permission_required
    def post(self, app_model: App):
        current_user, _ = current_account_with_tenant()
        args = WorkflowTypeConvertQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore
        target_type = WorkflowType.value_of(args.target_type)

        workflow_service = WorkflowService()
        with Session(db.engine) as session:
            try:
                workflow = workflow_service.convert_published_workflow_type(
                    session=session,
                    app_model=app_model,
                    target_type=target_type,
                    account=current_user,
                )
            except WorkflowNotFoundError as exc:
                raise NotFound(str(exc)) from exc
            except IsDraftWorkflowError as exc:
                raise BadRequest(str(exc)) from exc
            except ValueError as exc:
                raise BadRequest(str(exc)) from exc

            session.commit()

        return {
            "result": "success",
            "workflow_id": workflow.id,
            "type": workflow.type.value,
            "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
        }


@console_ns.route("/apps/<uuid:app_id>/workflows/<string:workflow_id>")
class WorkflowByIdApi(Resource):
    @console_ns.doc("update_workflow_by_id")
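
WorkflowType.value_of is not shown in this diff; a plausible sketch of the string-to-enum helper the handler relies on (the member names are assumptions based on the Literal["workflow", "evaluation"] query model):

```python
from enum import StrEnum


class WorkflowType(StrEnum):
    WORKFLOW = "workflow"
    EVALUATION = "evaluation"

    @classmethod
    def value_of(cls, value: str) -> "WorkflowType":
        for member in cls:
            if member.value == value:
                return member
        raise ValueError(f"invalid workflow type value: {value}")


print(WorkflowType.value_of("evaluation"))  # WorkflowType.EVALUATION
```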
@@ -1,4 +1,6 @@
import base64
import json
from datetime import UTC, datetime, timedelta
from typing import Literal

from flask import request

@@ -10,6 +12,7 @@ from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from enums.cloud_plan import CloudPlan
from extensions.ext_redis import redis_client
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService

@@ -77,3 +80,39 @@ class PartnerTenants(Resource):
            raise BadRequest("Invalid partner information")

        return BillingService.sync_partner_tenants_bindings(current_user.id, decoded_partner_key, click_id)


_DEBUG_KEY = "billing:debug"
_DEBUG_TTL = timedelta(days=7)


class DebugDataPayload(BaseModel):
    type: str = Field(..., min_length=1, description="Data type key")
    data: str = Field(..., min_length=1, description="Data value to append")


@console_ns.route("/billing/debug/data")
class DebugData(Resource):
    def post(self):
        body = DebugDataPayload.model_validate(request.get_json(force=True))
        item = json.dumps({
            "type": body.type,
            "data": body.data,
            "createTime": datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ"),
        })
        redis_client.lpush(_DEBUG_KEY, item)
        redis_client.expire(_DEBUG_KEY, _DEBUG_TTL)
        return {"result": "ok"}, 201

    def get(self):
        recent = request.args.get("recent", 10, type=int)
        items = redis_client.lrange(_DEBUG_KEY, 0, recent - 1)
        return {
            "data": [
                json.loads(item.decode("utf-8") if isinstance(item, bytes) else item) for item in items
            ]
        }

    def delete(self):
        redis_client.delete(_DEBUG_KEY)
        return {"result": "ok"}
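
Note that the debug endpoint expires the whole list after seven days but never trims it, so the list can grow without bound between expirations. A common variant (a sketch, not the committed code) caps the length with LTRIM in the same pipeline:

```python
import redis

r = redis.Redis()
MAX_ITEMS = 1000


def push_debug(item: str) -> None:
    pipe = r.pipeline()
    pipe.lpush("billing:debug", item)
    pipe.ltrim("billing:debug", 0, MAX_ITEMS - 1)  # keep only the newest MAX_ITEMS
    pipe.expire("billing:debug", 7 * 24 * 3600)    # refresh the 7-day TTL
    pipe.execute()
```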
@@ -1,11 +1,14 @@
import json
from typing import Any, cast
from urllib.parse import quote

-from flask import request
+from flask import Response, request
from flask_restx import Resource, fields, marshal, marshal_with
from graphon.model_runtime.entities.model_entities import ModelType
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import func, select
-from werkzeug.exceptions import Forbidden, NotFound
+from sqlalchemy.orm import Session
+from werkzeug.exceptions import BadRequest, Forbidden, NotFound

import services
from configs import dify_config

@@ -22,6 +25,7 @@ from controllers.console.wraps import (
    setup_required,
)
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.evaluation.entities.evaluation_entity import EvaluationCategory, EvaluationConfigData, EvaluationRunRequest
from core.indexing_runner import IndexingRunner
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from core.rag.datasource.vdb.vector_type import VectorType

@@ -30,6 +34,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from extensions.ext_storage import storage
from fields.app_fields import app_detail_kernel_fields, related_app_list
from fields.dataset_fields import (
    content_fields,

@@ -50,12 +55,19 @@ from fields.dataset_fields import (
)
from fields.document_fields import document_status_fields
from libs.login import current_account_with_tenant, login_required
-from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
+from models import ApiToken, Dataset, Document, DocumentSegment, EvaluationRun, EvaluationTargetType, UploadFile
from models.dataset import DatasetPermission, DatasetPermissionEnum
from models.enums import ApiTokenType, SegmentStatus
from models.provider_ids import ModelProviderID
from services.api_token_service import ApiTokenCache
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
from services.errors.evaluation import (
    EvaluationDatasetInvalidError,
    EvaluationFrameworkNotConfiguredError,
    EvaluationMaxConcurrentRunsError,
    EvaluationNotFoundError,
)
from services.evaluation_service import EvaluationService

# Register models for flask_restx to avoid dict type issues in Swagger
dataset_base_model = get_or_create_model("DatasetBase", dataset_fields)
@@ -983,3 +995,432 @@ class DatasetAutoDisableLogApi(Resource):
        if dataset is None:
            raise NotFound("Dataset not found.")
        return DatasetService.get_dataset_auto_disable_logs(dataset_id_str), 200


# ---- Knowledge Base Retrieval Evaluation ----


def _serialize_dataset_evaluation_run(run: EvaluationRun) -> dict[str, Any]:
    return {
        "id": run.id,
        "tenant_id": run.tenant_id,
        "target_type": run.target_type,
        "target_id": run.target_id,
        "evaluation_config_id": run.evaluation_config_id,
        "status": run.status,
        "dataset_file_id": run.dataset_file_id,
        "result_file_id": run.result_file_id,
        "total_items": run.total_items,
        "completed_items": run.completed_items,
        "failed_items": run.failed_items,
        "progress": run.progress,
        "metrics_summary": json.loads(run.metrics_summary) if run.metrics_summary else {},
        "error": run.error,
        "created_by": run.created_by,
        "started_at": int(run.started_at.timestamp()) if run.started_at else None,
        "completed_at": int(run.completed_at.timestamp()) if run.completed_at else None,
        "created_at": int(run.created_at.timestamp()) if run.created_at else None,
    }


def _serialize_dataset_evaluation_run_item(item: Any) -> dict[str, Any]:
    return {
        "id": item.id,
        "item_index": item.item_index,
        "inputs": item.inputs_dict,
        "expected_output": item.expected_output,
        "actual_output": item.actual_output,
        "metrics": item.metrics_list,
        "judgment": item.judgment_dict,
        "metadata": item.metadata_dict,
        "error": item.error,
        "overall_score": item.overall_score,
    }
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/template/download")
|
||||
class DatasetEvaluationTemplateDownloadApi(Resource):
|
||||
@console_ns.doc("download_dataset_evaluation_template")
|
||||
@console_ns.response(200, "Template file streamed as XLSX attachment")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, dataset_id):
|
||||
"""Download evaluation dataset template for knowledge base retrieval."""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
xlsx_content, filename = EvaluationService.generate_retrieval_dataset_template()
|
||||
encoded_filename = quote(filename)
|
||||
response = Response(
|
||||
xlsx_content,
|
||||
mimetype="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
)
|
||||
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
|
||||
response.headers["Content-Length"] = str(len(xlsx_content))
|
||||
return response
|
||||
|
||||
|
||||
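
A hypothetical client download for the template endpoint above; note the POST verb and the RFC 5987 `filename*` form of Content-Disposition the handler sets:

```python
import requests

resp = requests.post(
    "https://dify.example.com/console/api/datasets/<dataset_id>/evaluation/template/download",
    headers={"Authorization": "Bearer <console-session-token>"},
    timeout=30,
)
resp.raise_for_status()

with open("evaluation_template.xlsx", "wb") as f:
    f.write(resp.content)
```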
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation")
|
||||
class DatasetEvaluationDetailApi(Resource):
|
||||
@console_ns.doc("get_dataset_evaluation_config")
|
||||
@console_ns.response(200, "Evaluation configuration retrieved")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id):
|
||||
"""Get evaluation configuration for the knowledge base."""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
config = EvaluationService.get_evaluation_config(
|
||||
session, current_tenant_id, "dataset", dataset_id_str
|
||||
)
|
||||
|
||||
if config is None:
|
||||
return {
|
||||
"evaluation_model": None,
|
||||
"evaluation_model_provider": None,
|
||||
"default_metrics": None,
|
||||
"customized_metrics": None,
|
||||
"judgment_config": None,
|
||||
}
|
||||
|
||||
return {
|
||||
"evaluation_model": config.evaluation_model,
|
||||
"evaluation_model_provider": config.evaluation_model_provider,
|
||||
"default_metrics": config.default_metrics_list,
|
||||
"customized_metrics": config.customized_metrics_dict,
|
||||
"judgment_config": config.judgment_config_dict,
|
||||
}
|
||||
|
||||
@console_ns.doc("save_dataset_evaluation_config")
|
||||
@console_ns.response(200, "Evaluation configuration saved")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def put(self, dataset_id):
|
||||
"""Save evaluation configuration for the knowledge base."""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
body = request.get_json(force=True)
|
||||
try:
|
||||
config_data = EvaluationConfigData.model_validate(body)
|
||||
except Exception as e:
|
||||
raise BadRequest(f"Invalid request body: {e}")
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
config = EvaluationService.save_evaluation_config(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
target_type="dataset",
|
||||
target_id=dataset_id_str,
|
||||
account_id=str(current_user.id),
|
||||
data=config_data,
|
||||
)
|
||||
|
||||
return {
|
||||
"evaluation_model": config.evaluation_model,
|
||||
"evaluation_model_provider": config.evaluation_model_provider,
|
||||
"default_metrics": config.default_metrics_list,
|
||||
"customized_metrics": config.customized_metrics_dict,
|
||||
"judgment_config": config.judgment_config_dict,
|
||||
}
|
||||
|
||||
|
||||
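
An illustrative PUT body for the configuration endpoint above. Only the five top-level keys come from the handler; the metric name, selectors, and values are hypothetical, and the nested shapes follow the flask_restx field definitions registered in evaluation.py:

```python
config_body = {
    "evaluation_model": "gpt-4o-mini",       # hypothetical model name
    "evaluation_model_provider": "openai",   # hypothetical provider
    "default_metrics": [
        {"metric": "context_recall", "value_type": "number", "node_info_list": []},
    ],
    "customized_metrics": None,
    "judgment_config": {
        "logical_operator": "and",
        "conditions": [
            {
                "variable_selector": ["retrieval_node", "score"],
                "comparison_operator": ">=",
                "value": "0.8",
            },
        ],
    },
}
```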
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/run")
|
||||
class DatasetEvaluationRunApi(Resource):
|
||||
@console_ns.doc("start_dataset_evaluation_run")
|
||||
@console_ns.response(200, "Evaluation run started")
|
||||
@console_ns.response(400, "Invalid request")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, dataset_id):
|
||||
"""Start an evaluation run for the knowledge base retrieval."""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
body = request.get_json(force=True)
|
||||
if not body:
|
||||
raise BadRequest("Request body is required.")
|
||||
|
||||
try:
|
||||
run_request = EvaluationRunRequest.model_validate(body)
|
||||
except Exception as e:
|
||||
raise BadRequest(f"Invalid request body: {e}")
|
||||
|
||||
upload_file = (
|
||||
db.session.query(UploadFile).filter_by(id=run_request.file_id, tenant_id=current_tenant_id).first()
|
||||
)
|
||||
if not upload_file:
|
||||
raise NotFound("Dataset file not found.")
|
||||
|
||||
try:
|
||||
dataset_content = storage.load_once(upload_file.key)
|
||||
except Exception:
|
||||
raise BadRequest("Failed to read dataset file.")
|
||||
|
||||
if not dataset_content:
|
||||
raise BadRequest("Dataset file is empty.")
|
||||
|
||||
try:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
evaluation_run = EvaluationService.start_evaluation_run(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
target_type=EvaluationTargetType.KNOWLEDGE_BASE,
|
||||
target_id=dataset_id_str,
|
||||
account_id=str(current_user.id),
|
||||
dataset_file_content=dataset_content,
|
||||
run_request=run_request,
|
||||
)
|
||||
return _serialize_dataset_evaluation_run(evaluation_run), 200
|
||||
except EvaluationFrameworkNotConfiguredError as e:
|
||||
return {"message": str(e.description)}, 400
|
||||
except EvaluationNotFoundError as e:
|
||||
return {"message": str(e.description)}, 404
|
||||
except EvaluationMaxConcurrentRunsError as e:
|
||||
return {"message": str(e.description)}, 429
|
||||
except EvaluationDatasetInvalidError as e:
|
||||
return {"message": str(e.description)}, 400
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/logs")
|
||||
class DatasetEvaluationLogsApi(Resource):
|
||||
@console_ns.doc("get_dataset_evaluation_logs")
|
||||
@console_ns.response(200, "Evaluation logs retrieved")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id):
|
||||
"""Get evaluation run history for the knowledge base."""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
page = request.args.get("page", 1, type=int)
|
||||
page_size = request.args.get("page_size", 20, type=int)
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
runs, total = EvaluationService.get_evaluation_runs(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
target_type="dataset",
|
||||
target_id=dataset_id_str,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
return {
|
||||
"data": [_serialize_dataset_evaluation_run(run) for run in runs],
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/runs/<uuid:run_id>")
|
||||
class DatasetEvaluationRunDetailApi(Resource):
|
||||
@console_ns.doc("get_dataset_evaluation_run_detail")
|
||||
@console_ns.response(200, "Evaluation run detail retrieved")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Dataset or run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id, run_id):
|
||||
"""Get evaluation run detail including per-item results."""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
run_id_str = str(run_id)
|
||||
page = request.args.get("page", 1, type=int)
|
||||
page_size = request.args.get("page_size", 50, type=int)
|
||||
|
||||
try:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
run = EvaluationService.get_evaluation_run_detail(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
run_id=run_id_str,
|
||||
)
|
||||
items, total_items = EvaluationService.get_evaluation_run_items(
|
||||
session=session,
|
||||
run_id=run_id_str,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
return {
|
||||
"run": _serialize_dataset_evaluation_run(run),
|
||||
"items": {
|
||||
"data": [_serialize_dataset_evaluation_run_item(item) for item in items],
|
||||
"total": total_items,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
},
|
||||
}
|
||||
except EvaluationNotFoundError as e:
|
||||
return {"message": str(e.description)}, 404
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/runs/<uuid:run_id>/cancel")
|
||||
class DatasetEvaluationRunCancelApi(Resource):
|
||||
@console_ns.doc("cancel_dataset_evaluation_run")
|
||||
@console_ns.response(200, "Evaluation run cancelled")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Dataset or run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, dataset_id, run_id):
|
||||
"""Cancel a running knowledge base evaluation."""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
run_id_str = str(run_id)
|
||||
try:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
run = EvaluationService.cancel_evaluation_run(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
run_id=run_id_str,
|
||||
)
|
||||
return _serialize_dataset_evaluation_run(run)
|
||||
except EvaluationNotFoundError as e:
|
||||
return {"message": str(e.description)}, 404
|
||||
except ValueError as e:
|
||||
return {"message": str(e)}, 400
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/metrics")
|
||||
class DatasetEvaluationMetricsApi(Resource):
|
||||
@console_ns.doc("get_dataset_evaluation_metrics")
|
||||
@console_ns.response(200, "Available retrieval metrics retrieved")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id):
|
||||
"""Get available evaluation metrics for knowledge base retrieval."""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
return {
|
||||
"metrics": EvaluationService.get_supported_metrics(EvaluationCategory.KNOWLEDGE_BASE)
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/files/<uuid:file_id>")
|
||||
class DatasetEvaluationFileDownloadApi(Resource):
|
||||
@console_ns.doc("download_dataset_evaluation_file")
|
||||
@console_ns.response(200, "File download URL generated")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Dataset or file not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id, file_id):
|
||||
"""Download evaluation test file or result file for the knowledge base."""
|
||||
from core.workflow.file import helpers as file_helpers
|
||||
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
file_id_str = str(file_id)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(UploadFile).where(
|
||||
UploadFile.id == file_id_str,
|
||||
UploadFile.tenant_id == current_tenant_id,
|
||||
)
|
||||
upload_file = session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
if not upload_file:
|
||||
raise NotFound("File not found.")
|
||||
|
||||
download_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id, as_attachment=True)
|
||||
|
||||
return {
|
||||
"id": upload_file.id,
|
||||
"name": upload_file.name,
|
||||
"size": upload_file.size,
|
||||
"extension": upload_file.extension,
|
||||
"mime_type": upload_file.mime_type,
|
||||
"created_at": int(upload_file.created_at.timestamp()) if upload_file.created_at else None,
|
||||
"download_url": download_url,
|
||||
}
|
||||
|
||||
api/controllers/console/evaluation/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
# Evaluation controller module
api/controllers/console/evaluation/evaluation.py (new file, 869 lines)
@@ -0,0 +1,869 @@
from __future__ import annotations

import logging
from collections.abc import Callable
from functools import wraps
from typing import TYPE_CHECKING, ParamSpec, TypeVar, Union
from urllib.parse import quote

from flask import Response, request
from flask_restx import Resource, fields, marshal
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden, NotFound

from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.workflow import WorkflowListQuery
from controllers.console.wraps import (
    account_initialization_required,
    edit_permission_required,
    setup_required,
)
from core.evaluation.entities.evaluation_entity import EvaluationCategory, EvaluationConfigData, EvaluationRunRequest
from extensions.ext_database import db
from extensions.ext_storage import storage
from fields.member_fields import simple_account_fields
from graphon.file import helpers as file_helpers
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models import App, Dataset
from models.model import UploadFile
from models.snippet import CustomizedSnippet
from services.errors.evaluation import (
    EvaluationDatasetInvalidError,
    EvaluationFrameworkNotConfiguredError,
    EvaluationMaxConcurrentRunsError,
    EvaluationNotFoundError,
)
from services.evaluation_service import EvaluationService
from services.workflow_service import WorkflowService

if TYPE_CHECKING:
    from models.evaluation import EvaluationRun, EvaluationRunItem

logger = logging.getLogger(__name__)

P = ParamSpec("P")
R = TypeVar("R")

# Valid evaluation target types
EVALUATE_TARGET_TYPES = {"app", "snippets"}


class VersionQuery(BaseModel):
    """Query parameters for version endpoint."""

    version: str


register_schema_models(
    console_ns,
    VersionQuery,
)
# Response field definitions
file_info_fields = {
    "id": fields.String,
    "name": fields.String,
}

evaluation_log_fields = {
    "created_at": TimestampField,
    "created_by": fields.String,
    "test_file": fields.Nested(
        console_ns.model(
            "EvaluationTestFile",
            file_info_fields,
        )
    ),
    "result_file": fields.Nested(
        console_ns.model(
            "EvaluationResultFile",
            file_info_fields,
        ),
        allow_null=True,
    ),
    "version": fields.String,
}

evaluation_log_list_model = console_ns.model(
    "EvaluationLogList",
    {
        "data": fields.List(fields.Nested(console_ns.model("EvaluationLog", evaluation_log_fields))),
    },
)

evaluation_default_metric_node_info_fields = {
    "node_id": fields.String,
    "type": fields.String,
    "title": fields.String,
}
evaluation_default_metric_item_fields = {
    "metric": fields.String,
    "value_type": fields.String,
    "node_info_list": fields.List(
        fields.Nested(
            console_ns.model("EvaluationDefaultMetricNodeInfo", evaluation_default_metric_node_info_fields),
        ),
    ),
}

customized_metrics_fields = {
    "evaluation_workflow_id": fields.String,
    "input_fields": fields.Raw,
    "output_fields": fields.Raw,
}

judgment_condition_fields = {
    "variable_selector": fields.List(fields.String),
    "comparison_operator": fields.String,
    "value": fields.String,
}

judgment_config_fields = {
    "logical_operator": fields.String,
    "conditions": fields.List(fields.Nested(console_ns.model("JudgmentCondition", judgment_condition_fields))),
}

evaluation_detail_fields = {
    "evaluation_model": fields.String,
    "evaluation_model_provider": fields.String,
    "default_metrics": fields.List(
        fields.Nested(console_ns.model("EvaluationDefaultMetricItem_Detail", evaluation_default_metric_item_fields)),
        allow_null=True,
    ),
    "customized_metrics": fields.Nested(
        console_ns.model("EvaluationCustomizedMetrics", customized_metrics_fields),
        allow_null=True,
    ),
    "judgment_config": fields.Nested(
        console_ns.model("EvaluationJudgmentConfig", judgment_config_fields),
        allow_null=True,
    ),
}

evaluation_detail_model = console_ns.model("EvaluationDetail", evaluation_detail_fields)

available_evaluation_workflow_list_fields = {
    "id": fields.String,
    "app_id": fields.String,
    "app_name": fields.String,
    "type": fields.String,
    "version": fields.String,
    "marked_name": fields.String,
    "marked_comment": fields.String,
    "hash": fields.String,
    "created_by": fields.Nested(simple_account_fields),
    "created_at": TimestampField,
    "updated_by": fields.Nested(simple_account_fields, allow_null=True),
    "updated_at": TimestampField,
}

available_evaluation_workflow_pagination_fields = {
    "items": fields.List(fields.Nested(available_evaluation_workflow_list_fields)),
    "page": fields.Integer,
    "limit": fields.Integer,
    "has_more": fields.Boolean,
}

available_evaluation_workflow_pagination_model = console_ns.model(
    "AvailableEvaluationWorkflowPagination",
    available_evaluation_workflow_pagination_fields,
)

evaluation_default_metrics_response_model = console_ns.model(
    "EvaluationDefaultMetricsResponse",
    {
        "default_metrics": fields.List(
            fields.Nested(console_ns.model("EvaluationDefaultMetricItem", evaluation_default_metric_item_fields)),
        ),
    },
)
def get_evaluation_target(view_func: Callable[P, R]):
    """
    Decorator to resolve polymorphic evaluation target (app or snippet).

    Validates the target_type parameter and fetches the corresponding
    model (App or CustomizedSnippet) with tenant isolation.
    """

    @wraps(view_func)
    def decorated_view(*args: P.args, **kwargs: P.kwargs):
        target_type = kwargs.get("evaluate_target_type")
        target_id = kwargs.get("evaluate_target_id")

        if target_type not in EVALUATE_TARGET_TYPES:
            raise NotFound(f"Invalid evaluation target type: {target_type}")

        _, current_tenant_id = current_account_with_tenant()

        target_id = str(target_id)

        # Remove path parameters
        del kwargs["evaluate_target_type"]
        del kwargs["evaluate_target_id"]

        target: Union[App, CustomizedSnippet, Dataset] | None = None

        if target_type == "app":
            target = db.session.query(App).where(App.id == target_id, App.tenant_id == current_tenant_id).first()
        elif target_type == "snippets":
            target = (
                db.session.query(CustomizedSnippet)
                .where(CustomizedSnippet.id == target_id, CustomizedSnippet.tenant_id == current_tenant_id)
                .first()
            )
        elif target_type == "knowledge":
            # Note: unreachable as committed, since "knowledge" is not in
            # EVALUATE_TARGET_TYPES and the membership check above raises first.
            target = (
                db.session.query(Dataset)
                .where(Dataset.id == target_id, Dataset.tenant_id == current_tenant_id)
                .first()
            )

        if not target:
            raise NotFound(f"{str(target_type)} not found")

        kwargs["target"] = target
        kwargs["target_type"] = target_type

        return view_func(*args, **kwargs)

    return decorated_view
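
A minimal standalone sketch of the kwargs-rewriting decorator pattern used above: path parameters go in, resolved objects come out, and @wraps preserves the view's identity for flask_restx:

```python
from functools import wraps


def resolve_target(view):
    @wraps(view)
    def wrapper(*args, **kwargs):
        target_id = kwargs.pop("evaluate_target_id")   # consume the path params
        kwargs.pop("evaluate_target_type", None)
        kwargs["target"] = {"id": target_id}           # stand-in for the DB lookup
        kwargs["target_type"] = "app"
        return view(*args, **kwargs)

    return wrapper


@resolve_target
def get(target, target_type):
    return target["id"], target_type


print(get(evaluate_target_type="app", evaluate_target_id="123"))  # ('123', 'app')
```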
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/dataset-template/download")
|
||||
class EvaluationDatasetTemplateDownloadApi(Resource):
|
||||
@console_ns.doc("download_evaluation_dataset_template")
|
||||
@console_ns.response(200, "Template file streamed as XLSX attachment")
|
||||
@console_ns.response(400, "Invalid target type or excluded app mode")
|
||||
@console_ns.response(404, "Target not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
@edit_permission_required
|
||||
def post(self, target: Union[App, CustomizedSnippet], target_type: str):
|
||||
"""
|
||||
Download evaluation dataset template.
|
||||
|
||||
Generates an XLSX template based on the target's input parameters
|
||||
and streams it directly as a file attachment.
|
||||
"""
|
||||
try:
|
||||
xlsx_content, filename = EvaluationService.generate_dataset_template(
|
||||
target=target,
|
||||
target_type=target_type,
|
||||
)
|
||||
except ValueError as e:
|
||||
return {"message": str(e)}, 400
|
||||
|
||||
encoded_filename = quote(filename)
|
||||
response = Response(
|
||||
xlsx_content,
|
||||
mimetype="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
)
|
||||
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
|
||||
response.headers["Content-Length"] = str(len(xlsx_content))
|
||||
return response
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation")
|
||||
class EvaluationDetailApi(Resource):
|
||||
@console_ns.doc("get_evaluation_detail")
|
||||
@console_ns.response(200, "Evaluation details retrieved successfully", evaluation_detail_model)
|
||||
@console_ns.response(404, "Target not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
|
||||
"""
|
||||
Get evaluation configuration for the target.
|
||||
|
||||
Returns evaluation configuration including model settings,
|
||||
metrics config, and judgement conditions.
|
||||
"""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
config = EvaluationService.get_evaluation_config(session, current_tenant_id, target_type, str(target.id))
|
||||
|
||||
if config is None:
|
||||
return {
|
||||
"evaluation_model": None,
|
||||
"evaluation_model_provider": None,
|
||||
"default_metrics": None,
|
||||
"customized_metrics": None,
|
||||
"judgment_config": None,
|
||||
}
|
||||
|
||||
return {
|
||||
"evaluation_model": config.evaluation_model,
|
||||
"evaluation_model_provider": config.evaluation_model_provider,
|
||||
"default_metrics": config.default_metrics_list,
|
||||
"customized_metrics": config.customized_metrics_dict,
|
||||
"judgment_config": config.judgment_config_dict,
|
||||
}
|
||||
|
||||
@console_ns.doc("save_evaluation_detail")
|
||||
@console_ns.response(200, "Evaluation configuration saved successfully")
|
||||
@console_ns.response(404, "Target not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
@edit_permission_required
|
||||
def put(self, target: Union[App, CustomizedSnippet], target_type: str):
|
||||
"""
|
||||
Save evaluation configuration for the target.
|
||||
"""
|
||||
current_account, current_tenant_id = current_account_with_tenant()
|
||||
body = request.get_json(force=True)
|
||||
|
||||
try:
|
||||
config_data = EvaluationConfigData.model_validate(body)
|
||||
except Exception as e:
|
||||
raise BadRequest(f"Invalid request body: {e}")
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
config = EvaluationService.save_evaluation_config(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
target_type=target_type,
|
||||
target_id=str(target.id),
|
||||
account_id=str(current_account.id),
|
||||
data=config_data,
|
||||
)
|
||||
|
||||
return {
|
||||
"evaluation_model": config.evaluation_model,
|
||||
"evaluation_model_provider": config.evaluation_model_provider,
|
||||
"default_metrics": config.default_metrics_list,
|
||||
"customized_metrics": config.customized_metrics_dict,
|
||||
"judgment_config": config.judgment_config_dict,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/logs")
|
||||
class EvaluationLogsApi(Resource):
|
||||
@console_ns.doc("get_evaluation_logs")
|
||||
@console_ns.response(200, "Evaluation logs retrieved successfully")
|
||||
@console_ns.response(404, "Target not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
|
||||
"""
|
||||
Get evaluation run history for the target.
|
||||
|
||||
Returns a paginated list of evaluation runs.
|
||||
"""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
page = request.args.get("page", 1, type=int)
|
||||
page_size = request.args.get("page_size", 20, type=int)
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
runs, total = EvaluationService.get_evaluation_runs(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
target_type=target_type,
|
||||
target_id=str(target.id),
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
return {
|
||||
"data": [_serialize_evaluation_run(run) for run in runs],
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/run")
|
||||
class EvaluationRunApi(Resource):
|
||||
@console_ns.doc("start_evaluation_run")
|
||||
@console_ns.response(200, "Evaluation run started")
|
||||
@console_ns.response(400, "Invalid request")
|
||||
@console_ns.response(404, "Target not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
@edit_permission_required
|
||||
def post(self, target: Union[App, CustomizedSnippet, Dataset], target_type: str):
|
||||
"""
|
||||
Start an evaluation run.
|
||||
|
||||
Expects JSON body with:
|
||||
- file_id: uploaded dataset file ID
|
||||
- evaluation_model: evaluation model name
|
||||
- evaluation_model_provider: evaluation model provider
|
||||
- default_metrics: list of default metric objects
|
||||
- customized_metrics: customized metrics object (optional)
|
||||
- judgment_config: judgment conditions config (optional)
|
||||
"""
|
||||
current_account, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
body = request.get_json(force=True)
|
||||
if not body:
|
||||
raise BadRequest("Request body is required.")
|
||||
|
||||
# Validate and parse request body
|
||||
try:
|
||||
run_request = EvaluationRunRequest.model_validate(body)
|
||||
except Exception as e:
|
||||
raise BadRequest(f"Invalid request body: {e}")
|
||||
|
||||
# Load dataset file
|
||||
upload_file = (
|
||||
db.session.query(UploadFile).filter_by(id=run_request.file_id, tenant_id=current_tenant_id).first()
|
||||
)
|
||||
if not upload_file:
|
||||
raise NotFound("Dataset file not found.")
|
||||
|
||||
try:
|
||||
dataset_content = storage.load_once(upload_file.key)
|
||||
except Exception:
|
||||
raise BadRequest("Failed to read dataset file.")
|
||||
|
||||
if not dataset_content:
|
||||
raise BadRequest("Dataset file is empty.")
|
||||
|
||||
try:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
evaluation_run = EvaluationService.start_evaluation_run(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
target_type=target_type,
|
||||
target_id=str(target.id),
|
||||
account_id=str(current_account.id),
|
||||
dataset_file_content=dataset_content,
|
||||
run_request=run_request,
|
||||
)
|
||||
return _serialize_evaluation_run(evaluation_run), 200
|
||||
except EvaluationFrameworkNotConfiguredError as e:
|
||||
return {"message": str(e.description)}, 400
|
||||
except EvaluationNotFoundError as e:
|
||||
return {"message": str(e.description)}, 404
|
||||
except EvaluationMaxConcurrentRunsError as e:
|
||||
return {"message": str(e.description)}, 429
|
||||
except EvaluationDatasetInvalidError as e:
|
||||
return {"message": str(e.description)}, 400
|
||||
|
||||
|
||||
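
Since EvaluationMaxConcurrentRunsError surfaces as HTTP 429 (bounded by EVALUATION_MAX_CONCURRENT_RUNS), a client can back off and retry. A hypothetical sketch; the base URL, token, and body are placeholders:

```python
import time

import requests


def start_run(base_url: str, token: str, target_type: str, target_id: str, body: dict) -> dict:
    url = f"{base_url}/{target_type}/{target_id}/evaluation/run"
    for attempt in range(3):
        resp = requests.post(url, json=body, headers={"Authorization": f"Bearer {token}"}, timeout=30)
        if resp.status_code != 429:
            resp.raise_for_status()
            return resp.json()
        time.sleep(2 ** attempt)  # back off while the tenant is at its concurrency cap
    raise RuntimeError("evaluation run slots still busy after retries")
```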
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/runs/<uuid:run_id>")
|
||||
class EvaluationRunDetailApi(Resource):
|
||||
@console_ns.doc("get_evaluation_run_detail")
|
||||
@console_ns.response(200, "Evaluation run detail retrieved")
|
||||
@console_ns.response(404, "Run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
def get(self, target: Union[App, CustomizedSnippet], target_type: str, run_id: str):
|
||||
"""
|
||||
Get evaluation run detail including items.
|
||||
"""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
run_id = str(run_id)
|
||||
page = request.args.get("page", 1, type=int)
|
||||
page_size = request.args.get("page_size", 50, type=int)
|
||||
|
||||
try:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
run = EvaluationService.get_evaluation_run_detail(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
run_id=run_id,
|
||||
)
|
||||
items, total_items = EvaluationService.get_evaluation_run_items(
|
||||
session=session,
|
||||
run_id=run_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
return {
|
||||
"run": _serialize_evaluation_run(run),
|
||||
"items": {
|
||||
"data": [_serialize_evaluation_run_item(item) for item in items],
|
||||
"total": total_items,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
},
|
||||
}
|
||||
except EvaluationNotFoundError as e:
|
||||
return {"message": str(e.description)}, 404
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/runs/<uuid:run_id>/cancel")
|
||||
class EvaluationRunCancelApi(Resource):
|
||||
@console_ns.doc("cancel_evaluation_run")
|
||||
@console_ns.response(200, "Evaluation run cancelled")
|
||||
@console_ns.response(404, "Run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
@edit_permission_required
|
||||
def post(self, target: Union[App, CustomizedSnippet], target_type: str, run_id: str):
|
||||
"""Cancel a running evaluation."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
run_id = str(run_id)
|
||||
|
||||
try:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
run = EvaluationService.cancel_evaluation_run(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
run_id=run_id,
|
||||
)
|
||||
return _serialize_evaluation_run(run)
|
||||
except EvaluationNotFoundError as e:
|
||||
return {"message": str(e.description)}, 404
|
||||
except ValueError as e:
|
||||
return {"message": str(e)}, 400
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/metrics")
|
||||
class EvaluationMetricsApi(Resource):
|
||||
@console_ns.doc("get_evaluation_metrics")
|
||||
@console_ns.response(200, "Available metrics retrieved")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
|
||||
"""
|
||||
Get available evaluation metrics for the current framework.
|
||||
"""
|
||||
result = {}
|
||||
for category in EvaluationCategory:
|
||||
result[category.value] = EvaluationService.get_supported_metrics(category)
|
||||
return {"metrics": result}
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/default-metrics")
|
||||
class EvaluationDefaultMetricsApi(Resource):
|
||||
@console_ns.doc(
|
||||
"get_evaluation_default_metrics_with_nodes",
|
||||
description=(
|
||||
"List default metrics supported by the current evaluation framework with matching nodes "
|
||||
"from the target's published workflow only (draft is ignored)."
|
||||
),
|
||||
)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Default metrics and node candidates for the published workflow",
|
||||
evaluation_default_metrics_response_model,
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
|
||||
default_metrics = EvaluationService.get_default_metrics_with_nodes_for_published_target(
|
||||
target=target,
|
||||
target_type=target_type,
|
||||
)
|
||||
return {"default_metrics": [m.model_dump() for m in default_metrics]}
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/node-info")
|
||||
class EvaluationNodeInfoApi(Resource):
|
||||
@console_ns.doc("get_evaluation_node_info")
|
||||
@console_ns.response(200, "Node info grouped by metric")
|
||||
@console_ns.response(404, "Target not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
def post(self, target: Union[App, CustomizedSnippet], target_type: str):
|
||||
"""Return workflow/snippet node info grouped by requested metrics.
|
||||
|
||||
Request body (JSON):
|
||||
- metrics: list[str] | None – metric names to query; omit or pass
|
||||
an empty list to get all nodes under key ``"all"``.
|
||||
|
||||
Response:
|
||||
``{metric_or_all: [{"node_id": ..., "type": ..., "title": ...}, ...]}``
|
||||
"""
|
||||
body = request.get_json(silent=True) or {}
|
||||
metrics: list[str] | None = body.get("metrics") or None
|
||||
|
||||
result = EvaluationService.get_nodes_for_metrics(
|
||||
target=target,
|
||||
target_type=target_type,
|
||||
metrics=metrics,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@console_ns.route("/evaluation/available-metrics")
|
||||
class EvaluationAvailableMetricsApi(Resource):
|
||||
@console_ns.doc("get_available_evaluation_metrics")
|
||||
@console_ns.response(200, "Available metrics list")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
"""Return the centrally-defined list of evaluation metrics."""
|
||||
return {"metrics": EvaluationService.get_available_metrics()}
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/files/<uuid:file_id>")
|
||||
class EvaluationFileDownloadApi(Resource):
|
||||
@console_ns.doc("download_evaluation_file")
|
||||
@console_ns.response(200, "File download URL generated successfully")
|
||||
@console_ns.response(404, "Target or file not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
def get(self, target: Union[App, CustomizedSnippet], target_type: str, file_id: str):
|
||||
"""
|
||||
Download evaluation test file or result file.
|
||||
|
||||
Looks up the specified file, verifies it belongs to the same tenant,
|
||||
and returns file info and download URL.
|
||||
"""
|
||||
file_id = str(file_id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(UploadFile).where(
|
||||
UploadFile.id == file_id,
|
||||
UploadFile.tenant_id == current_tenant_id,
|
||||
)
|
||||
upload_file = session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
if not upload_file:
|
||||
raise NotFound("File not found")
|
||||
|
||||
download_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id, as_attachment=True)
|
||||
|
||||
return {
|
||||
"id": upload_file.id,
|
||||
"name": upload_file.name,
|
||||
"size": upload_file.size,
|
||||
"extension": upload_file.extension,
|
||||
"mime_type": upload_file.mime_type,
|
||||
"created_at": int(upload_file.created_at.timestamp()) if upload_file.created_at else None,
|
||||
"download_url": download_url,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/version")
|
||||
class EvaluationVersionApi(Resource):
|
||||
@console_ns.doc("get_evaluation_version_detail")
|
||||
@console_ns.expect(console_ns.models.get(VersionQuery.__name__))
|
||||
@console_ns.response(200, "Version details retrieved successfully")
|
||||
@console_ns.response(404, "Target or version not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
|
||||
"""
|
||||
Get evaluation target version details.
|
||||
|
||||
Returns the workflow graph for the specified version.
|
||||
"""
|
||||
version = request.args.get("version")
|
||||
|
||||
if not version:
|
||||
return {"message": "version parameter is required"}, 400
|
||||
|
||||
graph = {}
|
||||
if target_type == "snippets" and isinstance(target, CustomizedSnippet):
|
||||
graph = target.graph_dict
|
||||
|
||||
return {
|
||||
"graph": graph,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/available-evaluation-workflows")
|
||||
class AvailableEvaluationWorkflowsApi(Resource):
|
||||
@console_ns.expect(console_ns.models[WorkflowListQuery.__name__])
|
||||
@console_ns.doc("list_available_evaluation_workflows")
|
||||
@console_ns.doc(description="List published evaluation workflows in the current workspace (all apps)")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Available evaluation workflows retrieved",
|
||||
available_evaluation_workflow_pagination_model,
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self):
|
||||
"""List published evaluation-type workflows for the current tenant (cross-app)."""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
args = WorkflowListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
page = args.page
|
||||
limit = args.limit
|
||||
user_id = args.user_id
|
||||
named_only = args.named_only
|
||||
keyword = args.keyword
|
||||
|
||||
if user_id and user_id != current_user.id:
|
||||
raise Forbidden()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
with Session(db.engine) as session:
|
||||
workflows, has_more = workflow_service.list_published_evaluation_workflows(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
page=page,
|
||||
limit=limit,
|
||||
user_id=user_id,
|
||||
named_only=named_only,
|
||||
keyword=keyword,
|
||||
)
|
||||
|
||||
app_ids = {w.app_id for w in workflows}
|
||||
if app_ids:
|
||||
apps = session.scalars(select(App).where(App.id.in_(app_ids))).all()
|
||||
app_names = {a.id: a.name for a in apps}
|
||||
else:
|
||||
app_names = {}
|
||||
|
||||
items = []
|
||||
for wf in workflows:
|
||||
items.append(
|
||||
{
|
||||
"id": wf.id,
|
||||
"app_id": wf.app_id,
|
||||
"app_name": app_names.get(wf.app_id, ""),
|
||||
"type": wf.type.value,
|
||||
"version": wf.version,
|
||||
"marked_name": wf.marked_name,
|
||||
"marked_comment": wf.marked_comment,
|
||||
"hash": wf.unique_hash,
|
||||
"created_by": wf.created_by_account,
|
||||
"created_at": wf.created_at,
|
||||
"updated_by": wf.updated_by_account,
|
||||
"updated_at": wf.updated_at,
|
||||
}
|
||||
)
|
||||
|
||||
return (
|
||||
marshal(
|
||||
{"items": items, "page": page, "limit": limit, "has_more": has_more},
|
||||
available_evaluation_workflow_pagination_fields,
|
||||
),
|
||||
200,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/evaluation-workflows/<string:workflow_id>/associated-targets")
|
||||
class EvaluationWorkflowAssociatedTargetsApi(Resource):
|
||||
@console_ns.doc("list_evaluation_workflow_associated_targets")
|
||||
@console_ns.doc(
|
||||
description="List targets (apps / snippets / knowledge bases) that use the given workflow as customized metrics"
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, workflow_id: str):
|
||||
"""Return all evaluation targets that reference this workflow as customized metrics."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
configs = EvaluationService.list_targets_by_customized_workflow(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
customized_workflow_id=workflow_id,
|
||||
)
|
||||
|
||||
target_ids_by_type: dict[str, list[str]] = {}
|
||||
for cfg in configs:
|
||||
target_ids_by_type.setdefault(cfg.target_type, []).append(cfg.target_id)
|
||||
|
||||
app_names: dict[str, str] = {}
|
||||
if "app" in target_ids_by_type:
|
||||
apps = session.scalars(select(App).where(App.id.in_(target_ids_by_type["app"]))).all()
|
||||
app_names = {a.id: a.name for a in apps}
|
||||
|
||||
snippet_names: dict[str, str] = {}
|
||||
if "snippets" in target_ids_by_type:
|
||||
snippets = session.scalars(
|
||||
select(CustomizedSnippet).where(CustomizedSnippet.id.in_(target_ids_by_type["snippets"]))
|
||||
).all()
|
||||
snippet_names = {s.id: s.name for s in snippets}
|
||||
|
||||
dataset_names: dict[str, str] = {}
|
||||
if "knowledge_base" in target_ids_by_type:
|
||||
datasets = session.scalars(
|
||||
select(Dataset).where(Dataset.id.in_(target_ids_by_type["knowledge_base"]))
|
||||
).all()
|
||||
dataset_names = {d.id: d.name for d in datasets}
|
||||
|
||||
items = []
|
||||
for cfg in configs:
|
||||
name = ""
|
||||
if cfg.target_type == "app":
|
||||
name = app_names.get(cfg.target_id, "")
|
||||
elif cfg.target_type == "snippets":
|
||||
name = snippet_names.get(cfg.target_id, "")
|
||||
elif cfg.target_type == "knowledge_base":
|
||||
name = dataset_names.get(cfg.target_id, "")
|
||||
|
||||
items.append(
|
||||
{
|
||||
"target_type": cfg.target_type,
|
||||
"target_id": cfg.target_id,
|
||||
"target_name": name,
|
||||
}
|
||||
)
|
||||
|
||||
return {"items": items}, 200
|
||||
|
||||
|
||||
# ---- Serialization Helpers ----
|
||||
|
||||
|
||||
def _serialize_evaluation_run(run: EvaluationRun) -> dict[str, object]:
|
||||
return {
|
||||
"id": run.id,
|
||||
"tenant_id": run.tenant_id,
|
||||
"target_type": run.target_type,
|
||||
"target_id": run.target_id,
|
||||
"evaluation_config_id": run.evaluation_config_id,
|
||||
"status": run.status,
|
||||
"dataset_file_id": run.dataset_file_id,
|
||||
"result_file_id": run.result_file_id,
|
||||
"total_items": run.total_items,
|
||||
"completed_items": run.completed_items,
|
||||
"failed_items": run.failed_items,
|
||||
"progress": run.progress,
|
||||
"metrics_summary": run.metrics_summary_dict,
|
||||
"error": run.error,
|
||||
"created_by": run.created_by,
|
||||
"started_at": int(run.started_at.timestamp()) if run.started_at else None,
|
||||
"completed_at": int(run.completed_at.timestamp()) if run.completed_at else None,
|
||||
"created_at": int(run.created_at.timestamp()) if run.created_at else None,
|
||||
}
|
||||
|
||||
|
||||
def _serialize_evaluation_run_item(item: EvaluationRunItem) -> dict[str, object]:
|
||||
return {
|
||||
"id": item.id,
|
||||
"item_index": item.item_index,
|
||||
"inputs": item.inputs_dict,
|
||||
"expected_output": item.expected_output,
|
||||
"actual_output": item.actual_output,
|
||||
"metrics": item.metrics_list,
|
||||
"judgment": item.judgment_dict,
|
||||
"metadata": item.metadata_dict,
|
||||
"error": item.error,
|
||||
"overall_score": item.overall_score,
|
||||
}
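
These helpers flatten ORM rows into JSON-safe dicts, with datetimes reduced to Unix seconds. A self-contained sketch of that conversion guard, using an invented stand-in for EvaluationRun:

from dataclasses import dataclass
from datetime import datetime, timezone


@dataclass
class FakeRun:  # stand-in for EvaluationRun, for illustration only
    started_at: datetime | None = datetime(2024, 1, 1, tzinfo=timezone.utc)
    completed_at: datetime | None = None


run = FakeRun()
payload = {
    # Same guard as the serializers above: None stays None, datetimes become ints.
    "started_at": int(run.started_at.timestamp()) if run.started_at else None,
    "completed_at": int(run.completed_at.timestamp()) if run.completed_at else None,
}
assert payload == {"started_at": 1704067200, "completed_at": None}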

135 api/controllers/console/snippets/payloads.py Normal file
@ -0,0 +1,135 @@
from typing import Any, Literal

from pydantic import BaseModel, Field, field_validator


class SnippetListQuery(BaseModel):
    """Query parameters for listing snippets."""

    page: int = Field(default=1, ge=1, le=99999)
    limit: int = Field(default=20, ge=1, le=100)
    keyword: str | None = None
    is_published: bool | None = Field(default=None, description="Filter by published status")
    creators: list[str] | None = Field(default=None, description="Filter by creator account IDs")

    @field_validator("creators", mode="before")
    @classmethod
    def parse_creators(cls, value: object) -> list[str] | None:
        """Normalize creators filter from query string or list input."""
        if value is None:
            return None
        if isinstance(value, str):
            return [creator.strip() for creator in value.split(",") if creator.strip()] or None
        if isinstance(value, list):
            return [str(creator).strip() for creator in value if str(creator).strip()] or None
        return None
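
A quick illustrative check of parse_creators above (pydantic v2, as the file already assumes): comma-separated strings are split and trimmed, and an all-whitespace filter collapses to None rather than an empty list.

q = SnippetListQuery.model_validate({"creators": " a1 , ,a2 "})
assert q.creators == ["a1", "a2"]

q = SnippetListQuery.model_validate({"creators": " , "})
assert q.creators is None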


class IconInfo(BaseModel):
    """Icon information model."""

    icon: str | None = None
    icon_type: Literal["emoji", "image"] | None = None
    icon_background: str | None = None
    icon_url: str | None = None


class InputFieldDefinition(BaseModel):
    """Input field definition for snippet parameters."""

    default: str | None = None
    hint: bool | None = None
    label: str | None = None
    max_length: int | None = None
    options: list[str] | None = None
    placeholder: str | None = None
    required: bool | None = None
    type: str | None = None  # e.g., "text-input"


class CreateSnippetPayload(BaseModel):
    """Payload for creating a new snippet."""

    name: str = Field(..., min_length=1, max_length=255)
    description: str | None = Field(default=None, max_length=2000)
    type: Literal["node", "group"] = "node"
    icon_info: IconInfo | None = None
    graph: dict[str, Any] | None = None
    input_fields: list[InputFieldDefinition] | None = Field(default_factory=list)


class UpdateSnippetPayload(BaseModel):
    """Payload for updating a snippet."""

    name: str | None = Field(default=None, min_length=1, max_length=255)
    description: str | None = Field(default=None, max_length=2000)
    icon_info: IconInfo | None = None


class SnippetDraftSyncPayload(BaseModel):
    """Payload for syncing snippet draft workflow."""

    graph: dict[str, Any]
    hash: str | None = None
    conversation_variables: list[dict[str, Any]] | None = Field(
        default=None,
        description="Ignored. Snippet workflows do not persist conversation variables.",
    )
    input_fields: list[dict[str, Any]] | None = None


class WorkflowRunQuery(BaseModel):
    """Query parameters for workflow runs."""

    last_id: str | None = None
    limit: int = Field(default=20, ge=1, le=100)


class SnippetDraftRunPayload(BaseModel):
    """Payload for running snippet draft workflow."""

    inputs: dict[str, Any]
    files: list[dict[str, Any]] | None = None


class SnippetDraftNodeRunPayload(BaseModel):
    """Payload for running a single node in snippet draft workflow."""

    inputs: dict[str, Any]
    query: str = ""
    files: list[dict[str, Any]] | None = None


class SnippetIterationNodeRunPayload(BaseModel):
    """Payload for running an iteration node in snippet draft workflow."""

    inputs: dict[str, Any] | None = None


class SnippetLoopNodeRunPayload(BaseModel):
    """Payload for running a loop node in snippet draft workflow."""

    inputs: dict[str, Any] | None = None


class PublishWorkflowPayload(BaseModel):
    """Payload for publishing snippet workflow."""

    knowledge_base_setting: dict[str, Any] | None = None


class SnippetImportPayload(BaseModel):
    """Payload for importing snippet from DSL."""

    mode: str = Field(..., description="Import mode: yaml-content or yaml-url")
    yaml_content: str | None = Field(default=None, description="YAML content (required for yaml-content mode)")
    yaml_url: str | None = Field(default=None, description="YAML URL (required for yaml-url mode)")
    name: str | None = Field(default=None, description="Override snippet name")
    description: str | None = Field(default=None, description="Override snippet description")
    snippet_id: str | None = Field(default=None, description="Snippet ID to update (optional)")


class IncludeSecretQuery(BaseModel):
    """Query parameter for including secret variables in export."""

    include_secret: str = Field(default="false", description="Whether to include secret variables")
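
To make the defaults concrete, a small illustrative round-trip through CreateSnippetPayload; the values are invented.

payload = CreateSnippetPayload.model_validate(
    {
        "name": "Summarize ticket",
        "icon_info": {"icon": "🧩", "icon_type": "emoji"},
    }
)
assert payload.type == "node"      # Literal default
assert payload.input_fields == []  # default_factory=list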

534 api/controllers/console/snippets/snippet_workflow.py Normal file
@ -0,0 +1,534 @@
import logging
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar

from flask import request
from flask_restx import Resource, marshal_with
from sqlalchemy.orm import Session
from werkzeug.exceptions import InternalServerError, NotFound

from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync
from controllers.console.app.workflow import workflow_model
from controllers.console.app.workflow_run import (
    workflow_run_detail_model,
    workflow_run_node_execution_list_model,
    workflow_run_node_execution_model,
    workflow_run_pagination_model,
)
from controllers.console.snippets.payloads import (
    PublishWorkflowPayload,
    SnippetDraftNodeRunPayload,
    SnippetDraftRunPayload,
    SnippetDraftSyncPayload,
    SnippetIterationNodeRunPayload,
    SnippetLoopNodeRunPayload,
    WorkflowRunQuery,
)
from controllers.console.wraps import (
    account_initialization_required,
    edit_permission_required,
    setup_required,
)
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from graphon.graph_engine.manager import GraphEngineManager
from libs import helper
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models.snippet import CustomizedSnippet
from services.errors.app import WorkflowHashNotEqualError
from services.snippet_generate_service import SnippetGenerateService
from services.snippet_service import SnippetService

logger = logging.getLogger(__name__)

P = ParamSpec("P")
R = TypeVar("R")

# Register Pydantic models with Swagger
register_schema_models(
    console_ns,
    SnippetDraftSyncPayload,
    SnippetDraftNodeRunPayload,
    SnippetDraftRunPayload,
    SnippetIterationNodeRunPayload,
    SnippetLoopNodeRunPayload,
    WorkflowRunQuery,
    PublishWorkflowPayload,
)


class SnippetNotFoundError(Exception):
    """Snippet not found error."""

    pass


def get_snippet(view_func: Callable[P, R]):
    """Decorator to fetch and validate snippet access."""

    @wraps(view_func)
    def decorated_view(*args: P.args, **kwargs: P.kwargs):
        if not kwargs.get("snippet_id"):
            raise ValueError("missing snippet_id in path parameters")

        _, current_tenant_id = current_account_with_tenant()

        snippet_id = str(kwargs.get("snippet_id"))
        del kwargs["snippet_id"]

        snippet = SnippetService.get_snippet_by_id(
            snippet_id=snippet_id,
            tenant_id=current_tenant_id,
        )

        if not snippet:
            raise NotFound("Snippet not found")

        kwargs["snippet"] = snippet

        return view_func(*args, **kwargs)

    return decorated_view
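
Note that get_snippet rewrites the view kwargs: the snippet_id path parameter is popped and replaced with a loaded snippet model, which is why every handler below takes snippet: CustomizedSnippet instead of an ID. A standalone sketch of that kwarg swap, with the DB lookup stubbed out:

def inject(view):
    def wrapper(**kwargs):
        # Stands in for the tenant-scoped lookup performed by get_snippet.
        kwargs["snippet"] = f"loaded:{kwargs.pop('snippet_id')}"
        return view(**kwargs)
    return wrapper

@inject
def show(snippet):
    return snippet

assert show(snippet_id="abc") == "loaded:abc"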


@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft")
class SnippetDraftWorkflowApi(Resource):
    @console_ns.doc("get_snippet_draft_workflow")
    @console_ns.response(200, "Draft workflow retrieved successfully", workflow_model)
    @console_ns.response(404, "Snippet or draft workflow not found")
    @setup_required
    @login_required
    @account_initialization_required
    @get_snippet
    @edit_permission_required
    @marshal_with(workflow_model)
    def get(self, snippet: CustomizedSnippet):
        """Get draft workflow for snippet."""
        snippet_service = SnippetService()
        workflow = snippet_service.get_draft_workflow(snippet=snippet)

        if not workflow:
            raise DraftWorkflowNotExist()

        db.session.expunge(workflow)
        workflow.conversation_variables = []
        return workflow

    @console_ns.doc("sync_snippet_draft_workflow")
    @console_ns.expect(console_ns.models.get(SnippetDraftSyncPayload.__name__))
    @console_ns.response(200, "Draft workflow synced successfully")
    @console_ns.response(400, "Hash mismatch")
    @setup_required
    @login_required
    @account_initialization_required
    @get_snippet
    @edit_permission_required
    def post(self, snippet: CustomizedSnippet):
        """Sync draft workflow for snippet."""
        current_user, _ = current_account_with_tenant()

        payload = SnippetDraftSyncPayload.model_validate(console_ns.payload or {})

        try:
            snippet_service = SnippetService()
            workflow = snippet_service.sync_draft_workflow(
                snippet=snippet,
                graph=payload.graph,
                unique_hash=payload.hash,
                account=current_user,
                input_fields=payload.input_fields,
            )
        except WorkflowHashNotEqualError:
            raise DraftWorkflowNotSync()
        except ValueError as e:
            return {"message": str(e)}, 400

        return {
            "result": "success",
            "hash": workflow.unique_hash,
            "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
        }


@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/config")
class SnippetDraftConfigApi(Resource):
    @console_ns.doc("get_snippet_draft_config")
    @console_ns.response(200, "Draft config retrieved successfully")
    @setup_required
    @login_required
    @account_initialization_required
    @get_snippet
    @edit_permission_required
    def get(self, snippet: CustomizedSnippet):
        """Get snippet draft workflow configuration limits."""
        return {
            "parallel_depth_limit": 3,
        }


@console_ns.route("/snippets/<uuid:snippet_id>/workflows/publish")
class SnippetPublishedWorkflowApi(Resource):
    @console_ns.doc("get_snippet_published_workflow")
    @console_ns.response(200, "Published workflow retrieved successfully", workflow_model)
    @console_ns.response(404, "Snippet not found")
    @setup_required
    @login_required
    @account_initialization_required
    @get_snippet
    @edit_permission_required
    @marshal_with(workflow_model)
    def get(self, snippet: CustomizedSnippet):
        """Get published workflow for snippet."""
        if not snippet.is_published:
            return None

        snippet_service = SnippetService()
        workflow = snippet_service.get_published_workflow(snippet=snippet)

        return workflow

    @console_ns.doc("publish_snippet_workflow")
    @console_ns.expect(console_ns.models.get(PublishWorkflowPayload.__name__))
    @console_ns.response(200, "Workflow published successfully")
    @console_ns.response(400, "No draft workflow found")
    @setup_required
    @login_required
    @account_initialization_required
    @get_snippet
    @edit_permission_required
    def post(self, snippet: CustomizedSnippet):
        """Publish snippet workflow."""
        current_user, _ = current_account_with_tenant()
        snippet_service = SnippetService()

        with Session(db.engine) as session:
            snippet = session.merge(snippet)
            try:
                workflow = snippet_service.publish_workflow(
                    session=session,
                    snippet=snippet,
                    account=current_user,
                )
                workflow_created_at = TimestampField().format(workflow.created_at)
                session.commit()
            except ValueError as e:
                return {"message": str(e)}, 400

        return {
            "result": "success",
            "created_at": workflow_created_at,
        }


@console_ns.route("/snippets/<uuid:snippet_id>/workflows/default-workflow-block-configs")
class SnippetDefaultBlockConfigsApi(Resource):
    @console_ns.doc("get_snippet_default_block_configs")
    @console_ns.response(200, "Default block configs retrieved successfully")
    @setup_required
    @login_required
    @account_initialization_required
    @get_snippet
    @edit_permission_required
    def get(self, snippet: CustomizedSnippet):
        """Get default block configurations for snippet workflow."""
        snippet_service = SnippetService()
        return snippet_service.get_default_block_configs()


@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs")
class SnippetWorkflowRunsApi(Resource):
    @console_ns.doc("list_snippet_workflow_runs")
    @console_ns.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_model)
    @setup_required
    @login_required
    @account_initialization_required
    @get_snippet
    @marshal_with(workflow_run_pagination_model)
    def get(self, snippet: CustomizedSnippet):
        """List workflow runs for snippet."""
        query = WorkflowRunQuery.model_validate(
            {
                "last_id": request.args.get("last_id"),
                "limit": request.args.get("limit", type=int, default=20),
            }
        )
        args = {
            "last_id": query.last_id,
            "limit": query.limit,
        }

        snippet_service = SnippetService()
        result = snippet_service.get_snippet_workflow_runs(snippet=snippet, args=args)

        return result


@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/<uuid:run_id>")
class SnippetWorkflowRunDetailApi(Resource):
    @console_ns.doc("get_snippet_workflow_run_detail")
    @console_ns.response(200, "Workflow run detail retrieved successfully", workflow_run_detail_model)
    @console_ns.response(404, "Workflow run not found")
    @setup_required
    @login_required
    @account_initialization_required
    @get_snippet
    @marshal_with(workflow_run_detail_model)
    def get(self, snippet: CustomizedSnippet, run_id):
        """Get workflow run detail for snippet."""
        run_id = str(run_id)

        snippet_service = SnippetService()
        workflow_run = snippet_service.get_snippet_workflow_run(snippet=snippet, run_id=run_id)

        if not workflow_run:
            raise NotFound("Workflow run not found")

        return workflow_run


@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/<uuid:run_id>/node-executions")
class SnippetWorkflowRunNodeExecutionsApi(Resource):
    @console_ns.doc("list_snippet_workflow_run_node_executions")
    @console_ns.response(200, "Node executions retrieved successfully", workflow_run_node_execution_list_model)
    @setup_required
    @login_required
    @account_initialization_required
    @get_snippet
    @marshal_with(workflow_run_node_execution_list_model)
    def get(self, snippet: CustomizedSnippet, run_id):
        """List node executions for a workflow run."""
        run_id = str(run_id)

        snippet_service = SnippetService()
        node_executions = snippet_service.get_snippet_workflow_run_node_executions(
            snippet=snippet,
            run_id=run_id,
        )

        return {"data": node_executions}


@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/nodes/<string:node_id>/run")
class SnippetDraftNodeRunApi(Resource):
    @console_ns.doc("run_snippet_draft_node")
    @console_ns.doc(description="Run a single node in snippet draft workflow (single-step debugging)")
    @console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
    @console_ns.expect(console_ns.models.get(SnippetDraftNodeRunPayload.__name__))
    @console_ns.response(200, "Node run completed successfully", workflow_run_node_execution_model)
    @console_ns.response(404, "Snippet or draft workflow not found")
    @setup_required
    @login_required
    @account_initialization_required
    @get_snippet
    @marshal_with(workflow_run_node_execution_model)
    @edit_permission_required
    def post(self, snippet: CustomizedSnippet, node_id: str):
        """
        Run a single node in snippet draft workflow.

        Executes a specific node with provided inputs for single-step debugging.
        Returns the node execution result including status, outputs, and timing.
        """
        current_user, _ = current_account_with_tenant()
        payload = SnippetDraftNodeRunPayload.model_validate(console_ns.payload or {})

        user_inputs = payload.inputs

        # Get draft workflow for file parsing
        snippet_service = SnippetService()
        draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
        if not draft_workflow:
            raise NotFound("Draft workflow not found")

        files = SnippetGenerateService.parse_files(draft_workflow, payload.files)

        workflow_node_execution = SnippetGenerateService.run_draft_node(
            snippet=snippet,
            node_id=node_id,
            user_inputs=user_inputs,
            account=current_user,
            query=payload.query,
            files=files,
        )

        return workflow_node_execution


@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/nodes/<string:node_id>/last-run")
class SnippetDraftNodeLastRunApi(Resource):
    @console_ns.doc("get_snippet_draft_node_last_run")
    @console_ns.doc(description="Get last run result for a node in snippet draft workflow")
    @console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
    @console_ns.response(200, "Node last run retrieved successfully", workflow_run_node_execution_model)
    @console_ns.response(404, "Snippet, draft workflow, or node last run not found")
    @setup_required
    @login_required
    @account_initialization_required
    @get_snippet
    @marshal_with(workflow_run_node_execution_model)
    def get(self, snippet: CustomizedSnippet, node_id: str):
        """
        Get the last run result for a specific node in snippet draft workflow.

        Returns the most recent execution record for the given node,
        including status, inputs, outputs, and timing information.
        """
        snippet_service = SnippetService()
        draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
        if not draft_workflow:
            raise NotFound("Draft workflow not found")

        node_exec = snippet_service.get_snippet_node_last_run(
            snippet=snippet,
            workflow=draft_workflow,
            node_id=node_id,
        )
        if node_exec is None:
            raise NotFound("Node last run not found")

        return node_exec


@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/iteration/nodes/<string:node_id>/run")
class SnippetDraftRunIterationNodeApi(Resource):
    @console_ns.doc("run_snippet_draft_iteration_node")
    @console_ns.doc(description="Run draft workflow iteration node for snippet")
    @console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
    @console_ns.expect(console_ns.models.get(SnippetIterationNodeRunPayload.__name__))
    @console_ns.response(200, "Iteration node run started successfully (SSE stream)")
    @console_ns.response(404, "Snippet or draft workflow not found")
    @setup_required
    @login_required
    @account_initialization_required
    @get_snippet
    @edit_permission_required
    def post(self, snippet: CustomizedSnippet, node_id: str):
        """
        Run a draft workflow iteration node for snippet.

        Iteration nodes execute their internal sub-graph multiple times over an input list.
        Returns an SSE event stream with iteration progress and results.
        """
        current_user, _ = current_account_with_tenant()
        args = SnippetIterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)

        try:
            response = SnippetGenerateService.generate_single_iteration(
                snippet=snippet, user=current_user, node_id=node_id, args=args, streaming=True
            )

            return helper.compact_generate_response(response)
        except ValueError as e:
            raise e
        except Exception:
            logger.exception("internal server error.")
            raise InternalServerError()


@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/loop/nodes/<string:node_id>/run")
class SnippetDraftRunLoopNodeApi(Resource):
    @console_ns.doc("run_snippet_draft_loop_node")
    @console_ns.doc(description="Run draft workflow loop node for snippet")
    @console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
    @console_ns.expect(console_ns.models.get(SnippetLoopNodeRunPayload.__name__))
    @console_ns.response(200, "Loop node run started successfully (SSE stream)")
    @console_ns.response(404, "Snippet or draft workflow not found")
    @setup_required
    @login_required
    @account_initialization_required
    @get_snippet
    @edit_permission_required
    def post(self, snippet: CustomizedSnippet, node_id: str):
        """
        Run a draft workflow loop node for snippet.

        Loop nodes execute their internal sub-graph repeatedly until a condition is met.
        Returns an SSE event stream with loop progress and results.
        """
        current_user, _ = current_account_with_tenant()
        args = SnippetLoopNodeRunPayload.model_validate(console_ns.payload or {})

        try:
            response = SnippetGenerateService.generate_single_loop(
                snippet=snippet, user=current_user, node_id=node_id, args=args, streaming=True
            )

            return helper.compact_generate_response(response)
        except ValueError as e:
            raise e
        except Exception:
            logger.exception("internal server error.")
            raise InternalServerError()


@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/run")
class SnippetDraftWorkflowRunApi(Resource):
    @console_ns.doc("run_snippet_draft_workflow")
    @console_ns.expect(console_ns.models.get(SnippetDraftRunPayload.__name__))
    @console_ns.response(200, "Draft workflow run started successfully (SSE stream)")
    @console_ns.response(404, "Snippet or draft workflow not found")
    @setup_required
    @login_required
    @account_initialization_required
    @get_snippet
    @edit_permission_required
    def post(self, snippet: CustomizedSnippet):
        """
        Run draft workflow for snippet.

        Executes the snippet's draft workflow with the provided inputs
        and returns an SSE event stream with execution progress and results.
        """
        current_user, _ = current_account_with_tenant()

        payload = SnippetDraftRunPayload.model_validate(console_ns.payload or {})
        args = payload.model_dump(exclude_none=True)

        try:
            response = SnippetGenerateService.generate(
                snippet=snippet,
                user=current_user,
                args=args,
                invoke_from=InvokeFrom.DEBUGGER,
                streaming=True,
            )

            return helper.compact_generate_response(response)
        except ValueError as e:
            raise e
        except Exception:
            logger.exception("internal server error.")
            raise InternalServerError()


@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/tasks/<string:task_id>/stop")
class SnippetWorkflowTaskStopApi(Resource):
    @console_ns.doc("stop_snippet_workflow_task")
    @console_ns.response(200, "Task stopped successfully")
    @console_ns.response(404, "Snippet not found")
    @setup_required
    @login_required
    @account_initialization_required
    @get_snippet
    @edit_permission_required
    def post(self, snippet: CustomizedSnippet, task_id: str):
        """
        Stop a running snippet workflow task.

        Uses both the legacy stop flag mechanism and the graph engine
        command channel for backward compatibility.
        """
        # Stop using both mechanisms for backward compatibility
        # Legacy stop flag mechanism (without user check)
        AppQueueManager.set_stop_flag_no_user_check(task_id)

        # New graph engine command channel mechanism
        GraphEngineManager(redis_client).send_stop_command(task_id)

        return {"result": "success"}

319 api/controllers/console/snippets/snippet_workflow_draft_variable.py Normal file
@ -0,0 +1,319 @@
"""
Snippet draft workflow variable APIs.

Mirrors console app routes under /apps/.../workflows/draft/variables for snippet scope,
using CustomizedSnippet.id as WorkflowDraftVariable.app_id (same invariant as snippet execution).

Snippet workflows do not expose system variables (`node_id == sys`) or conversation variables
(`node_id == conversation`): paginated list queries exclude those rows; single-variable GET/PATCH/DELETE/reset
reject them; `GET .../system-variables` and `GET .../conversation-variables` return empty lists for API parity.
Other routes mirror `workflow_draft_variable` app APIs under `/snippets/...`.
"""

from collections.abc import Callable
from functools import wraps
from typing import Any, ParamSpec, TypeVar

from flask import Response, request
from flask_restx import Resource, marshal, marshal_with
from sqlalchemy.orm import Session

from controllers.console import console_ns
from controllers.console.app.error import DraftWorkflowNotExist
from controllers.console.app.workflow_draft_variable import (
    WorkflowDraftVariableListQuery,
    WorkflowDraftVariableUpdatePayload,
    _ensure_variable_access,
    _file_access_controller,
    validate_node_id,
    workflow_draft_variable_list_model,
    workflow_draft_variable_list_without_value_model,
    workflow_draft_variable_model,
)
from controllers.console.snippets.snippet_workflow import get_snippet
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type
from graphon.variables.types import SegmentType
from libs.login import current_user, login_required
from models.snippet import CustomizedSnippet
from models.workflow import WorkflowDraftVariable
from services.snippet_service import SnippetService
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService

P = ParamSpec("P")
R = TypeVar("R")

_SNIPPET_EXCLUDED_DRAFT_VARIABLE_NODE_IDS: frozenset[str] = frozenset(
    {SYSTEM_VARIABLE_NODE_ID, CONVERSATION_VARIABLE_NODE_ID}
)


def _ensure_snippet_draft_variable_row_allowed(
    *,
    variable: WorkflowDraftVariable,
    variable_id: str,
) -> None:
    """Snippet scope only supports canvas-node draft variables; treat sys/conversation rows as not found."""
    if variable.node_id in _SNIPPET_EXCLUDED_DRAFT_VARIABLE_NODE_IDS:
        raise NotFoundError(description=f"variable not found, id={variable_id}")


def _snippet_draft_var_prerequisite(f: Callable[P, R]) -> Callable[P, R]:
    """Setup, auth, snippet resolution, and tenant edit permission (same stack as snippet workflow APIs)."""

    @setup_required
    @login_required
    @account_initialization_required
    @get_snippet
    @edit_permission_required
    @wraps(f)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        return f(*args, **kwargs)

    return wrapper
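
_snippet_draft_var_prerequisite simply pre-composes the five decorators every route in this module needs, so each view declares one decorator instead of five. The same pattern can be written generically; a hedged sketch with an invented compose helper:

from functools import reduce

def compose(*decorators):
    """Stack decorators in top-to-bottom reading order, like @a @b @c."""
    def apply(func):
        return reduce(lambda f, deco: deco(f), reversed(decorators), func)
    return apply

trace = []

def tag(name):
    def deco(func):
        def wrapper(*args, **kwargs):
            trace.append(name)
            return func(*args, **kwargs)
        return wrapper
    return deco

@compose(tag("setup"), tag("login"), tag("snippet"))
def view():
    return "ok"

assert view() == "ok"
assert trace == ["setup", "login", "snippet"]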


@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/variables")
class SnippetWorkflowVariableCollectionApi(Resource):
    @console_ns.expect(console_ns.models[WorkflowDraftVariableListQuery.__name__])
    @console_ns.doc("get_snippet_workflow_variables")
    @console_ns.doc(description="List draft workflow variables without values (paginated, snippet scope)")
    @console_ns.response(
        200,
        "Workflow variables retrieved successfully",
        workflow_draft_variable_list_without_value_model,
    )
    @_snippet_draft_var_prerequisite
    @marshal_with(workflow_draft_variable_list_without_value_model)
    def get(self, snippet: CustomizedSnippet) -> WorkflowDraftVariableList:
        args = WorkflowDraftVariableListQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore

        snippet_service = SnippetService()
        if snippet_service.get_draft_workflow(snippet=snippet) is None:
            raise DraftWorkflowNotExist()

        with Session(bind=db.engine, expire_on_commit=False) as session:
            draft_var_srv = WorkflowDraftVariableService(session=session)
            workflow_vars = draft_var_srv.list_variables_without_values(
                app_id=snippet.id,
                page=args.page,
                limit=args.limit,
                user_id=current_user.id,
                exclude_node_ids=_SNIPPET_EXCLUDED_DRAFT_VARIABLE_NODE_IDS,
            )

        return workflow_vars

    @console_ns.doc("delete_snippet_workflow_variables")
    @console_ns.doc(description="Delete all draft workflow variables for the current user (snippet scope)")
    @console_ns.response(204, "Workflow variables deleted successfully")
    @_snippet_draft_var_prerequisite
    def delete(self, snippet: CustomizedSnippet) -> Response:
        draft_var_srv = WorkflowDraftVariableService(session=db.session())
        draft_var_srv.delete_user_workflow_variables(snippet.id, user_id=current_user.id)
        db.session.commit()
        return Response("", 204)


@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/nodes/<string:node_id>/variables")
class SnippetNodeVariableCollectionApi(Resource):
    @console_ns.doc("get_snippet_node_variables")
    @console_ns.doc(description="Get variables for a specific node (snippet draft workflow)")
    @console_ns.response(200, "Node variables retrieved successfully", workflow_draft_variable_list_model)
    @_snippet_draft_var_prerequisite
    @marshal_with(workflow_draft_variable_list_model)
    def get(self, snippet: CustomizedSnippet, node_id: str) -> WorkflowDraftVariableList:
        validate_node_id(node_id)
        with Session(bind=db.engine, expire_on_commit=False) as session:
            draft_var_srv = WorkflowDraftVariableService(session=session)
            node_vars = draft_var_srv.list_node_variables(snippet.id, node_id, user_id=current_user.id)

        return node_vars

    @console_ns.doc("delete_snippet_node_variables")
    @console_ns.doc(description="Delete all variables for a specific node (snippet draft workflow)")
    @console_ns.response(204, "Node variables deleted successfully")
    @_snippet_draft_var_prerequisite
    def delete(self, snippet: CustomizedSnippet, node_id: str) -> Response:
        validate_node_id(node_id)
        srv = WorkflowDraftVariableService(db.session())
        srv.delete_node_variables(snippet.id, node_id, user_id=current_user.id)
        db.session.commit()
        return Response("", 204)


@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/variables/<uuid:variable_id>")
class SnippetVariableApi(Resource):
    @console_ns.doc("get_snippet_workflow_variable")
    @console_ns.doc(description="Get a specific draft workflow variable (snippet scope)")
    @console_ns.response(200, "Variable retrieved successfully", workflow_draft_variable_model)
    @console_ns.response(404, "Variable not found")
    @_snippet_draft_var_prerequisite
    @marshal_with(workflow_draft_variable_model)
    def get(self, snippet: CustomizedSnippet, variable_id: str) -> WorkflowDraftVariable:
        draft_var_srv = WorkflowDraftVariableService(session=db.session())
        variable = _ensure_variable_access(
            variable=draft_var_srv.get_variable(variable_id=variable_id),
            app_id=snippet.id,
            variable_id=variable_id,
        )
        _ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id)
        return variable

    @console_ns.doc("update_snippet_workflow_variable")
    @console_ns.doc(description="Update a draft workflow variable (snippet scope)")
    @console_ns.expect(console_ns.models[WorkflowDraftVariableUpdatePayload.__name__])
    @console_ns.response(200, "Variable updated successfully", workflow_draft_variable_model)
    @console_ns.response(404, "Variable not found")
    @_snippet_draft_var_prerequisite
    @marshal_with(workflow_draft_variable_model)
    def patch(self, snippet: CustomizedSnippet, variable_id: str) -> WorkflowDraftVariable:
        draft_var_srv = WorkflowDraftVariableService(session=db.session())
        args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {})

        variable = _ensure_variable_access(
            variable=draft_var_srv.get_variable(variable_id=variable_id),
            app_id=snippet.id,
            variable_id=variable_id,
        )
        _ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id)

        new_name = args_model.name
        raw_value = args_model.value
        if new_name is None and raw_value is None:
            return variable

        new_value = None
        if raw_value is not None:
            if variable.value_type == SegmentType.FILE:
                if not isinstance(raw_value, dict):
                    raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
                raw_value = build_from_mapping(
                    mapping=raw_value,
                    tenant_id=snippet.tenant_id,
                    access_controller=_file_access_controller,
                )
            elif variable.value_type == SegmentType.ARRAY_FILE:
                if not isinstance(raw_value, list):
                    raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
                if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
                    raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
                raw_value = build_from_mappings(
                    mappings=raw_value,
                    tenant_id=snippet.tenant_id,
                    access_controller=_file_access_controller,
                )
            new_value = build_segment_with_type(variable.value_type, raw_value)
        draft_var_srv.update_variable(variable, name=new_name, value=new_value)
        db.session.commit()
        return variable

    @console_ns.doc("delete_snippet_workflow_variable")
    @console_ns.doc(description="Delete a draft workflow variable (snippet scope)")
    @console_ns.response(204, "Variable deleted successfully")
    @console_ns.response(404, "Variable not found")
    @_snippet_draft_var_prerequisite
    def delete(self, snippet: CustomizedSnippet, variable_id: str) -> Response:
        draft_var_srv = WorkflowDraftVariableService(session=db.session())
        variable = _ensure_variable_access(
            variable=draft_var_srv.get_variable(variable_id=variable_id),
            app_id=snippet.id,
            variable_id=variable_id,
        )
        _ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id)
        draft_var_srv.delete_variable(variable)
        db.session.commit()
        return Response("", 204)


@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/variables/<uuid:variable_id>/reset")
class SnippetVariableResetApi(Resource):
    @console_ns.doc("reset_snippet_workflow_variable")
    @console_ns.doc(description="Reset a draft workflow variable to its default value (snippet scope)")
    @console_ns.response(200, "Variable reset successfully", workflow_draft_variable_model)
    @console_ns.response(204, "Variable reset (no content)")
    @console_ns.response(404, "Variable not found")
    @_snippet_draft_var_prerequisite
    def put(self, snippet: CustomizedSnippet, variable_id: str) -> Response | Any:
        draft_var_srv = WorkflowDraftVariableService(session=db.session())
        snippet_service = SnippetService()
        draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
        if draft_workflow is None:
            raise NotFoundError(
                f"Draft workflow not found, snippet_id={snippet.id}",
            )
        variable = _ensure_variable_access(
            variable=draft_var_srv.get_variable(variable_id=variable_id),
            app_id=snippet.id,
            variable_id=variable_id,
        )
        _ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id)

        resetted = draft_var_srv.reset_variable(draft_workflow, variable)
        db.session.commit()
        if resetted is None:
            return Response("", 204)
        return marshal(resetted, workflow_draft_variable_model)


@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/conversation-variables")
class SnippetConversationVariableCollectionApi(Resource):
    @console_ns.doc("get_snippet_conversation_variables")
    @console_ns.doc(
        description="Conversation variables are not used in snippet workflows; returns an empty list for API parity"
    )
    @console_ns.response(200, "Conversation variables retrieved successfully", workflow_draft_variable_list_model)
    @_snippet_draft_var_prerequisite
    @marshal_with(workflow_draft_variable_list_model)
    def get(self, snippet: CustomizedSnippet) -> WorkflowDraftVariableList:
        return WorkflowDraftVariableList(variables=[])


@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/system-variables")
class SnippetSystemVariableCollectionApi(Resource):
    @console_ns.doc("get_snippet_system_variables")
    @console_ns.doc(
        description="System variables are not used in snippet workflows; returns an empty list for API parity"
    )
    @console_ns.response(200, "System variables retrieved successfully", workflow_draft_variable_list_model)
    @_snippet_draft_var_prerequisite
    @marshal_with(workflow_draft_variable_list_model)
    def get(self, snippet: CustomizedSnippet) -> WorkflowDraftVariableList:
        return WorkflowDraftVariableList(variables=[])


@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/environment-variables")
class SnippetEnvironmentVariableCollectionApi(Resource):
    @console_ns.doc("get_snippet_environment_variables")
    @console_ns.doc(description="Get environment variables from snippet draft workflow graph")
    @console_ns.response(200, "Environment variables retrieved successfully")
    @console_ns.response(404, "Draft workflow not found")
    @_snippet_draft_var_prerequisite
    def get(self, snippet: CustomizedSnippet) -> dict[str, list[dict[str, Any]]]:
        snippet_service = SnippetService()
        workflow = snippet_service.get_draft_workflow(snippet=snippet)
        if workflow is None:
            raise DraftWorkflowNotExist()

        env_vars_list: list[dict[str, Any]] = []
        for v in workflow.environment_variables:
            env_vars_list.append(
                {
                    "id": v.id,
                    "type": "env",
                    "name": v.name,
                    "description": v.description,
                    "selector": v.selector,
                    "value_type": v.value_type.exposed_type().value,
                    "value": v.value,
                    "edited": False,
                    "visible": True,
                    "editable": True,
                }
            )

        return {"items": env_vars_list}

380 api/controllers/console/workspace/snippets.py Normal file
@ -0,0 +1,380 @@
|
||||
import logging
|
||||
from urllib.parse import quote
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, marshal
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.snippets.payloads import (
|
||||
CreateSnippetPayload,
|
||||
IncludeSecretQuery,
|
||||
SnippetImportPayload,
|
||||
SnippetListQuery,
|
||||
UpdateSnippetPayload,
|
||||
)
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.snippet_fields import snippet_fields, snippet_list_fields, snippet_pagination_fields
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.snippet import SnippetType
|
||||
from services.app_dsl_service import ImportStatus
|
||||
from services.snippet_dsl_service import SnippetDslService
|
||||
from services.snippet_service import SnippetService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Register Pydantic models with Swagger
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
SnippetListQuery,
|
||||
CreateSnippetPayload,
|
||||
UpdateSnippetPayload,
|
||||
SnippetImportPayload,
|
||||
IncludeSecretQuery,
|
||||
)
|
||||
|
||||
# Create namespace models for marshaling
|
||||
snippet_model = console_ns.model("Snippet", snippet_fields)
|
||||
snippet_list_model = console_ns.model("SnippetList", snippet_list_fields)
|
||||
snippet_pagination_model = console_ns.model("SnippetPagination", snippet_pagination_fields)


@console_ns.route("/workspaces/current/customized-snippets")
class CustomizedSnippetsApi(Resource):
    @console_ns.doc("list_customized_snippets")
    @console_ns.expect(console_ns.models.get(SnippetListQuery.__name__))
    @console_ns.response(200, "Snippets retrieved successfully", snippet_pagination_model)
    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        """List customized snippets with pagination and search."""
        _, current_tenant_id = current_account_with_tenant()

        query_params = request.args.to_dict()
        query = SnippetListQuery.model_validate(query_params)

        snippets, total, has_more = SnippetService.get_snippets(
            tenant_id=current_tenant_id,
            page=query.page,
            limit=query.limit,
            keyword=query.keyword,
            is_published=query.is_published,
            creators=query.creators,
        )

        return {
            "data": marshal(snippets, snippet_list_fields),
            "page": query.page,
            "limit": query.limit,
            "total": total,
            "has_more": has_more,
        }, 200

    @console_ns.doc("create_customized_snippet")
    @console_ns.expect(console_ns.models.get(CreateSnippetPayload.__name__))
    @console_ns.response(201, "Snippet created successfully", snippet_model)
    @console_ns.response(400, "Invalid request or name already exists")
    @setup_required
    @login_required
    @account_initialization_required
    @edit_permission_required
    def post(self):
        """Create a new customized snippet."""
        current_user, current_tenant_id = current_account_with_tenant()

        payload = CreateSnippetPayload.model_validate(console_ns.payload or {})

        try:
            snippet_type = SnippetType(payload.type)
        except ValueError:
            snippet_type = SnippetType.NODE

        try:
            snippet = SnippetService.create_snippet(
                tenant_id=current_tenant_id,
                name=payload.name,
                description=payload.description,
                snippet_type=snippet_type,
                icon_info=payload.icon_info.model_dump() if payload.icon_info else None,
                input_fields=[f.model_dump() for f in payload.input_fields] if payload.input_fields else None,
                account=current_user,
            )
        except ValueError as e:
            return {"message": str(e)}, 400

        return marshal(snippet, snippet_fields), 201


@console_ns.route("/workspaces/current/customized-snippets/<uuid:snippet_id>")
class CustomizedSnippetDetailApi(Resource):
    @console_ns.doc("get_customized_snippet")
    @console_ns.response(200, "Snippet retrieved successfully", snippet_model)
    @console_ns.response(404, "Snippet not found")
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, snippet_id: str):
        """Get customized snippet details."""
        _, current_tenant_id = current_account_with_tenant()

        snippet = SnippetService.get_snippet_by_id(
            snippet_id=str(snippet_id),
            tenant_id=current_tenant_id,
        )

        if not snippet:
            raise NotFound("Snippet not found")

        return marshal(snippet, snippet_fields), 200

    @console_ns.doc("update_customized_snippet")
    @console_ns.expect(console_ns.models.get(UpdateSnippetPayload.__name__))
    @console_ns.response(200, "Snippet updated successfully", snippet_model)
    @console_ns.response(400, "Invalid request or name already exists")
    @console_ns.response(404, "Snippet not found")
    @setup_required
    @login_required
    @account_initialization_required
    @edit_permission_required
    def patch(self, snippet_id: str):
        """Update customized snippet."""
        current_user, current_tenant_id = current_account_with_tenant()

        snippet = SnippetService.get_snippet_by_id(
            snippet_id=str(snippet_id),
            tenant_id=current_tenant_id,
        )

        if not snippet:
            raise NotFound("Snippet not found")

        payload = UpdateSnippetPayload.model_validate(console_ns.payload or {})
        update_data = payload.model_dump(exclude_unset=True)

        if "icon_info" in update_data and update_data["icon_info"] is not None:
            update_data["icon_info"] = payload.icon_info.model_dump() if payload.icon_info else None

        if not update_data:
            return {"message": "No valid fields to update"}, 400

        try:
            with Session(db.engine, expire_on_commit=False) as session:
                snippet = session.merge(snippet)
                snippet = SnippetService.update_snippet(
                    session=session,
                    snippet=snippet,
                    account_id=current_user.id,
                    data=update_data,
                )
                session.commit()
        except ValueError as e:
            return {"message": str(e)}, 400

        return marshal(snippet, snippet_fields), 200

    @console_ns.doc("delete_customized_snippet")
    @console_ns.response(204, "Snippet deleted successfully")
    @console_ns.response(404, "Snippet not found")
    @setup_required
    @login_required
    @account_initialization_required
    @edit_permission_required
    def delete(self, snippet_id: str):
        """Delete customized snippet."""
        _, current_tenant_id = current_account_with_tenant()

        snippet = SnippetService.get_snippet_by_id(
            snippet_id=str(snippet_id),
            tenant_id=current_tenant_id,
        )

        if not snippet:
            raise NotFound("Snippet not found")

        with Session(db.engine) as session:
            snippet = session.merge(snippet)
            SnippetService.delete_snippet(
                session=session,
                snippet=snippet,
            )
            session.commit()

        return "", 204


@console_ns.route("/workspaces/current/customized-snippets/<uuid:snippet_id>/export")
class CustomizedSnippetExportApi(Resource):
    @console_ns.doc("export_customized_snippet")
    @console_ns.doc(description="Export snippet configuration as DSL")
    @console_ns.doc(params={"snippet_id": "Snippet ID to export"})
    @console_ns.response(200, "Snippet exported successfully")
    @console_ns.response(404, "Snippet not found")
    @setup_required
    @login_required
    @account_initialization_required
    @edit_permission_required
    def get(self, snippet_id: str):
        """Export snippet as DSL."""
        _, current_tenant_id = current_account_with_tenant()

        snippet = SnippetService.get_snippet_by_id(
            snippet_id=str(snippet_id),
            tenant_id=current_tenant_id,
        )

        if not snippet:
            raise NotFound("Snippet not found")

        # Get include_secret parameter
        query = IncludeSecretQuery.model_validate(request.args.to_dict())

        with Session(db.engine) as session:
            export_service = SnippetDslService(session)
            result = export_service.export_snippet_dsl(snippet=snippet, include_secret=query.include_secret == "true")

        # Set filename with .snippet extension
        filename = f"{snippet.name}.snippet"
        encoded_filename = quote(filename)

        response = Response(
            result,
            mimetype="application/x-yaml",
        )
        response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
        response.headers["Content-Type"] = "application/x-yaml"

        return response


@console_ns.route("/workspaces/current/customized-snippets/imports")
class CustomizedSnippetImportApi(Resource):
    @console_ns.doc("import_customized_snippet")
    @console_ns.doc(description="Import snippet from DSL")
    @console_ns.expect(console_ns.models.get(SnippetImportPayload.__name__))
    @console_ns.response(200, "Snippet imported successfully")
    @console_ns.response(202, "Import pending confirmation")
    @console_ns.response(400, "Import failed")
    @setup_required
    @login_required
    @account_initialization_required
    @edit_permission_required
    def post(self):
        """Import snippet from DSL."""
        current_user, _ = current_account_with_tenant()
        payload = SnippetImportPayload.model_validate(console_ns.payload or {})

        with Session(db.engine) as session:
            import_service = SnippetDslService(session)
            result = import_service.import_snippet(
                account=current_user,
                import_mode=payload.mode,
                yaml_content=payload.yaml_content,
                yaml_url=payload.yaml_url,
                snippet_id=payload.snippet_id,
                name=payload.name,
                description=payload.description,
            )
            session.commit()

        # Return appropriate status code based on result
        status = result.status
        if status == ImportStatus.FAILED:
            return result.model_dump(mode="json"), 400
        elif status == ImportStatus.PENDING:
            return result.model_dump(mode="json"), 202
        return result.model_dump(mode="json"), 200


@console_ns.route("/workspaces/current/customized-snippets/imports/<string:import_id>/confirm")
class CustomizedSnippetImportConfirmApi(Resource):
    @console_ns.doc("confirm_snippet_import")
    @console_ns.doc(description="Confirm a pending snippet import")
    @console_ns.doc(params={"import_id": "Import ID to confirm"})
    @console_ns.response(200, "Import confirmed successfully")
    @console_ns.response(400, "Import failed")
    @setup_required
    @login_required
    @account_initialization_required
    @edit_permission_required
    def post(self, import_id: str):
        """Confirm a pending snippet import."""
        current_user, _ = current_account_with_tenant()

        with Session(db.engine) as session:
            import_service = SnippetDslService(session)
            result = import_service.confirm_import(import_id=import_id, account=current_user)
            session.commit()

        if result.status == ImportStatus.FAILED:
            return result.model_dump(mode="json"), 400
        return result.model_dump(mode="json"), 200


@console_ns.route("/workspaces/current/customized-snippets/<uuid:snippet_id>/check-dependencies")
class CustomizedSnippetCheckDependenciesApi(Resource):
    @console_ns.doc("check_snippet_dependencies")
    @console_ns.doc(description="Check dependencies for a snippet")
    @console_ns.doc(params={"snippet_id": "Snippet ID"})
    @console_ns.response(200, "Dependencies checked successfully")
    @console_ns.response(404, "Snippet not found")
    @setup_required
    @login_required
    @account_initialization_required
    @edit_permission_required
    def get(self, snippet_id: str):
        """Check dependencies for a snippet."""
        _, current_tenant_id = current_account_with_tenant()

        snippet = SnippetService.get_snippet_by_id(
            snippet_id=str(snippet_id),
            tenant_id=current_tenant_id,
        )

        if not snippet:
            raise NotFound("Snippet not found")

        with Session(db.engine) as session:
            import_service = SnippetDslService(session)
            result = import_service.check_dependencies(snippet=snippet)

        return result.model_dump(mode="json"), 200


@console_ns.route("/workspaces/current/customized-snippets/<uuid:snippet_id>/use-count/increment")
class CustomizedSnippetUseCountIncrementApi(Resource):
    @console_ns.doc("increment_snippet_use_count")
    @console_ns.doc(description="Increment snippet use count by 1")
    @console_ns.doc(params={"snippet_id": "Snippet ID"})
    @console_ns.response(200, "Use count incremented successfully")
    @console_ns.response(404, "Snippet not found")
    @setup_required
    @login_required
    @account_initialization_required
    @edit_permission_required
    def post(self, snippet_id: str):
        """Increment snippet use count when it is inserted into a workflow."""
        _, current_tenant_id = current_account_with_tenant()

        snippet = SnippetService.get_snippet_by_id(
            snippet_id=str(snippet_id),
            tenant_id=current_tenant_id,
        )

        if not snippet:
            raise NotFound("Snippet not found")

        with Session(db.engine) as session:
            snippet = session.merge(snippet)
            SnippetService.increment_use_count(session=session, snippet=snippet)
            session.commit()
            session.refresh(snippet)

        return {"result": "success", "use_count": snippet.use_count}, 200
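For orientation, a minimal client-side sketch of the list endpoint above. The host, API prefix, and session cookie are assumptions for illustration, not part of this change:

    import requests

    # Hypothetical console session; authentication details depend on the deployment.
    resp = requests.get(
        "http://localhost:5001/console/api/workspaces/current/customized-snippets",
        params={"page": 1, "limit": 20, "keyword": "http"},
        cookies={"session": "<console-session-cookie>"},
    )
    data = resp.json()  # {"data": [...], "page": 1, "limit": 20, "total": ..., "has_more": ...}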
@ -14,7 +14,7 @@ from graphon.runtime import GraphRuntimeState
 from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
 from pydantic import ValidationError
 from sqlalchemy import select
-from sqlalchemy.orm import sessionmaker
+from sqlalchemy.orm import Session, sessionmaker

 import contexts
 from configs import dify_config
@ -54,6 +54,25 @@ logger = logging.getLogger(__name__)


class WorkflowAppGenerator(BaseAppGenerator):
    @staticmethod
    def _ensure_snippet_start_node_in_worker(*, session: Session, workflow: Workflow) -> Workflow:
        """Re-apply snippet virtual Start injection after worker reloads workflow from DB."""
        if workflow.type != "snippet":
            return workflow

        from models.snippet import CustomizedSnippet
        from services.snippet_generate_service import SnippetGenerateService

        snippet = session.scalar(
            select(CustomizedSnippet).where(
                CustomizedSnippet.id == workflow.app_id,
                CustomizedSnippet.tenant_id == workflow.tenant_id,
            )
        )
        if snippet is None:
            return workflow
        return SnippetGenerateService.ensure_start_node_for_worker(workflow, snippet)

    @staticmethod
    def _should_prepare_user_inputs(args: Mapping[str, Any]) -> bool:
        return not bool(args.get(SKIP_PREPARE_USER_INPUTS_KEY))
@ -557,6 +576,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
        if workflow is None:
            raise ValueError("Workflow not found")

        workflow = self._ensure_snippet_start_node_in_worker(session=session, workflow=workflow)

        # Determine system_user_id based on invocation source
        is_external_api_call = application_generate_entity.invoke_from in {
            InvokeFrom.WEB_APP,
0
api/core/evaluation/__init__.py
Normal file
279
api/core/evaluation/base_evaluation_instance.py
Normal file
@ -0,0 +1,279 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import Any

from core.evaluation.entities.evaluation_entity import (
    CustomizedMetrics,
    EvaluationCategory,
    EvaluationItemInput,
    EvaluationItemResult,
    EvaluationMetric,
    NodeInfo,
)
from graphon.node_events.base import NodeRunResult

logger = logging.getLogger(__name__)


class BaseEvaluationInstance(ABC):
    """Abstract base class for evaluation framework adapters."""

    @abstractmethod
    def evaluate_llm(
        self,
        items: list[EvaluationItemInput],
        metric_names: list[str],
        model_provider: str,
        model_name: str,
        tenant_id: str,
    ) -> list[EvaluationItemResult]:
        """Evaluate LLM outputs using the configured framework."""
        ...

    @abstractmethod
    def evaluate_retrieval(
        self,
        items: list[EvaluationItemInput],
        metric_names: list[str],
        model_provider: str,
        model_name: str,
        tenant_id: str,
    ) -> list[EvaluationItemResult]:
        """Evaluate retrieval quality using the configured framework."""
        ...

    @abstractmethod
    def evaluate_agent(
        self,
        items: list[EvaluationItemInput],
        metric_names: list[str],
        model_provider: str,
        model_name: str,
        tenant_id: str,
    ) -> list[EvaluationItemResult]:
        """Evaluate agent outputs using the configured framework."""
        ...

    @abstractmethod
    def get_supported_metrics(self, category: EvaluationCategory) -> list[str]:
        """Return the list of supported metric names for a given evaluation category."""
        ...

    def evaluate_with_customized_workflow(
        self,
        node_run_result_mapping_list: list[dict[str, NodeRunResult]],
        customized_metrics: CustomizedMetrics,
        tenant_id: str,
    ) -> list[EvaluationItemResult]:
        """Evaluate using a published workflow as the evaluator.

        The evaluator workflow's output variables are treated as metrics:
        each output variable name becomes a metric name, and its value
        becomes the score.

        Args:
            node_run_result_mapping_list: One mapping per test-data item,
                where each mapping is ``{node_id: NodeRunResult}`` from the
                target execution.
            customized_metrics: Contains ``evaluation_workflow_id`` (the
                published evaluator workflow) and ``input_fields`` (value
                sources for the evaluator's input variables).
            tenant_id: Tenant scope.

        Returns:
            A list of ``EvaluationItemResult`` with metrics extracted from
            the evaluator workflow's output variables.
        """
        from sqlalchemy.orm import Session

        from core.app.apps.workflow.app_generator import WorkflowAppGenerator
        from core.app.entities.app_invoke_entities import InvokeFrom
        from core.evaluation.runners import get_service_account_for_app
        from models.engine import db
        from models.model import App
        from services.workflow_service import WorkflowService

        workflow_id = customized_metrics.evaluation_workflow_id
        if not workflow_id:
            raise ValueError("customized_metrics must contain 'evaluation_workflow_id' for customized evaluator")

        # Load the evaluator workflow resources using a dedicated session
        with Session(db.engine, expire_on_commit=False) as session, session.begin():
            app = session.query(App).filter_by(id=workflow_id, tenant_id=tenant_id).first()
            if not app:
                raise ValueError(f"Evaluation workflow app {workflow_id} not found in tenant {tenant_id}")
            service_account = get_service_account_for_app(session, workflow_id)

            workflow_service = WorkflowService()
            published_workflow = workflow_service.get_published_workflow(app_model=app)
            if not published_workflow:
                raise ValueError(f"No published workflow found for evaluation app {workflow_id}")

        eval_results: list[EvaluationItemResult] = []
        for idx, node_run_result_mapping in enumerate(node_run_result_mapping_list):
            try:
                workflow_inputs = self._build_workflow_inputs(
                    customized_metrics.input_fields,
                    node_run_result_mapping,
                )

                generator = WorkflowAppGenerator()
                response: Mapping[str, Any] = generator.generate(
                    app_model=app,
                    workflow=published_workflow,
                    user=service_account,
                    args={"inputs": workflow_inputs},
                    invoke_from=InvokeFrom.SERVICE_API,
                    streaming=False,
                    call_depth=0,
                )

                metrics = self._extract_workflow_metrics(response, workflow_id)
                eval_results.append(
                    EvaluationItemResult(
                        index=idx,
                        metrics=metrics,
                    )
                )
            except Exception:
                logger.exception(
                    "Customized evaluator failed for item %d with workflow %s",
                    idx,
                    workflow_id,
                )
                eval_results.append(EvaluationItemResult(index=idx))

        return eval_results

    @staticmethod
    def _build_workflow_inputs(
        input_fields: dict[str, Any],
        node_run_result_mapping: dict[str, NodeRunResult],
    ) -> dict[str, Any]:
        """Build customized workflow inputs by resolving value sources.

        Each entry in ``input_fields`` maps a workflow input variable name
        to its value source, which can be:

        - **Constant**: a plain string without ``{{#…#}}`` used as-is.
        - **Expression**: a string containing one or more
          ``{{#node_id.output_key#}}`` selectors (same format as
          ``VariableTemplateParser``) resolved from
          ``node_run_result_mapping``.
        """
        from graphon.nodes.base.variable_template_parser import REGEX as VARIABLE_REGEX

        workflow_inputs: dict[str, Any] = {}

        for field_name, value_source in input_fields.items():
            if not isinstance(value_source, str):
                # Non-string values (numbers, bools, dicts) are used directly.
                workflow_inputs[field_name] = value_source
                continue

            # Check if the entire value is a single expression.
            full_match = VARIABLE_REGEX.fullmatch(value_source)
            if full_match:
                workflow_inputs[field_name] = resolve_variable_selector(
                    full_match.group(1),
                    node_run_result_mapping,
                )
            elif VARIABLE_REGEX.search(value_source):
                # Mixed template: interpolate all expressions as strings.
                workflow_inputs[field_name] = VARIABLE_REGEX.sub(
                    lambda m: str(resolve_variable_selector(m.group(1), node_run_result_mapping)),
                    value_source,
                )
            else:
                # Plain constant — no expression markers.
                workflow_inputs[field_name] = value_source

        return workflow_inputs

    @staticmethod
    def _extract_workflow_metrics(
        response: Mapping[str, object],
        evaluation_workflow_id: str,
    ) -> list[EvaluationMetric]:
        """Extract evaluation metrics from workflow output variables.

        Each metric's ``node_info`` is set with *evaluation_workflow_id* as
        the ``node_id``, so that judgment conditions can reference customized
        metrics via ``variable_selector: [evaluation_workflow_id, metric_name]``.
        """
        metrics: list[EvaluationMetric] = []
        node_info = NodeInfo(node_id=evaluation_workflow_id, type="customized", title="customized")

        data = response.get("data")
        if not isinstance(data, Mapping):
            logger.warning("Unexpected workflow response format: missing 'data' dict")
            return metrics

        outputs = data.get("outputs")
        if not isinstance(outputs, dict):
            logger.warning("Unexpected workflow response format: 'outputs' is not a dict")
            return metrics

        for key, raw_value in outputs.items():
            if not isinstance(key, str):
                continue
            metrics.append(EvaluationMetric(name=key, value=raw_value, node_info=node_info))

        return metrics


def resolve_variable_selector(
    selector_raw: str,
    node_run_result_mapping: dict[str, NodeRunResult],
) -> object:
    """Resolve a ``#node_id.output_key#`` selector against node run results."""
    cleaned = selector_raw.strip("#")
    parts = cleaned.split(".")

    if len(parts) < 2:
        logger.warning(
            "Selector '%s' must have at least node_id.output_key",
            selector_raw,
        )
        return ""

    node_id = parts[0]
    output_path = parts[1:]

    node_result = node_run_result_mapping.get(node_id)
    if not node_result or not node_result.outputs:
        logger.warning(
            "Selector '%s': node '%s' not found or has no outputs",
            selector_raw,
            node_id,
        )
        return ""

    # Traverse the output path to support nested keys.
    current: object = node_result.outputs
    for key in output_path:
        if isinstance(current, Mapping):
            next_val = current.get(key)
            if next_val is None:
                logger.warning(
                    "Selector '%s': key '%s' not found in node '%s' outputs",
                    selector_raw,
                    key,
                    node_id,
                )
                return ""
            current = next_val
        else:
            logger.warning(
                "Selector '%s': cannot traverse into non-dict value at key '%s'",
                selector_raw,
                key,
            )
            return ""

    return current if current is not None else ""
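To make the selector rules above concrete, a small sketch. The stand-in object is hypothetical and only mimics the .outputs attribute the resolver reads from a real graphon NodeRunResult:

    from types import SimpleNamespace

    # Hypothetical stand-in for NodeRunResult; only `.outputs` is accessed.
    llm_result = SimpleNamespace(outputs={"text": "hello", "usage": {"tokens": 42}})
    mapping = {"llm_node": llm_result}

    resolve_variable_selector("#llm_node.text#", mapping)          # -> "hello"
    resolve_variable_selector("#llm_node.usage.tokens#", mapping)  # -> 42 (nested traversal)
    resolve_variable_selector("#other_node.text#", mapping)        # -> "" (warning logged)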
0
api/core/evaluation/entities/__init__.py
Normal file
27
api/core/evaluation/entities/config_entity.py
Normal file
@ -0,0 +1,27 @@
from enum import StrEnum

from pydantic import BaseModel


class EvaluationFrameworkEnum(StrEnum):
    RAGAS = "ragas"
    DEEPEVAL = "deepeval"
    NONE = "none"


class BaseEvaluationConfig(BaseModel):
    """Base configuration for evaluation frameworks."""

    pass


class RagasConfig(BaseEvaluationConfig):
    """RAGAS-specific configuration."""

    pass


class DeepEvalConfig(BaseEvaluationConfig):
    """DeepEval-specific configuration."""

    pass
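These config models are intentionally empty stubs; framework-specific options would be added as ordinary pydantic fields. A hypothetical illustration (the field is invented, not part of this change):

    class RagasConfig(BaseEvaluationConfig):
        """RAGAS-specific configuration."""

        # Hypothetical knob: cap concurrent metric computations per run.
        max_concurrency: int = 4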
226
api/core/evaluation/entities/evaluation_entity.py
Normal file
@ -0,0 +1,226 @@
from enum import StrEnum
from typing import Any

from pydantic import BaseModel, Field

from core.evaluation.entities.judgment_entity import JudgmentConfig, JudgmentResult


class EvaluationCategory(StrEnum):
    LLM = "llm"
    RETRIEVAL = "knowledge_retrieval"
    AGENT = "agent"
    WORKFLOW = "workflow"
    SNIPPET = "snippet"
    KNOWLEDGE_BASE = "knowledge_base"


class EvaluationMetricName(StrEnum):
    """Canonical metric names shared across all evaluation frameworks.

    Each framework maps these names to its own internal implementation.
    A framework that does not support a given metric should log a warning
    and skip it rather than raising an error.

    ── LLM / general text-quality metrics ──────────────────────────────────
    FAITHFULNESS
        Measures whether every claim in the model's response is grounded in
        the provided retrieved context. A high score means the answer
        contains no hallucinated content — each statement can be traced back
        to a passage in the context.
        Required fields: user_input, response, retrieved_contexts.

    ANSWER_RELEVANCY
        Measures how well the model's response addresses the user's question.
        A high score means the answer stays on-topic; a low score indicates
        irrelevant content or a failure to answer the actual question.
        Required fields: user_input, response.

    ANSWER_CORRECTNESS
        Measures the factual accuracy and completeness of the model's answer
        relative to a ground-truth reference. It combines semantic similarity
        with key-fact coverage, so both meaning and content matter.
        Required fields: user_input, response, reference (expected_output).

    SEMANTIC_SIMILARITY
        Measures the cosine similarity between the model's response and the
        reference answer in an embedding space. It evaluates whether the two
        texts convey the same meaning, independent of factual correctness.
        Required fields: response, reference (expected_output).

    ── Retrieval-quality metrics ────────────────────────────────────────────
    CONTEXT_PRECISION
        Measures the proportion of retrieved context chunks that are actually
        relevant to the question (precision). A high score means the retrieval
        pipeline returns little noise.
        Required fields: user_input, reference, retrieved_contexts.

    CONTEXT_RECALL
        Measures the proportion of ground-truth information that is covered by
        the retrieved context chunks (recall). A high score means the retrieval
        pipeline does not miss important supporting evidence.
        Required fields: user_input, reference, retrieved_contexts.

    CONTEXT_RELEVANCE
        Measures how relevant each individual retrieved chunk is to the query.
        Similar to CONTEXT_PRECISION but evaluated at the chunk level rather
        than against a reference answer.
        Required fields: user_input, retrieved_contexts.

    ── Agent-quality metrics ────────────────────────────────────────────────
    TOOL_CORRECTNESS
        Measures the correctness of the tool calls made by the agent during
        task execution — both the choice of tool and the arguments passed.
        A high score means the agent's tool-use strategy matches the expected
        behavior.
        Required fields: actual tool calls vs. expected tool calls.

    TASK_COMPLETION
        Measures whether the agent ultimately achieves the user's stated goal.
        It evaluates the reasoning chain, intermediate steps, and final output
        holistically; a high score means the task was fully accomplished.
        Required fields: user_input, actual_output.
    """

    # LLM / general text-quality metrics
    FAITHFULNESS = "faithfulness"
    ANSWER_RELEVANCY = "answer_relevancy"
    ANSWER_CORRECTNESS = "answer_correctness"
    SEMANTIC_SIMILARITY = "semantic_similarity"

    # Retrieval-quality metrics
    CONTEXT_PRECISION = "context_precision"
    CONTEXT_RECALL = "context_recall"
    CONTEXT_RELEVANCE = "context_relevance"

    # Agent-quality metrics
    TOOL_CORRECTNESS = "tool_correctness"
    TASK_COMPLETION = "task_completion"


# Per-category canonical metric lists used by get_supported_metrics().
LLM_METRIC_NAMES: list[EvaluationMetricName] = [
    EvaluationMetricName.FAITHFULNESS,  # Every claim is grounded in context; no hallucinations
    EvaluationMetricName.ANSWER_RELEVANCY,  # Response stays on-topic and addresses the question
    EvaluationMetricName.ANSWER_CORRECTNESS,  # Factual accuracy and completeness vs. reference
    EvaluationMetricName.SEMANTIC_SIMILARITY,  # Semantic closeness to the reference answer
]

RETRIEVAL_METRIC_NAMES: list[EvaluationMetricName] = [
    EvaluationMetricName.CONTEXT_PRECISION,  # Fraction of retrieved chunks that are relevant (precision)
    EvaluationMetricName.CONTEXT_RECALL,  # Fraction of ground-truth info covered by retrieval (recall)
    EvaluationMetricName.CONTEXT_RELEVANCE,  # Per-chunk relevance to the query
]

AGENT_METRIC_NAMES: list[EvaluationMetricName] = [
    EvaluationMetricName.TOOL_CORRECTNESS,  # Correct tool selection and arguments
    EvaluationMetricName.TASK_COMPLETION,  # Whether the agent fully achieves the user's goal
]

WORKFLOW_METRIC_NAMES: list[EvaluationMetricName] = [
    EvaluationMetricName.FAITHFULNESS,
    EvaluationMetricName.ANSWER_RELEVANCY,
    EvaluationMetricName.ANSWER_CORRECTNESS,
]

METRIC_NODE_TYPE_MAPPING: dict[str, str] = {
    **{m.value: "llm" for m in LLM_METRIC_NAMES},
    **{m.value: "knowledge-retrieval" for m in RETRIEVAL_METRIC_NAMES},
    **{m.value: "agent" for m in AGENT_METRIC_NAMES},
}

METRIC_VALUE_TYPE_MAPPING: dict[str, str] = {
    EvaluationMetricName.FAITHFULNESS: "number",
    EvaluationMetricName.ANSWER_RELEVANCY: "number",
    EvaluationMetricName.ANSWER_CORRECTNESS: "number",
    EvaluationMetricName.SEMANTIC_SIMILARITY: "number",
    EvaluationMetricName.CONTEXT_PRECISION: "number",
    EvaluationMetricName.CONTEXT_RECALL: "number",
    EvaluationMetricName.CONTEXT_RELEVANCE: "number",
    EvaluationMetricName.TOOL_CORRECTNESS: "number",
    EvaluationMetricName.TASK_COMPLETION: "number",
}


class NodeInfo(BaseModel):
    node_id: str
    type: str
    title: str


class EvaluationMetric(BaseModel):
    name: str
    value: Any
    details: dict[str, Any] = Field(default_factory=dict)
    node_info: NodeInfo | None = None


class EvaluationItemInput(BaseModel):
    index: int
    inputs: dict[str, Any]
    output: str
    expected_output: str | None = None
    context: list[str] | None = None


class EvaluationDatasetInput(BaseModel):
    index: int
    inputs: dict[str, Any]
    expected_output: str | None = None


class EvaluationItemResult(BaseModel):
    index: int
    actual_output: str | None = None
    metrics: list[EvaluationMetric] = Field(default_factory=list)
    metadata: dict[str, Any] = Field(default_factory=dict)
    judgment: JudgmentResult = Field(default_factory=JudgmentResult)
    error: str | None = None


class DefaultMetric(BaseModel):
    metric: str
    value_type: str = ""
    node_info_list: list[NodeInfo]


class CustomizedMetricOutputField(BaseModel):
    variable: str
    value_type: str


class CustomizedMetrics(BaseModel):
    evaluation_workflow_id: str
    input_fields: dict[str, Any]
    output_fields: list[CustomizedMetricOutputField]


class EvaluationConfigData(BaseModel):
    """Structured data for saving evaluation configuration."""

    evaluation_model: str = ""
    evaluation_model_provider: str = ""
    default_metrics: list[DefaultMetric] = Field(default_factory=list)
    customized_metrics: CustomizedMetrics | None = None
    judgment_config: JudgmentConfig | None = None


class EvaluationRunRequest(EvaluationConfigData):
    """Request body for starting an evaluation run."""

    file_id: str


class EvaluationRunData(BaseModel):
    """Serializable data for Celery task."""

    evaluation_run_id: str
    tenant_id: str
    target_type: str
    target_id: str
    evaluation_model_provider: str
    evaluation_model: str
    default_metrics: list[DefaultMetric] = Field(default_factory=list)
    customized_metrics: CustomizedMetrics | None = None
    judgment_config: JudgmentConfig | None = None
    input_list: list[EvaluationDatasetInput]
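A minimal sketch of how these entities fit together when describing one test-data item (all values are illustrative):

    item = EvaluationItemInput(
        index=0,
        inputs={"prompt": "What is Dify?"},
        output="Dify is an LLM app development platform.",
        expected_output="Dify is an open-source LLM application platform.",
        context=["Dify is an open-source platform for building LLM apps."],
    )

    # Lookups used when rendering metric results:
    METRIC_NODE_TYPE_MAPPING["faithfulness"]                      # -> "llm"
    METRIC_VALUE_TYPE_MAPPING[EvaluationMetricName.FAITHFULNESS]  # -> "number"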
96
api/core/evaluation/entities/judgment_entity.py
Normal file
@ -0,0 +1,96 @@
"""Judgment condition entities for evaluation metric assessment.

Condition structure mirrors the workflow if-else ``Condition`` model from
``graphon.utils.condition.entities``. The left-hand side uses
``variable_selector`` — a two-element list ``[node_id, metric_name]`` — to
uniquely identify an evaluation metric (different nodes may produce metrics
with the same name).

Operators reuse ``SupportedComparisonOperator`` from the workflow engine so
that type semantics stay consistent across the platform.

Typical usage::

    judgment_config = JudgmentConfig(
        logical_operator="and",
        conditions=[
            JudgmentCondition(
                variable_selector=["node_abc", "faithfulness"],
                comparison_operator=">",
                value="0.8",
            )
        ],
    )
"""

from collections.abc import Sequence
from typing import Any, Literal

from pydantic import BaseModel, Field

from graphon.utils.condition.entities import SupportedComparisonOperator


class JudgmentCondition(BaseModel):
    """A single judgment condition that checks one metric value.

    Mirrors ``graphon.utils.condition.entities.Condition`` with the left-hand
    side being a metric selector instead of a workflow variable selector.

    Attributes:
        variable_selector: ``[node_id, metric_name]`` identifying the metric.
        comparison_operator: Reuses workflow's ``SupportedComparisonOperator``.
        value: The comparison target (right side). For unary operators such
            as ``empty`` or ``null`` this can be ``None``.
    """

    variable_selector: list[str]
    comparison_operator: SupportedComparisonOperator
    value: str | Sequence[str] | bool | None = None


class JudgmentConfig(BaseModel):
    """A group of judgment conditions combined with a logical operator.

    Attributes:
        logical_operator: How to combine condition results — "and" requires
            all conditions to pass, "or" requires at least one.
        conditions: The list of individual conditions to evaluate.
    """

    logical_operator: Literal["and", "or"] = "and"
    conditions: list[JudgmentCondition] = Field(default_factory=list)


class JudgmentConditionResult(BaseModel):
    """Result of evaluating a single judgment condition.

    Attributes:
        variable_selector: ``[node_id, metric_name]`` that was checked.
        comparison_operator: The operator that was applied.
        expected_value: The resolved comparison value.
        actual_value: The actual metric value that was evaluated.
        passed: Whether this individual condition passed.
        error: Error message if the condition evaluation failed.
    """

    variable_selector: list[str]
    comparison_operator: str
    expected_value: Any = None
    actual_value: Any = None
    passed: bool = False
    error: str | None = None


class JudgmentResult(BaseModel):
    """Overall result of evaluating all judgment conditions for one item.

    Attributes:
        passed: Whether the overall judgment passed (based on logical_operator).
        logical_operator: The logical operator used to combine conditions.
        condition_results: Detailed result for each individual condition.
    """

    passed: bool = False
    logical_operator: Literal["and", "or"] = "and"
    condition_results: list[JudgmentConditionResult] = Field(default_factory=list)
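The evaluator that applies these conditions lives in the runner code elsewhere in this change; as a rough sketch of the intended semantics (hypothetical helper, only a numeric ">" comparison shown; the real implementation reuses the workflow engine's operator handling):

    def apply_judgment_sketch(
        config: JudgmentConfig,
        metric_values: dict[tuple[str, str], float],
    ) -> bool:
        # metric_values is keyed by (node_id, metric_name), mirroring variable_selector.
        outcomes = []
        for cond in config.conditions:
            actual = metric_values.get((cond.variable_selector[0], cond.variable_selector[1]))
            passed = (
                actual is not None
                and cond.value is not None
                and str(cond.comparison_operator) == ">"
                and actual > float(str(cond.value))
            )
            outcomes.append(passed)
        return all(outcomes) if config.logical_operator == "and" else any(outcomes)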
61
api/core/evaluation/evaluation_manager.py
Normal file
@ -0,0 +1,61 @@
import collections
import logging
from typing import Any

from configs import dify_config
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.config_entity import EvaluationFrameworkEnum
from core.evaluation.entities.evaluation_entity import EvaluationCategory

logger = logging.getLogger(__name__)


class EvaluationFrameworkConfigMap(collections.UserDict[str, dict[str, Any]]):
    """Registry mapping framework enum -> {config_class, evaluator_class}."""

    def __getitem__(self, framework: str) -> dict[str, Any]:
        match framework:
            case EvaluationFrameworkEnum.RAGAS:
                from core.evaluation.entities.config_entity import RagasConfig
                from core.evaluation.frameworks.ragas.ragas_evaluator import RagasEvaluator

                return {
                    "config_class": RagasConfig,
                    "evaluator_class": RagasEvaluator,
                }
            case EvaluationFrameworkEnum.DEEPEVAL:
                raise NotImplementedError("DeepEval adapter is not yet implemented.")
            case _:
                raise ValueError(f"Unknown evaluation framework: {framework}")


evaluation_framework_config_map = EvaluationFrameworkConfigMap()


class EvaluationManager:
    """Factory for evaluation instances based on global configuration."""

    @staticmethod
    def get_evaluation_instance() -> BaseEvaluationInstance | None:
        """Create and return an evaluation instance based on EVALUATION_FRAMEWORK env var."""
        framework = dify_config.EVALUATION_FRAMEWORK
        if not framework or framework == EvaluationFrameworkEnum.NONE:
            return None

        try:
            config_map = evaluation_framework_config_map[framework]
            evaluator_class = config_map["evaluator_class"]
            config_class = config_map["config_class"]
            config = config_class()
            return evaluator_class(config)
        except Exception:
            logger.exception("Failed to create evaluation instance for framework: %s", framework)
            return None

    @staticmethod
    def get_supported_metrics(category: EvaluationCategory) -> list[str]:
        """Return supported metrics for the current framework and given category."""
        instance = EvaluationManager.get_evaluation_instance()
        if instance is None:
            return []
        return instance.get_supported_metrics(category)
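Callers interact with the factory rather than a concrete framework. A minimal usage sketch, assuming EVALUATION_FRAMEWORK=ragas is set and the ragas dependency is installed:

    from core.evaluation.entities.evaluation_entity import EvaluationCategory
    from core.evaluation.evaluation_manager import EvaluationManager

    instance = EvaluationManager.get_evaluation_instance()
    if instance is None:
        # Evaluation disabled (EVALUATION_FRAMEWORK=none) or instantiation failed.
        print("evaluation is not configured")
    else:
        print(instance.get_supported_metrics(EvaluationCategory.LLM))
        # e.g. ["faithfulness", "answer_relevancy", "answer_correctness", "semantic_similarity"]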
0
api/core/evaluation/frameworks/__init__.py
Normal file
1
api/core/evaluation/frameworks/deepeval/__init__.py
Normal file
@ -0,0 +1 @@

299
api/core/evaluation/frameworks/deepeval/deepeval_evaluator.py
Normal file
@ -0,0 +1,299 @@
import logging
from typing import Any

from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.config_entity import DeepEvalConfig
from core.evaluation.entities.evaluation_entity import (
    AGENT_METRIC_NAMES,
    LLM_METRIC_NAMES,
    RETRIEVAL_METRIC_NAMES,
    WORKFLOW_METRIC_NAMES,
    EvaluationCategory,
    EvaluationItemInput,
    EvaluationItemResult,
    EvaluationMetric,
    EvaluationMetricName,
)
from core.evaluation.frameworks.ragas.ragas_model_wrapper import DifyModelWrapper

logger = logging.getLogger(__name__)

# Maps canonical EvaluationMetricName to the corresponding deepeval metric class name.
# deepeval metric field requirements (LLMTestCase fields):
#   - faithfulness: input, actual_output, retrieval_context
#   - answer_relevancy: input, actual_output
#   - context_precision: input, actual_output, expected_output, retrieval_context
#   - context_recall: input, actual_output, expected_output, retrieval_context
#   - context_relevance: input, actual_output, retrieval_context
#   - tool_correctness: input, actual_output, expected_tools
#   - task_completion: input, actual_output
# Metrics not listed here are unsupported by deepeval and will be skipped.
_DEEPEVAL_METRIC_MAP: dict[EvaluationMetricName, str] = {
    EvaluationMetricName.FAITHFULNESS: "FaithfulnessMetric",
    EvaluationMetricName.ANSWER_RELEVANCY: "AnswerRelevancyMetric",
    EvaluationMetricName.CONTEXT_PRECISION: "ContextualPrecisionMetric",
    EvaluationMetricName.CONTEXT_RECALL: "ContextualRecallMetric",
    EvaluationMetricName.CONTEXT_RELEVANCE: "ContextualRelevancyMetric",
    EvaluationMetricName.TOOL_CORRECTNESS: "ToolCorrectnessMetric",
    EvaluationMetricName.TASK_COMPLETION: "TaskCompletionMetric",
}


class DeepEvalEvaluator(BaseEvaluationInstance):
    """DeepEval framework adapter for evaluation."""

    def __init__(self, config: DeepEvalConfig):
        self.config = config

    def get_supported_metrics(self, category: EvaluationCategory) -> list[str]:
        match category:
            case EvaluationCategory.LLM:
                candidates = LLM_METRIC_NAMES
            case EvaluationCategory.RETRIEVAL:
                candidates = RETRIEVAL_METRIC_NAMES
            case EvaluationCategory.AGENT:
                candidates = AGENT_METRIC_NAMES
            case EvaluationCategory.WORKFLOW | EvaluationCategory.SNIPPET:
                candidates = WORKFLOW_METRIC_NAMES
            case _:
                return []
        return [m for m in candidates if m in _DEEPEVAL_METRIC_MAP]

    def evaluate_llm(
        self,
        items: list[EvaluationItemInput],
        metric_names: list[str],
        model_provider: str,
        model_name: str,
        tenant_id: str,
    ) -> list[EvaluationItemResult]:
        return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.LLM)

    def evaluate_retrieval(
        self,
        items: list[EvaluationItemInput],
        metric_names: list[str],
        model_provider: str,
        model_name: str,
        tenant_id: str,
    ) -> list[EvaluationItemResult]:
        return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.RETRIEVAL)

    def evaluate_agent(
        self,
        items: list[EvaluationItemInput],
        metric_names: list[str],
        model_provider: str,
        model_name: str,
        tenant_id: str,
    ) -> list[EvaluationItemResult]:
        return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.AGENT)

    def evaluate_workflow(
        self,
        items: list[EvaluationItemInput],
        metric_names: list[str],
        model_provider: str,
        model_name: str,
        tenant_id: str,
    ) -> list[EvaluationItemResult]:
        return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.WORKFLOW)

    def _evaluate(
        self,
        items: list[EvaluationItemInput],
        metric_names: list[str],
        model_provider: str,
        model_name: str,
        tenant_id: str,
        category: EvaluationCategory,
    ) -> list[EvaluationItemResult]:
        """Core evaluation logic using DeepEval."""
        model_wrapper = DifyModelWrapper(model_provider, model_name, tenant_id)
        requested_metrics = metric_names or self.get_supported_metrics(category)

        try:
            return self._evaluate_with_deepeval(items, requested_metrics, category)
        except ImportError:
            logger.warning("DeepEval not installed, falling back to simple evaluation")
            return self._evaluate_simple(items, requested_metrics, model_wrapper)

    def _evaluate_with_deepeval(
        self,
        items: list[EvaluationItemInput],
        requested_metrics: list[str],
        category: EvaluationCategory,
    ) -> list[EvaluationItemResult]:
        """Evaluate using the DeepEval library.

        Builds LLMTestCase differently per category:
        - LLM/Workflow: input=prompt, actual_output=output, retrieval_context=context
        - Retrieval: input=query, actual_output=output, expected_output, retrieval_context=context
        - Agent: input=query, actual_output=output
        """
        metric_pairs = _build_deepeval_metrics(requested_metrics)
        if not metric_pairs:
            logger.warning("No valid DeepEval metrics found for: %s", requested_metrics)
            return [EvaluationItemResult(index=item.index) for item in items]

        results: list[EvaluationItemResult] = []
        for item in items:
            test_case = self._build_test_case(item, category)
            metrics: list[EvaluationMetric] = []
            for canonical_name, metric in metric_pairs:
                try:
                    metric.measure(test_case)
                    if metric.score is not None:
                        metrics.append(EvaluationMetric(name=canonical_name, value=float(metric.score)))
                except Exception:
                    logger.exception(
                        "Failed to compute metric %s for item %d",
                        canonical_name,
                        item.index,
                    )
            results.append(EvaluationItemResult(index=item.index, metrics=metrics))
        return results

    @staticmethod
    def _build_test_case(item: EvaluationItemInput, category: EvaluationCategory) -> Any:
        """Build a deepeval LLMTestCase with the correct fields per category."""
        from deepeval.test_case import LLMTestCase

        user_input = _format_input(item.inputs, category)

        match category:
            case EvaluationCategory.LLM | EvaluationCategory.WORKFLOW:
                # faithfulness needs: input, actual_output, retrieval_context
                # answer_relevancy needs: input, actual_output
                return LLMTestCase(
                    input=user_input,
                    actual_output=item.output,
                    expected_output=item.expected_output or None,
                    retrieval_context=item.context or None,
                )
            case EvaluationCategory.RETRIEVAL:
                # contextual_precision/recall needs: input, actual_output, expected_output, retrieval_context
                return LLMTestCase(
                    input=user_input,
                    actual_output=item.output or "",
                    expected_output=item.expected_output or "",
                    retrieval_context=item.context or [],
                )
            case _:
                return LLMTestCase(
                    input=user_input,
                    actual_output=item.output,
                )

    def _evaluate_simple(
        self,
        items: list[EvaluationItemInput],
        requested_metrics: list[str],
        model_wrapper: DifyModelWrapper,
    ) -> list[EvaluationItemResult]:
        """Simple LLM-as-judge fallback when DeepEval is not available."""
        results: list[EvaluationItemResult] = []
        for item in items:
            metrics: list[EvaluationMetric] = []
            for m_name in requested_metrics:
                try:
                    score = self._judge_with_llm(model_wrapper, m_name, item)
                    metrics.append(EvaluationMetric(name=m_name, value=score))
                except Exception:
                    logger.exception("Failed to compute metric %s for item %d", m_name, item.index)
            results.append(EvaluationItemResult(index=item.index, metrics=metrics))
        return results

    def _judge_with_llm(
        self,
        model_wrapper: DifyModelWrapper,
        metric_name: str,
        item: EvaluationItemInput,
    ) -> float:
        """Use the LLM to judge a single metric for a single item."""
        prompt = self._build_judge_prompt(metric_name, item)
        response = model_wrapper.invoke(prompt)
        return self._parse_score(response)

    @staticmethod
    def _build_judge_prompt(metric_name: str, item: EvaluationItemInput) -> str:
        """Build a scoring prompt for the LLM judge."""
        parts = [
            f"Evaluate the following on the metric '{metric_name}' using a scale of 0.0 to 1.0.",
            f"\nInput: {item.inputs}",
            f"\nOutput: {item.output}",
        ]
        if item.expected_output:
            parts.append(f"\nExpected Output: {item.expected_output}")
        if item.context:
            parts.append(f"\nContext: {'; '.join(item.context)}")
        parts.append("\nRespond with ONLY a single floating point number between 0.0 and 1.0, nothing else.")
        return "\n".join(parts)

    @staticmethod
    def _parse_score(response: str) -> float:
        """Parse a float score from an LLM response."""
        import re

        cleaned = response.strip()
        try:
            score = float(cleaned)
            return max(0.0, min(1.0, score))
        except ValueError:
            # Fall back to the first number in the text, clamped to [0.0, 1.0].
            match = re.search(r"(\d+\.?\d*)", cleaned)
            if match:
                score = float(match.group(1))
                return max(0.0, min(1.0, score))
            return 0.0


def _format_input(inputs: dict[str, Any], category: EvaluationCategory) -> str:
    """Extract the user-facing input string from the inputs dict."""
    match category:
        case EvaluationCategory.LLM | EvaluationCategory.WORKFLOW:
            return str(inputs.get("prompt", ""))
        case EvaluationCategory.RETRIEVAL:
            return str(inputs.get("query", ""))
        case _:
            return str(next(iter(inputs.values()), "")) if inputs else ""


def _build_deepeval_metrics(requested_metrics: list[str]) -> list[tuple[str, Any]]:
    """Build DeepEval metric instances from canonical metric names.

    Returns a list of (canonical_name, metric_instance) pairs so that callers
    can record the canonical name rather than the framework-internal class name.
    """
    try:
        from deepeval.metrics import (
            AnswerRelevancyMetric,
            ContextualPrecisionMetric,
            ContextualRecallMetric,
            ContextualRelevancyMetric,
            FaithfulnessMetric,
            TaskCompletionMetric,
            ToolCorrectnessMetric,
        )

        # Maps canonical name → deepeval metric class
        deepeval_class_map: dict[str, Any] = {
            EvaluationMetricName.FAITHFULNESS: FaithfulnessMetric,
            EvaluationMetricName.ANSWER_RELEVANCY: AnswerRelevancyMetric,
            EvaluationMetricName.CONTEXT_PRECISION: ContextualPrecisionMetric,
            EvaluationMetricName.CONTEXT_RECALL: ContextualRecallMetric,
            EvaluationMetricName.CONTEXT_RELEVANCE: ContextualRelevancyMetric,
            EvaluationMetricName.TOOL_CORRECTNESS: ToolCorrectnessMetric,
            EvaluationMetricName.TASK_COMPLETION: TaskCompletionMetric,
        }

        pairs: list[tuple[str, Any]] = []
        for name in requested_metrics:
            metric_class = deepeval_class_map.get(name)
            if metric_class:
                pairs.append((name, metric_class(threshold=0.5)))
            else:
                logger.warning("Metric '%s' is not supported by DeepEval, skipping", name)
        return pairs
    except ImportError:
        logger.warning("DeepEval metrics not available")
        return []
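The score parser is deliberately tolerant of chatty judge output. Illustrative cases, following directly from the parsing logic above:

    DeepEvalEvaluator._parse_score("0.85")            # -> 0.85 (direct float parse)
    DeepEvalEvaluator._parse_score("Score: 0.7/1.0")  # -> 0.7  (first number via regex fallback)
    DeepEvalEvaluator._parse_score("about 3")         # -> 1.0  (clamped into [0.0, 1.0])
    DeepEvalEvaluator._parse_score("no number here")  # -> 0.0  (default when nothing parses)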
0
api/core/evaluation/frameworks/ragas/__init__.py
Normal file
312
api/core/evaluation/frameworks/ragas/ragas_evaluator.py
Normal file
@ -0,0 +1,312 @@
import logging
from typing import Any

from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.config_entity import RagasConfig
from core.evaluation.entities.evaluation_entity import (
    AGENT_METRIC_NAMES,
    LLM_METRIC_NAMES,
    RETRIEVAL_METRIC_NAMES,
    WORKFLOW_METRIC_NAMES,
    EvaluationCategory,
    EvaluationItemInput,
    EvaluationItemResult,
    EvaluationMetric,
    EvaluationMetricName,
)
from core.evaluation.frameworks.ragas.ragas_model_wrapper import DifyModelWrapper

logger = logging.getLogger(__name__)

# Maps canonical EvaluationMetricName to the corresponding ragas metric class.
# Metrics not listed here are unsupported by ragas and will be skipped.
_RAGAS_METRIC_MAP: dict[EvaluationMetricName, str] = {
    EvaluationMetricName.FAITHFULNESS: "Faithfulness",
    EvaluationMetricName.ANSWER_RELEVANCY: "AnswerRelevancy",
    EvaluationMetricName.ANSWER_CORRECTNESS: "AnswerCorrectness",
    EvaluationMetricName.SEMANTIC_SIMILARITY: "SemanticSimilarity",
    EvaluationMetricName.CONTEXT_PRECISION: "ContextPrecision",
    EvaluationMetricName.CONTEXT_RECALL: "ContextRecall",
    EvaluationMetricName.CONTEXT_RELEVANCE: "ContextRelevance",
    EvaluationMetricName.TOOL_CORRECTNESS: "ToolCallAccuracy",
}


class RagasEvaluator(BaseEvaluationInstance):
    """RAGAS framework adapter for evaluation."""

    def __init__(self, config: RagasConfig):
        self.config = config

    def get_supported_metrics(self, category: EvaluationCategory) -> list[str]:
        match category:
            case EvaluationCategory.LLM:
                candidates = LLM_METRIC_NAMES
            case EvaluationCategory.RETRIEVAL:
                candidates = RETRIEVAL_METRIC_NAMES
            case EvaluationCategory.AGENT:
                candidates = AGENT_METRIC_NAMES
            case EvaluationCategory.WORKFLOW | EvaluationCategory.SNIPPET:
                candidates = WORKFLOW_METRIC_NAMES
            case _:
                return []
        return [m for m in candidates if m in _RAGAS_METRIC_MAP]

    def evaluate_llm(
        self,
        items: list[EvaluationItemInput],
        metric_names: list[str],
        model_provider: str,
        model_name: str,
        tenant_id: str,
    ) -> list[EvaluationItemResult]:
        return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.LLM)

    def evaluate_retrieval(
        self,
        items: list[EvaluationItemInput],
        metric_names: list[str],
        model_provider: str,
        model_name: str,
        tenant_id: str,
    ) -> list[EvaluationItemResult]:
        return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.RETRIEVAL)

    def evaluate_agent(
        self,
        items: list[EvaluationItemInput],
        metric_names: list[str],
        model_provider: str,
        model_name: str,
        tenant_id: str,
    ) -> list[EvaluationItemResult]:
        return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.AGENT)

    def evaluate_workflow(
        self,
        items: list[EvaluationItemInput],
        metric_names: list[str],
        model_provider: str,
        model_name: str,
        tenant_id: str,
    ) -> list[EvaluationItemResult]:
        return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.WORKFLOW)

    def _evaluate(
        self,
        items: list[EvaluationItemInput],
        metric_names: list[str],
        model_provider: str,
        model_name: str,
        tenant_id: str,
        category: EvaluationCategory,
    ) -> list[EvaluationItemResult]:
        """Core evaluation logic using RAGAS."""
        model_wrapper = DifyModelWrapper(model_provider, model_name, tenant_id)
        requested_metrics = metric_names or self.get_supported_metrics(category)

        try:
            return self._evaluate_with_ragas(items, requested_metrics, model_wrapper, category)
        except ImportError:
            logger.warning("RAGAS not installed, falling back to simple evaluation")
            return self._evaluate_simple(items, requested_metrics, model_wrapper)

    def _evaluate_with_ragas(
        self,
        items: list[EvaluationItemInput],
        requested_metrics: list[str],
        model_wrapper: DifyModelWrapper,
        category: EvaluationCategory,
    ) -> list[EvaluationItemResult]:
        """Evaluate using the RAGAS library.

        Builds SingleTurnSample differently per category to match ragas requirements:
        - LLM/Workflow: user_input=prompt, response=output, reference=expected_output
        - Retrieval: user_input=query, reference=expected_output, retrieved_contexts=context
        - Agent: Not supported via EvaluationDataset (requires message-based API)
        """
        from ragas import evaluate as ragas_evaluate
        from ragas.dataset_schema import EvaluationDataset

        samples: list[Any] = []
        for item in items:
            sample = self._build_sample(item, category)
            samples.append(sample)

        dataset = EvaluationDataset(samples=samples)

        ragas_metrics = self._build_ragas_metrics(requested_metrics)
        if not ragas_metrics:
            logger.warning("No valid RAGAS metrics found for: %s", requested_metrics)
            return [EvaluationItemResult(index=item.index) for item in items]

        try:
            result = ragas_evaluate(
                dataset=dataset,
                metrics=ragas_metrics,
            )

            results: list[EvaluationItemResult] = []
            result_df = result.to_pandas()
            for i, item in enumerate(items):
                metrics: list[EvaluationMetric] = []
                for m_name in requested_metrics:
                    if m_name in result_df.columns:
                        score = result_df.iloc[i][m_name]
                        # Skip None and NaN scores (NaN compares unequal to itself).
                        if score is not None and not (isinstance(score, float) and score != score):
                            metrics.append(EvaluationMetric(name=m_name, value=float(score)))
                results.append(EvaluationItemResult(index=item.index, metrics=metrics))
            return results
        except Exception:
            logger.exception("RAGAS evaluation failed, falling back to simple evaluation")
            return self._evaluate_simple(items, requested_metrics, model_wrapper)

    @staticmethod
    def _build_sample(item: EvaluationItemInput, category: EvaluationCategory) -> Any:
        """Build a ragas SingleTurnSample with the correct fields per category.

        ragas metric field requirements:
        - faithfulness: user_input, response, retrieved_contexts
        - answer_relevancy: user_input, response
        - answer_correctness: user_input, response, reference
        - semantic_similarity: user_input, response, reference
        - context_precision: user_input, reference, retrieved_contexts
        - context_recall: user_input, reference, retrieved_contexts
        - context_relevance: user_input, retrieved_contexts
        """
        from ragas.dataset_schema import SingleTurnSample

        user_input = _format_input(item.inputs, category)

        match category:
            case EvaluationCategory.LLM:
                # response = actual LLM output, reference = expected output
                return SingleTurnSample(
                    user_input=user_input,
                    response=item.output,
                    reference=item.expected_output or "",
                    retrieved_contexts=item.context or [],
                )
            case EvaluationCategory.RETRIEVAL:
                # context_precision/recall only need reference + retrieved_contexts
                return SingleTurnSample(
                    user_input=user_input,
                    reference=item.expected_output or "",
                    retrieved_contexts=item.context or [],
                )
            case _:
                return SingleTurnSample(
                    user_input=user_input,
                    response=item.output,
                )

    def _evaluate_simple(
        self,
        items: list[EvaluationItemInput],
        requested_metrics: list[str],
        model_wrapper: DifyModelWrapper,
    ) -> list[EvaluationItemResult]:
        """Simple LLM-as-judge fallback when RAGAS is not available."""
        results: list[EvaluationItemResult] = []
        for item in items:
            metrics: list[EvaluationMetric] = []
            for m_name in requested_metrics:
                try:
                    score = self._judge_with_llm(model_wrapper, m_name, item)
                    metrics.append(EvaluationMetric(name=m_name, value=score))
                except Exception:
                    logger.exception("Failed to compute metric %s for item %d", m_name, item.index)
            results.append(EvaluationItemResult(index=item.index, metrics=metrics))
        return results

    def _judge_with_llm(
        self,
        model_wrapper: DifyModelWrapper,
        metric_name: str,
        item: EvaluationItemInput,
    ) -> float:
        """Use the LLM to judge a single metric for a single item."""
        prompt = self._build_judge_prompt(metric_name, item)
        response = model_wrapper.invoke(prompt)
        return self._parse_score(response)

    @staticmethod
    def _build_judge_prompt(metric_name: str, item: EvaluationItemInput) -> str:
        """Build a scoring prompt for the LLM judge."""
        parts = [
            f"Evaluate the following on the metric '{metric_name}' using a scale of 0.0 to 1.0.",
            f"\nInput: {item.inputs}",
            f"\nOutput: {item.output}",
        ]
        if item.expected_output:
            parts.append(f"\nExpected Output: {item.expected_output}")
        if item.context:
            parts.append(f"\nContext: {'; '.join(item.context)}")
        parts.append("\nRespond with ONLY a single floating point number between 0.0 and 1.0, nothing else.")
        return "\n".join(parts)

    @staticmethod
    def _parse_score(response: str) -> float:
        """Parse a float score from an LLM response."""
        import re

        cleaned = response.strip()
|
||||
try:
|
||||
score = float(cleaned)
|
||||
return max(0.0, min(1.0, score))
|
||||
except ValueError:
|
||||
match = re.search(r"(\d+\.?\d*)", cleaned)
|
||||
if match:
|
||||
score = float(match.group(1))
|
||||
return max(0.0, min(1.0, score))
|
||||
return 0.0
|
||||
|
||||
@staticmethod
|
||||
def _build_ragas_metrics(requested_metrics: list[str]) -> list[Any]:
|
||||
"""Build RAGAS metric instances from canonical metric names."""
|
||||
try:
|
||||
from ragas.metrics.collections import (
|
||||
AnswerCorrectness,
|
||||
AnswerRelevancy,
|
||||
ContextPrecision,
|
||||
ContextRecall,
|
||||
ContextRelevance,
|
||||
Faithfulness,
|
||||
SemanticSimilarity,
|
||||
ToolCallAccuracy,
|
||||
)
|
||||
|
||||
# Maps canonical name → ragas metric class
|
||||
ragas_class_map: dict[str, Any] = {
|
||||
EvaluationMetricName.FAITHFULNESS: Faithfulness,
|
||||
EvaluationMetricName.ANSWER_RELEVANCY: AnswerRelevancy,
|
||||
EvaluationMetricName.ANSWER_CORRECTNESS: AnswerCorrectness,
|
||||
EvaluationMetricName.SEMANTIC_SIMILARITY: SemanticSimilarity,
|
||||
EvaluationMetricName.CONTEXT_PRECISION: ContextPrecision,
|
||||
EvaluationMetricName.CONTEXT_RECALL: ContextRecall,
|
||||
EvaluationMetricName.CONTEXT_RELEVANCE: ContextRelevance,
|
||||
EvaluationMetricName.TOOL_CORRECTNESS: ToolCallAccuracy,
|
||||
}
|
||||
|
||||
metrics = []
|
||||
for name in requested_metrics:
|
||||
metric_class = ragas_class_map.get(name)
|
||||
if metric_class:
|
||||
metrics.append(metric_class())
|
||||
else:
|
||||
logger.warning("Metric '%s' is not supported by RAGAS, skipping", name)
|
||||
return metrics
|
||||
except ImportError:
|
||||
logger.warning("RAGAS metrics not available")
|
||||
return []
|
||||
|
||||
|
||||
def _format_input(inputs: dict[str, Any], category: EvaluationCategory) -> str:
|
||||
"""Extract the user-facing input string from the inputs dict."""
|
||||
match category:
|
||||
case EvaluationCategory.LLM | EvaluationCategory.WORKFLOW:
|
||||
return str(inputs.get("prompt", ""))
|
||||
case EvaluationCategory.RETRIEVAL:
|
||||
return str(inputs.get("query", ""))
|
||||
case _:
|
||||
return str(next(iter(inputs.values()), "")) if inputs else ""
|
||||
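A quick way to sanity-check the judge-score parsing above is to exercise the same logic standalone. This is a minimal sketch that mirrors _parse_score rather than importing it, so the free function name is illustrative only:

import re

def parse_score(response: str) -> float:
    # Mirrors _parse_score above: direct float parse first, then a regex
    # fallback for answers like "Score: 0.7", always clamped to [0.0, 1.0].
    cleaned = response.strip()
    try:
        return max(0.0, min(1.0, float(cleaned)))
    except ValueError:
        match = re.search(r"(\d+\.?\d*)", cleaned)
        return max(0.0, min(1.0, float(match.group(1)))) if match else 0.0

assert parse_score("0.85") == 0.85
assert parse_score("Score: 0.7 out of 1.0") == 0.7  # regex fallback picks the first number
assert parse_score("1.8") == 1.0                     # clamped at the upper bound
assert parse_score("no number here") == 0.0          # default when nothing parses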
48 api/core/evaluation/frameworks/ragas/ragas_model_wrapper.py Normal file
@ -0,0 +1,48 @@
import logging
from typing import Any

logger = logging.getLogger(__name__)


class DifyModelWrapper:
    """Wraps Dify's model invocation interface for use by RAGAS as an LLM judge.

    RAGAS requires an LLM to compute certain metrics (faithfulness, answer_relevancy, etc.).
    This wrapper bridges Dify's ModelInstance to a callable that RAGAS can use.
    """

    def __init__(self, model_provider: str, model_name: str, tenant_id: str):
        self.model_provider = model_provider
        self.model_name = model_name
        self.tenant_id = tenant_id

    def _get_model_instance(self) -> Any:
        from core.model_manager import ModelManager
        from core.model_runtime.entities.model_entities import ModelType

        model_manager = ModelManager()
        model_instance = model_manager.get_model_instance(
            tenant_id=self.tenant_id,
            provider=self.model_provider,
            model_type=ModelType.LLM,
            model=self.model_name,
        )
        return model_instance

    def invoke(self, prompt: str) -> str:
        """Invoke the model with a text prompt and return the text response."""
        from core.model_runtime.entities.message_entities import (
            SystemPromptMessage,
            UserPromptMessage,
        )

        model_instance = self._get_model_instance()
        result = model_instance.invoke_llm(
            prompt_messages=[
                SystemPromptMessage(content="You are an evaluation judge. Answer precisely and concisely."),
                UserPromptMessage(content=prompt),
            ],
            model_parameters={"temperature": 0.0, "max_tokens": 2048},
            stream=False,
        )
        return result.message.content
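For reference, the judge path above consumes this wrapper as a plain prompt-in, text-out callable. A hedged usage sketch; the provider, model, and tenant values are placeholders, not values this commit prescribes:

# Hypothetical values: "openai" / "gpt-4o-mini" / "tenant-id" are illustrative only.
wrapper = DifyModelWrapper(model_provider="openai", model_name="gpt-4o-mini", tenant_id="tenant-id")
text = wrapper.invoke("Evaluate the following on the metric 'faithfulness' using a scale of 0.0 to 1.0. ...")
# Determinism comes from temperature=0.0 set inside invoke(); the caller only
# sees a string, which the judge then feeds through _parse_score.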
0 api/core/evaluation/judgment/__init__.py Normal file
160 api/core/evaluation/judgment/processor.py Normal file
@ -0,0 +1,160 @@
"""Judgment condition processor for evaluation metrics.

Evaluates pass/fail judgment conditions against evaluation metric values.
Each condition uses ``variable_selector`` (``[node_id, metric_name]``) to
look up the metric value, then delegates the actual comparison to the
workflow condition engine (``graphon.utils.condition.processor``).

The processor is intentionally decoupled from evaluation frameworks and
runners. It operates on plain ``dict`` mappings and can be invoked
anywhere that already has per-item metric results.
"""

import logging
from collections.abc import Sequence
from typing import Any, cast

from core.evaluation.entities.judgment_entity import (
    JudgmentCondition,
    JudgmentConditionResult,
    JudgmentConfig,
    JudgmentResult,
)
from graphon.utils.condition.entities import SupportedComparisonOperator
from graphon.utils.condition.processor import _evaluate_condition  # pyright: ignore[reportPrivateUsage]

logger = logging.getLogger(__name__)

_UNARY_OPERATORS = frozenset({"null", "not null", "empty", "not empty"})


class JudgmentProcessor:
    @staticmethod
    def evaluate(
        metric_values: dict[tuple[str, str], Any],
        config: JudgmentConfig,
    ) -> JudgmentResult:
        """Evaluate all judgment conditions against the given metric values.

        Args:
            metric_values: Mapping of ``(node_id, metric_name)`` → metric
                value (e.g. ``{("node_abc", "faithfulness"): 0.85}``).
            config: The judgment configuration with logical_operator and
                conditions.

        Returns:
            JudgmentResult with overall pass/fail and per-condition details.
        """
        if not config.conditions:
            return JudgmentResult(
                passed=True,
                logical_operator=config.logical_operator,
                condition_results=[],
            )

        condition_results: list[JudgmentConditionResult] = []

        for condition in config.conditions:
            result = JudgmentProcessor._evaluate_single_condition(metric_values, condition)
            condition_results.append(result)

            # Short-circuit: "and" fails on the first failed condition,
            # "or" passes on the first passed condition.
            if config.logical_operator == "and" and not result.passed:
                return JudgmentResult(
                    passed=False,
                    logical_operator=config.logical_operator,
                    condition_results=condition_results,
                )
            if config.logical_operator == "or" and result.passed:
                return JudgmentResult(
                    passed=True,
                    logical_operator=config.logical_operator,
                    condition_results=condition_results,
                )

        if config.logical_operator == "and":
            final_passed = all(r.passed for r in condition_results)
        else:
            final_passed = any(r.passed for r in condition_results)

        return JudgmentResult(
            passed=final_passed,
            logical_operator=config.logical_operator,
            condition_results=condition_results,
        )

    @staticmethod
    def _evaluate_single_condition(
        metric_values: dict[tuple[str, str], Any],
        condition: JudgmentCondition,
    ) -> JudgmentConditionResult:
        """Evaluate a single judgment condition.

        Steps:
        1. Extract ``(node_id, metric_name)`` from ``variable_selector``.
        2. Look up the metric value from ``metric_values``.
        3. Delegate comparison to the workflow condition engine.
        """
        selector = condition.variable_selector
        if len(selector) < 2:
            return JudgmentConditionResult(
                variable_selector=selector,
                comparison_operator=condition.comparison_operator,
                expected_value=condition.value,
                actual_value=None,
                passed=False,
                error=f"variable_selector must have at least 2 elements, got {len(selector)}",
            )

        node_id, metric_name = selector[0], selector[1]
        actual_value = metric_values.get((node_id, metric_name))

        if actual_value is None and condition.comparison_operator not in _UNARY_OPERATORS:
            return JudgmentConditionResult(
                variable_selector=selector,
                comparison_operator=condition.comparison_operator,
                expected_value=condition.value,
                actual_value=None,
                passed=False,
                error=f"Metric '{metric_name}' on node '{node_id}' not found in evaluation results",
            )

        try:
            expected = condition.value
            # Numeric operators need the actual value coerced to int/float
            # so that the workflow engine's numeric assertions work correctly.
            coerced_actual: object = actual_value
            if (
                condition.comparison_operator in {"=", "≠", ">", "<", "≥", "≤"}
                and actual_value is not None
                and not isinstance(actual_value, (int, float, bool))
            ):
                coerced_actual = float(actual_value)

            passed = _evaluate_condition(
                operator=cast(SupportedComparisonOperator, condition.comparison_operator),
                value=coerced_actual,
                expected=cast(str | Sequence[str] | bool | Sequence[bool] | None, expected),
            )

            return JudgmentConditionResult(
                variable_selector=selector,
                comparison_operator=condition.comparison_operator,
                expected_value=expected,
                actual_value=actual_value,
                passed=passed,
            )
        except Exception as e:
            logger.warning(
                "Judgment condition evaluation failed for [%s, %s]: %s",
                node_id,
                metric_name,
                str(e),
            )
            return JudgmentConditionResult(
                variable_selector=selector,
                comparison_operator=condition.comparison_operator,
                expected_value=condition.value,
                actual_value=actual_value,
                passed=False,
                error=str(e),
            )
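To make the short-circuit behaviour concrete, here is a small usage sketch. It assumes the judgment entities accept the field names used in this file (variable_selector, comparison_operator, value, logical_operator, conditions); treat it as illustrative rather than as the committed test suite:

# Hypothetical "and" judgment over two metrics computed on node "node_abc".
metric_values = {("node_abc", "faithfulness"): 0.85, ("node_abc", "answer_relevancy"): 0.55}
config = JudgmentConfig(
    logical_operator="and",
    conditions=[
        JudgmentCondition(variable_selector=["node_abc", "faithfulness"], comparison_operator="≥", value="0.8"),
        JudgmentCondition(variable_selector=["node_abc", "answer_relevancy"], comparison_operator="≥", value="0.6"),
    ],
)
result = JudgmentProcessor.evaluate(metric_values, config)
# The second condition fails (0.55 < 0.6), so the "and" run short-circuits there:
# result.passed is False and result.condition_results records both conditions.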
52 api/core/evaluation/runners/__init__.py Normal file
@ -0,0 +1,52 @@
from sqlalchemy import select
from sqlalchemy.orm import Session

from models import Account, App, CustomizedSnippet, TenantAccountJoin


def get_service_account_for_app(session: Session, app_id: str) -> Account:
    """Get the creator account for an app with tenant context set up.

    This follows the same pattern as BaseTraceInstance.get_service_account_with_tenant().
    """
    app = session.scalar(select(App).where(App.id == app_id))
    if not app:
        raise ValueError(f"App with id {app_id} not found")

    if not app.created_by:
        raise ValueError(f"App with id {app_id} has no creator")

    account = session.scalar(select(Account).where(Account.id == app.created_by))
    if not account:
        raise ValueError(f"Creator account not found for app {app_id}")

    current_tenant = session.query(TenantAccountJoin).filter_by(account_id=account.id, current=True).first()
    if not current_tenant:
        raise ValueError(f"Current tenant not found for account {account.id}")

    account.set_tenant_id(current_tenant.tenant_id)
    return account


def get_service_account_for_snippet(session: Session, snippet_id: str) -> Account:
    """Get the creator account for a snippet with tenant context set up.

    Mirrors :func:`get_service_account_for_app` but queries CustomizedSnippet.
    """
    snippet = session.scalar(select(CustomizedSnippet).where(CustomizedSnippet.id == snippet_id))
    if not snippet:
        raise ValueError(f"Snippet with id {snippet_id} not found")

    if not snippet.created_by:
        raise ValueError(f"Snippet with id {snippet_id} has no creator")

    account = session.scalar(select(Account).where(Account.id == snippet.created_by))
    if not account:
        raise ValueError(f"Creator account not found for snippet {snippet_id}")

    current_tenant = session.query(TenantAccountJoin).filter_by(account_id=account.id, current=True).first()
    if not current_tenant:
        raise ValueError(f"Current tenant not found for account {account.id}")

    account.set_tenant_id(current_tenant.tenant_id)
    return account
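Both helpers expect a short-lived session owned by the caller. A minimal sketch of the calling pattern; db.engine stands in for Dify's usual SQLAlchemy engine handle and is an assumption of this sketch, not an import made by the file:

from sqlalchemy.orm import Session

# Hypothetical call site inside an evaluation task; db.engine is assumed.
with Session(db.engine, expire_on_commit=False) as session:
    account = get_service_account_for_app(session, app_id)
    # The returned account carries the tenant context set via set_tenant_id(),
    # so downstream model invocations can be tenant-scoped.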
62 api/core/evaluation/runners/agent_evaluation_runner.py Normal file
@ -0,0 +1,62 @@
import logging
from collections.abc import Mapping
from typing import Any

from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.evaluation_entity import (
    DefaultMetric,
    EvaluationItemInput,
    EvaluationItemResult,
)
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
from graphon.node_events import NodeRunResult

logger = logging.getLogger(__name__)


class AgentEvaluationRunner(BaseEvaluationRunner):
    """Runner for agent evaluation: collects tool calls and final output."""

    def __init__(self, evaluation_instance: BaseEvaluationInstance):
        super().__init__(evaluation_instance)

    def evaluate_metrics(
        self,
        node_run_result_list: list[NodeRunResult],
        default_metric: DefaultMetric,
        model_provider: str,
        model_name: str,
        tenant_id: str,
    ) -> list[EvaluationItemResult]:
        """Compute agent evaluation metrics."""
        if not node_run_result_list:
            return []
        merged_items = self._merge_results_into_items(node_run_result_list)
        return self.evaluation_instance.evaluate_agent(
            merged_items, [default_metric.metric], model_provider, model_name, tenant_id
        )

    @staticmethod
    def _merge_results_into_items(items: list[NodeRunResult]) -> list[EvaluationItemInput]:
        """Create EvaluationItemInput list from NodeRunResult for agent evaluation."""
        merged = []
        for i, item in enumerate(items):
            output = _extract_agent_output(item.outputs)
            merged.append(
                EvaluationItemInput(
                    index=i,
                    inputs=dict(item.inputs),
                    output=output,
                )
            )
        return merged


def _extract_agent_output(outputs: Mapping[str, Any]) -> str:
    """Extract the primary output text from agent NodeRunResult.outputs."""
    if "answer" in outputs:
        return str(outputs["answer"])
    if "text" in outputs:
        return str(outputs["text"])
    values = list(outputs.values())
    return str(values[0]) if values else ""
51 api/core/evaluation/runners/base_evaluation_runner.py Normal file
@ -0,0 +1,51 @@
"""Base evaluation runner.

Provides the abstract interface for metric computation. Each concrete runner
(LLM, Retrieval, Agent, Workflow, Snippet) implements ``evaluate_metrics``
to compute scores for a specific node type.

Orchestration (merging results from multiple runners, applying judgment, and
persisting to the database) is handled by the evaluation task, not the runner.
"""

import logging
from abc import ABC, abstractmethod

from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.evaluation_entity import (
    DefaultMetric,
    EvaluationItemResult,
)
from graphon.node_events import NodeRunResult

logger = logging.getLogger(__name__)


class BaseEvaluationRunner(ABC):
    """Abstract base class for evaluation runners.

    Runners are stateless metric calculators: they receive node execution
    results and a metric specification, then return scored results. They
    do **not** touch the database or apply judgment logic.
    """

    def __init__(self, evaluation_instance: BaseEvaluationInstance):
        self.evaluation_instance = evaluation_instance

    @abstractmethod
    def evaluate_metrics(
        self,
        node_run_result_list: list[NodeRunResult],
        default_metric: DefaultMetric,
        model_provider: str,
        model_name: str,
        tenant_id: str,
    ) -> list[EvaluationItemResult]:
        """Compute evaluation metrics on the collected results.

        The returned ``EvaluationItemResult.index`` values are positional
        (0-based) and correspond to the order of *node_run_result_list*.
        The caller is responsible for mapping them back to the original
        dataset indices.
        """
        ...
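Because runner indices are positional, the orchestrating task has to remap them onto the dataset rows that actually executed (rows that errored before evaluation never reach the runner). A minimal sketch of that remapping, assuming the task keeps a parallel list of surviving row indices:

# executed_row_indices[i] is the dataset row behind node_run_result_list[i].
executed_row_indices = [0, 2, 3]  # hypothetical: row 1 failed before evaluation

for item_result in runner.evaluate_metrics(node_run_result_list, default_metric, provider, model, tenant_id):
    dataset_row = executed_row_indices[item_result.index]
    # ... persist item_result.metrics against dataset_row ...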
83 api/core/evaluation/runners/llm_evaluation_runner.py Normal file
@ -0,0 +1,83 @@
import logging
from collections.abc import Mapping
from typing import Any

from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.evaluation_entity import (
    DefaultMetric,
    EvaluationItemInput,
    EvaluationItemResult,
)
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
from graphon.node_events import NodeRunResult

logger = logging.getLogger(__name__)


class LLMEvaluationRunner(BaseEvaluationRunner):
    """Runner for LLM evaluation: extracts prompts/outputs then evaluates."""

    def __init__(self, evaluation_instance: BaseEvaluationInstance):
        super().__init__(evaluation_instance)

    def evaluate_metrics(
        self,
        node_run_result_list: list[NodeRunResult],
        default_metric: DefaultMetric,
        model_provider: str,
        model_name: str,
        tenant_id: str,
    ) -> list[EvaluationItemResult]:
        """Use the evaluation instance to compute LLM metrics."""
        if not node_run_result_list:
            return []
        merged_items = self._merge_results_into_items(node_run_result_list)
        return self.evaluation_instance.evaluate_llm(
            merged_items, [default_metric.metric], model_provider, model_name, tenant_id
        )

    @staticmethod
    def _merge_results_into_items(
        items: list[NodeRunResult],
    ) -> list[EvaluationItemInput]:
        """Create new items from NodeRunResult for ragas evaluation.

        Extracts prompts from process_data and concatenates them into a single
        string with role prefixes (e.g. "system: ...\nuser: ...\nassistant: ...").
        The last assistant message in outputs is used as the actual output.
        """
        merged = []
        for i, item in enumerate(items):
            prompt = _format_prompts(item.process_data.get("prompts", []))
            output = _extract_llm_output(item.outputs)
            merged.append(
                EvaluationItemInput(
                    index=i,
                    inputs={"prompt": prompt},
                    output=output,
                )
            )
        return merged


def _format_prompts(prompts: list[dict[str, Any]]) -> str:
    """Concatenate a list of prompt messages into a single string for evaluation.

    Each message is formatted as "role: text" and joined with newlines.
    """
    parts: list[str] = []
    for msg in prompts:
        role = msg.get("role", "unknown")
        text = msg.get("text", "")
        parts.append(f"{role}: {text}")
    return "\n".join(parts)


def _extract_llm_output(outputs: Mapping[str, Any]) -> str:
    """Extract the LLM output text from NodeRunResult.outputs."""
    if "text" in outputs:
        return str(outputs["text"])
    if "answer" in outputs:
        return str(outputs["answer"])
    values = list(outputs.values())
    return str(values[0]) if values else ""
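_format_prompts flattens the structured chat history into the single prompt string that the ragas sample builder later reads back out of inputs["prompt"]. A quick illustration of the shape it produces, with message dicts mirroring the process_data["prompts"] structure assumed above:

prompts = [
    {"role": "system", "text": "You are a helpful assistant."},
    {"role": "user", "text": "What is the capital of France?"},
]
# _format_prompts(prompts) returns the newline-joined string:
#   system: You are a helpful assistant.
#   user: What is the capital of France?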
61 api/core/evaluation/runners/retrieval_evaluation_runner.py Normal file
@ -0,0 +1,61 @@
import logging
from typing import Any

from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.evaluation_entity import (
    DefaultMetric,
    EvaluationItemInput,
    EvaluationItemResult,
)
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
from graphon.node_events import NodeRunResult

logger = logging.getLogger(__name__)


class RetrievalEvaluationRunner(BaseEvaluationRunner):
    """Runner for retrieval evaluation: performs knowledge base retrieval, then evaluates."""

    def __init__(self, evaluation_instance: BaseEvaluationInstance):
        super().__init__(evaluation_instance)

    def evaluate_metrics(
        self,
        node_run_result_list: list[NodeRunResult],
        default_metric: DefaultMetric,
        model_provider: str,
        model_name: str,
        tenant_id: str,
    ) -> list[EvaluationItemResult]:
        """Compute retrieval evaluation metrics."""
        if not node_run_result_list:
            return []

        merged_items = []
        for i, node_result in enumerate(node_run_result_list):
            outputs = node_result.outputs
            query = self._extract_query(dict(node_result.inputs))
            result_list = outputs.get("result", [])
            contexts = [item.get("content", "") for item in result_list if item.get("content")]
            output = "\n---\n".join(contexts)

            merged_items.append(
                EvaluationItemInput(
                    index=i,
                    inputs={"query": query},
                    output=output,
                    context=contexts,
                )
            )

        return self.evaluation_instance.evaluate_retrieval(
            merged_items, [default_metric.metric], model_provider, model_name, tenant_id
        )

    @staticmethod
    def _extract_query(inputs: dict[str, Any]) -> str:
        for key in ("query", "question", "input", "text"):
            if key in inputs:
                return str(inputs[key])
        values = list(inputs.values())
        return str(values[0]) if values else ""
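Unlike the other runners, the retrieval runner fills both output and context from the same retrieved chunks: context is what ragas' context_precision/context_recall read, while output is just a human-readable "---"-joined rendering. An illustration with a made-up node result:

# Hypothetical retrieval node outputs.
outputs = {"result": [{"content": "Paris is the capital of France."}, {"content": "France is in Europe."}]}
contexts = [item["content"] for item in outputs["result"]]
# context -> ["Paris is the capital of France.", "France is in Europe."]
# output  -> "Paris is the capital of France.\n---\nFrance is in Europe."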
68 api/core/evaluation/runners/snippet_evaluation_runner.py Normal file
@ -0,0 +1,68 @@
"""Runner for Snippet evaluation.

Snippets are essentially workflows, so we reuse ``evaluate_workflow`` from
the evaluation instance for metric computation.
"""

import logging
from collections.abc import Mapping
from typing import Any

from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.evaluation_entity import (
    DefaultMetric,
    EvaluationItemInput,
    EvaluationItemResult,
)
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
from graphon.node_events import NodeRunResult

logger = logging.getLogger(__name__)


class SnippetEvaluationRunner(BaseEvaluationRunner):
    """Runner for snippet evaluation: evaluates a published Snippet workflow."""

    def __init__(self, evaluation_instance: BaseEvaluationInstance):
        super().__init__(evaluation_instance)

    def evaluate_metrics(
        self,
        node_run_result_list: list[NodeRunResult],
        default_metric: DefaultMetric,
        model_provider: str,
        model_name: str,
        tenant_id: str,
    ) -> list[EvaluationItemResult]:
        """Compute evaluation metrics for snippet outputs."""
        if not node_run_result_list:
            return []
        merged_items = self._merge_results_into_items(node_run_result_list)
        return self.evaluation_instance.evaluate_workflow(
            merged_items, [default_metric.metric], model_provider, model_name, tenant_id
        )

    @staticmethod
    def _merge_results_into_items(items: list[NodeRunResult]) -> list[EvaluationItemInput]:
        """Create EvaluationItemInput list from NodeRunResult for snippet evaluation."""
        merged = []
        for i, item in enumerate(items):
            output = _extract_snippet_output(item.outputs)
            merged.append(
                EvaluationItemInput(
                    index=i,
                    inputs=dict(item.inputs),
                    output=output,
                )
            )
        return merged


def _extract_snippet_output(outputs: Mapping[str, Any]) -> str:
    """Extract the primary output text from snippet NodeRunResult.outputs."""
    if "answer" in outputs:
        return str(outputs["answer"])
    if "text" in outputs:
        return str(outputs["text"])
    values = list(outputs.values())
    return str(values[0]) if values else ""
62 api/core/evaluation/runners/workflow_evaluation_runner.py Normal file
@ -0,0 +1,62 @@
import logging
from collections.abc import Mapping
from typing import Any

from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.evaluation_entity import (
    DefaultMetric,
    EvaluationItemInput,
    EvaluationItemResult,
)
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
from graphon.node_events import NodeRunResult

logger = logging.getLogger(__name__)


class WorkflowEvaluationRunner(BaseEvaluationRunner):
    """Runner for workflow evaluation: executes workflow App in non-streaming mode."""

    def __init__(self, evaluation_instance: BaseEvaluationInstance):
        super().__init__(evaluation_instance)

    def evaluate_metrics(
        self,
        node_run_result_list: list[NodeRunResult],
        default_metric: DefaultMetric,
        model_provider: str,
        model_name: str,
        tenant_id: str,
    ) -> list[EvaluationItemResult]:
        """Compute workflow evaluation metrics (end-to-end)."""
        if not node_run_result_list:
            return []
        merged_items = self._merge_results_into_items(node_run_result_list)
        return self.evaluation_instance.evaluate_workflow(
            merged_items, [default_metric.metric], model_provider, model_name, tenant_id
        )

    @staticmethod
    def _merge_results_into_items(items: list[NodeRunResult]) -> list[EvaluationItemInput]:
        """Create EvaluationItemInput list from NodeRunResult for workflow evaluation."""
        merged = []
        for i, item in enumerate(items):
            output = _extract_workflow_output(item.outputs)
            merged.append(
                EvaluationItemInput(
                    index=i,
                    inputs=dict(item.inputs),
                    output=output,
                )
            )
        return merged


def _extract_workflow_output(outputs: Mapping[str, Any]) -> str:
    """Extract the primary output text from workflow NodeRunResult.outputs."""
    if "answer" in outputs:
        return str(outputs["answer"])
    if "text" in outputs:
        return str(outputs["text"])
    values = list(outputs.values())
    return str(values[0]) if values else ""
@ -1,56 +1,17 @@
import logging
from dataclasses import dataclass
from enum import StrEnum, auto

logger = logging.getLogger(__name__)


@dataclass
class QuotaCharge:
    """
    Result of a quota consumption operation.

    Attributes:
        success: Whether the quota charge succeeded
        charge_id: UUID for refund, or None if failed/disabled
    """

    success: bool
    charge_id: str | None
    _quota_type: "QuotaType"

    def refund(self) -> None:
        """
        Refund this quota charge.

        Safe to call even if the charge failed or billing was disabled.
        This method guarantees no exceptions will be raised.
        """
        if self.charge_id:
            self._quota_type.refund(self.charge_id)
            logger.info("Refunded quota for %s with charge_id: %s", self._quota_type.value, self.charge_id)


class QuotaType(StrEnum):
    """
    Supported quota types for tenant feature usage.

    Add additional types here whenever new billable features become available.
    """

    # Trigger execution quota
    TRIGGER = auto()

    # Workflow execution quota
    WORKFLOW = auto()

    UNLIMITED = auto()

    @property
    def billing_key(self) -> str:
        """
        Get the billing key for the feature.
        """
        match self:
            case QuotaType.TRIGGER:
                return "trigger_event"
@ -58,152 +19,3 @@ class QuotaType(StrEnum):
                return "api_rate_limit"
            case _:
                raise ValueError(f"Invalid quota type: {self}")

    def consume(self, tenant_id: str, amount: int = 1) -> QuotaCharge:
        """
        Consume quota for the feature.

        Args:
            tenant_id: The tenant identifier
            amount: Amount to consume (default: 1)

        Returns:
            QuotaCharge with success status and charge_id for refund

        Raises:
            QuotaExceededError: When quota is insufficient
        """
        from configs import dify_config
        from services.billing_service import BillingService
        from services.errors.app import QuotaExceededError

        if not dify_config.BILLING_ENABLED:
            logger.debug("Billing disabled, allowing request for %s", tenant_id)
            return QuotaCharge(success=True, charge_id=None, _quota_type=self)

        logger.info("Consuming %d %s quota for tenant %s", amount, self.value, tenant_id)

        if amount <= 0:
            raise ValueError("Amount to consume must be greater than 0")

        try:
            response = BillingService.update_tenant_feature_plan_usage(tenant_id, self.billing_key, delta=amount)

            if response.get("result") != "success":
                logger.warning(
                    "Failed to consume quota for %s, feature %s, details: %s",
                    tenant_id,
                    self.value,
                    response.get("detail"),
                )
                raise QuotaExceededError(feature=self.value, tenant_id=tenant_id, required=amount)

            charge_id = response.get("history_id")
            logger.debug(
                "Successfully consumed %d %s quota for tenant %s, charge_id: %s",
                amount,
                self.value,
                tenant_id,
                charge_id,
            )
            return QuotaCharge(success=True, charge_id=charge_id, _quota_type=self)

        except QuotaExceededError:
            raise
        except Exception:
            # Fail-safe: allow the request on billing errors
            logger.exception("Failed to consume quota for %s, feature %s", tenant_id, self.value)
            return unlimited()

    def check(self, tenant_id: str, amount: int = 1) -> bool:
        """
        Check if the tenant has sufficient quota without consuming it.

        Args:
            tenant_id: The tenant identifier
            amount: Amount to check (default: 1)

        Returns:
            True if quota is sufficient, False otherwise
        """
        from configs import dify_config

        if not dify_config.BILLING_ENABLED:
            return True

        if amount <= 0:
            raise ValueError("Amount to check must be greater than 0")

        try:
            remaining = self.get_remaining(tenant_id)
            return remaining >= amount if remaining != -1 else True
        except Exception:
            logger.exception("Failed to check quota for %s, feature %s", tenant_id, self.value)
            # Fail-safe: allow the request on billing errors
            return True

    def refund(self, charge_id: str) -> None:
        """
        Refund quota using the charge_id from consume().

        This method guarantees no exceptions will be raised.
        All errors are logged but silently handled.

        Args:
            charge_id: The UUID returned from consume()
        """
        try:
            from configs import dify_config
            from services.billing_service import BillingService

            if not dify_config.BILLING_ENABLED:
                return

            if not charge_id:
                logger.warning("Cannot refund: charge_id is empty")
                return

            logger.info("Refunding %s quota with charge_id: %s", self.value, charge_id)

            response = BillingService.refund_tenant_feature_plan_usage(charge_id)
            if response.get("result") == "success":
                logger.debug("Successfully refunded %s quota, charge_id: %s", self.value, charge_id)
            else:
                logger.warning("Refund failed for charge_id: %s", charge_id)

        except Exception:
            # Catch ALL exceptions - refund must never fail
            logger.exception("Failed to refund quota for charge_id: %s", charge_id)
            # Don't raise - refund is best-effort and must be silent

    def get_remaining(self, tenant_id: str) -> int:
        """
        Get the remaining quota for the tenant.

        Args:
            tenant_id: The tenant identifier

        Returns:
            Remaining quota amount, or -1 if it could not be determined
        """
        from services.billing_service import BillingService

        try:
            usage_info = BillingService.get_tenant_feature_plan_usage(tenant_id, self.billing_key)
            # Assuming the API returns a dict with 'remaining' or 'limit' and 'used'
            if isinstance(usage_info, dict):
                return usage_info.get("remaining", 0)
            # If it returns a simple number, treat it as remaining
            return int(usage_info) if usage_info else 0
        except Exception:
            logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, self.value)
            return -1


def unlimited() -> QuotaCharge:
    """
    Return a quota charge for unlimited quota.

    This is useful for features that are not subject to quota limits, such as the UNLIMITED quota type.
    """
    return QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.UNLIMITED)
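The QuotaCharge handle makes the consume/refund pairing explicit at call sites. A sketch of the intended pattern; run_evaluation is a placeholder for whatever work the quota guards:

charge = QuotaType.WORKFLOW.consume(tenant_id)  # raises QuotaExceededError when exhausted
try:
    run_evaluation()  # hypothetical guarded work
except Exception:
    charge.refund()  # best-effort and silent; a no-op when charge_id is None
    raise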
45 api/fields/snippet_fields.py Normal file
@ -0,0 +1,45 @@
from flask_restx import fields

from fields.member_fields import simple_account_fields
from libs.helper import TimestampField

# Snippet list item fields (lightweight for list display)
snippet_list_fields = {
    "id": fields.String,
    "name": fields.String,
    "description": fields.String,
    "type": fields.String,
    "version": fields.Integer,
    "use_count": fields.Integer,
    "is_published": fields.Boolean,
    "icon_info": fields.Raw,
    "created_at": TimestampField,
    "updated_at": TimestampField,
}

# Full snippet fields (includes creator info and graph data)
snippet_fields = {
    "id": fields.String,
    "name": fields.String,
    "description": fields.String,
    "type": fields.String,
    "version": fields.Integer,
    "use_count": fields.Integer,
    "is_published": fields.Boolean,
    "icon_info": fields.Raw,
    "graph": fields.Raw(attribute="graph_dict"),
    "input_fields": fields.Raw(attribute="input_fields_list"),
    "created_by": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
    "created_at": TimestampField,
    "updated_by": fields.Nested(simple_account_fields, attribute="updated_by_account", allow_null=True),
    "updated_at": TimestampField,
}

# Pagination response fields
snippet_pagination_fields = {
    "data": fields.List(fields.Nested(snippet_list_fields)),
    "page": fields.Integer,
    "limit": fields.Integer,
    "total": fields.Integer,
    "has_more": fields.Boolean,
}
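These dicts plug into flask_restx marshalling. A minimal sketch of a serializing call; the snippet object itself is assumed to be a loaded CustomizedSnippet row:

from flask_restx import marshal

payload = marshal(snippet, snippet_list_fields)  # snippet assumed loaded elsewhere
# In the full snippet_fields, "graph" and "input_fields" read the model's
# graph_dict / input_fields_list properties via attribute=..., so the stored
# JSON text is decoded once on the model rather than in every controller.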
@ -14,6 +14,7 @@ workflow_app_log_partial_fields = {
    "id": fields.String,
    "workflow_run": fields.Nested(workflow_run_for_log_fields, attribute="workflow_run", allow_null=True),
    "details": fields.Raw(attribute="details"),
    "evaluation": fields.Raw(attribute="evaluation", default=None),
    "created_from": fields.String,
    "created_by_role": fields.String,
    "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
@ -0,0 +1,83 @@
"""add_customized_snippets_table

Revision ID: 1c05e80d2380
Revises: 788d3099ae3a
Create Date: 2026-01-29 12:00:00.000000

"""

import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql

import models as models


def _is_pg(conn):
    return conn.dialect.name == "postgresql"


# revision identifiers, used by Alembic.
revision = "1c05e80d2380"
down_revision = "788d3099ae3a"
branch_labels = None
depends_on = None


def upgrade():
    conn = op.get_bind()

    if _is_pg(conn):
        op.create_table(
            "customized_snippets",
            sa.Column("id", models.types.StringUUID(), server_default=sa.text("uuidv7()"), nullable=False),
            sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
            sa.Column("name", sa.String(length=255), nullable=False),
            sa.Column("description", sa.Text(), nullable=True),
            sa.Column("type", sa.String(length=50), server_default=sa.text("'node'"), nullable=False),
            sa.Column("workflow_id", models.types.StringUUID(), nullable=True),
            sa.Column("is_published", sa.Boolean(), server_default=sa.text("false"), nullable=False),
            sa.Column("version", sa.Integer(), server_default=sa.text("1"), nullable=False),
            sa.Column("use_count", sa.Integer(), server_default=sa.text("0"), nullable=False),
            sa.Column("icon_info", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
            sa.Column("graph", sa.Text(), nullable=True),
            sa.Column("input_fields", sa.Text(), nullable=True),
            sa.Column("created_by", models.types.StringUUID(), nullable=True),
            sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
            sa.Column("updated_by", models.types.StringUUID(), nullable=True),
            sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
            sa.PrimaryKeyConstraint("id", name="customized_snippet_pkey"),
            sa.UniqueConstraint("tenant_id", "name", name="customized_snippet_tenant_name_key"),
        )
    else:
        op.create_table(
            "customized_snippets",
            sa.Column("id", models.types.StringUUID(), nullable=False),
            sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
            sa.Column("name", sa.String(length=255), nullable=False),
            sa.Column("description", models.types.LongText(), nullable=True),
            sa.Column("type", sa.String(length=50), server_default=sa.text("'node'"), nullable=False),
            sa.Column("workflow_id", models.types.StringUUID(), nullable=True),
            sa.Column("is_published", sa.Boolean(), server_default=sa.text("false"), nullable=False),
            sa.Column("version", sa.Integer(), server_default=sa.text("1"), nullable=False),
            sa.Column("use_count", sa.Integer(), server_default=sa.text("0"), nullable=False),
            sa.Column("icon_info", models.types.AdjustedJSON(astext_type=sa.Text()), nullable=True),
            sa.Column("graph", models.types.LongText(), nullable=True),
            sa.Column("input_fields", models.types.LongText(), nullable=True),
            sa.Column("created_by", models.types.StringUUID(), nullable=True),
            sa.Column("created_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
            sa.Column("updated_by", models.types.StringUUID(), nullable=True),
            sa.Column("updated_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
            sa.PrimaryKeyConstraint("id", name="customized_snippet_pkey"),
            sa.UniqueConstraint("tenant_id", "name", name="customized_snippet_tenant_name_key"),
        )

    with op.batch_alter_table("customized_snippets", schema=None) as batch_op:
        batch_op.create_index("customized_snippet_tenant_idx", ["tenant_id"], unique=False)


def downgrade():
    with op.batch_alter_table("customized_snippets", schema=None) as batch_op:
        batch_op.drop_index("customized_snippet_tenant_idx")

    op.drop_table("customized_snippets")
@ -0,0 +1,116 @@
"""add_evaluation_tables

Revision ID: a1b2c3d4e5f6
Revises: 1c05e80d2380
Create Date: 2026-03-03 00:01:00.000000

"""

import sqlalchemy as sa
from alembic import op

import models as models


# revision identifiers, used by Alembic.
revision = "a1b2c3d4e5f6"
down_revision = "1c05e80d2380"
branch_labels = None
depends_on = None


def upgrade():
    # evaluation_configurations
    op.create_table(
        "evaluation_configurations",
        sa.Column("id", models.types.StringUUID(), nullable=False),
        sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
        sa.Column("target_type", sa.String(length=20), nullable=False),
        sa.Column("target_id", models.types.StringUUID(), nullable=False),
        sa.Column("evaluation_model_provider", sa.String(length=255), nullable=True),
        sa.Column("evaluation_model", sa.String(length=255), nullable=True),
        sa.Column("metrics_config", models.types.LongText(), nullable=True),
        sa.Column("judgement_conditions", models.types.LongText(), nullable=True),
        sa.Column("created_by", models.types.StringUUID(), nullable=False),
        sa.Column("updated_by", models.types.StringUUID(), nullable=False),
        sa.Column("created_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
        sa.Column("updated_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
        sa.PrimaryKeyConstraint("id", name="evaluation_configuration_pkey"),
        sa.UniqueConstraint("tenant_id", "target_type", "target_id", name="evaluation_configuration_unique"),
    )
    with op.batch_alter_table("evaluation_configurations", schema=None) as batch_op:
        batch_op.create_index(
            "evaluation_configuration_target_idx", ["tenant_id", "target_type", "target_id"], unique=False
        )

    # evaluation_runs
    op.create_table(
        "evaluation_runs",
        sa.Column("id", models.types.StringUUID(), nullable=False),
        sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
        sa.Column("target_type", sa.String(length=20), nullable=False),
        sa.Column("target_id", models.types.StringUUID(), nullable=False),
        sa.Column("evaluation_config_id", models.types.StringUUID(), nullable=False),
        sa.Column("status", sa.String(length=20), nullable=False, server_default=sa.text("'pending'")),
        sa.Column("dataset_file_id", models.types.StringUUID(), nullable=True),
        sa.Column("result_file_id", models.types.StringUUID(), nullable=True),
        sa.Column("total_items", sa.Integer(), nullable=False, server_default=sa.text("0")),
        sa.Column("completed_items", sa.Integer(), nullable=False, server_default=sa.text("0")),
        sa.Column("failed_items", sa.Integer(), nullable=False, server_default=sa.text("0")),
        sa.Column("metrics_summary", models.types.LongText(), nullable=True),
        sa.Column("error", sa.Text(), nullable=True),
        sa.Column("celery_task_id", sa.String(length=255), nullable=True),
        sa.Column("created_by", models.types.StringUUID(), nullable=False),
        sa.Column("started_at", sa.DateTime(), nullable=True),
        sa.Column("completed_at", sa.DateTime(), nullable=True),
        sa.Column("created_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
        sa.Column("updated_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
        sa.PrimaryKeyConstraint("id", name="evaluation_run_pkey"),
    )
    with op.batch_alter_table("evaluation_runs", schema=None) as batch_op:
        batch_op.create_index(
            "evaluation_run_target_idx", ["tenant_id", "target_type", "target_id"], unique=False
        )
        batch_op.create_index("evaluation_run_status_idx", ["tenant_id", "status"], unique=False)

    # evaluation_run_items
    op.create_table(
        "evaluation_run_items",
        sa.Column("id", models.types.StringUUID(), nullable=False),
        sa.Column("evaluation_run_id", models.types.StringUUID(), nullable=False),
        sa.Column("workflow_run_id", models.types.StringUUID(), nullable=True),
        sa.Column("item_index", sa.Integer(), nullable=False),
        sa.Column("inputs", models.types.LongText(), nullable=True),
        sa.Column("expected_output", models.types.LongText(), nullable=True),
        sa.Column("context", models.types.LongText(), nullable=True),
        sa.Column("actual_output", models.types.LongText(), nullable=True),
        sa.Column("metrics", models.types.LongText(), nullable=True),
        sa.Column("metadata_json", models.types.LongText(), nullable=True),
        sa.Column("error", sa.Text(), nullable=True),
        sa.Column("overall_score", sa.Float(), nullable=True),
        sa.Column("created_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
        sa.PrimaryKeyConstraint("id", name="evaluation_run_item_pkey"),
    )
    with op.batch_alter_table("evaluation_run_items", schema=None) as batch_op:
        batch_op.create_index("evaluation_run_item_run_idx", ["evaluation_run_id"], unique=False)
        batch_op.create_index(
            "evaluation_run_item_index_idx", ["evaluation_run_id", "item_index"], unique=False
        )
        batch_op.create_index("evaluation_run_item_workflow_run_idx", ["workflow_run_id"], unique=False)


def downgrade():
    with op.batch_alter_table("evaluation_run_items", schema=None) as batch_op:
        batch_op.drop_index("evaluation_run_item_workflow_run_idx")
        batch_op.drop_index("evaluation_run_item_index_idx")
        batch_op.drop_index("evaluation_run_item_run_idx")
    op.drop_table("evaluation_run_items")

    with op.batch_alter_table("evaluation_runs", schema=None) as batch_op:
        batch_op.drop_index("evaluation_run_status_idx")
        batch_op.drop_index("evaluation_run_target_idx")
    op.drop_table("evaluation_runs")

    with op.batch_alter_table("evaluation_configurations", schema=None) as batch_op:
        batch_op.drop_index("evaluation_configuration_target_idx")
    op.drop_table("evaluation_configurations")
@ -0,0 +1,25 @@
"""merge migration heads

Revision ID: 4c60d8d3ee74
Revises: fce013ca180e, a1b2c3d4e5f6
Create Date: 2026-03-17 17:21:12.105536

"""
from alembic import op
import models as models
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = '4c60d8d3ee74'
down_revision = ('fce013ca180e', 'a1b2c3d4e5f6')
branch_labels = None
depends_on = None


def upgrade():
    pass


def downgrade():
    pass
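This no-op revision exists only to join the two parallel migration branches (fce013ca180e and a1b2c3d4e5f6) back into a single head; Alembic generates files of exactly this shape from its merge command, typically something like `alembic merge -m "merge migration heads" fce013ca180e a1b2c3d4e5f6`.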
@ -33,6 +33,13 @@ from .enums import (
    WorkflowRunTriggeredFrom,
    WorkflowTriggerStatus,
)
from .evaluation import (
    EvaluationConfiguration,
    EvaluationRun,
    EvaluationRunItem,
    EvaluationRunStatus,
    EvaluationTargetType,
)
from .execution_extra_content import ExecutionExtraContent, HumanInputContent
from .human_input import HumanInputForm
from .model import (
@ -80,6 +87,7 @@ from .provider import (
    TenantDefaultModel,
    TenantPreferredModelProvider,
)
from .snippet import CustomizedSnippet, SnippetType
from .source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding
from .task import CeleryTask, CeleryTaskSet
from .tools import (
@ -139,6 +147,7 @@ __all__ = [
    "Conversation",
    "ConversationVariable",
    "CreatorUserRole",
    "CustomizedSnippet",
    "DataSourceApiKeyAuthBinding",
    "DataSourceOauthBinding",
    "Dataset",
@ -156,6 +165,11 @@ __all__ = [
    "DocumentSegment",
    "Embedding",
    "EndUser",
    "EvaluationConfiguration",
    "EvaluationRun",
    "EvaluationRunItem",
    "EvaluationRunStatus",
    "EvaluationTargetType",
    "ExecutionExtraContent",
    "ExporleBanner",
    "ExternalKnowledgeApis",
@ -183,6 +197,7 @@ __all__ = [
    "RecommendedApp",
    "SavedMessage",
    "Site",
    "SnippetType",
    "Tag",
    "TagBinding",
    "Tenant",
205 api/models/evaluation.py Normal file
@ -0,0 +1,205 @@
from __future__ import annotations

import json
from datetime import datetime
from enum import StrEnum
from typing import Any

import sqlalchemy as sa
from sqlalchemy import DateTime, Float, Integer, String, Text, func
from sqlalchemy.orm import Mapped, mapped_column

from libs.uuid_utils import uuidv7

from .base import Base
from .types import LongText, StringUUID


class EvaluationRunStatus(StrEnum):
    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"


class EvaluationTargetType(StrEnum):
    APP = "app"
    SNIPPETS = "snippets"
    KNOWLEDGE_BASE = "knowledge_base"


class EvaluationConfiguration(Base):
    """Stores evaluation configuration for each target (App or Snippet)."""

    __tablename__ = "evaluation_configurations"
    __table_args__ = (
        sa.PrimaryKeyConstraint("id", name="evaluation_configuration_pkey"),
        sa.Index("evaluation_configuration_target_idx", "tenant_id", "target_type", "target_id"),
        sa.Index("evaluation_configuration_workflow_idx", "customized_workflow_id"),
        sa.UniqueConstraint("tenant_id", "target_type", "target_id", name="evaluation_configuration_unique"),
    )

    id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()))
    tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    target_type: Mapped[str] = mapped_column(String(20), nullable=False)
    target_id: Mapped[str] = mapped_column(StringUUID, nullable=False)

    evaluation_model_provider: Mapped[str | None] = mapped_column(String(255), nullable=True)
    evaluation_model: Mapped[str | None] = mapped_column(String(255), nullable=True)
    metrics_config: Mapped[str | None] = mapped_column(LongText, nullable=True)
    judgement_conditions: Mapped[str | None] = mapped_column(LongText, nullable=True)
    customized_workflow_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)

    created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
    updated_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
    updated_at: Mapped[datetime] = mapped_column(
        DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
    )

    @property
    def metrics_config_dict(self) -> dict[str, Any]:
        if self.metrics_config:
            return json.loads(self.metrics_config)
        return {}

    @metrics_config_dict.setter
    def metrics_config_dict(self, value: dict[str, Any]) -> None:
        self.metrics_config = json.dumps(value)

    @property
    def default_metrics_list(self) -> list[dict[str, Any]]:
        """Extract default_metrics from the stored metrics_config JSON."""
        config = self.metrics_config_dict
        return config.get("default_metrics", [])

    @property
    def customized_metrics_dict(self) -> dict[str, Any] | None:
        """Extract customized_metrics from the stored metrics_config JSON."""
        config = self.metrics_config_dict
        return config.get("customized_metrics")

    @property
    def judgment_config_dict(self) -> dict[str, Any] | None:
        """Return judgment config (stored in the judgement_conditions column)."""
        if self.judgement_conditions:
            parsed = json.loads(self.judgement_conditions)
            return parsed if parsed else None
        return None

    @property
    def judgement_conditions_dict(self) -> dict[str, Any]:
        if self.judgement_conditions:
            return json.loads(self.judgement_conditions)
        return {}

    @judgement_conditions_dict.setter
    def judgement_conditions_dict(self, value: dict[str, Any]) -> None:
        self.judgement_conditions = json.dumps(value)

    def __repr__(self) -> str:
        return f"<EvaluationConfiguration(id={self.id}, target={self.target_type}:{self.target_id})>"


class EvaluationRun(Base):
    """Stores each evaluation run record."""

    __tablename__ = "evaluation_runs"
    __table_args__ = (
        sa.PrimaryKeyConstraint("id", name="evaluation_run_pkey"),
        sa.Index("evaluation_run_target_idx", "tenant_id", "target_type", "target_id"),
        sa.Index("evaluation_run_status_idx", "tenant_id", "status"),
    )

    id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()))
    tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    target_type: Mapped[str] = mapped_column(String(20), nullable=False)
    target_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    evaluation_config_id: Mapped[str] = mapped_column(StringUUID, nullable=False)

    status: Mapped[str] = mapped_column(String(20), nullable=False, default=EvaluationRunStatus.PENDING)
    dataset_file_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
    result_file_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)

    total_items: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    completed_items: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    failed_items: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    error: Mapped[str | None] = mapped_column(Text, nullable=True)

    celery_task_id: Mapped[str | None] = mapped_column(String(255), nullable=True)

    created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
    started_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
    completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
    updated_at: Mapped[datetime] = mapped_column(
        DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
    )

    @property
    def progress(self) -> float:
        if self.total_items == 0:
            return 0.0
        return (self.completed_items + self.failed_items) / self.total_items

    def __repr__(self) -> str:
        return f"<EvaluationRun(id={self.id}, status={self.status})>"
|
||||
|
||||
class EvaluationRunItem(Base):
|
||||
"""Stores per-row evaluation results."""
|
||||
|
||||
__tablename__ = "evaluation_run_items"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="evaluation_run_item_pkey"),
|
||||
sa.Index("evaluation_run_item_run_idx", "evaluation_run_id"),
|
||||
sa.Index("evaluation_run_item_index_idx", "evaluation_run_id", "item_index"),
|
||||
sa.Index("evaluation_run_item_workflow_run_idx", "workflow_run_id"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()))
|
||||
evaluation_run_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
workflow_run_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||
|
||||
item_index: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
inputs: Mapped[str | None] = mapped_column(LongText, nullable=True)
|
||||
expected_output: Mapped[str | None] = mapped_column(LongText, nullable=True)
|
||||
context: Mapped[str | None] = mapped_column(LongText, nullable=True)
|
||||
actual_output: Mapped[str | None] = mapped_column(LongText, nullable=True)
|
||||
|
||||
metrics: Mapped[str | None] = mapped_column(LongText, nullable=True)
|
||||
judgment: Mapped[str | None] = mapped_column(LongText, nullable=True)
|
||||
metadata_json: Mapped[str | None] = mapped_column(LongText, nullable=True)
|
||||
error: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
|
||||
overall_score: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
@property
|
||||
def inputs_dict(self) -> dict[str, Any]:
|
||||
if self.inputs:
|
||||
return json.loads(self.inputs)
|
||||
return {}
|
||||
|
||||
@property
|
||||
def metrics_list(self) -> list[dict[str, Any]]:
|
||||
if self.metrics:
|
||||
return json.loads(self.metrics)
|
||||
return []
|
||||
|
||||
@property
|
||||
def judgment_dict(self) -> dict[str, Any]:
|
||||
if self.judgment:
|
||||
return json.loads(self.judgment)
|
||||
return {}
|
||||
|
||||
@property
|
||||
def metadata_dict(self) -> dict[str, Any]:
|
||||
if self.metadata_json:
|
||||
return json.loads(self.metadata_json)
|
||||
return {}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<EvaluationRunItem(id={self.id}, run={self.evaluation_run_id}, index={self.item_index})>"
|
||||
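The three models above share one storage convention: structured payloads live in JSON text columns with paired `_dict`/`_list` accessors. A minimal illustrative sketch of the round-trip (IDs and the metric name are made up for illustration):

config = EvaluationConfiguration(
    tenant_id="tenant-1",  # hypothetical IDs
    target_type=EvaluationTargetType.APP,
    target_id="app-1",
    created_by="account-1",
    updated_by="account-1",
)
config.metrics_config_dict = {"default_metrics": [{"metric": "faithfulness"}], "customized_metrics": None}
assert config.default_metrics_list == [{"metric": "faithfulness"}]
assert config.customized_metrics_dict is None  # read back out of the same JSON blob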
101 api/models/snippet.py Normal file
@ -0,0 +1,101 @@
import json
from datetime import datetime
from enum import StrEnum
from typing import Any

import sqlalchemy as sa
from sqlalchemy import DateTime, String, func
from sqlalchemy.orm import Mapped, mapped_column

from libs.uuid_utils import uuidv7

from .account import Account
from .base import Base
from .engine import db
from .types import AdjustedJSON, LongText, StringUUID


class SnippetType(StrEnum):
    """Snippet Type Enum"""

    NODE = "node"
    GROUP = "group"


class CustomizedSnippet(Base):
    """
    Customized Snippet Model

    Stores reusable workflow components (nodes or node groups) that can be
    shared across applications within a workspace.
    """

    __tablename__ = "customized_snippets"
    __table_args__ = (
        sa.PrimaryKeyConstraint("id", name="customized_snippet_pkey"),
        sa.Index("customized_snippet_tenant_idx", "tenant_id"),
        sa.UniqueConstraint("tenant_id", "name", name="customized_snippet_tenant_name_key"),
    )

    id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()))
    tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    name: Mapped[str] = mapped_column(String(255), nullable=False)
    description: Mapped[str | None] = mapped_column(LongText, nullable=True)
    type: Mapped[str] = mapped_column(String(50), nullable=False, server_default=sa.text("'node'"))

    # Workflow reference for published version
    workflow_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)

    # State flags
    is_published: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
    version: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("1"))
    use_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))

    # Visual customization
    icon_info: Mapped[dict | None] = mapped_column(AdjustedJSON, nullable=True)

    # Snippet configuration (stored as JSON text)
    input_fields: Mapped[str | None] = mapped_column(LongText, nullable=True)

    # Audit fields
    created_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
    updated_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
    updated_at: Mapped[datetime] = mapped_column(
        DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
    )

    @property
    def graph_dict(self) -> dict[str, Any]:
        """Get graph from associated workflow."""
        if self.workflow_id:
            from .workflow import Workflow

            workflow = db.session.get(Workflow, self.workflow_id)
            if workflow:
                return json.loads(workflow.graph) if workflow.graph else {}
        return {}

    @property
    def input_fields_list(self) -> list[dict[str, Any]]:
        """Parse input_fields JSON to list."""
        return json.loads(self.input_fields) if self.input_fields else []

    @property
    def created_by_account(self) -> Account | None:
        """Get the account that created this snippet."""
        if self.created_by:
            return db.session.get(Account, self.created_by)
        return None

    @property
    def updated_by_account(self) -> Account | None:
        """Get the account that last updated this snippet."""
        if self.updated_by:
            return db.session.get(Account, self.updated_by)
        return None

    @property
    def version_str(self) -> str:
        """Get version as string for API response."""
        return str(self.version)
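The properties above resolve related rows on demand through `db.session.get`, so a snippet row can hydrate its workflow graph lazily. A small usage sketch (the ID is hypothetical):

snippet = db.session.get(CustomizedSnippet, "snippet-1")  # hypothetical ID
if snippet and snippet.is_published:
    graph = snippet.graph_dict           # {} when no workflow is attached yet
    fields = snippet.input_fields_list   # parsed from the JSON text column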
@ -106,6 +106,8 @@ class WorkflowType(StrEnum):
    WORKFLOW = "workflow"
    CHAT = "chat"
    RAG_PIPELINE = "rag-pipeline"
    SNIPPET = "snippet"
    EVALUATION = "evaluation"

    @classmethod
    def value_of(cls, value: str) -> "WorkflowType":
@ -234,6 +234,12 @@ storage = [
############################################################
tools = ["cloudscraper~=1.2.71", "nltk~=3.9.1"]

############################################################
# [ Evaluation ] dependency group
# Required for evaluation frameworks
############################################################
evaluation = ["ragas>=0.2.0", "deepeval>=2.0.0"]

############################################################
# [ VDB ] workspace plugins — hollow packages under providers/vdb/*
# Each declares its own third-party deps and registers dify.vector_backends entry points.
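Since `ragas` and `deepeval` sit in an optional dependency group, runtime code has to tolerate their absence. A minimal sketch of the usual guard (an assumed pattern, not code from this diff):

# Assumed pattern: import an optional evaluation framework lazily and
# degrade gracefully when the "evaluation" dependency group isn't installed.
try:
    import ragas  # noqa: F401
    RAGAS_AVAILABLE = True
except ImportError:
    RAGAS_AVAILABLE = False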
@ -18,12 +18,13 @@ from core.app.features.rate_limiting import RateLimit
from core.app.features.rate_limiting.rate_limit import rate_limit_context
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig
from core.db import session_factory
from enums.quota_type import QuotaType, unlimited
from enums.quota_type import QuotaType
from extensions.otel import AppGenerateHandler, trace_span
from models.model import Account, App, AppMode, EndUser
from models.workflow import Workflow, WorkflowRun
from services.errors.app import QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError
from services.quota_service import QuotaService, unlimited
from services.workflow_service import WorkflowService
from tasks.app_generate.workflow_execute_task import AppExecutionParams, workflow_based_app_execution_task
@ -106,7 +107,7 @@ class AppGenerateService:
        quota_charge = unlimited()
        if dify_config.BILLING_ENABLED:
            try:
                quota_charge = QuotaType.WORKFLOW.consume(app_model.tenant_id)
                quota_charge = QuotaService.reserve(QuotaType.WORKFLOW, app_model.tenant_id)
            except QuotaExceededError:
                raise InvokeRateLimitError(f"Workflow execution quota limit reached for tenant {app_model.tenant_id}")

@ -116,6 +117,7 @@ class AppGenerateService:
        request_id = RateLimit.gen_request_key()
        try:
            request_id = rate_limit.enter(request_id)
            quota_charge.commit()
            effective_mode = (
                AppMode.AGENT_CHAT if app_model.is_agent and app_model.mode != AppMode.AGENT_CHAT else app_model.mode
            )
@ -22,6 +22,7 @@ from models.trigger import WorkflowTriggerLog, WorkflowTriggerLogDict
from models.workflow import Workflow
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
from services.errors.app import QuotaExceededError, WorkflowNotFoundError, WorkflowQuotaLimitError
from services.quota_service import QuotaService, unlimited
from services.workflow.entities import AsyncTriggerResponse, TriggerData, WorkflowTaskData
from services.workflow.queue_dispatcher import QueueDispatcherManager, QueuePriority
from services.workflow_service import WorkflowService
@ -131,9 +132,10 @@ class AsyncWorkflowService:
        trigger_log = trigger_log_repo.create(trigger_log)
        session.commit()

        # 7. Check and consume quota
        # 7. Reserve quota (commit after successful dispatch)
        quota_charge = unlimited()
        try:
            QuotaType.WORKFLOW.consume(trigger_data.tenant_id)
            quota_charge = QuotaService.reserve(QuotaType.WORKFLOW, trigger_data.tenant_id)
        except QuotaExceededError as e:
            # Update trigger log status
            trigger_log.status = WorkflowTriggerStatus.RATE_LIMITED
@ -153,13 +155,18 @@ class AsyncWorkflowService:
        # 9. Dispatch to appropriate queue
        task_data_dict = task_data.model_dump(mode="json")

        task: AsyncResult[Any] | None = None
        if queue_name == QueuePriority.PROFESSIONAL:
            task = execute_workflow_professional.delay(task_data_dict)
        elif queue_name == QueuePriority.TEAM:
            task = execute_workflow_team.delay(task_data_dict)
        else:  # SANDBOX
            task = execute_workflow_sandbox.delay(task_data_dict)
        try:
            task: AsyncResult[Any] | None = None
            if queue_name == QueuePriority.PROFESSIONAL:
                task = execute_workflow_professional.delay(task_data_dict)
            elif queue_name == QueuePriority.TEAM:
                task = execute_workflow_team.delay(task_data_dict)
            else:  # SANDBOX
                task = execute_workflow_sandbox.delay(task_data_dict)
            quota_charge.commit()
        except Exception:
            quota_charge.refund()
            raise

        # 10. Update trigger log with task info
        trigger_log.status = WorkflowTriggerStatus.QUEUED
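The change above moves quota accounting to a two-phase shape: reserve before enqueueing, commit only after `.delay()` succeeds, refund on any dispatch failure. Condensed to its skeleton (`some_queue_task` is a placeholder for the queue-specific task):

quota_charge = QuotaService.reserve(QuotaType.WORKFLOW, tenant_id)  # may raise QuotaExceededError
try:
    task = some_queue_task.delay(payload)  # placeholder for the queue-specific Celery task
    quota_charge.commit()                  # consumption confirmed only after dispatch
except Exception:
    quota_charge.refund()                  # frozen quota returned on failure
    raise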
@ -32,6 +32,102 @@ class SubscriptionPlan(TypedDict):
    expiration_date: int


class QuotaReserveResult(TypedDict):
    reservation_id: str
    available: int
    reserved: int


class QuotaCommitResult(TypedDict):
    available: int
    reserved: int
    refunded: int


class QuotaReleaseResult(TypedDict):
    available: int
    reserved: int
    released: int


_quota_reserve_adapter = TypeAdapter(QuotaReserveResult)
_quota_commit_adapter = TypeAdapter(QuotaCommitResult)
_quota_release_adapter = TypeAdapter(QuotaReleaseResult)


class _BillingQuota(TypedDict):
    size: int
    limit: int


class _VectorSpaceQuota(TypedDict):
    size: float
    limit: int


class _KnowledgeRateLimit(TypedDict):
    # NOTE (hj24):
    # 1. Returned for sandbox users but null for other plans; it's defined but never used.
    # 2. Keep it for compatibility for now; it can be deprecated in future versions.
    size: NotRequired[int]
    # NOTE END
    limit: int


class _BillingSubscription(TypedDict):
    plan: str
    interval: str
    education: bool


class BillingInfo(TypedDict):
    """Response of /subscription/info.

    NOTE (hj24):
    - Fields not listed here (e.g. trigger_event, api_rate_limit) are stripped by TypeAdapter.validate_python()
    - To preserve precision, billing may serialize int fields as str; be careful when using TypeAdapter:
      1. validate_python in non-strict mode will coerce it to the expected type
      2. In strict mode, it will raise ValidationError
      3. To preserve compatibility, always keep non-strict mode here and avoid strict mode
    """

    enabled: bool
    subscription: _BillingSubscription
    members: _BillingQuota
    apps: _BillingQuota
    vector_space: _VectorSpaceQuota
    knowledge_rate_limit: _KnowledgeRateLimit
    documents_upload_quota: _BillingQuota
    annotation_quota_limit: _BillingQuota
    docs_processing: str
    can_replace_logo: bool
    model_load_balancing_enabled: bool
    knowledge_pipeline_publish_enabled: bool
    next_credit_reset_date: NotRequired[int]


_billing_info_adapter = TypeAdapter(BillingInfo)


class _TenantFeatureQuota(TypedDict):
    usage: int
    limit: int
    reset_date: NotRequired[int]


class TenantFeatureQuotaInfo(TypedDict):
    """Response of /quota/info.

    NOTE (hj24):
    - Same convention as BillingInfo: billing may return int fields as str,
      always keep non-strict mode to auto-coerce.
    """

    trigger_event: _TenantFeatureQuota
    api_rate_limit: _TenantFeatureQuota


_tenant_feature_quota_info_adapter = TypeAdapter(TenantFeatureQuotaInfo)


class _BillingQuota(TypedDict):
    size: int
    limit: int
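As the NOTE above warns, the billing API may stringify integers, so the adapters rely on pydantic's non-strict coercion. An illustrative check (payload values invented):

# Non-strict validation coerces "9" -> 9; strict mode would raise ValidationError.
payload = {"reservation_id": "res-1", "available": "9", "reserved": "1"}
result = _quota_reserve_adapter.validate_python(payload)
assert result["available"] == 9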
@ -149,11 +245,63 @@ class BillingService:

    @classmethod
    def get_tenant_feature_plan_usage_info(cls, tenant_id: str):
        """Deprecated: Use get_quota_info instead."""
        params = {"tenant_id": tenant_id}

        usage_info = cls._send_request("GET", "/tenant-feature-usage/info", params=params)
        return usage_info

    @classmethod
    def get_quota_info(cls, tenant_id: str) -> TenantFeatureQuotaInfo:
        params = {"tenant_id": tenant_id}
        return _tenant_feature_quota_info_adapter.validate_python(
            cls._send_request("GET", "/quota/info", params=params)
        )

    @classmethod
    def quota_reserve(
        cls, tenant_id: str, feature_key: str, request_id: str, amount: int = 1, meta: dict | None = None
    ) -> QuotaReserveResult:
        """Reserve quota before task execution."""
        payload: dict = {
            "tenant_id": tenant_id,
            "feature_key": feature_key,
            "request_id": request_id,
            "amount": amount,
        }
        if meta:
            payload["meta"] = meta
        return _quota_reserve_adapter.validate_python(cls._send_request("POST", "/quota/reserve", json=payload))

    @classmethod
    def quota_commit(
        cls, tenant_id: str, feature_key: str, reservation_id: str, actual_amount: int, meta: dict | None = None
    ) -> QuotaCommitResult:
        """Commit a reservation with actual consumption."""
        payload: dict = {
            "tenant_id": tenant_id,
            "feature_key": feature_key,
            "reservation_id": reservation_id,
            "actual_amount": actual_amount,
        }
        if meta:
            payload["meta"] = meta
        return _quota_commit_adapter.validate_python(cls._send_request("POST", "/quota/commit", json=payload))

    @classmethod
    def quota_release(cls, tenant_id: str, feature_key: str, reservation_id: str) -> QuotaReleaseResult:
        """Release a reservation (cancel, return frozen quota)."""
        return _quota_release_adapter.validate_python(
            cls._send_request(
                "POST",
                "/quota/release",
                json={
                    "tenant_id": tenant_id,
                    "feature_key": feature_key,
                    "reservation_id": reservation_id,
                },
            )
        )

    @classmethod
    def get_knowledge_rate_limit(cls, tenant_id: str) -> KnowledgeRateLimitDict:
        params = {"tenant_id": tenant_id}
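Taken together, the three endpoints implement reserve/commit/release semantics against the billing backend. A hedged usage sketch (tenant, feature key, and request ID are placeholders):

reserved = BillingService.quota_reserve(
    tenant_id="tenant-1", feature_key="workflow", request_id="req-1", amount=1
)
try:
    # ... do the metered work ...
    BillingService.quota_commit(
        tenant_id="tenant-1", feature_key="workflow",
        reservation_id=reserved["reservation_id"], actual_amount=1,
    )
except Exception:
    BillingService.quota_release(
        tenant_id="tenant-1", feature_key="workflow",
        reservation_id=reserved["reservation_id"],
    )
    raise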
21 api/services/errors/evaluation.py Normal file
@ -0,0 +1,21 @@
from services.errors.base import BaseServiceError


class EvaluationFrameworkNotConfiguredError(BaseServiceError):
    def __init__(self, description: str | None = None):
        super().__init__(description or "Evaluation framework is not configured. Set EVALUATION_FRAMEWORK env var.")


class EvaluationNotFoundError(BaseServiceError):
    def __init__(self, description: str | None = None):
        super().__init__(description or "Evaluation not found.")


class EvaluationDatasetInvalidError(BaseServiceError):
    def __init__(self, description: str | None = None):
        super().__init__(description or "Evaluation dataset is invalid.")


class EvaluationMaxConcurrentRunsError(BaseServiceError):
    def __init__(self, description: str | None = None):
        super().__init__(description or "Maximum number of concurrent evaluation runs reached.")
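A controller would typically map these service errors onto HTTP responses. A minimal sketch (the 400/409 mapping is an assumption, not taken from this diff; arguments are elided):

try:
    run = EvaluationService.start_evaluation_run(...)  # arguments elided
except EvaluationFrameworkNotConfiguredError as e:
    return {"message": str(e)}, 400   # assumed status-code mapping
except EvaluationMaxConcurrentRunsError as e:
    return {"message": str(e)}, 409   # assumed status-code mapping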
985 api/services/evaluation_service.py Normal file
@ -0,0 +1,985 @@
import io
import json
import logging
from collections.abc import Mapping
from typing import Any, Union

from openpyxl import Workbook, load_workbook
from openpyxl.styles import Alignment, Border, Font, PatternFill, Side
from openpyxl.utils import get_column_letter
from sqlalchemy.orm import Session

from configs import dify_config
from core.evaluation.entities.evaluation_entity import (
    METRIC_NODE_TYPE_MAPPING,
    METRIC_VALUE_TYPE_MAPPING,
    DefaultMetric,
    EvaluationCategory,
    EvaluationConfigData,
    EvaluationDatasetInput,
    EvaluationMetricName,
    EvaluationRunData,
    EvaluationRunRequest,
    NodeInfo,
)
from core.evaluation.evaluation_manager import EvaluationManager
from graphon.enums import WorkflowNodeExecutionMetadataKey
from graphon.node_events.base import NodeRunResult
from models.evaluation import (
    EvaluationConfiguration,
    EvaluationRun,
    EvaluationRunItem,
    EvaluationRunStatus,
)
from models.model import App, AppMode
from models.snippet import CustomizedSnippet
from models.workflow import Workflow
from services.errors.evaluation import (
    EvaluationDatasetInvalidError,
    EvaluationFrameworkNotConfiguredError,
    EvaluationMaxConcurrentRunsError,
    EvaluationNotFoundError,
)
from services.snippet_service import SnippetService
from services.workflow_service import WorkflowService

logger = logging.getLogger(__name__)


class EvaluationService:
    """
    Service for evaluation-related operations.

    Provides functionality to generate evaluation dataset templates
    based on App or Snippet input parameters.
    """

    # Excluded app modes that don't support evaluation templates
    EXCLUDED_APP_MODES = {AppMode.RAG_PIPELINE}

    @classmethod
    def generate_dataset_template(
        cls,
        target: Union[App, CustomizedSnippet],
        target_type: str,
    ) -> tuple[bytes, str]:
        """
        Generate evaluation dataset template as XLSX bytes.

        Creates an XLSX file with headers based on the evaluation target's input parameters.
        The first column is index, followed by input parameter columns.

        :param target: App or CustomizedSnippet instance
        :param target_type: Target type string ("app" or "snippet")
        :return: Tuple of (xlsx_content_bytes, filename)
        :raises ValueError: If target type is not supported or app mode is excluded
        """
        # Validate target type
        if target_type == "app":
            if not isinstance(target, App):
                raise ValueError("Invalid target: expected App instance")
            if AppMode.value_of(target.mode) in cls.EXCLUDED_APP_MODES:
                raise ValueError(f"App mode '{target.mode}' does not support evaluation templates")
            input_fields = cls._get_app_input_fields(target)
        elif target_type == "snippet":
            if not isinstance(target, CustomizedSnippet):
                raise ValueError("Invalid target: expected CustomizedSnippet instance")
            input_fields = cls._get_snippet_input_fields(target)
        else:
            raise ValueError(f"Unsupported target type: {target_type}")

        # Generate XLSX template
        xlsx_content = cls._generate_xlsx_template(input_fields, target.name)

        # Build filename
        truncated_name = target.name[:10] + "..." if len(target.name) > 10 else target.name
        filename = f"{truncated_name}-evaluation-dataset.xlsx"

        return xlsx_content, filename
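    # Usage sketch: in a Flask controller the returned bytes/filename pair maps
    # naturally onto a file download. The response wiring below is an assumed
    # pattern, not part of this diff:
    #
    #     from flask import send_file
    #     xlsx_bytes, filename = EvaluationService.generate_dataset_template(
    #         target=app_model, target_type="app"
    #     )
    #     return send_file(io.BytesIO(xlsx_bytes), download_name=filename, as_attachment=True)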
    @classmethod
    def _get_app_input_fields(cls, app: App) -> list[dict]:
        """
        Get input fields from App's workflow.

        :param app: App instance
        :return: List of input field definitions
        """
        workflow_service = WorkflowService()
        workflow = workflow_service.get_published_workflow(app_model=app)
        if not workflow:
            workflow = workflow_service.get_draft_workflow(app_model=app)

        if not workflow:
            return []

        # Get user input form from workflow
        user_input_form = workflow.user_input_form()
        return user_input_form

    @classmethod
    def _get_snippet_input_fields(cls, snippet: CustomizedSnippet) -> list[dict]:
        """
        Get input fields from Snippet.

        Tries to get from snippet's own input_fields first,
        then falls back to workflow's user_input_form.

        :param snippet: CustomizedSnippet instance
        :return: List of input field definitions
        """
        # Try snippet's own input_fields first
        input_fields = snippet.input_fields_list
        if input_fields:
            return input_fields

        # Fallback to workflow's user_input_form
        snippet_service = SnippetService()
        workflow = snippet_service.get_published_workflow(snippet=snippet)
        if not workflow:
            workflow = snippet_service.get_draft_workflow(snippet=snippet)

        if workflow:
            return workflow.user_input_form()

        return []

    @classmethod
    def _generate_xlsx_template(cls, input_fields: list[dict], target_name: str) -> bytes:
        """
        Generate XLSX template file content.

        Creates a workbook with:
        - First row as header row with "index" and input field names
        - Styled header with background color and borders
        - Empty data rows ready for user input

        :param input_fields: List of input field definitions
        :param target_name: Name of the target (for sheet name)
        :return: XLSX file content as bytes
        """
        wb = Workbook()
        ws = wb.active
        if ws is None:
            ws = wb.create_sheet("Evaluation Dataset")

        sheet_name = "Evaluation Dataset"
        ws.title = sheet_name

        header_font = Font(bold=True, color="FFFFFF")
        header_fill = PatternFill(start_color="4472C4", end_color="4472C4", fill_type="solid")
        header_alignment = Alignment(horizontal="center", vertical="center")
        thin_border = Border(
            left=Side(style="thin"),
            right=Side(style="thin"),
            top=Side(style="thin"),
            bottom=Side(style="thin"),
        )

        # Build header row
        headers = ["index"]

        for field in input_fields:
            field_label = str(field.get("label") or field.get("variable") or "")
            headers.append(field_label)

        # Write header row
        for col_idx, header in enumerate(headers, start=1):
            cell = ws.cell(row=1, column=col_idx, value=header)
            cell.font = header_font
            cell.fill = header_fill
            cell.alignment = header_alignment
            cell.border = thin_border

        # Set column widths
        ws.column_dimensions["A"].width = 10  # index column
        for col_idx in range(2, len(headers) + 1):
            ws.column_dimensions[get_column_letter(col_idx)].width = 20

        # Add one empty row with row number for user reference
        for col_idx in range(1, len(headers) + 1):
            cell = ws.cell(row=2, column=col_idx, value="")
            cell.border = thin_border
            if col_idx == 1:
                cell.value = 1
                cell.alignment = Alignment(horizontal="center")

        # Save to bytes
        output = io.BytesIO()
        wb.save(output)
        output.seek(0)

        return output.getvalue()
    @classmethod
    def generate_retrieval_dataset_template(cls) -> tuple[bytes, str]:
        """Generate evaluation dataset XLSX template for knowledge base retrieval.

        The template contains three columns: ``index``, ``query``, and
        ``expected_output``. Callers upload a filled copy and start an
        evaluation run with ``target_type="dataset"``.

        :returns: (xlsx_content_bytes, filename)
        """
        wb = Workbook()
        ws = wb.active
        if ws is None:
            ws = wb.create_sheet("Evaluation Dataset")
        ws.title = "Evaluation Dataset"

        header_font = Font(bold=True, color="FFFFFF")
        header_fill = PatternFill(start_color="4472C4", end_color="4472C4", fill_type="solid")
        header_alignment = Alignment(horizontal="center", vertical="center")
        thin_border = Border(
            left=Side(style="thin"),
            right=Side(style="thin"),
            top=Side(style="thin"),
            bottom=Side(style="thin"),
        )

        headers = ["index", "query", "expected_output"]
        for col_idx, header in enumerate(headers, start=1):
            cell = ws.cell(row=1, column=col_idx, value=header)
            cell.font = header_font
            cell.fill = header_fill
            cell.alignment = header_alignment
            cell.border = thin_border

        ws.column_dimensions["A"].width = 10
        ws.column_dimensions["B"].width = 30
        ws.column_dimensions["C"].width = 30

        # Add one sample row
        for col_idx in range(1, len(headers) + 1):
            cell = ws.cell(row=2, column=col_idx, value="")
            cell.border = thin_border
            if col_idx == 1:
                cell.value = 1
                cell.alignment = Alignment(horizontal="center")

        output = io.BytesIO()
        wb.save(output)
        output.seek(0)
        return output.getvalue(), "retrieval-evaluation-dataset.xlsx"

    # ---- Evaluation Configuration CRUD ----

    @classmethod
    def get_evaluation_config(
        cls,
        session: Session,
        tenant_id: str,
        target_type: str,
        target_id: str,
    ) -> EvaluationConfiguration | None:
        return (
            session.query(EvaluationConfiguration)
            .filter_by(tenant_id=tenant_id, target_type=target_type, target_id=target_id)
            .first()
        )

    @classmethod
    def save_evaluation_config(
        cls,
        session: Session,
        tenant_id: str,
        target_type: str,
        target_id: str,
        account_id: str,
        data: EvaluationConfigData,
    ) -> EvaluationConfiguration:
        config = cls.get_evaluation_config(session, tenant_id, target_type, target_id)
        if config is None:
            config = EvaluationConfiguration(
                tenant_id=tenant_id,
                target_type=target_type,
                target_id=target_id,
                created_by=account_id,
                updated_by=account_id,
            )
            session.add(config)

        config.evaluation_model_provider = data.evaluation_model_provider
        config.evaluation_model = data.evaluation_model
        config.metrics_config = json.dumps(
            {
                "default_metrics": [m.model_dump() for m in data.default_metrics],
                "customized_metrics": data.customized_metrics.model_dump() if data.customized_metrics else None,
            }
        )
        config.judgement_conditions = json.dumps(data.judgment_config.model_dump() if data.judgment_config else {})
        config.customized_workflow_id = (
            data.customized_metrics.evaluation_workflow_id if data.customized_metrics else None
        )
        config.updated_by = account_id
        session.commit()
        session.refresh(config)
        return config
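    # After save_evaluation_config, the persisted JSON columns can be read back
    # through the model's accessors defined in models/evaluation.py, e.g.:
    #
    #     config = EvaluationService.get_evaluation_config(session, tenant_id, "app", app_id)
    #     if config:
    #         metrics = config.metrics_config_dict["default_metrics"]  # list of metric dicts
    #         judgment = config.judgment_config_dict                   # None when stored as {}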
    @classmethod
    def list_targets_by_customized_workflow(
        cls,
        session: Session,
        tenant_id: str,
        customized_workflow_id: str,
    ) -> list[EvaluationConfiguration]:
        """Return all evaluation configs that reference the given workflow as customized metrics."""
        from sqlalchemy import select

        return list(
            session.scalars(
                select(EvaluationConfiguration).where(
                    EvaluationConfiguration.tenant_id == tenant_id,
                    EvaluationConfiguration.customized_workflow_id == customized_workflow_id,
                )
            ).all()
        )

    # ---- Evaluation Run Management ----

    @classmethod
    def start_evaluation_run(
        cls,
        session: Session,
        tenant_id: str,
        target_type: str,
        target_id: str,
        account_id: str,
        dataset_file_content: bytes,
        run_request: EvaluationRunRequest,
    ) -> EvaluationRun:
        """Validate dataset, create run record, dispatch Celery task.

        Saves the provided parameters as the latest EvaluationConfiguration
        before creating the run.
        """
        # Check framework is configured
        evaluation_instance = EvaluationManager.get_evaluation_instance()
        if evaluation_instance is None:
            raise EvaluationFrameworkNotConfiguredError()

        # Save as latest EvaluationConfiguration
        config = cls.save_evaluation_config(
            session=session,
            tenant_id=tenant_id,
            target_type=target_type,
            target_id=target_id,
            account_id=account_id,
            data=run_request,
        )

        # Check concurrent run limit
        active_runs = (
            session.query(EvaluationRun)
            .filter_by(tenant_id=tenant_id)
            .filter(EvaluationRun.status.in_([EvaluationRunStatus.PENDING, EvaluationRunStatus.RUNNING]))
            .count()
        )
        max_concurrent = dify_config.EVALUATION_MAX_CONCURRENT_RUNS
        if active_runs >= max_concurrent:
            raise EvaluationMaxConcurrentRunsError(f"Maximum concurrent runs ({max_concurrent}) reached.")

        # Parse dataset
        items = cls._parse_dataset(dataset_file_content)
        max_rows = dify_config.EVALUATION_MAX_DATASET_ROWS
        if len(items) > max_rows:
            raise EvaluationDatasetInvalidError(f"Dataset has {len(items)} rows, max is {max_rows}.")

        # Create evaluation run
        evaluation_run = EvaluationRun(
            tenant_id=tenant_id,
            target_type=target_type,
            target_id=target_id,
            evaluation_config_id=config.id,
            status=EvaluationRunStatus.PENDING,
            total_items=len(items),
            created_by=account_id,
        )
        session.add(evaluation_run)
        session.commit()
        session.refresh(evaluation_run)

        # Build Celery task data
        run_data = EvaluationRunData(
            evaluation_run_id=evaluation_run.id,
            tenant_id=tenant_id,
            target_type=target_type,
            target_id=target_id,
            evaluation_model_provider=run_request.evaluation_model_provider,
            evaluation_model=run_request.evaluation_model,
            default_metrics=run_request.default_metrics,
            customized_metrics=run_request.customized_metrics,
            judgment_config=run_request.judgment_config,
            input_list=items,
        )

        # Dispatch Celery task
        from tasks.evaluation_task import run_evaluation

        task = run_evaluation.delay(run_data.model_dump())
        evaluation_run.celery_task_id = task.id
        session.commit()

        return evaluation_run
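    # End to end, a caller uploads the filled XLSX and hands its bytes to the
    # service; config snapshotting, limit checks, and Celery dispatch all happen
    # inside. A hedged sketch (session handle and IDs are placeholders):
    #
    #     with Session(db.engine) as session:  # assumes the models.engine db handle
    #         run = EvaluationService.start_evaluation_run(
    #             session=session,
    #             tenant_id="tenant-1",
    #             target_type="app",
    #             target_id="app-1",
    #             account_id="account-1",
    #             dataset_file_content=uploaded_xlsx_bytes,
    #             run_request=run_request,  # EvaluationRunRequest built from the API payload
    #         )
    #         # run.status stays "pending" until the Celery worker picks it up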
    @classmethod
    def get_evaluation_runs(
        cls,
        session: Session,
        tenant_id: str,
        target_type: str,
        target_id: str,
        page: int = 1,
        page_size: int = 20,
    ) -> tuple[list[EvaluationRun], int]:
        """Query evaluation run history with pagination."""
        query = (
            session.query(EvaluationRun)
            .filter_by(tenant_id=tenant_id, target_type=target_type, target_id=target_id)
            .order_by(EvaluationRun.created_at.desc())
        )
        total = query.count()
        runs = query.offset((page - 1) * page_size).limit(page_size).all()
        return runs, total

    @classmethod
    def get_evaluation_run_detail(
        cls,
        session: Session,
        tenant_id: str,
        run_id: str,
    ) -> EvaluationRun:
        run = session.query(EvaluationRun).filter_by(id=run_id, tenant_id=tenant_id).first()
        if not run:
            raise EvaluationNotFoundError("Evaluation run not found.")
        return run

    @classmethod
    def get_evaluation_run_items(
        cls,
        session: Session,
        run_id: str,
        page: int = 1,
        page_size: int = 50,
    ) -> tuple[list[EvaluationRunItem], int]:
        """Query evaluation run items with pagination."""
        query = (
            session.query(EvaluationRunItem)
            .filter_by(evaluation_run_id=run_id)
            .order_by(EvaluationRunItem.item_index.asc())
        )
        total = query.count()
        items = query.offset((page - 1) * page_size).limit(page_size).all()
        return items, total

    @classmethod
    def cancel_evaluation_run(
        cls,
        session: Session,
        tenant_id: str,
        run_id: str,
    ) -> EvaluationRun:
        run = cls.get_evaluation_run_detail(session, tenant_id, run_id)
        if run.status not in (EvaluationRunStatus.PENDING, EvaluationRunStatus.RUNNING):
            raise ValueError(f"Cannot cancel evaluation run in status: {run.status}")

        run.status = EvaluationRunStatus.CANCELLED

        # Revoke Celery task if running
        if run.celery_task_id:
            try:
                from celery import current_app as celery_app

                celery_app.control.revoke(run.celery_task_id, terminate=True)
            except Exception:
                logger.exception("Failed to revoke Celery task %s", run.celery_task_id)

        session.commit()
        return run

    @classmethod
    def get_supported_metrics(cls, category: EvaluationCategory) -> list[str]:
        return EvaluationManager.get_supported_metrics(category)

    @staticmethod
    def get_available_metrics() -> list[str]:
        """Return the centrally-defined list of evaluation metrics."""
        return [m.value for m in EvaluationMetricName]
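    # The paired (rows, total) return value supports offset pagination directly,
    # e.g. (placeholder IDs):
    #
    #     runs, total = EvaluationService.get_evaluation_runs(
    #         session, "tenant-1", "app", "app-1", page=2, page_size=20
    #     )
    #     has_more = total > 2 * 20  # page * page_size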
    @classmethod
    def _nodes_for_metrics_from_workflow(
        cls,
        workflow: Workflow,
        metrics: list[str],
    ) -> dict[str, list[dict[str, str]]]:
        node_type_to_nodes: dict[str, list[dict[str, str]]] = {}
        for node_id, node_data in workflow.walk_nodes():
            ntype = node_data.get("type", "")
            node_type_to_nodes.setdefault(ntype, []).append(
                NodeInfo(node_id=node_id, type=ntype, title=node_data.get("title", "")).model_dump()
            )

        result: dict[str, list[dict[str, str]]] = {}
        for metric in metrics:
            required_node_type = METRIC_NODE_TYPE_MAPPING.get(metric)
            if required_node_type is None:
                result[metric] = []
                continue
            result[metric] = node_type_to_nodes.get(required_node_type, [])
        return result

    @classmethod
    def _union_supported_metric_names(cls) -> list[str]:
        """Metric names the current evaluation framework supports for any :class:`EvaluationCategory`."""
        ordered: list[str] = []
        seen: set[str] = set()
        for category in EvaluationCategory:
            for name in cls.get_supported_metrics(category):
                if name not in seen:
                    seen.add(name)
                    ordered.append(name)
        return ordered

    @classmethod
    def get_default_metrics_with_nodes_for_published_target(
        cls,
        target: Union[App, CustomizedSnippet],
        target_type: str,
    ) -> list[DefaultMetric]:
        """List default metrics and matching nodes using only the *published* workflow graph.

        Metrics are those supported by the configured evaluation framework and present in
        :data:`METRIC_NODE_TYPE_MAPPING`. Node lists are derived from the published workflow only
        (no draft fallback).
        """
        workflow = cls._resolve_published_workflow(target, target_type)
        if not workflow:
            return []

        supported = cls._union_supported_metric_names()
        metric_names = sorted(m for m in supported if m in METRIC_NODE_TYPE_MAPPING)
        if not metric_names:
            return []

        nodes_by_metric = cls._nodes_for_metrics_from_workflow(workflow, metric_names)
        return [
            DefaultMetric(
                metric=m,
                value_type=METRIC_VALUE_TYPE_MAPPING.get(m, "number"),
                node_info_list=[NodeInfo.model_validate(n) for n in nodes_by_metric.get(m, [])],
            )
            for m in metric_names
        ]

    @classmethod
    def get_nodes_for_metrics(
        cls,
        target: Union[App, CustomizedSnippet],
        target_type: str,
        metrics: list[str] | None = None,
    ) -> dict[str, list[dict[str, str]]]:
        """Return node info grouped by metric (or all nodes when *metrics* is empty).

        :param target: App or CustomizedSnippet instance.
        :param target_type: ``"app"`` or ``"snippets"``.
        :param metrics: Optional list of metric names to filter by.
            When *None* or empty, returns ``{"all": [<every node>]}``.
        :returns: ``{metric_name: [NodeInfo dict, ...]}`` or
            ``{"all": [NodeInfo dict, ...]}``.
        """
        workflow = cls._resolve_workflow(target, target_type)
        if not workflow:
            return {"all": []} if not metrics else {m: [] for m in metrics}

        if not metrics:
            all_nodes = [
                NodeInfo(node_id=node_id, type=node_data.get("type", ""), title=node_data.get("title", "")).model_dump()
                for node_id, node_data in workflow.walk_nodes()
            ]
            return {"all": all_nodes}

        return cls._nodes_for_metrics_from_workflow(workflow, metrics)
    @classmethod
    def _resolve_published_workflow(
        cls,
        target: Union[App, CustomizedSnippet],
        target_type: str,
    ) -> Workflow | None:
        """Resolve only the published workflow for the target (no draft fallback)."""
        if target_type == "snippets" and isinstance(target, CustomizedSnippet):
            return SnippetService().get_published_workflow(snippet=target)
        if target_type == "app" and isinstance(target, App):
            return WorkflowService().get_published_workflow(app_model=target)
        return None

    @classmethod
    def _resolve_workflow(
        cls,
        target: Union[App, CustomizedSnippet],
        target_type: str,
    ) -> Workflow | None:
        """Resolve the *published* (preferred) or *draft* workflow for the target."""
        if target_type == "snippets" and isinstance(target, CustomizedSnippet):
            snippet_service = SnippetService()
            workflow = snippet_service.get_published_workflow(snippet=target)
            if not workflow:
                workflow = snippet_service.get_draft_workflow(snippet=target)
            return workflow
        elif target_type == "app" and isinstance(target, App):
            workflow_service = WorkflowService()
            workflow = workflow_service.get_published_workflow(app_model=target)
            if not workflow:
                workflow = workflow_service.get_draft_workflow(app_model=target)
            return workflow
        return None

    # ---- Category Resolution ----

    @classmethod
    def _resolve_evaluation_category(cls, default_metrics: list[DefaultMetric]) -> EvaluationCategory:
        """Derive evaluation category from default_metrics node_info types.

        Uses the type of the first node_info found in default_metrics.
        Falls back to LLM if no metrics are provided.
        """
        for metric in default_metrics:
            for node_info in metric.node_info_list:
                try:
                    return EvaluationCategory(node_info.type)
                except ValueError:
                    continue
        return EvaluationCategory.LLM

    @classmethod
    def execute_targets(
        cls,
        tenant_id: str,
        target_type: str,
        target_id: str,
        input_list: list[EvaluationDatasetInput],
        max_workers: int = 5,
    ) -> tuple[list[dict[str, NodeRunResult]], list[str | None]]:
        """Execute the evaluation target for every test-data item in parallel.

        :param tenant_id: Workspace / tenant ID.
        :param target_type: ``"app"`` or ``"snippet"``.
        :param target_id: ID of the App or CustomizedSnippet.
        :param input_list: All test-data items parsed from the dataset.
        :param max_workers: Maximum number of parallel worker threads.
        :return: Tuple of (node_results, workflow_run_ids).
            node_results: ordered list of ``{node_id: NodeRunResult}`` mappings;
            the *i*-th element corresponds to ``input_list[i]``.
            workflow_run_ids: ordered list of workflow_run_id strings (or None)
            for each input item.
        """
        from concurrent.futures import ThreadPoolExecutor

        from flask import Flask, current_app

        flask_app: Flask = current_app._get_current_object()  # type: ignore

        def _worker(item: EvaluationDatasetInput) -> tuple[dict[str, NodeRunResult], str | None]:
            with flask_app.app_context():
                from models.engine import db

                with Session(db.engine, expire_on_commit=False) as thread_session:
                    try:
                        response = cls._run_single_target(
                            session=thread_session,
                            target_type=target_type,
                            target_id=target_id,
                            item=item,
                        )

                        workflow_run_id = cls._extract_workflow_run_id(response)
                        if not workflow_run_id:
                            logger.warning(
                                "No workflow_run_id for item %d (target=%s)",
                                item.index,
                                target_id,
                            )
                            return {}, None

                        node_results = cls._query_node_run_results(
                            session=thread_session,
                            tenant_id=tenant_id,
                            app_id=target_id,
                            workflow_run_id=workflow_run_id,
                        )
                        return node_results, workflow_run_id
                    except Exception:
                        logger.exception(
                            "Target execution failed for item %d (target=%s)",
                            item.index,
                            target_id,
                        )
                        return {}, None

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [executor.submit(_worker, item) for item in input_list]
            ordered_results: list[dict[str, NodeRunResult]] = []
            ordered_workflow_run_ids: list[str | None] = []
            for future in futures:
                try:
                    node_result, wf_run_id = future.result()
                    ordered_results.append(node_result)
                    ordered_workflow_run_ids.append(wf_run_id)
                except Exception:
                    logger.exception("Unexpected error collecting target execution result")
                    ordered_results.append({})
                    ordered_workflow_run_ids.append(None)

        return ordered_results, ordered_workflow_run_ids
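    # Each worker thread pushes its own Flask app context and opens a dedicated
    # SQLAlchemy session, so no ORM state crosses threads, and results are
    # collected in submission order so index i stays aligned with input_list[i].
    # Caller-side sketch (placeholder IDs):
    #
    #     node_results, run_ids = EvaluationService.execute_targets(
    #         tenant_id="tenant-1", target_type="app", target_id="app-1",
    #         input_list=items, max_workers=5,
    #     )
    #     assert len(node_results) == len(items) == len(run_ids)  # order preserved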
    @classmethod
    def _run_single_target(
        cls,
        session: Session,
        target_type: str,
        target_id: str,
        item: EvaluationDatasetInput,
    ) -> Mapping[str, object]:
        """Execute a single evaluation target with one test-data item.

        Dispatches to the appropriate execution service based on
        ``target_type``:

        * ``"snippet"`` → :meth:`SnippetGenerateService.run_published`
        * ``"app"`` → :meth:`WorkflowAppGenerator().generate` (blocking mode)

        :returns: The blocking response mapping from the workflow engine.
        :raises ValueError: If the target is not found or not published.
        """
        from core.app.apps.workflow.app_generator import WorkflowAppGenerator
        from core.app.entities.app_invoke_entities import InvokeFrom
        from core.evaluation.runners import get_service_account_for_app, get_service_account_for_snippet

        if target_type == "snippet":
            from services.snippet_generate_service import SnippetGenerateService

            snippet = session.query(CustomizedSnippet).filter_by(id=target_id).first()
            if not snippet:
                raise ValueError(f"Snippet {target_id} not found")

            service_account = get_service_account_for_snippet(session, target_id)

            return SnippetGenerateService.run_published(
                snippet=snippet,
                user=service_account,
                args={"inputs": item.inputs},
                invoke_from=InvokeFrom.SERVICE_API,
            )
        else:
            # target_type == "app"
            app = session.query(App).filter_by(id=target_id).first()
            if not app:
                raise ValueError(f"App {target_id} not found")

            service_account = get_service_account_for_app(session, target_id)

            workflow_service = WorkflowService()
            workflow = workflow_service.get_published_workflow(app_model=app)
            if not workflow:
                raise ValueError(f"No published workflow for app {target_id}")

            response: Mapping[str, object] = WorkflowAppGenerator().generate(
                app_model=app,
                workflow=workflow,
                user=service_account,
                args={"inputs": item.inputs},
                invoke_from=InvokeFrom.SERVICE_API,
                streaming=False,
                call_depth=0,
            )
            return response

    @staticmethod
    def _extract_workflow_run_id(response: Mapping[str, object]) -> str | None:
        """Extract ``workflow_run_id`` from a blocking workflow response."""
        wf_run_id = response.get("workflow_run_id")
        if wf_run_id:
            return str(wf_run_id)
        data = response.get("data")
        if isinstance(data, Mapping) and data.get("id"):
            return str(data["id"])
        return None

    @staticmethod
    def _query_node_run_results(
        session: Session,
        tenant_id: str,
        app_id: str,
        workflow_run_id: str,
    ) -> dict[str, NodeRunResult]:
        """Query all node execution records for a workflow run."""
        from sqlalchemy import asc, select

        from graphon.enums import WorkflowNodeExecutionStatus
        from models.workflow import WorkflowNodeExecutionModel

        stmt = (
            WorkflowNodeExecutionModel.preload_offload_data(select(WorkflowNodeExecutionModel))
            .where(
                WorkflowNodeExecutionModel.tenant_id == tenant_id,
                WorkflowNodeExecutionModel.app_id == app_id,
                WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
            )
            .order_by(asc(WorkflowNodeExecutionModel.created_at))
        )

        node_models: list[WorkflowNodeExecutionModel] = list(session.execute(stmt).scalars().all())

        result: dict[str, NodeRunResult] = {}
        for node in node_models:
            # Convert string-keyed metadata to WorkflowNodeExecutionMetadataKey-keyed
            raw_metadata = node.execution_metadata_dict
            typed_metadata: dict[WorkflowNodeExecutionMetadataKey, object] = {}
            for key, val in raw_metadata.items():
                try:
                    typed_metadata[WorkflowNodeExecutionMetadataKey(key)] = val
                except ValueError:
                    pass  # skip unknown metadata keys

            result[node.node_id] = NodeRunResult(
                status=WorkflowNodeExecutionStatus(node.status),
                inputs=node.inputs_dict or {},
                process_data=node.process_data_dict or {},
                outputs=node.outputs_dict or {},
                metadata=typed_metadata,
                error=node.error or "",
            )
        return result
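    # _extract_workflow_run_id accepts both blocking-response shapes returned by
    # the engine, which is directly checkable (run IDs invented):
    #
    #     assert EvaluationService._extract_workflow_run_id({"workflow_run_id": "wr-1"}) == "wr-1"
    #     assert EvaluationService._extract_workflow_run_id({"data": {"id": "wr-2"}}) == "wr-2"
    #     assert EvaluationService._extract_workflow_run_id({}) is None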
    # ---- Dataset Parsing ----

    @classmethod
    def _parse_dataset(cls, xlsx_content: bytes) -> list[EvaluationDatasetInput]:
        """Parse evaluation dataset from XLSX bytes."""
        wb = load_workbook(io.BytesIO(xlsx_content), read_only=True)
        ws = wb.active
        if ws is None:
            raise EvaluationDatasetInvalidError("XLSX file has no active worksheet.")

        rows = list(ws.iter_rows(values_only=True))
        if len(rows) < 2:
            raise EvaluationDatasetInvalidError("Dataset must have at least a header row and one data row.")

        headers = [str(h).strip() if h is not None else "" for h in rows[0]]
        if not headers or headers[0].lower() != "index":
            raise EvaluationDatasetInvalidError("First column header must be 'index'.")

        input_headers = headers[1:]  # Skip 'index'
        items = []
        for row_idx, row in enumerate(rows[1:], start=1):
            values = list(row)
            if all(v is None or str(v).strip() == "" for v in values):
                continue  # Skip empty rows

            index_val = values[0] if values else row_idx
            try:
                index = int(str(index_val))
            except (TypeError, ValueError):
                index = row_idx

            inputs: dict[str, Any] = {}
            for col_idx, header in enumerate(input_headers):
                val = values[col_idx + 1] if col_idx + 1 < len(values) else None
                inputs[header] = str(val) if val is not None else ""

            # Extract expected_output column into dedicated field
            expected_output = inputs.pop("expected_output", None)

            items.append(
                EvaluationDatasetInput(
                    index=index,
                    inputs=inputs,
                    expected_output=expected_output,
                )
            )

        wb.close()
        return items
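    # Given the template produced earlier, a filled sheet like the one below
    # parses into one EvaluationDatasetInput per non-empty row, with
    # expected_output split out of inputs (cell values invented):
    #
    #     index | question          | expected_output
    #     1     | "What is Dify?"   | "An LLM app platform"
    #
    #     -> EvaluationDatasetInput(index=1,
    #                               inputs={"question": "What is Dify?"},
    #                               expected_output="An LLM app platform")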
@classmethod
|
||||
def execute_retrieval_test_targets(
|
||||
cls,
|
||||
dataset_id: str,
|
||||
account_id: str,
|
||||
input_list: list[EvaluationDatasetInput],
|
||||
max_workers: int = 5,
|
||||
) -> list[NodeRunResult]:
|
||||
"""Run hit testing against a knowledge base for every input item in parallel.
|
||||
|
||||
Each item must supply a ``query`` key in its ``inputs`` dict. The
|
||||
retrieved segments are normalised into the same ``NodeRunResult`` format
|
||||
that :class:`RetrievalEvaluationRunner` expects:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
NodeRunResult(
|
||||
inputs={"query": "..."},
|
||||
outputs={"result": [{"content": "...", "score": ...}, ...]},
|
||||
)
|
||||
|
||||
:returns: Ordered list of ``NodeRunResult`` — one per input item.
|
||||
If retrieval fails for an item the result has an empty ``result``
|
||||
list so the runner can still persist a (metric-less) row.
|
||||
"""
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from flask import current_app
|
||||
|
||||
flask_app = current_app._get_current_object() # type: ignore
|
||||
|
||||
def _worker(item: EvaluationDatasetInput) -> NodeRunResult:
|
||||
with flask_app.app_context():
|
||||
from extensions.ext_database import db as flask_db
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset
|
||||
from services.hit_testing_service import HitTestingService
|
||||
|
||||
dataset = flask_db.session.query(Dataset).filter_by(id=dataset_id).first()
|
||||
if not dataset:
|
||||
raise ValueError(f"Dataset {dataset_id} not found")
|
||||
|
||||
account = flask_db.session.query(Account).filter_by(id=account_id).first()
|
||||
if not account:
|
||||
raise ValueError(f"Account {account_id} not found")
|
||||
|
||||
query = str(item.inputs.get("query", ""))
|
||||
response = HitTestingService.retrieve(
|
||||
dataset=dataset,
|
||||
query=query,
|
||||
account=account,
|
||||
retrieval_model=None, # Use dataset's configured retrieval model
|
||||
external_retrieval_model={},
|
||||
limit=10,
|
||||
)
|
||||
|
||||
records = response.get("records", [])
|
||||
result_list = [
|
||||
{
|
||||
"content": r.get("segment", {}).get("content", "") or r.get("content", ""),
|
||||
"score": r.get("score"),
|
||||
}
|
||||
for r in records
|
||||
if r.get("segment", {}).get("content") or r.get("content")
|
||||
]
|
||||
|
||||
return NodeRunResult(
|
||||
inputs={"query": query},
|
||||
outputs={"result": result_list},
|
||||
)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
futures = [executor.submit(_worker, item) for item in input_list]
|
||||
results: list[NodeRunResult] = []
|
||||
for item, future in zip(input_list, futures):
|
||||
try:
|
||||
results.append(future.result())
|
||||
except Exception:
|
||||
logger.exception("Retrieval test failed for item %d (dataset=%s)", item.index, dataset_id)
|
||||
results.append(NodeRunResult(inputs={}, outputs={"result": []}))
|
||||
|
||||
return results
|
||||
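A minimal usage sketch for the parallel hit-testing helper above, assuming the enclosing evaluation service class is exposed as EvaluationService (the class name is not shown in this hunk) and using placeholder IDs:

    # Hypothetical caller: hit-test two queries and inspect the retrieved segments.
    items = [
        EvaluationDatasetInput(index=0, inputs={"query": "What is Dify?"}, expected_output=None),
        EvaluationDatasetInput(index=1, inputs={"query": "How are snippets published?"}, expected_output=None),
    ]
    results = EvaluationService.execute_retrieval_test_targets(
        dataset_id="<dataset-id>",   # placeholder
        account_id="<account-id>",   # placeholder
        input_list=items,
    )
    for item, result in zip(items, results):
        segments = result.outputs["result"]  # empty list when retrieval failed for this item
        print(item.index, [s["score"] for s in segments])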
@ -281,7 +281,7 @@ class FeatureService:
    def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str):
        billing_info = BillingService.get_info(tenant_id)

        features_usage_info = BillingService.get_tenant_feature_plan_usage_info(tenant_id)
        features_usage_info = BillingService.get_quota_info(tenant_id)

        features.billing.enabled = billing_info["enabled"]
        features.billing.subscription.plan = billing_info["subscription"]["plan"]
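For context, the quota info consumed here (and by QuotaService.get_remaining in the new file below) appears to be a per-feature mapping of limits and usage; the following shape is an assumption for illustration, not confirmed by this commit:

    # Assumed response shape of BillingService.get_quota_info(tenant_id).
    # Feature keys would come from QuotaType.billing_key; limit == -1 means unlimited.
    {
        "trigger": {"limit": 1000, "usage": 42},
        "evaluation": {"limit": -1, "usage": 7},
    }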
233 api/services/quota_service.py Normal file
@ -0,0 +1,233 @@
from __future__ import annotations

import logging
import uuid
from dataclasses import dataclass, field
from typing import TYPE_CHECKING

from configs import dify_config

if TYPE_CHECKING:
    from enums.quota_type import QuotaType

logger = logging.getLogger(__name__)


@dataclass
class QuotaCharge:
    """
    Result of a quota reservation (Reserve phase).

    Lifecycle:
        charge = QuotaService.consume(QuotaType.TRIGGER, tenant_id)
        try:
            do_work()
            charge.commit()  # Confirm consumption
        except:
            charge.refund()  # Release frozen quota

    If neither commit() nor refund() is called, the billing system's
    cleanup CronJob will auto-release the reservation within ~75 seconds.
    """

    success: bool
    charge_id: str | None  # reservation_id
    _quota_type: QuotaType
    _tenant_id: str | None = None
    _feature_key: str | None = None
    _amount: int = 0
    _committed: bool = field(default=False, repr=False)

    def commit(self, actual_amount: int | None = None) -> None:
        """
        Confirm the consumption with actual amount.

        Args:
            actual_amount: Actual amount consumed. Defaults to the reserved amount.
                If less than reserved, the difference is refunded automatically.
        """
        if self._committed or not self.charge_id or not self._tenant_id or not self._feature_key:
            return

        try:
            from services.billing_service import BillingService

            amount = actual_amount if actual_amount is not None else self._amount
            BillingService.quota_commit(
                tenant_id=self._tenant_id,
                feature_key=self._feature_key,
                reservation_id=self.charge_id,
                actual_amount=amount,
            )
            self._committed = True
            logger.debug(
                "Committed %s quota for tenant %s, reservation_id: %s, amount: %d",
                self._quota_type,
                self._tenant_id,
                self.charge_id,
                amount,
            )
        except Exception:
            logger.exception("Failed to commit quota, reservation_id: %s", self.charge_id)

    def refund(self) -> None:
        """
        Release the reserved quota (cancel the charge).

        Safe to call even if:
        - charge failed or was disabled (charge_id is None)
        - already committed (Release after Commit is a no-op)
        - already refunded (idempotent)

        This method guarantees no exceptions will be raised.
        """
        if not self.charge_id or not self._tenant_id or not self._feature_key:
            return

        QuotaService.release(self._quota_type, self.charge_id, self._tenant_id, self._feature_key)


def unlimited() -> QuotaCharge:
    from enums.quota_type import QuotaType

    return QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.UNLIMITED)


class QuotaService:
    """Orchestrates quota reserve / commit / release lifecycle via BillingService."""

    @staticmethod
    def consume(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> QuotaCharge:
        """
        Reserve + immediate Commit (one-shot mode).

        The returned QuotaCharge supports .refund() which calls Release.
        For two-phase usage (e.g. streaming), use reserve() directly.
        """
        charge = QuotaService.reserve(quota_type, tenant_id, amount)
        if charge.success and charge.charge_id:
            charge.commit()
        return charge

    @staticmethod
    def reserve(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> QuotaCharge:
        """
        Reserve quota before task execution (Reserve phase only).

        The caller MUST call charge.commit() after the task succeeds,
        or charge.refund() if the task fails.

        Raises:
            QuotaExceededError: When quota is insufficient
        """
        from services.billing_service import BillingService
        from services.errors.app import QuotaExceededError

        if not dify_config.BILLING_ENABLED:
            logger.debug("Billing disabled, allowing request for %s", tenant_id)
            return QuotaCharge(success=True, charge_id=None, _quota_type=quota_type)

        logger.info("Reserving %d %s quota for tenant %s", amount, quota_type.value, tenant_id)

        if amount <= 0:
            raise ValueError("Amount to reserve must be greater than 0")

        request_id = str(uuid.uuid4())
        feature_key = quota_type.billing_key

        try:
            reserve_resp = BillingService.quota_reserve(
                tenant_id=tenant_id,
                feature_key=feature_key,
                request_id=request_id,
                amount=amount,
            )

            reservation_id = reserve_resp.get("reservation_id")
            if not reservation_id:
                logger.warning(
                    "Reserve returned no reservation_id for %s, feature %s, response: %s",
                    tenant_id,
                    quota_type.value,
                    reserve_resp,
                )
                raise QuotaExceededError(feature=quota_type.value, tenant_id=tenant_id, required=amount)

            logger.debug(
                "Reserved %d %s quota for tenant %s, reservation_id: %s",
                amount,
                quota_type.value,
                tenant_id,
                reservation_id,
            )
            return QuotaCharge(
                success=True,
                charge_id=reservation_id,
                _quota_type=quota_type,
                _tenant_id=tenant_id,
                _feature_key=feature_key,
                _amount=amount,
            )

        except QuotaExceededError:
            raise
        except ValueError:
            raise
        except Exception:
            logger.exception("Failed to reserve quota for %s, feature %s", tenant_id, quota_type.value)
            return unlimited()

    @staticmethod
    def check(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> bool:
        if not dify_config.BILLING_ENABLED:
            return True

        if amount <= 0:
            raise ValueError("Amount to check must be greater than 0")

        try:
            remaining = QuotaService.get_remaining(quota_type, tenant_id)
            return remaining >= amount if remaining != -1 else True
        except Exception:
            logger.exception("Failed to check quota for %s, feature %s", tenant_id, quota_type.value)
            return True

    @staticmethod
    def release(quota_type: QuotaType, reservation_id: str, tenant_id: str, feature_key: str) -> None:
        """Release a reservation. Guarantees no exceptions."""
        try:
            from services.billing_service import BillingService

            if not dify_config.BILLING_ENABLED:
                return

            if not reservation_id:
                return

            logger.info("Releasing %s quota, reservation_id: %s", quota_type.value, reservation_id)
            BillingService.quota_release(
                tenant_id=tenant_id,
                feature_key=feature_key,
                reservation_id=reservation_id,
            )
        except Exception:
            logger.exception("Failed to release quota, reservation_id: %s", reservation_id)

    @staticmethod
    def get_remaining(quota_type: QuotaType, tenant_id: str) -> int:
        from services.billing_service import BillingService

        try:
            usage_info = BillingService.get_quota_info(tenant_id)
            if isinstance(usage_info, dict):
                feature_info = usage_info.get(quota_type.billing_key, {})
                if isinstance(feature_info, dict):
                    limit = feature_info.get("limit", 0)
                    usage = feature_info.get("usage", 0)
                    if limit == -1:
                        return -1
                    return max(0, limit - usage)
            return 0
        except Exception:
            logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, quota_type.value)
            return -1
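A minimal sketch of the two-phase pattern the reserve() docstring mandates, e.g. around a streaming task; QuotaType.TRIGGER is borrowed from the QuotaCharge lifecycle example above and run_streaming_task is a placeholder:

    from enums.quota_type import QuotaType
    from services.quota_service import QuotaService

    charge = QuotaService.reserve(QuotaType.TRIGGER, tenant_id, amount=1)
    try:
        run_streaming_task()  # placeholder for the actual work
        charge.commit()       # confirm the reserved amount (or pass actual_amount)
    except Exception:
        charge.refund()       # release the frozen quota; safe and idempotent
        raise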
555 api/services/snippet_dsl_service.py Normal file
@ -0,0 +1,555 @@
import json
import logging
import uuid
from collections.abc import Mapping
from datetime import UTC, datetime
from enum import StrEnum
from urllib.parse import urlparse

import yaml  # type: ignore
from packaging import version
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session

from core.helper import ssrf_proxy
from core.plugin.entities.plugin import PluginDependency
from extensions.ext_redis import redis_client
from graphon.enums import BuiltinNodeTypes
from graphon.model_runtime.utils.encoders import jsonable_encoder
from models import Account
from models.snippet import CustomizedSnippet, SnippetType
from models.workflow import Workflow
from services.plugin.dependencies_analysis import DependenciesAnalysisService
from services.snippet_service import SNIPPET_FORBIDDEN_NODE_TYPES, SnippetService

logger = logging.getLogger(__name__)

IMPORT_INFO_REDIS_KEY_PREFIX = "snippet_import_info:"
CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "snippet_check_dependencies:"
IMPORT_INFO_REDIS_EXPIRY = 10 * 60  # 10 minutes
DSL_MAX_SIZE = 10 * 1024 * 1024  # 10MB
CURRENT_DSL_VERSION = "0.1.0"


class ImportMode(StrEnum):
    YAML_CONTENT = "yaml-content"
    YAML_URL = "yaml-url"


class ImportStatus(StrEnum):
    COMPLETED = "completed"
    COMPLETED_WITH_WARNINGS = "completed-with-warnings"
    PENDING = "pending"
    FAILED = "failed"


class SnippetImportInfo(BaseModel):
    id: str
    status: ImportStatus
    snippet_id: str | None = None
    current_dsl_version: str = CURRENT_DSL_VERSION
    imported_dsl_version: str = ""
    error: str = ""


class CheckDependenciesResult(BaseModel):
    leaked_dependencies: list[PluginDependency] = Field(default_factory=list)


def _check_version_compatibility(imported_version: str) -> ImportStatus:
    """Determine import status based on version comparison"""
    try:
        current_ver = version.parse(CURRENT_DSL_VERSION)
        imported_ver = version.parse(imported_version)
    except version.InvalidVersion:
        return ImportStatus.FAILED

    # If imported version is newer than current, always return PENDING
    if imported_ver > current_ver:
        return ImportStatus.PENDING

    # If imported version is older than current's major, return PENDING
    if imported_ver.major < current_ver.major:
        return ImportStatus.PENDING

    # If imported version is older than current's minor, return COMPLETED_WITH_WARNINGS
    if imported_ver.minor < current_ver.minor:
        return ImportStatus.COMPLETED_WITH_WARNINGS

    # If imported version equals or is older than current's micro, return COMPLETED
    return ImportStatus.COMPLETED


class SnippetPendingData(BaseModel):
    import_mode: str
    yaml_content: str
    snippet_id: str | None


class CheckDependenciesPendingData(BaseModel):
    dependencies: list[PluginDependency]
    snippet_id: str | None


class SnippetDslService:
    def __init__(self, session: Session):
        self._session = session

    def import_snippet(
        self,
        *,
        account: Account,
        import_mode: str,
        yaml_content: str | None = None,
        yaml_url: str | None = None,
        snippet_id: str | None = None,
        name: str | None = None,
        description: str | None = None,
    ) -> SnippetImportInfo:
        """Import a snippet from YAML content or URL."""
        import_id = str(uuid.uuid4())

        # Validate import mode
        try:
            mode = ImportMode(import_mode)
        except ValueError:
            raise ValueError(f"Invalid import_mode: {import_mode}")

        # Get YAML content
        content: str = ""
        if mode == ImportMode.YAML_URL:
            if not yaml_url:
                return SnippetImportInfo(
                    id=import_id,
                    status=ImportStatus.FAILED,
                    error="yaml_url is required when import_mode is yaml-url",
                )
            try:
                parsed_url = urlparse(yaml_url)
                if parsed_url.scheme not in ["http", "https"]:
                    return SnippetImportInfo(
                        id=import_id,
                        status=ImportStatus.FAILED,
                        error="Invalid URL scheme, only http and https are allowed",
                    )
                response = ssrf_proxy.get(yaml_url, timeout=(10, 30))
                if response.status_code != 200:
                    return SnippetImportInfo(
                        id=import_id,
                        status=ImportStatus.FAILED,
                        error=f"Failed to fetch YAML from URL: {response.status_code}",
                    )
                content = response.text
                if len(content) > DSL_MAX_SIZE:
                    return SnippetImportInfo(
                        id=import_id,
                        status=ImportStatus.FAILED,
                        error=f"YAML content size exceeds maximum limit of {DSL_MAX_SIZE} bytes",
                    )
            except Exception as e:
                logger.exception("Failed to fetch YAML from URL")
                return SnippetImportInfo(
                    id=import_id,
                    status=ImportStatus.FAILED,
                    error=f"Failed to fetch YAML from URL: {str(e)}",
                )
        elif mode == ImportMode.YAML_CONTENT:
            if not yaml_content:
                return SnippetImportInfo(
                    id=import_id,
                    status=ImportStatus.FAILED,
                    error="yaml_content is required when import_mode is yaml-content",
                )
            content = yaml_content
            if len(content) > DSL_MAX_SIZE:
                return SnippetImportInfo(
                    id=import_id,
                    status=ImportStatus.FAILED,
                    error=f"YAML content size exceeds maximum limit of {DSL_MAX_SIZE} bytes",
                )

        try:
            # Parse YAML
            data = yaml.safe_load(content)
            if not isinstance(data, dict):
                return SnippetImportInfo(
                    id=import_id,
                    status=ImportStatus.FAILED,
                    error="Invalid YAML format: expected a dictionary",
                )

            # Validate and fix DSL version
            if not data.get("version"):
                data["version"] = "0.1.0"

            # Strictly validate kind field
            kind = data.get("kind")
            if not kind:
                return SnippetImportInfo(
                    id=import_id,
                    status=ImportStatus.FAILED,
                    error="Missing 'kind' field in DSL. Expected 'kind: snippet'.",
                )
            if kind != "snippet":
                return SnippetImportInfo(
                    id=import_id,
                    status=ImportStatus.FAILED,
                    error=f"Invalid DSL kind: expected 'snippet', got '{kind}'. This DSL is for {kind}, not snippet.",
                )

            imported_version = data.get("version", "0.1.0")
            if not isinstance(imported_version, str):
                raise ValueError(f"Invalid version type, expected str, got {type(imported_version)}")
            status = _check_version_compatibility(imported_version)

            # Extract snippet data
            snippet_data = data.get("snippet")
            if not snippet_data:
                return SnippetImportInfo(
                    id=import_id,
                    status=ImportStatus.FAILED,
                    error="Missing snippet data in YAML content",
                )

            # Validate workflow nodes - check for forbidden node types
            workflow_data = data.get("workflow", {})
            if workflow_data:
                graph = workflow_data.get("graph", {})
                nodes = graph.get("nodes", [])
                forbidden_nodes_found = []
                for node in nodes:
                    node_data = node.get("data", {})
                    if not node_data:
                        continue
                    node_type = node_data.get("type", "")
                    if node_type in SNIPPET_FORBIDDEN_NODE_TYPES:
                        forbidden_nodes_found.append(node_type)

                if forbidden_nodes_found:
                    forbidden_types_str = ", ".join(set(forbidden_nodes_found))
                    return SnippetImportInfo(
                        id=import_id,
                        status=ImportStatus.FAILED,
                        error=f"Snippet cannot contain the following node types: {forbidden_types_str}",
                    )

            # If snippet_id is provided, check if it exists
            snippet = None
            if snippet_id:
                stmt = select(CustomizedSnippet).where(
                    CustomizedSnippet.id == snippet_id,
                    CustomizedSnippet.tenant_id == account.current_tenant_id,
                )
                snippet = self._session.scalar(stmt)

                if not snippet:
                    return SnippetImportInfo(
                        id=import_id,
                        status=ImportStatus.FAILED,
                        error="Snippet not found",
                    )

            # If major version mismatch, store import info in Redis
            if status == ImportStatus.PENDING:
                pending_data = SnippetPendingData(
                    import_mode=import_mode,
                    yaml_content=content,
                    snippet_id=snippet_id,
                )
                redis_client.setex(
                    f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}",
                    IMPORT_INFO_REDIS_EXPIRY,
                    pending_data.model_dump_json(),
                )

                return SnippetImportInfo(
                    id=import_id,
                    status=status,
                    snippet_id=snippet_id,
                    imported_dsl_version=imported_version,
                )

            # Extract dependencies
            dependencies = data.get("dependencies", [])
            check_dependencies_pending_data = None
            if dependencies:
                check_dependencies_pending_data = [PluginDependency.model_validate(d) for d in dependencies]

            # Create or update snippet
            snippet = self._create_or_update_snippet(
                snippet=snippet,
                data=data,
                account=account,
                name=name,
                description=description,
                dependencies=check_dependencies_pending_data,
            )

            return SnippetImportInfo(
                id=import_id,
                status=status,
                snippet_id=snippet.id,
                imported_dsl_version=imported_version,
            )

        except yaml.YAMLError as e:
            return SnippetImportInfo(
                id=import_id,
                status=ImportStatus.FAILED,
                error=f"Invalid YAML format: {str(e)}",
            )

        except Exception as e:
            logger.exception("Failed to import snippet")
            return SnippetImportInfo(
                id=import_id,
                status=ImportStatus.FAILED,
                error=str(e),
            )

    def confirm_import(self, *, import_id: str, account: Account) -> SnippetImportInfo:
        """
        Confirm an import that requires confirmation
        """
        redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}"
        pending_data = redis_client.get(redis_key)

        if not pending_data:
            return SnippetImportInfo(
                id=import_id,
                status=ImportStatus.FAILED,
                error="Import information expired or does not exist",
            )

        try:
            if not isinstance(pending_data, str | bytes):
                return SnippetImportInfo(
                    id=import_id,
                    status=ImportStatus.FAILED,
                    error="Invalid import information",
                )

            pending_data_str = pending_data.decode("utf-8") if isinstance(pending_data, bytes) else pending_data
            pending = SnippetPendingData.model_validate_json(pending_data_str)

            # Re-import with the pending data
            return self.import_snippet(
                account=account,
                import_mode=pending.import_mode,
                yaml_content=pending.yaml_content,
                snippet_id=pending.snippet_id,
            )

        except Exception as e:
            logger.exception("Failed to confirm import")
            return SnippetImportInfo(
                id=import_id,
                status=ImportStatus.FAILED,
                error=str(e),
            )

    def check_dependencies(self, snippet: CustomizedSnippet) -> CheckDependenciesResult:
        """
        Check dependencies for a snippet
        """
        snippet_service = SnippetService()
        workflow = snippet_service.get_draft_workflow(snippet=snippet)
        if not workflow:
            return CheckDependenciesResult(leaked_dependencies=[])

        dependencies = self._extract_dependencies_from_workflow(workflow)
        leaked_dependencies = DependenciesAnalysisService.generate_dependencies(
            tenant_id=snippet.tenant_id, dependencies=dependencies
        )

        return CheckDependenciesResult(leaked_dependencies=leaked_dependencies)

    def _create_or_update_snippet(
        self,
        *,
        snippet: CustomizedSnippet | None,
        data: dict,
        account: Account,
        name: str | None = None,
        description: str | None = None,
        dependencies: list[PluginDependency] | None = None,
    ) -> CustomizedSnippet:
        """
        Create or update snippet from DSL data
        """
        snippet_data = data.get("snippet", {})
        workflow_data = data.get("workflow", {})

        # Extract snippet info
        snippet_name = name or snippet_data.get("name") or "Untitled Snippet"
        snippet_description = description or snippet_data.get("description") or ""
        snippet_type_str = snippet_data.get("type", "node")
        try:
            snippet_type = SnippetType(snippet_type_str)
        except ValueError:
            snippet_type = SnippetType.NODE

        icon_info = snippet_data.get("icon_info", {})
        input_fields = snippet_data.get("input_fields", [])

        # Create or update snippet
        if snippet:
            # Update existing snippet
            snippet.name = snippet_name
            snippet.description = snippet_description
            snippet.type = snippet_type.value
            snippet.icon_info = icon_info or None
            snippet.input_fields = json.dumps(input_fields) if input_fields else None
            snippet.updated_by = account.id
            snippet.updated_at = datetime.now(UTC).replace(tzinfo=None)
        else:
            # Create new snippet
            snippet = CustomizedSnippet(
                tenant_id=account.current_tenant_id,
                name=snippet_name,
                description=snippet_description,
                type=snippet_type.value,
                icon_info=icon_info or None,
                input_fields=json.dumps(input_fields) if input_fields else None,
                created_by=account.id,
            )
            self._session.add(snippet)
            self._session.flush()

        # Create or update draft workflow
        if workflow_data:
            graph = workflow_data.get("graph", {})

            snippet_service = SnippetService()
            # Get existing workflow hash if exists
            existing_workflow = snippet_service.get_draft_workflow(snippet=snippet)
            unique_hash = existing_workflow.unique_hash if existing_workflow else None

            snippet_service.sync_draft_workflow(
                snippet=snippet,
                graph=graph,
                unique_hash=unique_hash,
                account=account,
                input_fields=input_fields,
            )

        self._session.commit()
        return snippet

    def export_snippet_dsl(self, snippet: CustomizedSnippet, include_secret: bool = False) -> str:
        """
        Export snippet as DSL
        :param snippet: CustomizedSnippet instance
        :param include_secret: Whether include secret variable
        :return: YAML string
        """
        snippet_service = SnippetService()
        workflow = snippet_service.get_draft_workflow(snippet=snippet)
        if not workflow:
            raise ValueError("Missing draft workflow configuration, please check.")

        icon_info = snippet.icon_info or {}
        export_data = {
            "version": CURRENT_DSL_VERSION,
            "kind": "snippet",
            "snippet": {
                "name": snippet.name,
                "description": snippet.description or "",
                "type": snippet.type,
                "icon_info": icon_info,
                "input_fields": snippet.input_fields_list,
            },
        }

        self._append_workflow_export_data(
            export_data=export_data, snippet=snippet, workflow=workflow, include_secret=include_secret
        )

        return yaml.dump(export_data, allow_unicode=True)  # type: ignore

    def _append_workflow_export_data(
        self, *, export_data: dict, snippet: CustomizedSnippet, workflow: Workflow, include_secret: bool
    ) -> None:
        """
        Append workflow export data
        """
        workflow_dict = workflow.to_dict(include_secret=include_secret)
        # Filter workspace related data from nodes
        workflow_dict["environment_variables"] = []
        workflow_dict["conversation_variables"] = []

        for node in workflow_dict.get("graph", {}).get("nodes", []):
            node_data = node.get("data", {})
            if not node_data:
                continue
            data_type = node_data.get("type", "")
            if data_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
                dataset_ids = node_data.get("dataset_ids", [])
                node["data"]["dataset_ids"] = [
                    self._encrypt_dataset_id(dataset_id=dataset_id, tenant_id=snippet.tenant_id)
                    for dataset_id in dataset_ids
                ]
            # filter credential id from tool node
            if not include_secret and data_type == BuiltinNodeTypes.TOOL:
                node_data.pop("credential_id", None)
            # filter credential id from agent node
            if not include_secret and data_type == BuiltinNodeTypes.AGENT:
                for tool in node_data.get("agent_parameters", {}).get("tools", {}).get("value", []):
                    tool.pop("credential_id", None)

        export_data["workflow"] = workflow_dict
        dependencies = self._extract_dependencies_from_workflow(workflow)
        export_data["dependencies"] = [
            jsonable_encoder(d.model_dump())
            for d in DependenciesAnalysisService.generate_dependencies(
                tenant_id=snippet.tenant_id, dependencies=dependencies
            )
        ]

    def _encrypt_dataset_id(self, *, dataset_id: str, tenant_id: str) -> str:
        """
        Encrypt dataset ID for export
        """
        # For now, just return the dataset_id as-is
        # In the future, we might want to encrypt it
        return dataset_id

    def _extract_dependencies_from_workflow(self, workflow: Workflow) -> list[str]:
        """
        Extract dependencies from workflow
        :param workflow: Workflow instance
        :return: dependencies list format like ["langgenius/google"]
        """
        graph = workflow.graph_dict
        dependencies = self._extract_dependencies_from_workflow_graph(graph)
        return dependencies

    def _extract_dependencies_from_workflow_graph(self, graph: Mapping) -> list[str]:
        """
        Extract dependencies from workflow graph
        :param graph: Workflow graph
        :return: dependencies list format like ["langgenius/google"]
        """
        dependencies = []
        for node in graph.get("nodes", []):
            node_data = node.get("data", {})
            if not node_data:
                continue
            data_type = node_data.get("type", "")
            if data_type == BuiltinNodeTypes.TOOL:
                tool_config = node_data.get("tool_configurations", {})
                provider_type = tool_config.get("provider_type")
                provider_name = tool_config.get("provider")
                if provider_type and provider_name:
                    dependencies.append(f"{provider_name}/{provider_name}")
            elif data_type == BuiltinNodeTypes.AGENT:
                agent_parameters = node_data.get("agent_parameters", {})
                tools = agent_parameters.get("tools", {}).get("value", [])
                for tool in tools:
                    provider_type = tool.get("provider_type")
                    provider_name = tool.get("provider")
                    if provider_type and provider_name:
                        dependencies.append(f"{provider_name}/{provider_name}")

        return dependencies
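A minimal sketch of driving the importer; the inline DSL mirrors the structure export_snippet_dsl emits (version/kind/snippet/workflow), while the account object and session wiring are placeholders:

    from sqlalchemy.orm import Session

    dsl = (
        "version: 0.1.0\n"
        "kind: snippet\n"
        "snippet:\n"
        "  name: My Snippet\n"
        "  type: node\n"
        "  input_fields: []\n"
        "workflow:\n"
        "  graph:\n"
        "    nodes: []\n"
        "    edges: []\n"
    )
    with Session(db.engine) as session:
        service = SnippetDslService(session)
        info = service.import_snippet(account=account, import_mode="yaml-content", yaml_content=dsl)
        if info.status == ImportStatus.PENDING:
            # Major-version mismatch: the pending payload lives in Redis for
            # 10 minutes and must be confirmed explicitly.
            info = service.confirm_import(import_id=info.id, account=account)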
421 api/services/snippet_generate_service.py Normal file
@ -0,0 +1,421 @@
|
||||
"""
|
||||
Service for generating snippet workflow executions.
|
||||
|
||||
Uses an adapter pattern to bridge CustomizedSnippet with the App-based
|
||||
WorkflowAppGenerator. The adapter (_SnippetAsApp) provides the minimal App-like
|
||||
interface needed by the generator, avoiding modifications to core workflow
|
||||
infrastructure.
|
||||
|
||||
Key invariants:
|
||||
- Snippets always run as WORKFLOW mode (not CHAT or ADVANCED_CHAT).
|
||||
- The adapter maps snippet.id to app_id in workflow execution records.
|
||||
- Snippet debugging has no rate limiting (max_active_requests = 0).
|
||||
|
||||
Supported execution modes:
|
||||
- Full workflow run (generate): Runs the entire draft workflow as SSE stream.
|
||||
- Single node run (run_draft_node): Synchronous single-step debugging for regular nodes.
|
||||
- Single iteration run (generate_single_iteration): SSE stream for iteration container nodes.
|
||||
- Single loop run (generate_single_loop): SSE stream for loop container nodes.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, Union
|
||||
|
||||
from sqlalchemy.orm import make_transient
|
||||
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from factories import file_factory
|
||||
from graphon.file.models import File
|
||||
from models import Account
|
||||
from models.model import AppMode, EndUser
|
||||
from models.snippet import CustomizedSnippet
|
||||
from models.workflow import Workflow, WorkflowNodeExecutionModel
|
||||
from services.snippet_service import SnippetService
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _SnippetAsApp:
|
||||
"""
|
||||
Minimal adapter that wraps a CustomizedSnippet to satisfy the App-like
|
||||
interface required by WorkflowAppGenerator, WorkflowAppConfigManager,
|
||||
and WorkflowService.run_draft_workflow_node.
|
||||
|
||||
Used properties:
|
||||
- id: maps to snippet.id (stored as app_id in workflows table)
|
||||
- tenant_id: maps to snippet.tenant_id
|
||||
- mode: hardcoded to AppMode.WORKFLOW since snippets always run as workflows
|
||||
- max_active_requests: defaults to 0 (no limit) for snippet debugging
|
||||
- app_model_config_id: None (snippets don't have app model configs)
|
||||
"""
|
||||
|
||||
id: str
|
||||
tenant_id: str
|
||||
mode: str
|
||||
max_active_requests: int
|
||||
app_model_config_id: str | None
|
||||
|
||||
def __init__(self, snippet: CustomizedSnippet) -> None:
|
||||
self.id = snippet.id
|
||||
self.tenant_id = snippet.tenant_id
|
||||
self.mode = AppMode.WORKFLOW.value
|
||||
self.max_active_requests = 0
|
||||
self.app_model_config_id = None
|
||||
|
||||
|
||||
class SnippetGenerateService:
|
||||
"""
|
||||
Service for running snippet workflow executions.
|
||||
|
||||
Adapts CustomizedSnippet to work with the existing App-based
|
||||
WorkflowAppGenerator infrastructure, avoiding duplication of the
|
||||
complex workflow execution pipeline.
|
||||
"""
|
||||
|
||||
# Specific ID for the injected virtual Start node so it can be recognised
|
||||
_VIRTUAL_START_NODE_ID = "__snippet_virtual_start__"
|
||||
|
||||
@classmethod
|
||||
def generate(
|
||||
cls,
|
||||
snippet: CustomizedSnippet,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
|
||||
"""
|
||||
Run a snippet's draft workflow.
|
||||
|
||||
Retrieves the draft workflow, adapts the snippet to an App-like proxy,
|
||||
then delegates execution to WorkflowAppGenerator.
|
||||
|
||||
If the workflow graph has no Start node, a virtual Start node is injected
|
||||
in-memory so that:
|
||||
1. Graph validation passes (root node must have execution_type=ROOT).
|
||||
2. User inputs are processed into the variable pool by the StartNode logic.
|
||||
|
||||
:param snippet: CustomizedSnippet instance
|
||||
:param user: Account or EndUser initiating the run
|
||||
:param args: Workflow inputs (must include "inputs" key)
|
||||
:param invoke_from: Source of invocation (typically DEBUGGER)
|
||||
:param streaming: Whether to stream the response
|
||||
:return: Blocking response mapping or SSE streaming generator
|
||||
:raises ValueError: If the snippet has no draft workflow
|
||||
"""
|
||||
snippet_service = SnippetService()
|
||||
workflow = snippet_service.get_draft_workflow(snippet=snippet)
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not initialized")
|
||||
|
||||
# Inject a virtual Start node when the graph doesn't have one.
|
||||
workflow = cls._ensure_start_node(workflow, snippet)
|
||||
|
||||
# Adapt snippet to App-like interface for WorkflowAppGenerator
|
||||
app_proxy = _SnippetAsApp(snippet)
|
||||
|
||||
return WorkflowAppGenerator.convert_to_event_stream(
|
||||
WorkflowAppGenerator().generate(
|
||||
app_model=app_proxy, # type: ignore[arg-type]
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
streaming=streaming,
|
||||
call_depth=0,
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def run_published(
|
||||
cls,
|
||||
snippet: CustomizedSnippet,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
) -> Mapping[str, Any]:
|
||||
"""
|
||||
Run a snippet's published workflow in non-streaming (blocking) mode.
|
||||
|
||||
Similar to :meth:`generate` but targets the published workflow instead
|
||||
of the draft, and returns the raw blocking response without SSE
|
||||
wrapping. Designed for programmatic callers such as evaluation runners.
|
||||
|
||||
:param snippet: CustomizedSnippet instance (must be published)
|
||||
:param user: Account or EndUser initiating the run
|
||||
:param args: Workflow inputs (must include "inputs" key)
|
||||
:param invoke_from: Source of invocation
|
||||
:return: Blocking response mapping with workflow outputs
|
||||
:raises ValueError: If the snippet has no published workflow
|
||||
"""
|
||||
snippet_service = SnippetService()
|
||||
workflow = snippet_service.get_published_workflow(snippet)
|
||||
if not workflow:
|
||||
raise ValueError("No published workflow found for snippet")
|
||||
|
||||
# Inject a virtual Start node when the graph doesn't have one.
|
||||
workflow = cls._ensure_start_node(workflow, snippet)
|
||||
|
||||
app_proxy = _SnippetAsApp(snippet)
|
||||
|
||||
response: Mapping[str, Any] = WorkflowAppGenerator().generate(
|
||||
app_model=app_proxy, # type: ignore[arg-type]
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
streaming=False,
|
||||
)
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
def ensure_start_node_for_worker(cls, workflow: Workflow, snippet: CustomizedSnippet) -> Workflow:
|
||||
"""Public wrapper for worker-thread start-node injection."""
|
||||
return cls._ensure_start_node(workflow, snippet)
|
||||
|
||||
@classmethod
|
||||
def _ensure_start_node(cls, workflow: Workflow, snippet: CustomizedSnippet) -> Workflow:
|
||||
"""
|
||||
Return *workflow* with a Start node.
|
||||
|
||||
If the graph already contains a Start node, the original workflow is
|
||||
returned unchanged. Otherwise a virtual Start node is injected and the
|
||||
workflow object is detached from the SQLAlchemy session so the in-memory
|
||||
change is never flushed to the database.
|
||||
"""
|
||||
graph_dict = workflow.graph_dict
|
||||
nodes: list[dict[str, Any]] = graph_dict.get("nodes", [])
|
||||
|
||||
has_start = any(node.get("data", {}).get("type") == "start" for node in nodes)
|
||||
if has_start:
|
||||
return workflow
|
||||
|
||||
modified_graph = cls._inject_virtual_start_node(
|
||||
graph_dict=graph_dict,
|
||||
input_fields=snippet.input_fields_list,
|
||||
)
|
||||
|
||||
# Detach from session to prevent accidental DB persistence of the
|
||||
# modified graph. All attributes remain accessible for read.
|
||||
make_transient(workflow)
|
||||
workflow.graph = json.dumps(modified_graph)
|
||||
return workflow
|
||||
|
||||
@classmethod
|
||||
def _inject_virtual_start_node(
|
||||
cls,
|
||||
graph_dict: Mapping[str, Any],
|
||||
input_fields: list[dict[str, Any]],
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Build a new graph dict with a virtual Start node prepended.
|
||||
|
||||
The virtual Start node is wired to every existing node that has no
|
||||
incoming edges (i.e. the current root candidates). This guarantees:
|
||||
|
||||
:param graph_dict: Original graph configuration.
|
||||
:param input_fields: Snippet input field definitions from
|
||||
``CustomizedSnippet.input_fields_list``.
|
||||
:return: New graph dict containing the virtual Start node and edges.
|
||||
"""
|
||||
nodes: list[dict[str, Any]] = list(graph_dict.get("nodes", []))
|
||||
edges: list[dict[str, Any]] = list(graph_dict.get("edges", []))
|
||||
|
||||
# Identify nodes with no incoming edges.
|
||||
nodes_with_incoming: set[str] = set()
|
||||
for edge in edges:
|
||||
target = edge.get("target")
|
||||
if isinstance(target, str):
|
||||
nodes_with_incoming.add(target)
|
||||
root_candidate_ids = [n["id"] for n in nodes if n["id"] not in nodes_with_incoming]
|
||||
|
||||
# Build Start node ``variables`` from snippet input fields.
|
||||
start_variables: list[dict[str, Any]] = []
|
||||
for field in input_fields:
|
||||
var: dict[str, Any] = {
|
||||
"variable": field.get("variable", ""),
|
||||
"label": field.get("label", field.get("variable", "")),
|
||||
"type": field.get("type", "text-input"),
|
||||
"required": field.get("required", False),
|
||||
"options": field.get("options", []),
|
||||
}
|
||||
if field.get("max_length") is not None:
|
||||
var["max_length"] = field["max_length"]
|
||||
start_variables.append(var)
|
||||
|
||||
virtual_start_node: dict[str, Any] = {
|
||||
"id": cls._VIRTUAL_START_NODE_ID,
|
||||
"data": {
|
||||
"type": "start",
|
||||
"title": "Start",
|
||||
"variables": start_variables,
|
||||
},
|
||||
}
|
||||
|
||||
# Create edges from virtual Start to each root candidate.
|
||||
new_edges: list[dict[str, Any]] = [
|
||||
{
|
||||
"source": cls._VIRTUAL_START_NODE_ID,
|
||||
"sourceHandle": "source",
|
||||
"target": root_id,
|
||||
"targetHandle": "target",
|
||||
}
|
||||
for root_id in root_candidate_ids
|
||||
]
|
||||
|
||||
return {
|
||||
**graph_dict,
|
||||
"nodes": [virtual_start_node, *nodes],
|
||||
"edges": [*edges, *new_edges],
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def run_draft_node(
|
||||
cls,
|
||||
snippet: CustomizedSnippet,
|
||||
node_id: str,
|
||||
user_inputs: Mapping[str, Any],
|
||||
account: Account,
|
||||
query: str = "",
|
||||
files: Sequence[File] | None = None,
|
||||
) -> WorkflowNodeExecutionModel:
|
||||
"""
|
||||
Run a single node in a snippet's draft workflow (single-step debugging).
|
||||
|
||||
Retrieves the draft workflow, adapts the snippet to an App-like proxy,
|
||||
parses file inputs, then delegates to WorkflowService.run_draft_workflow_node.
|
||||
|
||||
:param snippet: CustomizedSnippet instance
|
||||
:param node_id: ID of the node to run
|
||||
:param user_inputs: User input values for the node
|
||||
:param account: Account initiating the run
|
||||
:param query: Optional query string
|
||||
:param files: Optional parsed file objects
|
||||
:return: WorkflowNodeExecutionModel with execution results
|
||||
:raises ValueError: If the snippet has no draft workflow
|
||||
"""
|
||||
snippet_service = SnippetService()
|
||||
draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
|
||||
if not draft_workflow:
|
||||
raise ValueError("Workflow not initialized")
|
||||
|
||||
app_proxy = _SnippetAsApp(snippet)
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
return workflow_service.run_draft_workflow_node(
|
||||
app_model=app_proxy, # type: ignore[arg-type]
|
||||
draft_workflow=draft_workflow,
|
||||
node_id=node_id,
|
||||
user_inputs=user_inputs,
|
||||
account=account,
|
||||
query=query,
|
||||
files=files,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def generate_single_iteration(
|
||||
cls,
|
||||
snippet: CustomizedSnippet,
|
||||
user: Union[Account, EndUser],
|
||||
node_id: str,
|
||||
args: Mapping[str, Any],
|
||||
streaming: bool = True,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
|
||||
"""
|
||||
Run a single iteration node in a snippet's draft workflow.
|
||||
|
||||
Iteration nodes are container nodes that execute their sub-graph multiple
|
||||
times, producing many events. Therefore, this uses the full WorkflowAppGenerator
|
||||
pipeline with SSE streaming (unlike regular single-step node run).
|
||||
|
||||
:param snippet: CustomizedSnippet instance
|
||||
:param user: Account or EndUser initiating the run
|
||||
:param node_id: ID of the iteration node to run
|
||||
:param args: Dict containing 'inputs' key with iteration input data
|
||||
:param streaming: Whether to stream the response (should be True)
|
||||
:return: SSE streaming generator
|
||||
:raises ValueError: If the snippet has no draft workflow
|
||||
"""
|
||||
snippet_service = SnippetService()
|
||||
workflow = snippet_service.get_draft_workflow(snippet=snippet)
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not initialized")
|
||||
|
||||
app_proxy = _SnippetAsApp(snippet)
|
||||
|
||||
return WorkflowAppGenerator.convert_to_event_stream(
|
||||
WorkflowAppGenerator().single_iteration_generate(
|
||||
app_model=app_proxy, # type: ignore[arg-type]
|
||||
workflow=workflow,
|
||||
node_id=node_id,
|
||||
user=user,
|
||||
args=args,
|
||||
streaming=streaming,
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def generate_single_loop(
|
||||
cls,
|
||||
snippet: CustomizedSnippet,
|
||||
user: Union[Account, EndUser],
|
||||
node_id: str,
|
||||
args: Any,
|
||||
streaming: bool = True,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
|
||||
"""
|
||||
Run a single loop node in a snippet's draft workflow.
|
||||
|
||||
Loop nodes are container nodes that execute their sub-graph repeatedly,
|
||||
producing many events. Therefore, this uses the full WorkflowAppGenerator
|
||||
pipeline with SSE streaming (unlike regular single-step node run).
|
||||
|
||||
:param snippet: CustomizedSnippet instance
|
||||
:param user: Account or EndUser initiating the run
|
||||
:param node_id: ID of the loop node to run
|
||||
:param args: Pydantic model with 'inputs' attribute containing loop input data
|
||||
:param streaming: Whether to stream the response (should be True)
|
||||
:return: SSE streaming generator
|
||||
:raises ValueError: If the snippet has no draft workflow
|
||||
"""
|
||||
snippet_service = SnippetService()
|
||||
workflow = snippet_service.get_draft_workflow(snippet=snippet)
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not initialized")
|
||||
|
||||
app_proxy = _SnippetAsApp(snippet)
|
||||
|
||||
return WorkflowAppGenerator.convert_to_event_stream(
|
||||
WorkflowAppGenerator().single_loop_generate(
|
||||
app_model=app_proxy, # type: ignore[arg-type]
|
||||
workflow=workflow,
|
||||
node_id=node_id,
|
||||
user=user,
|
||||
args=args, # type: ignore[arg-type]
|
||||
streaming=streaming,
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def parse_files(workflow: Workflow, files: list[dict] | None = None) -> Sequence[File]:
|
||||
"""
|
||||
Parse file mappings into File objects based on workflow configuration.
|
||||
|
||||
:param workflow: Workflow instance for file upload config
|
||||
:param files: Raw file mapping dicts
|
||||
:return: Parsed File objects
|
||||
"""
|
||||
files = files or []
|
||||
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
|
||||
if file_extra_config is None:
|
||||
return []
|
||||
return file_factory.build_from_mappings(
|
||||
mappings=files,
|
||||
tenant_id=workflow.tenant_id,
|
||||
config=file_extra_config,
|
||||
)
|
||||
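To make the root-candidate wiring in _inject_virtual_start_node concrete, here is a toy graph with two entry nodes; the expected result follows directly from the logic above (calling the private helper here is for illustration only):

    graph = {
        "nodes": [
            {"id": "a", "data": {"type": "llm"}},
            {"id": "b", "data": {"type": "code"}},
            {"id": "c", "data": {"type": "answer"}},
        ],
        "edges": [{"source": "a", "target": "c"}],
    }
    new_graph = SnippetGenerateService._inject_virtual_start_node(graph_dict=graph, input_fields=[])
    # "c" is the only node with an incoming edge, so "a" and "b" are root candidates:
    # the virtual Start node is prepended and wired to both, giving three edges total.
    assert [e["target"] for e in new_graph["edges"][1:]] == ["a", "b"]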
608 api/services/snippet_service.py Normal file
@ -0,0 +1,608 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.workflow.node_factory import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||
from extensions.ext_database import db
|
||||
from graphon.enums import BuiltinNodeTypes, NodeType
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from models import Account
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.snippet import CustomizedSnippet, SnippetType
|
||||
from models.workflow import (
|
||||
Workflow,
|
||||
WorkflowNodeExecutionModel,
|
||||
WorkflowRun,
|
||||
WorkflowType,
|
||||
)
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.errors.app import WorkflowHashNotEqualError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Node types not allowed in snippet workflows (sync, publish, DSL import).
|
||||
SNIPPET_FORBIDDEN_NODE_TYPES: frozenset[str] = frozenset(
|
||||
{
|
||||
BuiltinNodeTypes.START,
|
||||
BuiltinNodeTypes.HUMAN_INPUT,
|
||||
BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class SnippetService:
|
||||
"""Service for managing customized snippets."""
|
||||
|
||||
def __init__(self, session_maker: sessionmaker | None = None):
|
||||
"""Initialize SnippetService with repository dependencies."""
|
||||
if session_maker is None:
|
||||
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
|
||||
session_maker
|
||||
)
|
||||
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
|
||||
@staticmethod
|
||||
def validate_snippet_graph_forbidden_nodes(graph: Mapping[str, Any]) -> None:
|
||||
"""Reject graphs that contain node types not allowed in snippets."""
|
||||
nodes = graph.get("nodes") or []
|
||||
disallowed: list[tuple[str, str]] = []
|
||||
for node in nodes:
|
||||
if not isinstance(node, dict):
|
||||
continue
|
||||
node_data = node.get("data") or {}
|
||||
node_type = node_data.get("type")
|
||||
if not isinstance(node_type, str):
|
||||
continue
|
||||
if node_type in SNIPPET_FORBIDDEN_NODE_TYPES:
|
||||
node_id = node.get("id")
|
||||
disallowed.append((str(node_id) if node_id is not None else "?", node_type))
|
||||
if not disallowed:
|
||||
return
|
||||
detail = ", ".join(f"{nid}:{t}" for nid, t in disallowed)
|
||||
raise ValueError(
|
||||
"Snippet workflow cannot contain start, human-input, or knowledge-retrieval nodes. "
|
||||
f"Found: {detail}"
|
||||
)
|
||||
|
||||
# --- CRUD Operations ---
|
||||
|
||||
@staticmethod
|
||||
def get_snippets(
|
||||
*,
|
||||
tenant_id: str,
|
||||
page: int = 1,
|
||||
limit: int = 20,
|
||||
keyword: str | None = None,
|
||||
is_published: bool | None = None,
|
||||
creators: list[str] | None = None,
|
||||
) -> tuple[Sequence[CustomizedSnippet], int, bool]:
|
||||
"""
|
||||
Get paginated list of snippets with optional search.
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param page: Page number (1-indexed)
|
||||
:param limit: Number of items per page
|
||||
:param keyword: Optional search keyword for name/description
|
||||
:param is_published: Optional filter by published status (True/False/None for all)
|
||||
:param creators: Optional filter by creator account IDs
|
||||
:return: Tuple of (snippets list, total count, has_more flag)
|
||||
"""
|
||||
stmt = (
|
||||
select(CustomizedSnippet)
|
||||
.where(CustomizedSnippet.tenant_id == tenant_id)
|
||||
.order_by(CustomizedSnippet.created_at.desc())
|
||||
)
|
||||
|
||||
if keyword:
|
||||
stmt = stmt.where(
|
||||
CustomizedSnippet.name.ilike(f"%{keyword}%") | CustomizedSnippet.description.ilike(f"%{keyword}%")
|
||||
)
|
||||
|
||||
if is_published is not None:
|
||||
stmt = stmt.where(CustomizedSnippet.is_published == is_published)
|
||||
|
||||
if creators:
|
||||
stmt = stmt.where(CustomizedSnippet.created_by.in_(creators))
|
||||
|
||||
# Get total count
|
||||
count_stmt = select(func.count()).select_from(stmt.subquery())
|
||||
total = db.session.scalar(count_stmt) or 0
|
||||
|
||||
# Apply pagination
|
||||
stmt = stmt.limit(limit + 1).offset((page - 1) * limit)
|
||||
snippets = list(db.session.scalars(stmt).all())
|
||||
|
||||
has_more = len(snippets) > limit
|
||||
if has_more:
|
||||
snippets = snippets[:-1]
|
||||
|
||||
return snippets, total, has_more
|
||||
|
||||
@staticmethod
|
||||
def get_snippet_by_id(
|
||||
*,
|
||||
snippet_id: str,
|
||||
tenant_id: str,
|
||||
) -> CustomizedSnippet | None:
|
||||
"""
|
||||
Get snippet by ID with tenant isolation.
|
||||
|
||||
:param snippet_id: Snippet ID
|
||||
:param tenant_id: Tenant ID
|
||||
:return: CustomizedSnippet or None
|
||||
"""
|
||||
return (
|
||||
db.session.query(CustomizedSnippet)
|
||||
.where(
|
||||
CustomizedSnippet.id == snippet_id,
|
||||
CustomizedSnippet.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_snippet(
|
||||
*,
|
||||
tenant_id: str,
|
||||
name: str,
|
||||
description: str | None,
|
||||
snippet_type: SnippetType,
|
||||
icon_info: dict | None,
|
||||
input_fields: list[dict] | None,
|
||||
account: Account,
|
||||
) -> CustomizedSnippet:
|
||||
"""
|
||||
Create a new snippet.
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param name: Snippet name (must be unique per tenant)
|
||||
:param description: Snippet description
|
||||
:param snippet_type: Type of snippet (node or group)
|
||||
:param icon_info: Icon information
|
||||
:param input_fields: Input field definitions
|
||||
:param account: Creator account
|
||||
:return: Created CustomizedSnippet
|
||||
:raises ValueError: If name already exists
|
||||
"""
|
||||
# Check if name already exists for this tenant
|
||||
existing = (
|
||||
db.session.query(CustomizedSnippet)
|
||||
.where(
|
||||
CustomizedSnippet.tenant_id == tenant_id,
|
||||
CustomizedSnippet.name == name,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
raise ValueError(f"Snippet with name '{name}' already exists")
|
||||
|
||||
snippet = CustomizedSnippet(
|
||||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
description=description or "",
|
||||
type=snippet_type.value,
|
||||
icon_info=icon_info,
|
||||
input_fields=json.dumps(input_fields) if input_fields else None,
|
||||
created_by=account.id,
|
||||
)
|
||||
|
||||
db.session.add(snippet)
|
||||
db.session.commit()
|
||||
|
||||
return snippet
|
||||
|
||||
@staticmethod
|
||||
def update_snippet(
|
||||
*,
|
||||
session: Session,
|
||||
snippet: CustomizedSnippet,
|
||||
account_id: str,
|
||||
data: dict,
|
||||
) -> CustomizedSnippet:
|
||||
"""
|
||||
Update snippet attributes.
|
||||
|
||||
:param session: Database session
|
||||
:param snippet: Snippet to update
|
||||
:param account_id: ID of account making the update
|
||||
:param data: Dictionary of fields to update
|
||||
:return: Updated CustomizedSnippet
|
||||
"""
|
||||
if "name" in data:
|
||||
# Check if new name already exists for this tenant
|
||||
existing = (
|
||||
session.query(CustomizedSnippet)
|
||||
.where(
|
||||
CustomizedSnippet.tenant_id == snippet.tenant_id,
|
||||
CustomizedSnippet.name == data["name"],
|
||||
CustomizedSnippet.id != snippet.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
raise ValueError(f"Snippet with name '{data['name']}' already exists")
|
||||
snippet.name = data["name"]
|
||||
|
||||
if "description" in data:
|
||||
snippet.description = data["description"]
|
||||
|
||||
if "icon_info" in data:
|
||||
snippet.icon_info = data["icon_info"]
|
||||
|
||||
snippet.updated_by = account_id
|
||||
snippet.updated_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
|
||||
session.add(snippet)
|
||||
return snippet
|
||||
|
||||
@staticmethod
|
||||
def delete_snippet(
|
||||
*,
|
||||
session: Session,
|
||||
snippet: CustomizedSnippet,
|
||||
) -> bool:
|
||||
"""
|
||||
Delete a snippet.
|
||||
|
||||
:param session: Database session
|
||||
:param snippet: Snippet to delete
|
||||
:return: True if deleted successfully
|
||||
"""
|
||||
session.delete(snippet)
|
||||
return True
|
||||
|
||||
# --- Workflow Operations ---
|
||||
|
||||
def get_draft_workflow(self, snippet: CustomizedSnippet) -> Workflow | None:
|
||||
"""
|
||||
Get draft workflow for snippet.
|
||||
|
||||
:param snippet: CustomizedSnippet instance
|
||||
:return: Draft Workflow or None
|
||||
"""
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
.where(
|
||||
Workflow.tenant_id == snippet.tenant_id,
|
||||
Workflow.app_id == snippet.id,
|
||||
Workflow.type == WorkflowType.SNIPPET.value,
|
||||
Workflow.version == "draft",
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return workflow
|
||||
|
||||
def get_published_workflow(self, snippet: CustomizedSnippet) -> Workflow | None:
|
||||
"""
|
||||
Get published workflow for snippet.
|
||||
|
||||
:param snippet: CustomizedSnippet instance
|
||||
:return: Published Workflow or None
|
||||
"""
|
||||
if not snippet.workflow_id:
|
||||
return None
|
||||
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
.where(
|
||||
Workflow.tenant_id == snippet.tenant_id,
|
||||
Workflow.app_id == snippet.id,
|
||||
Workflow.type == WorkflowType.SNIPPET.value,
|
||||
Workflow.id == snippet.workflow_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return workflow
|
||||
|
||||
def sync_draft_workflow(
|
||||
self,
|
||||
*,
|
||||
snippet: CustomizedSnippet,
|
||||
graph: dict,
|
||||
unique_hash: str | None,
|
||||
account: Account,
|
||||
input_fields: list[dict] | None = None,
|
||||
) -> Workflow:
|
||||
"""
|
||||
Sync draft workflow for snippet.
|
||||
|
||||
Snippet workflows do not persist environment variables (always empty) or
|
||||
conversation variables (always empty).
|
||||
|
||||
:param snippet: CustomizedSnippet instance
|
||||
:param graph: Workflow graph configuration
|
||||
:param unique_hash: Hash for conflict detection
|
||||
:param account: Account making the change
|
||||
:param input_fields: Input fields for snippet
|
||||
:return: Synced Workflow
|
||||
:raises WorkflowHashNotEqualError: If hash mismatch
|
||||
"""
|
||||
SnippetService.validate_snippet_graph_forbidden_nodes(graph)
|
||||
|
||||
workflow = self.get_draft_workflow(snippet=snippet)
|
||||
|
||||
if workflow and workflow.unique_hash != unique_hash:
|
||||
raise WorkflowHashNotEqualError()
|
||||
|
||||
# Create draft workflow if not found
|
||||
if not workflow:
|
||||
workflow = Workflow(
|
||||
tenant_id=snippet.tenant_id,
|
||||
app_id=snippet.id,
|
||||
features="{}",
|
||||
type=WorkflowType.SNIPPET.value,
|
||||
version="draft",
|
||||
graph=json.dumps(graph),
|
||||
created_by=account.id,
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
db.session.add(workflow)
|
||||
db.session.flush()
|
||||
else:
|
||||
# Update existing draft workflow
|
||||
workflow.graph = json.dumps(graph)
|
||||
workflow.updated_by = account.id
|
||||
workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
workflow.environment_variables = []
|
||||
workflow.conversation_variables = []
|
||||
|
||||
# Update snippet's input_fields if provided
|
||||
if input_fields is not None:
|
||||
snippet.input_fields = json.dumps(input_fields)
|
||||
snippet.updated_by = account.id
|
||||
snippet.updated_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
|
||||
db.session.commit()
|
||||
return workflow
|
||||

    def publish_workflow(
        self,
        *,
        session: Session,
        snippet: CustomizedSnippet,
        account: Account,
    ) -> Workflow:
        """
        Publish the draft workflow as a new version.

        :param session: Database session
        :param snippet: CustomizedSnippet instance
        :param account: Account making the change
        :return: Published Workflow
        :raises ValueError: If no draft workflow exists
        """
        draft_workflow_stmt = select(Workflow).where(
            Workflow.tenant_id == snippet.tenant_id,
            Workflow.app_id == snippet.id,
            Workflow.type == WorkflowType.SNIPPET.value,
            Workflow.version == "draft",
        )
        draft_workflow = session.scalar(draft_workflow_stmt)
        if not draft_workflow:
            raise ValueError("No valid workflow found.")

        SnippetService.validate_snippet_graph_forbidden_nodes(draft_workflow.graph_dict)

        # Create new published workflow
        workflow = Workflow.new(
            tenant_id=snippet.tenant_id,
            app_id=snippet.id,
            type=draft_workflow.type,
            version=str(datetime.now(UTC).replace(tzinfo=None)),
            graph=draft_workflow.graph,
            features=draft_workflow.features,
            created_by=account.id,
            environment_variables=[],
            conversation_variables=[],
            rag_pipeline_variables=draft_workflow.rag_pipeline_variables,
            marked_name="",
            marked_comment="",
        )
        session.add(workflow)

        # Update snippet version
        snippet.version += 1
        snippet.is_published = True
        snippet.workflow_id = workflow.id
        snippet.updated_by = account.id
        session.add(snippet)

        return workflow

    def get_all_published_workflows(
        self,
        *,
        session: Session,
        snippet: CustomizedSnippet,
        page: int,
        limit: int,
    ) -> tuple[Sequence[Workflow], bool]:
        """
        Get all published workflow versions for snippet.

        :param session: Database session
        :param snippet: CustomizedSnippet instance
        :param page: Page number
        :param limit: Items per page
        :return: Tuple of (workflows list, has_more flag)
        """
        if not snippet.workflow_id:
            return [], False

        stmt = (
            select(Workflow)
            .where(
                Workflow.app_id == snippet.id,
                Workflow.type == WorkflowType.SNIPPET.value,
                Workflow.version != "draft",
            )
            .order_by(Workflow.version.desc())
            .limit(limit + 1)
            .offset((page - 1) * limit)
        )

        workflows = list(session.scalars(stmt).all())
        has_more = len(workflows) > limit
        if has_more:
            workflows = workflows[:-1]

        return workflows, has_more
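
The method above uses the common fetch-one-extra trick: query `limit + 1` rows, and if the extra row comes back there is at least one more page, so no separate COUNT query is needed. A standalone sketch of the same pattern, detached from SQLAlchemy:

# Generic illustration of the limit+1 pagination used above.
def paginate(items: list, page: int, limit: int) -> tuple[list, bool]:
    start = (page - 1) * limit
    window = items[start : start + limit + 1]  # fetch one extra row
    has_more = len(window) > limit
    return (window[:limit] if has_more else window), has_more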

    # --- Default Block Configs ---

    def get_default_block_configs(self) -> list[dict]:
        """
        Get default block configurations for all node types.

        :return: List of default configurations
        """
        default_block_configs: list[dict[str, Any]] = []
        for node_class_mapping in NODE_TYPE_CLASSES_MAPPING.values():
            node_class = node_class_mapping[LATEST_VERSION]
            default_config = node_class.get_default_config()
            if default_config:
                default_block_configs.append(dict(default_config))

        return default_block_configs

    def get_default_block_config(self, node_type: str, filters: dict | None = None) -> Mapping[str, object] | None:
        """
        Get default config for specific node type.

        :param node_type: Node type string
        :param filters: Optional filters
        :return: Default configuration or None
        """
        node_type_enum = NodeType(node_type)

        if node_type_enum not in NODE_TYPE_CLASSES_MAPPING:
            return None

        node_class = NODE_TYPE_CLASSES_MAPPING[node_type_enum][LATEST_VERSION]
        default_config = node_class.get_default_config(filters=filters)
        if not default_config:
            return None

        return default_config

    # --- Workflow Run Operations ---

    def get_snippet_workflow_runs(
        self,
        *,
        snippet: CustomizedSnippet,
        args: dict,
    ) -> InfiniteScrollPagination:
        """
        Get paginated workflow runs for snippet.

        :param snippet: CustomizedSnippet instance
        :param args: Request arguments (last_id, limit)
        :return: InfiniteScrollPagination result
        """
        limit = int(args.get("limit", 20))
        last_id = args.get("last_id")

        triggered_from_values = [
            WorkflowRunTriggeredFrom.DEBUGGING,
        ]

        return self._workflow_run_repo.get_paginated_workflow_runs(
            tenant_id=snippet.tenant_id,
            app_id=snippet.id,
            triggered_from=triggered_from_values,
            limit=limit,
            last_id=last_id,
        )

    def get_snippet_workflow_run(
        self,
        *,
        snippet: CustomizedSnippet,
        run_id: str,
    ) -> WorkflowRun | None:
        """
        Get workflow run details.

        :param snippet: CustomizedSnippet instance
        :param run_id: Workflow run ID
        :return: WorkflowRun or None
        """
        return self._workflow_run_repo.get_workflow_run_by_id(
            tenant_id=snippet.tenant_id,
            app_id=snippet.id,
            run_id=run_id,
        )

    def get_snippet_workflow_run_node_executions(
        self,
        *,
        snippet: CustomizedSnippet,
        run_id: str,
    ) -> Sequence[WorkflowNodeExecutionModel]:
        """
        Get workflow run node execution list.

        :param snippet: CustomizedSnippet instance
        :param run_id: Workflow run ID
        :return: List of WorkflowNodeExecutionModel
        """
        workflow_run = self.get_snippet_workflow_run(snippet=snippet, run_id=run_id)
        if not workflow_run:
            return []

        node_executions = self._node_execution_service_repo.get_executions_by_workflow_run(
            tenant_id=snippet.tenant_id,
            app_id=snippet.id,
            workflow_run_id=workflow_run.id,
        )

        return node_executions

    # --- Node Execution Operations ---

    def get_snippet_node_last_run(
        self,
        *,
        snippet: CustomizedSnippet,
        workflow: Workflow,
        node_id: str,
    ) -> WorkflowNodeExecutionModel | None:
        """
        Get the most recent execution for a specific node in a snippet workflow.

        :param snippet: CustomizedSnippet instance
        :param workflow: Workflow instance
        :param node_id: Node identifier
        :return: WorkflowNodeExecutionModel or None
        """
        return self._node_execution_service_repo.get_node_last_execution(
            tenant_id=snippet.tenant_id,
            app_id=snippet.id,
            workflow_id=workflow.id,
            node_id=node_id,
        )

    # --- Use Count ---

    @staticmethod
    def increment_use_count(
        *,
        session: Session,
        snippet: CustomizedSnippet,
    ) -> None:
        """
        Increment the use_count when snippet is used.

        :param session: Database session
        :param snippet: CustomizedSnippet instance
        """
        snippet.use_count += 1
        session.add(snippet)
@@ -38,6 +38,7 @@ from models.workflow import Workflow
from services.async_workflow_service import AsyncWorkflowService
from services.end_user_service import EndUserService
from services.errors.app import QuotaExceededError
from services.quota_service import QuotaService
from services.trigger.app_trigger_service import AppTriggerService
from services.workflow.entities import WebhookTriggerData

@@ -819,9 +820,9 @@ class WebhookService:
                user_id=None,
            )

            # consume quota before triggering workflow execution
            # reserve quota before triggering workflow execution
            try:
                QuotaType.TRIGGER.consume(webhook_trigger.tenant_id)
                quota_charge = QuotaService.reserve(QuotaType.TRIGGER, webhook_trigger.tenant_id)
            except QuotaExceededError:
                AppTriggerService.mark_tenant_triggers_rate_limited(webhook_trigger.tenant_id)
                logger.info(
@@ -832,11 +833,16 @@ class WebhookService:
                raise

            # Trigger workflow execution asynchronously
            AsyncWorkflowService.trigger_workflow_async(
                session,
                end_user,
                trigger_data,
            )
            try:
                AsyncWorkflowService.trigger_workflow_async(
                    session,
                    end_user,
                    trigger_data,
                )
                quota_charge.commit()
            except Exception:
                quota_charge.refund()
                raise

        except Exception:
            logger.exception("Failed to trigger workflow for webhook %s", webhook_trigger.webhook_id)
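
The hunk above replaces the one-shot QuotaType.TRIGGER.consume() with a two-phase reserve/commit/refund flow, so quota is only spent once the workflow is actually dispatched. The general shape of the pattern, sketched with the names from this diff:

# Sketch of the two-phase quota pattern introduced in this PR.
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, tenant_id)  # may raise QuotaExceededError
try:
    do_the_billable_work()  # placeholder for the actual dispatch call
    quota_charge.commit()   # quota is deducted only on success
except Exception:
    quota_charge.refund()   # release the reservation on failure
    raise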
@@ -23,17 +23,28 @@ class LogView:
    """Lightweight wrapper for WorkflowAppLog with computed details.

    - Exposes `details_` for marshalling to `details` in API response
    - Exposes `evaluation_` for marshalling evaluation metrics in API response
    - Proxies all other attributes to the underlying `WorkflowAppLog`
    """

    def __init__(self, log: WorkflowAppLog, details: LogViewDetails | None):
    def __init__(
        self,
        log: WorkflowAppLog,
        details: LogViewDetails | None,
        evaluation: list[dict] | None = None,
    ):
        self.log = log
        self.details_ = details
        self.evaluation_ = evaluation

    @property
    def details(self) -> LogViewDetails | None:
        return self.details_

    @property
    def evaluation(self) -> list[dict] | None:
        return self.evaluation_

    def __getattr__(self, name):
        return getattr(self.log, name)
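
Because LogView defines __getattr__, any attribute not set on the wrapper (created_at, workflow_run_id, and so on) falls through to the wrapped WorkflowAppLog, so marshalling code can treat a LogView as if it were the log row itself. A toy, self-contained illustration of the same delegation pattern:

# Minimal stand-alone illustration of the proxy pattern used by LogView.
class Wrapped:
    created_at = "2025-01-01"

class View:
    def __init__(self, inner):
        self.inner = inner
        self.extra = "computed"

    def __getattr__(self, name):
        # Called only when normal attribute lookup fails, so `extra` wins over `inner`.
        return getattr(self.inner, name)

v = View(Wrapped())
assert v.extra == "computed"          # defined on the wrapper
assert v.created_at == "2025-01-01"   # proxied from the wrapped object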
@@ -170,12 +181,20 @@ class WorkflowAppService:
        # Execute query and get items
        if detail:
            rows = session.execute(offset_stmt).all()
            items = [
                LogView(log, {"trigger_metadata": self.handle_trigger_metadata(app_model.tenant_id, meta_val)})
            logs_with_details = [
                (log, {"trigger_metadata": self.handle_trigger_metadata(app_model.tenant_id, meta_val)})
                for log, meta_val in rows
            ]
        else:
            items = [LogView(log, None) for log in session.scalars(offset_stmt).all()]
            logs_with_details = [(log, None) for log in session.scalars(offset_stmt).all()]

        workflow_run_ids = [log.workflow_run_id for log, _ in logs_with_details]
        eval_map = self._batch_query_evaluation_metrics(session, workflow_run_ids)

        items = [
            LogView(log, details, evaluation=eval_map.get(log.workflow_run_id))
            for log, details in logs_with_details
        ]
        return {
            "page": page,
            "limit": limit,
@@ -257,6 +276,45 @@ class WorkflowAppService:
            "data": items,
        }

    @staticmethod
    def _batch_query_evaluation_metrics(
        session: Session,
        workflow_run_ids: list[str],
    ) -> dict[str, list[dict[str, Any]]]:
        """Return evaluation metrics keyed by workflow_run_id.

        Only returns metrics from completed evaluation runs. If a workflow
        run was not part of any evaluation (or the evaluation has not
        completed), it will be absent from the result dict.
        """
        from models.evaluation import EvaluationRun, EvaluationRunItem, EvaluationRunStatus

        if not workflow_run_ids:
            return {}

        non_null_ids = [wid for wid in workflow_run_ids if wid]
        if not non_null_ids:
            return {}

        stmt = (
            select(EvaluationRunItem.workflow_run_id, EvaluationRunItem.metrics)
            .join(EvaluationRun, EvaluationRun.id == EvaluationRunItem.evaluation_run_id)
            .where(
                EvaluationRunItem.workflow_run_id.in_(non_null_ids),
                EvaluationRun.status == EvaluationRunStatus.COMPLETED,
            )
        )
        rows = session.execute(stmt).all()

        result: dict[str, list[dict[str, Any]]] = {}
        for wf_run_id, metrics_json in rows:
            if wf_run_id and metrics_json:
                parsed: list[dict[str, Any]] = json.loads(metrics_json)
                existing = result.get(wf_run_id, [])
                existing.extend(parsed)
                result[wf_run_id] = existing
        return result

    def handle_trigger_metadata(self, tenant_id: str, meta_val: str | None) -> dict[str, Any]:
        metadata: dict[str, Any] | None = self._safe_json_loads(meta_val)
        if not metadata:
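
The helper above deliberately returns a sparse dict: eval_map.get(workflow_run_id) is None for runs that were never evaluated, which the listing code passes straight into LogView(..., evaluation=None). A small sketch of the grouping step, using illustrative data and assuming metrics are stored as JSON arrays, as the query above implies:

import json

# Sketch of the grouping performed by _batch_query_evaluation_metrics:
# rows are (workflow_run_id, metrics_json) pairs; metrics accumulate per run id.
rows = [
    ("run-1", json.dumps([{"name": "faithfulness", "value": 0.9}])),
    ("run-1", json.dumps([{"name": "relevancy", "value": 0.7}])),
    ("run-2", None),  # no metrics recorded, so it is skipped
]
result: dict[str, list[dict]] = {}
for run_id, metrics_json in rows:
    if run_id and metrics_json:
        result.setdefault(run_id, []).extend(json.loads(metrics_json))

assert len(result["run-1"]) == 2
assert "run-2" not in result  # absent, so eval_map.get("run-2") is None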
@@ -1,7 +1,7 @@
import dataclasses
import json
import logging
from collections.abc import Mapping, Sequence
from collections.abc import Mapping, Sequence, Set
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from enum import StrEnum
@@ -271,12 +271,20 @@ class WorkflowDraftVariableService:
        )

    def list_variables_without_values(
        self, app_id: str, page: int, limit: int, user_id: str
        self,
        app_id: str,
        page: int,
        limit: int,
        user_id: str,
        *,
        exclude_node_ids: Set[str] | None = None,
    ) -> WorkflowDraftVariableList:
        criteria = [
            WorkflowDraftVariable.app_id == app_id,
            WorkflowDraftVariable.user_id == user_id,
        ]
        if exclude_node_ids:
            criteria.append(WorkflowDraftVariable.node_id.notin_(list(exclude_node_ids)))
        total = None
        base_stmt = select(WorkflowDraftVariable).where(*criteria)
        if page == 1:
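
The new keyword-only exclude_node_ids parameter lets callers hide variables produced by particular nodes from the listing; the set is translated into a SQL NOT IN filter. A hedged usage sketch (how the service instance is obtained is illustrative, not prescribed by this diff):

# Hypothetical call site — exact service construction may differ.
variables = draft_variable_service.list_variables_without_values(
    app_id="app-id",
    page=1,
    limit=20,
    user_id="user-id",
    exclude_node_ids={"trigger_node_1", "trigger_node_2"},  # filtered via NOT IN
)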
@@ -5,7 +5,7 @@ import uuid
from collections.abc import Callable, Generator, Mapping, Sequence
from typing import Any, cast

from graphon.entities import WorkflowNodeExecution
from graphon.entities import GraphInitParams, WorkflowNodeExecution
from graphon.entities.graph_config import NodeConfigDict
from graphon.entities.pause_reason import HumanInputRequired
from graphon.enums import (
@@ -30,7 +30,7 @@ from graphon.variable_loader import load_into_variable_pool
from graphon.variables import VariableBase
from graphon.variables.input_entities import VariableEntityType
from graphon.variables.variables import Variable
from sqlalchemy import exists, select
from sqlalchemy import and_, exists, or_, select
from sqlalchemy.orm import Session, sessionmaker

from configs import dify_config
@@ -42,7 +42,7 @@ from core.entities import PluginCredentialType
from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly, create_plugin_provider_manager
from core.repositories import DifyCoreRepositoryFactory
from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl
from core.trigger.constants import is_trigger_node_type
from core.trigger.constants import TRIGGER_NODE_TYPES, is_trigger_node_type
from core.workflow.human_input_compat import (
    DeliveryChannelConfig,
    normalize_human_input_node_data_for_graph,
@@ -65,6 +65,7 @@ from extensions.ext_database import db
from extensions.ext_storage import storage
from factories.file_factory import build_from_mapping, build_from_mappings
from libs.datetime_utils import naive_utc_now
from libs.helper import escape_like_pattern
from models import Account
from models.human_input import HumanInputFormRecipient, RecipientType
from models.model import App, AppMode
@@ -99,6 +100,15 @@ class WorkflowService:
    Workflow Service
    """

    # Centralized unsupported node types for evaluation workflow publishing.
    # Keep this set updated when evaluation workflow constraints change.
    EVALUATION_UNSUPPORTED_NODE_TYPES: frozenset[str] = frozenset(
        {
            BuiltinNodeTypes.HUMAN_INPUT,
            *TRIGGER_NODE_TYPES,
        }
    )

    def __init__(self, session_maker: sessionmaker | None = None):
        """Initialize WorkflowService with repository dependencies."""
        if session_maker is None:
@@ -237,6 +247,59 @@ class WorkflowService:

        return workflows, has_more

    def list_published_evaluation_workflows(
        self,
        *,
        session: Session,
        tenant_id: str,
        page: int,
        limit: int,
        user_id: str | None,
        named_only: bool = False,
        keyword: str | None = None,
    ) -> tuple[Sequence[Workflow], bool]:
        """
        List published evaluation-type workflows for a tenant (cross-app), excluding draft rows.

        When ``keyword`` is non-empty, match workflows whose marked name or parent app name contains
        the substring (case-insensitive, LIKE wildcards escaped).
        """
        stmt = select(Workflow).where(
            Workflow.tenant_id == tenant_id,
            Workflow.type == WorkflowType.EVALUATION,
            Workflow.version != Workflow.VERSION_DRAFT,
        )

        if user_id:
            stmt = stmt.where(Workflow.created_by == user_id)

        if named_only:
            stmt = stmt.where(Workflow.marked_name != "")

        keyword_stripped = keyword.strip() if keyword else ""
        if keyword_stripped:
            escaped = escape_like_pattern(keyword_stripped)
            pattern = f"%{escaped}%"
            stmt = stmt.join(
                App,
                and_(Workflow.app_id == App.id, App.tenant_id == tenant_id),
            ).where(
                or_(
                    Workflow.marked_name.ilike(pattern, escape="\\"),
                    App.name.ilike(pattern, escape="\\"),
                )
            )

        stmt = stmt.order_by(Workflow.created_at.desc()).limit(limit + 1).offset((page - 1) * limit)

        workflows = session.scalars(stmt).all()

        has_more = len(workflows) > limit
        if has_more:
            workflows = workflows[:-1]

        return workflows, has_more

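
Escaping the keyword before building the ILIKE pattern matters because % and _ are wildcards inside LIKE; without it a search for "100%" would match any name starting with "100". A sketch of what an escape helper like escape_like_pattern is assumed to do (the real implementation lives in libs.helper and may differ):

def escape_like_pattern_sketch(keyword: str) -> str:
    # Assumed behavior: escape the escape character first, then the wildcards.
    return keyword.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")

assert escape_like_pattern_sketch("100%_done") == "100\\%\\_done"
# The service then wraps it: pattern = f"%{escaped}%", passing escape="\\" to ilike().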
def sync_draft_workflow(
|
||||
self,
|
||||
*,
|
||||
@ -400,6 +463,127 @@ class WorkflowService:
|
||||
# return new workflow
|
||||
return workflow
|
||||
|
||||
def publish_evaluation_workflow(
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
app_model: App,
|
||||
account: Account,
|
||||
marked_name: str = "",
|
||||
marked_comment: str = "",
|
||||
) -> Workflow:
|
||||
"""Publish draft workflow as an evaluation workflow version.
|
||||
|
||||
Compared to standard publish:
|
||||
- force published workflow type to ``evaluation``;
|
||||
- reject graphs containing trigger or human-input nodes.
|
||||
"""
|
||||
draft_workflow_stmt = select(Workflow).where(
|
||||
Workflow.tenant_id == app_model.tenant_id,
|
||||
Workflow.app_id == app_model.id,
|
||||
Workflow.version == Workflow.VERSION_DRAFT,
|
||||
)
|
||||
draft_workflow = session.scalar(draft_workflow_stmt)
|
||||
if not draft_workflow:
|
||||
raise ValueError("No valid workflow found.")
|
||||
|
||||
# Validate credentials before publishing, for credential policy check
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
if FeatureService.get_system_features().plugin_manager.enabled:
|
||||
self._validate_workflow_credentials(draft_workflow)
|
||||
|
||||
# validate graph structure
|
||||
self.validate_graph_structure(graph=draft_workflow.graph_dict)
|
||||
self._validate_evaluation_workflow_nodes(draft_workflow)
|
||||
|
||||
workflow = Workflow.new(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
type=WorkflowType.EVALUATION.value,
|
||||
version=Workflow.version_from_datetime(naive_utc_now()),
|
||||
graph=draft_workflow.graph,
|
||||
created_by=account.id,
|
||||
environment_variables=draft_workflow.environment_variables,
|
||||
conversation_variables=draft_workflow.conversation_variables,
|
||||
marked_name=marked_name,
|
||||
marked_comment=marked_comment,
|
||||
rag_pipeline_variables=draft_workflow.rag_pipeline_variables,
|
||||
features=draft_workflow.features,
|
||||
)
|
||||
|
||||
session.add(workflow)
|
||||
|
||||
# trigger app workflow events
|
||||
app_published_workflow_was_updated.send(app_model, published_workflow=workflow)
|
||||
|
||||
return workflow
|
||||
|
||||
def convert_published_workflow_type(
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
app_model: App,
|
||||
target_type: WorkflowType,
|
||||
account: Account,
|
||||
) -> Workflow:
|
||||
"""
|
||||
Convert a published workflow type in-place.
|
||||
|
||||
This endpoint only supports conversion between standard workflow and evaluation workflow.
|
||||
"""
|
||||
if target_type not in {WorkflowType.WORKFLOW, WorkflowType.EVALUATION}:
|
||||
raise ValueError("target_type must be either 'workflow' or 'evaluation'")
|
||||
|
||||
if not app_model.workflow_id:
|
||||
raise WorkflowNotFoundError("Published workflow not found")
|
||||
|
||||
stmt = select(Workflow).where(
|
||||
Workflow.tenant_id == app_model.tenant_id,
|
||||
Workflow.app_id == app_model.id,
|
||||
Workflow.id == app_model.workflow_id,
|
||||
)
|
||||
workflow = session.scalar(stmt)
|
||||
if not workflow:
|
||||
raise WorkflowNotFoundError("Published workflow not found")
|
||||
|
||||
if workflow.version == Workflow.VERSION_DRAFT:
|
||||
raise IsDraftWorkflowError("Current effective workflow cannot be a draft version.")
|
||||
|
||||
if workflow.type == target_type:
|
||||
return workflow
|
||||
|
||||
if target_type == WorkflowType.EVALUATION:
|
||||
self._validate_evaluation_workflow_nodes(workflow)
|
||||
|
||||
workflow.type = target_type
|
||||
workflow.updated_by = account.id
|
||||
workflow.updated_at = naive_utc_now()
|
||||
|
||||
app_published_workflow_was_updated.send(app_model, published_workflow=workflow)
|
||||
|
||||
return workflow
|
||||
|
||||
@staticmethod
|
||||
def _validate_evaluation_workflow_nodes(workflow: Workflow) -> None:
|
||||
"""Ensure evaluation workflows do not contain unsupported node types."""
|
||||
disallowed_nodes: list[tuple[str, str]] = []
|
||||
for node_id, node_data in workflow.walk_nodes():
|
||||
node_type = node_data.get("type")
|
||||
if not isinstance(node_type, str):
|
||||
continue
|
||||
if node_type in WorkflowService.EVALUATION_UNSUPPORTED_NODE_TYPES:
|
||||
disallowed_nodes.append((node_id, node_type))
|
||||
|
||||
if not disallowed_nodes:
|
||||
return
|
||||
|
||||
formatted_nodes = ", ".join(f"{node_id}:{node_type}" for node_id, node_type in disallowed_nodes)
|
||||
raise ValueError(
|
||||
"Evaluation workflow cannot contain trigger or human-input nodes. "
|
||||
f"Found disallowed nodes: {formatted_nodes}"
|
||||
)
|
||||
|
||||
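
The validator walks every node in the graph and collects all offenders before raising, so the error names every disallowed node at once instead of failing on the first. For a graph containing one schedule-trigger node with id start_1, the raised message would look roughly like the following (the exact node type string is an assumption; it depends on TRIGGER_NODE_TYPES):

# Illustrative only — "schedule-trigger" is an assumed type name.
# ValueError: Evaluation workflow cannot contain trigger or human-input nodes.
#             Found disallowed nodes: start_1:schedule-trigger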
    def _validate_workflow_credentials(self, workflow: Workflow) -> None:
        """
        Validate all credentials in workflow nodes before publishing.
@@ -1550,8 +1734,8 @@ def _setup_variable_pool(
        "workflow_execution_id": str(uuid.uuid4()),
    }

    # Only add chatflow-specific variables for non-workflow types.
    if workflow.type != WorkflowType.WORKFLOW:
    # Only add chatflow-specific variables for chat-like workflow types.
    if workflow.type not in {WorkflowType.WORKFLOW, WorkflowType.EVALUATION}:
        system_variable_values.update(
            {
                "query": query,

api/tasks/evaluation_task.py (new file, 541 lines)
@@ -0,0 +1,541 @@
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from celery import shared_task
|
||||
from openpyxl import Workbook
|
||||
from openpyxl.styles import Alignment, Border, Font, PatternFill, Side
|
||||
from openpyxl.utils import get_column_letter
|
||||
|
||||
from configs import dify_config
|
||||
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
|
||||
from core.evaluation.entities.evaluation_entity import (
|
||||
EvaluationCategory,
|
||||
EvaluationDatasetInput,
|
||||
EvaluationItemResult,
|
||||
EvaluationRunData,
|
||||
NodeInfo,
|
||||
)
|
||||
from core.evaluation.entities.judgment_entity import JudgmentConfig
|
||||
from core.evaluation.evaluation_manager import EvaluationManager
|
||||
from core.evaluation.judgment.processor import JudgmentProcessor
|
||||
from core.evaluation.runners.agent_evaluation_runner import AgentEvaluationRunner
|
||||
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
|
||||
from core.evaluation.runners.llm_evaluation_runner import LLMEvaluationRunner
|
||||
from core.evaluation.runners.retrieval_evaluation_runner import RetrievalEvaluationRunner
|
||||
from core.evaluation.runners.snippet_evaluation_runner import SnippetEvaluationRunner
|
||||
from core.evaluation.runners.workflow_evaluation_runner import WorkflowEvaluationRunner
|
||||
from extensions.ext_database import db
|
||||
from graphon.node_events import NodeRunResult
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.enums import CreatorUserRole
|
||||
from models.evaluation import EvaluationRun, EvaluationRunItem, EvaluationRunStatus
|
||||
from models.model import UploadFile
|
||||
from services.evaluation_service import EvaluationService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@shared_task(queue="evaluation")
|
||||
def run_evaluation(run_data_dict: dict[str, Any]) -> None:
|
||||
"""Celery task for running evaluations asynchronously.
|
||||
|
||||
Workflow:
|
||||
1. Deserialize EvaluationRunData
|
||||
2. Execute target and collect node results
|
||||
3. Evaluate metrics via runners (one per metric-node pair)
|
||||
4. Merge results per test-data row (1 item = 1 EvaluationRunItem)
|
||||
5. Apply judgment conditions
|
||||
6. Persist results + generate result XLSX
|
||||
7. Update EvaluationRun status to COMPLETED
|
||||
"""
|
||||
run_data = EvaluationRunData.model_validate(run_data_dict)
|
||||
|
||||
with db.engine.connect() as connection:
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
session = Session(bind=connection)
|
||||
|
||||
try:
|
||||
_execute_evaluation(session, run_data)
|
||||
except Exception as e:
|
||||
logger.exception("Evaluation run %s failed", run_data.evaluation_run_id)
|
||||
_mark_run_failed(session, run_data.evaluation_run_id, str(e))
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
|
||||
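
run_evaluation is a plain Celery task bound to a dedicated "evaluation" queue, so callers hand it a JSON-serializable dict rather than the Pydantic model itself; the task re-validates it on arrival. A hedged sketch of the dispatch side (the enqueueing presumably lives in EvaluationService; this call site is illustrative):

from tasks.evaluation_task import run_evaluation

# run_data is an EvaluationRunData instance built by the service layer.
run_evaluation.delay(run_data.model_dump(mode="json"))  # serialized here, re-validated in the task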
def _execute_evaluation(session: Any, run_data: EvaluationRunData) -> None:
    """Core evaluation execution logic."""
    evaluation_run = session.query(EvaluationRun).filter_by(id=run_data.evaluation_run_id).first()
    if not evaluation_run:
        logger.error("EvaluationRun %s not found", run_data.evaluation_run_id)
        return

    if evaluation_run.status == EvaluationRunStatus.CANCELLED:
        logger.info("EvaluationRun %s was cancelled", run_data.evaluation_run_id)
        return

    evaluation_instance = EvaluationManager.get_evaluation_instance()
    if evaluation_instance is None:
        raise ValueError("Evaluation framework not configured")

    # Mark as running
    evaluation_run.status = EvaluationRunStatus.RUNNING
    evaluation_run.started_at = naive_utc_now()
    session.commit()

    if run_data.target_type == "dataset":
        results: list[EvaluationItemResult] = _execute_retrieval_test(
            session=session,
            evaluation_run=evaluation_run,
            run_data=run_data,
            evaluation_instance=evaluation_instance,
        )
    else:
        evaluation_service = EvaluationService()
        node_run_result_mapping_list, workflow_run_ids = evaluation_service.execute_targets(
            tenant_id=run_data.tenant_id,
            target_type=run_data.target_type,
            target_id=run_data.target_id,
            input_list=run_data.input_list,
        )

        workflow_run_id_map = {
            item.index: wf_run_id
            for item, wf_run_id in zip(run_data.input_list, workflow_run_ids)
            if wf_run_id
        }

        results = _execute_evaluation_runner(
            session=session,
            run_data=run_data,
            evaluation_instance=evaluation_instance,
            node_run_result_mapping_list=node_run_result_mapping_list,
            workflow_run_id_map=workflow_run_id_map,
        )

    # Compute summary metrics
    metrics_summary = _compute_metrics_summary(results, run_data.judgment_config)

    # Generate result XLSX
    result_xlsx = _generate_result_xlsx(run_data.input_list, results)

    # Store result file
    result_file_id = _store_result_file(run_data.tenant_id, run_data.evaluation_run_id, result_xlsx, session)

    # Update run to completed
    evaluation_run = session.query(EvaluationRun).filter_by(id=run_data.evaluation_run_id).first()
    if evaluation_run:
        evaluation_run.status = EvaluationRunStatus.COMPLETED
        evaluation_run.completed_at = naive_utc_now()
        evaluation_run.metrics_summary = json.dumps(metrics_summary)
        if result_file_id:
            evaluation_run.result_file_id = result_file_id
        session.commit()

    logger.info("Evaluation run %s completed successfully", run_data.evaluation_run_id)


# ---------------------------------------------------------------------------
# Evaluation orchestration — merge + judgment + persist
# ---------------------------------------------------------------------------


def _execute_evaluation_runner(
    session: Any,
    run_data: EvaluationRunData,
    evaluation_instance: BaseEvaluationInstance,
    node_run_result_mapping_list: list[dict[str, NodeRunResult]],
    workflow_run_id_map: dict[int, str] | None = None,
) -> list[EvaluationItemResult]:
    """Evaluate all metrics, merge per-item, apply judgment, persist once.

    Ensures 1 test-data row = 1 EvaluationRunItem with all metrics combined.
    """
    results_by_index: dict[int, EvaluationItemResult] = {}

    # Phase 1: Default metrics — one batch per (metric, node) pair
    for default_metric in run_data.default_metrics:
        for node_info in default_metric.node_info_list:
            node_run_result_list: list[NodeRunResult] = []
            item_indices: list[int] = []
            for i, mapping in enumerate(node_run_result_mapping_list):
                node_result = mapping.get(node_info.node_id)
                if node_result is not None:
                    node_run_result_list.append(node_result)
                    item_indices.append(i)

            if not node_run_result_list:
                continue

            runner = _create_runner(EvaluationCategory(node_info.type), evaluation_instance)
            try:
                evaluated = runner.evaluate_metrics(
                    node_run_result_list=node_run_result_list,
                    default_metric=default_metric,
                    model_provider=run_data.evaluation_model_provider,
                    model_name=run_data.evaluation_model,
                    tenant_id=run_data.tenant_id,
                )
            except Exception:
                logger.exception(
                    "Failed metrics for %s on node %s", default_metric.metric, node_info.node_id
                )
                continue

            _stamp_and_merge(evaluated, item_indices, node_info, results_by_index)

    # Phase 2: Customized metrics
    if run_data.customized_metrics:
        try:
            customized_results = evaluation_instance.evaluate_with_customized_workflow(
                node_run_result_mapping_list=node_run_result_mapping_list,
                customized_metrics=run_data.customized_metrics,
                tenant_id=run_data.tenant_id,
            )
            for result in customized_results:
                _merge_result(results_by_index, result.index, result)
        except Exception:
            logger.exception("Failed customized metrics for run %s", run_data.evaluation_run_id)

    results = list(results_by_index.values())

    # Phase 3: Judgment
    if run_data.judgment_config:
        results = _apply_judgment(results, run_data.judgment_config)

    # Phase 4: Persist — one EvaluationRunItem per test-data row
    _persist_results(
        session, run_data.evaluation_run_id, results, run_data.input_list, workflow_run_id_map
    )

    return results


def _execute_retrieval_test(
    session: Any,
    evaluation_run: EvaluationRun,
    run_data: EvaluationRunData,
    evaluation_instance: BaseEvaluationInstance,
) -> list[EvaluationItemResult]:
    """Execute knowledge base retrieval for all items, then evaluate metrics.

    Unlike the workflow-based path, there are no workflow nodes to traverse.
    Hit testing is run directly for each dataset item and the results are fed
    straight into :class:`RetrievalEvaluationRunner`.
    """
    node_run_result_list = EvaluationService.execute_retrieval_test_targets(
        dataset_id=run_data.target_id,
        account_id=evaluation_run.created_by,
        input_list=run_data.input_list,
    )

    results_by_index: dict[int, EvaluationItemResult] = {}
    runner = RetrievalEvaluationRunner(evaluation_instance)

    for default_metric in run_data.default_metrics:
        try:
            evaluated = runner.evaluate_metrics(
                node_run_result_list=node_run_result_list,
                default_metric=default_metric,
                model_provider=run_data.evaluation_model_provider,
                model_name=run_data.evaluation_model,
                tenant_id=run_data.tenant_id,
            )
            item_indices = list(range(len(node_run_result_list)))
            _stamp_and_merge(evaluated, item_indices, None, results_by_index)
        except Exception:
            logger.exception("Failed retrieval metrics for run %s", run_data.evaluation_run_id)

    results = list(results_by_index.values())

    if run_data.judgment_config:
        results = _apply_judgment(results, run_data.judgment_config)

    _persist_results(session, run_data.evaluation_run_id, results, run_data.input_list)

    return results


# ---------------------------------------------------------------------------
# Helpers — merge, judgment, persist
# ---------------------------------------------------------------------------


def _stamp_and_merge(
    evaluated: list[EvaluationItemResult],
    item_indices: list[int],
    node_info: NodeInfo | None,
    results_by_index: dict[int, EvaluationItemResult],
) -> None:
    """Attach node_info to each metric and merge into results_by_index."""
    for result in evaluated:
        original_index = item_indices[result.index]
        if node_info is not None:
            for metric in result.metrics:
                metric.node_info = node_info
        _merge_result(results_by_index, original_index, result)


def _merge_result(
    results_by_index: dict[int, EvaluationItemResult],
    index: int,
    new_result: EvaluationItemResult,
) -> None:
    """Merge new metrics into an existing result for the same index."""
    existing = results_by_index.get(index)
    if existing:
        merged_metrics = existing.metrics + new_result.metrics
        actual = existing.actual_output or new_result.actual_output
        results_by_index[index] = existing.model_copy(
            update={"metrics": merged_metrics, "actual_output": actual}
        )
    else:
        results_by_index[index] = new_result.model_copy(update={"index": index})

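
The merge keeps exactly one EvaluationItemResult per row index: metrics lists concatenate, and actual_output keeps the first non-empty value. A worked example of two runner passes landing on the same row (field names follow the entities above; optional fields such as node_info are assumed to have defaults):

# Sketch: two runner passes produce metrics for the same test-data row, index 3.
first = EvaluationItemResult(
    index=0,
    actual_output="answer A",
    metrics=[EvaluationMetric(name="faithfulness", value=0.9)],
)
second = EvaluationItemResult(
    index=0,
    actual_output=None,
    metrics=[EvaluationMetric(name="relevancy", value=0.7)],
)

merged: dict[int, EvaluationItemResult] = {}
_merge_result(merged, 3, first)   # stored under index 3 via model_copy(update={"index": 3})
_merge_result(merged, 3, second)  # metrics concatenated; actual_output "answer A" is kept

assert [m.name for m in merged[3].metrics] == ["faithfulness", "relevancy"]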
def _apply_judgment(
    results: list[EvaluationItemResult],
    judgment_config: JudgmentConfig,
) -> list[EvaluationItemResult]:
    """Evaluate pass/fail judgment conditions on each result's metrics."""
    judged: list[EvaluationItemResult] = []
    for result in results:
        if result.error is not None or not result.metrics:
            judged.append(result)
            continue
        metric_values: dict[tuple[str, str], object] = {
            (m.node_info.node_id, m.name): m.value for m in result.metrics if m.node_info
        }
        judgment_result = JudgmentProcessor.evaluate(metric_values, judgment_config)
        judged.append(result.model_copy(update={"judgment": judgment_result}))
    return judged


def _persist_results(
    session: Any,
    evaluation_run_id: str,
    results: list[EvaluationItemResult],
    input_list: list[EvaluationDatasetInput],
    workflow_run_id_map: dict[int, str] | None = None,
) -> None:
    """Persist evaluation results — one EvaluationRunItem per test-data row."""
    dataset_map = {item.index: item for item in input_list}
    wf_map = workflow_run_id_map or {}

    for result in results:
        item_input = dataset_map.get(result.index)
        run_item = EvaluationRunItem(
            evaluation_run_id=evaluation_run_id,
            workflow_run_id=wf_map.get(result.index),
            item_index=result.index,
            inputs=json.dumps(item_input.inputs) if item_input else None,
            expected_output=item_input.expected_output if item_input else None,
            actual_output=result.actual_output,
            metrics=json.dumps([m.model_dump() for m in result.metrics]) if result.metrics else None,
            judgment=json.dumps(result.judgment.model_dump()) if result.judgment else None,
            metadata_json=json.dumps(result.metadata) if result.metadata else None,
            error=result.error,
            overall_score=getattr(result, "overall_score", None),
        )
        session.add(run_item)

    session.commit()


def _create_runner(
    category: EvaluationCategory,
    evaluation_instance: BaseEvaluationInstance,
) -> BaseEvaluationRunner:
    """Create the appropriate runner for the evaluation category."""
    match category:
        case EvaluationCategory.LLM:
            return LLMEvaluationRunner(evaluation_instance)
        case EvaluationCategory.RETRIEVAL | EvaluationCategory.KNOWLEDGE_BASE:
            return RetrievalEvaluationRunner(evaluation_instance)
        case EvaluationCategory.AGENT:
            return AgentEvaluationRunner(evaluation_instance)
        case EvaluationCategory.WORKFLOW:
            return WorkflowEvaluationRunner(evaluation_instance)
        case EvaluationCategory.SNIPPET:
            return SnippetEvaluationRunner(evaluation_instance)
        case _:
            raise ValueError(f"Unknown evaluation category: {category}")


# ---------------------------------------------------------------------------
# Status / summary / XLSX / storage helpers (unchanged logic)
# ---------------------------------------------------------------------------


def _mark_run_failed(session: Any, run_id: str, error: str) -> None:
    """Mark an evaluation run as failed."""
    try:
        evaluation_run = session.query(EvaluationRun).filter_by(id=run_id).first()
        if evaluation_run:
            evaluation_run.status = EvaluationRunStatus.FAILED
            evaluation_run.error = error[:2000]
            evaluation_run.completed_at = naive_utc_now()
            session.commit()
    except Exception:
        logger.exception("Failed to mark run %s as failed", run_id)


def _compute_metrics_summary(
    results: list[EvaluationItemResult],
    judgment_config: JudgmentConfig | None,
) -> dict[str, Any]:
    """Compute aggregate metric and judgment summaries for an evaluation run."""
    summary: dict[str, Any] = {}

    if judgment_config is not None and judgment_config.conditions:
        evaluated_results: list[EvaluationItemResult] = [
            result for result in results if result.error is None and result.metrics
        ]
        passed_items = sum(1 for result in evaluated_results if result.judgment.passed)
        evaluated_items = len(evaluated_results)
        summary["_judgment"] = {
            "enabled": True,
            "logical_operator": judgment_config.logical_operator,
            "configured_conditions": len(judgment_config.conditions),
            "evaluated_items": evaluated_items,
            "passed_items": passed_items,
            "failed_items": evaluated_items - passed_items,
            "pass_rate": passed_items / evaluated_items if evaluated_items else 0.0,
        }

    return summary

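
As a concrete reading of the summary above: with 10 result rows of which 2 errored, evaluated_items is 8; if 6 of those pass judgment, the summary records 6 passed, 2 failed, and a pass_rate of 0.75. Errored items never enter the denominator.

# With 10 rows, 2 errored: evaluated_items = 8; 6 passed yields
# {"evaluated_items": 8, "passed_items": 6, "failed_items": 2, "pass_rate": 0.75}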
def _generate_result_xlsx(
    input_list: list[EvaluationDatasetInput],
    results: list[EvaluationItemResult],
) -> bytes:
    """Generate result XLSX with input data, actual output, metric scores, and judgment."""
    wb = Workbook()
    ws = wb.active
    if ws is None:
        ws = wb.create_sheet("Evaluation Results")
    ws.title = "Evaluation Results"

    header_font = Font(bold=True, color="FFFFFF")
    header_fill = PatternFill(start_color="4472C4", end_color="4472C4", fill_type="solid")
    header_alignment = Alignment(horizontal="center", vertical="center")
    thin_border = Border(
        left=Side(style="thin"),
        right=Side(style="thin"),
        top=Side(style="thin"),
        bottom=Side(style="thin"),
    )

    # Collect all metric names
    all_metric_names: list[str] = []
    for result in results:
        for metric in result.metrics:
            if metric.name not in all_metric_names:
                all_metric_names.append(metric.name)

    # Collect all input keys
    input_keys: list[str] = []
    for item in input_list:
        for key in item.inputs:
            if key not in input_keys:
                input_keys.append(key)

    has_judgment = any(bool(r.judgment.condition_results) for r in results)

    judgment_headers = ["judgment"] if has_judgment else []
    headers = (
        ["index"] + input_keys + ["expected_output", "actual_output"] + all_metric_names + judgment_headers + ["error"]
    )

    for col_idx, header in enumerate(headers, start=1):
        cell = ws.cell(row=1, column=col_idx, value=header)
        cell.font = header_font
        cell.fill = header_fill
        cell.alignment = header_alignment
        cell.border = thin_border

    ws.column_dimensions["A"].width = 10
    for col_idx in range(2, len(headers) + 1):
        ws.column_dimensions[get_column_letter(col_idx)].width = 25

    result_by_index = {r.index: r for r in results}

    for row_idx, item in enumerate(input_list, start=2):
        result = result_by_index.get(item.index)

        col = 1
        ws.cell(row=row_idx, column=col, value=item.index).border = thin_border
        col += 1

        for key in input_keys:
            val = item.inputs.get(key, "")
            ws.cell(row=row_idx, column=col, value=str(val)).border = thin_border
            col += 1

        ws.cell(row=row_idx, column=col, value=item.expected_output or "").border = thin_border
        col += 1

        ws.cell(row=row_idx, column=col, value=result.actual_output if result else "").border = thin_border
        col += 1

        metric_scores = {m.name: m.value for m in result.metrics} if result else {}
        for metric_name in all_metric_names:
            score = metric_scores.get(metric_name)
            ws.cell(row=row_idx, column=col, value=score if score is not None else "").border = thin_border
            col += 1

        if has_judgment:
            if result and result.judgment.condition_results:
                judgment_value = "Pass" if result.judgment.passed else "Fail"
            else:
                judgment_value = ""
            ws.cell(row=row_idx, column=col, value=judgment_value).border = thin_border
            col += 1

        ws.cell(row=row_idx, column=col, value=result.error if result else "").border = thin_border

    output = io.BytesIO()
    wb.save(output)
    output.seek(0)
    return output.getvalue()


def _store_result_file(
    tenant_id: str,
    run_id: str,
    xlsx_content: bytes,
    session: Any,
) -> str | None:
    """Store result XLSX file and return the UploadFile ID."""
    try:
        from extensions.ext_storage import storage
        from libs.uuid_utils import uuidv7

        filename = f"evaluation-result-{run_id[:8]}.xlsx"
        storage_key = f"evaluation_results/{tenant_id}/{str(uuidv7())}.xlsx"

        storage.save(storage_key, xlsx_content)

        upload_file: UploadFile = UploadFile(
            tenant_id=tenant_id,
            storage_type=dify_config.STORAGE_TYPE,
            key=storage_key,
            name=filename,
            size=len(xlsx_content),
            extension="xlsx",
            mime_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
            created_by_role=CreatorUserRole.ACCOUNT,
            created_by="system",
            created_at=naive_utc_now(),
            used=False,
        )
        session.add(upload_file)
        session.commit()
        return upload_file.id
    except Exception:
        logger.exception("Failed to store result file for run %s", run_id)
        return None
@@ -28,7 +28,7 @@ from core.trigger.entities.entities import TriggerProviderEntity
from core.trigger.provider import PluginTriggerProviderController
from core.trigger.trigger_manager import TriggerManager
from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData
from enums.quota_type import QuotaType, unlimited
from enums.quota_type import QuotaType
from models.enums import (
    AppTriggerType,
    CreatorUserRole,
@@ -42,6 +42,7 @@ from models.workflow import Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom,
from services.async_workflow_service import AsyncWorkflowService
from services.end_user_service import EndUserService
from services.errors.app import QuotaExceededError
from services.quota_service import QuotaService, unlimited
from services.trigger.app_trigger_service import AppTriggerService
from services.trigger.trigger_provider_service import TriggerProviderService
from services.trigger.trigger_request_service import TriggerHttpRequestCachingService
@@ -298,10 +299,10 @@ def dispatch_triggered_workflow(
        icon_dark_filename=trigger_entity.identity.icon_dark or "",
    )

    # consume quota before invoking trigger
    # reserve quota before invoking trigger
    quota_charge = unlimited()
    try:
        quota_charge = QuotaType.TRIGGER.consume(subscription.tenant_id)
        quota_charge = QuotaService.reserve(QuotaType.TRIGGER, subscription.tenant_id)
    except QuotaExceededError:
        AppTriggerService.mark_tenant_triggers_rate_limited(subscription.tenant_id)
        logger.info(
@@ -387,6 +388,7 @@ def dispatch_triggered_workflow(
            raise ValueError(f"End user not found for app {plugin_trigger.app_id}")

        AsyncWorkflowService.trigger_workflow_async(session=session, user=end_user, trigger_data=trigger_data)
        quota_charge.commit()
        dispatched_count += 1
        logger.info(
            "Triggered workflow for app %s with trigger event %s",

@@ -8,10 +8,11 @@ from core.workflow.nodes.trigger_schedule.exc import (
    ScheduleNotFoundError,
    TenantOwnerNotFoundError,
)
from enums.quota_type import QuotaType, unlimited
from enums.quota_type import QuotaType
from models.trigger import WorkflowSchedulePlan
from services.async_workflow_service import AsyncWorkflowService
from services.errors.app import QuotaExceededError
from services.quota_service import QuotaService, unlimited
from services.trigger.app_trigger_service import AppTriggerService
from services.trigger.schedule_service import ScheduleService
from services.workflow.entities import ScheduleTriggerData
@@ -43,7 +44,7 @@ def run_schedule_trigger(schedule_id: str) -> None:

    quota_charge = unlimited()
    try:
        quota_charge = QuotaType.TRIGGER.consume(schedule.tenant_id)
        quota_charge = QuotaService.reserve(QuotaType.TRIGGER, schedule.tenant_id)
    except QuotaExceededError:
        AppTriggerService.mark_tenant_triggers_rate_limited(schedule.tenant_id)
        logger.info("Tenant %s rate limited, skipping schedule trigger %s", schedule.tenant_id, schedule_id)
@@ -61,6 +62,7 @@ def run_schedule_trigger(schedule_id: str) -> None:
                tenant_id=schedule.tenant_id,
            ),
        )
        quota_charge.commit()
        logger.info("Schedule %s triggered workflow: %s", schedule_id, response.workflow_trigger_log_id)
    except Exception as e:
        quota_charge.refund()

@@ -36,12 +36,19 @@ class TestAppGenerateService:
            ) as mock_message_based_generator,
            patch("services.account_service.FeatureService", autospec=True) as mock_account_feature_service,
            patch("services.app_generate_service.dify_config", autospec=True) as mock_dify_config,
            patch("services.quota_service.dify_config", autospec=True) as mock_quota_dify_config,
            patch("configs.dify_config", autospec=True) as mock_global_dify_config,
        ):
            # Setup default mock returns for billing service
            mock_billing_service.update_tenant_feature_plan_usage.return_value = {
                "result": "success",
                "history_id": "test_history_id",
            mock_billing_service.quota_reserve.return_value = {
                "reservation_id": "test-reservation-id",
                "available": 100,
                "reserved": 1,
            }
            mock_billing_service.quota_commit.return_value = {
                "available": 99,
                "reserved": 0,
                "refunded": 0,
            }

            # Setup default mock returns for workflow service
@@ -101,6 +108,8 @@ class TestAppGenerateService:
            mock_dify_config.APP_DEFAULT_ACTIVE_REQUESTS = 100
            mock_dify_config.APP_DAILY_RATE_LIMIT = 1000

            mock_quota_dify_config.BILLING_ENABLED = False

            mock_global_dify_config.BILLING_ENABLED = False
            mock_global_dify_config.APP_MAX_ACTIVE_REQUESTS = 100
            mock_global_dify_config.APP_DAILY_RATE_LIMIT = 1000
@@ -118,6 +127,7 @@ class TestAppGenerateService:
                "message_based_generator": mock_message_based_generator,
                "account_feature_service": mock_account_feature_service,
                "dify_config": mock_dify_config,
                "quota_dify_config": mock_quota_dify_config,
                "global_dify_config": mock_global_dify_config,
            }

@@ -465,6 +475,7 @@ class TestAppGenerateService:

        # Set BILLING_ENABLED to True for this test
        mock_external_service_dependencies["dify_config"].BILLING_ENABLED = True
        mock_external_service_dependencies["quota_dify_config"].BILLING_ENABLED = True
        mock_external_service_dependencies["global_dify_config"].BILLING_ENABLED = True

        # Setup test arguments
@@ -478,8 +489,10 @@ class TestAppGenerateService:
        # Verify the result
        assert result == ["test_response"]

        # Verify billing service was called to consume quota
        mock_external_service_dependencies["billing_service"].update_tenant_feature_plan_usage.assert_called_once()
        # Verify billing two-phase quota (reserve + commit)
        billing = mock_external_service_dependencies["billing_service"]
        billing.quota_reserve.assert_called_once()
        billing.quota_commit.assert_called_once()

    def test_generate_with_invalid_app_mode(
        self, db_session_with_containers: Session, mock_external_service_dependencies

@@ -605,9 +605,9 @@ def test_schedule_trigger_creates_trigger_log(
    )

    # Mock quota to avoid rate limiting
    from enums import quota_type
    from services import quota_service

    monkeypatch.setattr(quota_type.QuotaType.TRIGGER, "consume", lambda _tenant_id: quota_type.unlimited())
    monkeypatch.setattr(quota_service.QuotaService, "reserve", lambda *_args, **_kwargs: quota_service.unlimited())

    # Execute schedule trigger
    workflow_schedule_tasks.run_schedule_trigger(plan.id)
api/tests/unit_tests/core/evaluation/judgment/test_processor.py (new file, 145 lines)
@@ -0,0 +1,145 @@
|
||||
"""Unit tests for metric-based judgment evaluation."""
|
||||
|
||||
from core.evaluation.entities.judgment_entity import JudgmentCondition, JudgmentConfig
|
||||
from core.evaluation.judgment.processor import JudgmentProcessor
|
||||
|
||||
|
||||
def test_evaluate_uses_and_conditions_against_metric_values() -> None:
|
||||
"""All conditions must pass when the logical operator is ``and``."""
|
||||
config = JudgmentConfig(
|
||||
logical_operator="and",
|
||||
conditions=[
|
||||
JudgmentCondition(
|
||||
variable_selector=["llm_node_1", "faithfulness"],
|
||||
comparison_operator=">",
|
||||
value="0.8",
|
||||
),
|
||||
JudgmentCondition(
|
||||
variable_selector=["llm_node_1", "answer_relevancy"],
|
||||
comparison_operator="≥",
|
||||
value="0.7",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
result = JudgmentProcessor.evaluate(
|
||||
{
|
||||
("llm_node_1", "faithfulness"): 0.9,
|
||||
("llm_node_1", "answer_relevancy"): 0.75,
|
||||
},
|
||||
config,
|
||||
)
|
||||
|
||||
assert result.passed is True
|
||||
assert len(result.condition_results) == 2
|
||||
assert all(condition_result.passed for condition_result in result.condition_results)
|
||||
|
||||
|
||||
def test_evaluate_sets_passed_false_when_any_and_condition_fails() -> None:
|
||||
"""A failed metric comparison should make the overall judgment fail."""
|
||||
config = JudgmentConfig(
|
||||
logical_operator="and",
|
||||
conditions=[
|
||||
JudgmentCondition(
|
||||
variable_selector=["llm_node_1", "faithfulness"],
|
||||
comparison_operator=">",
|
||||
value="0.8",
|
||||
),
|
||||
JudgmentCondition(
|
||||
variable_selector=["llm_node_1", "answer_relevancy"],
|
||||
comparison_operator="≥",
|
||||
value="0.7",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
result = JudgmentProcessor.evaluate(
|
||||
{
|
||||
("llm_node_1", "faithfulness"): 0.9,
|
||||
("llm_node_1", "answer_relevancy"): 0.6,
|
||||
},
|
||||
config,
|
||||
)
|
||||
|
||||
assert result.passed is False
|
||||
assert result.condition_results[-1].passed is False
|
||||
|
||||
|
||||
def test_evaluate_with_different_nodes_same_metric() -> None:
|
||||
"""Conditions can target different nodes even with the same metric name."""
|
||||
config = JudgmentConfig(
|
||||
logical_operator="and",
|
||||
conditions=[
|
||||
JudgmentCondition(
|
||||
variable_selector=["llm_node_1", "faithfulness"],
|
||||
comparison_operator=">",
|
||||
value="0.8",
|
||||
),
|
||||
JudgmentCondition(
|
||||
variable_selector=["llm_node_2", "faithfulness"],
|
||||
comparison_operator=">",
|
||||
value="0.5",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
result = JudgmentProcessor.evaluate(
|
||||
{
|
||||
("llm_node_1", "faithfulness"): 0.9,
|
||||
("llm_node_2", "faithfulness"): 0.6,
|
||||
},
|
||||
config,
|
||||
)
|
||||
|
||||
assert result.passed is True
|
||||
assert len(result.condition_results) == 2
|
||||
|
||||
|
||||
def test_evaluate_or_operator_passes_when_one_condition_met() -> None:
|
||||
"""With ``or`` logical operator, one passing condition should suffice."""
|
||||
config = JudgmentConfig(
|
||||
logical_operator="or",
|
||||
conditions=[
|
||||
JudgmentCondition(
|
||||
variable_selector=["node_a", "score"],
|
||||
comparison_operator=">",
|
||||
value="0.9",
|
||||
),
|
||||
JudgmentCondition(
|
||||
variable_selector=["node_b", "score"],
|
||||
comparison_operator=">",
|
||||
value="0.5",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
result = JudgmentProcessor.evaluate(
|
||||
{
|
||||
("node_a", "score"): 0.3,
|
||||
("node_b", "score"): 0.8,
|
||||
},
|
||||
config,
|
||||
)
|
||||
|
||||
assert result.passed is True
|
||||
|
||||
|
||||
def test_evaluate_string_contains_operator() -> None:
|
||||
"""String operators should work correctly via workflow engine delegation."""
|
||||
config = JudgmentConfig(
|
||||
logical_operator="and",
|
||||
conditions=[
|
||||
JudgmentCondition(
|
||||
variable_selector=["node_a", "status"],
|
||||
comparison_operator="contains",
|
||||
value="success",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
result = JudgmentProcessor.evaluate(
|
||||
{("node_a", "status"): "evaluation_success_done"},
|
||||
config,
|
||||
)
|
||||
|
||||
assert result.passed is True
|
||||
@@ -0,0 +1,78 @@
"""Tests for judgment application logic (moved from BaseEvaluationRunner to evaluation_task)."""

from core.evaluation.entities.evaluation_entity import EvaluationItemResult, EvaluationMetric, NodeInfo
from core.evaluation.entities.judgment_entity import JudgmentCondition, JudgmentConfig
from tasks.evaluation_task import _apply_judgment

_NODE_INFO = NodeInfo(node_id="llm_1", type="llm", title="LLM Node")


def test_apply_judgment_marks_passing_result() -> None:
    """Items whose metrics satisfy the judgment conditions should be marked as passed."""
    results = [
        EvaluationItemResult(
            index=0,
            actual_output="result",
            metrics=[EvaluationMetric(name="faithfulness", value=0.91, node_info=_NODE_INFO)],
        )
    ]
    judgment_config = JudgmentConfig(
        logical_operator="and",
        conditions=[
            JudgmentCondition(
                variable_selector=["llm_1", "faithfulness"],
                comparison_operator=">",
                value="0.8",
            )
        ],
    )

    judged = _apply_judgment(results, judgment_config)

    assert judged[0].judgment.passed is True


def test_apply_judgment_marks_failing_result() -> None:
    """Items whose metrics do NOT satisfy the conditions should be marked as failed."""
    results = [
        EvaluationItemResult(
            index=0,
            metrics=[EvaluationMetric(name="faithfulness", value=0.5, node_info=_NODE_INFO)],
        )
    ]
    judgment_config = JudgmentConfig(
        logical_operator="and",
        conditions=[
            JudgmentCondition(
                variable_selector=["llm_1", "faithfulness"],
                comparison_operator=">",
                value="0.8",
            )
        ],
    )

    judged = _apply_judgment(results, judgment_config)

    assert judged[0].judgment.passed is False


def test_apply_judgment_skips_errored_items() -> None:
    """Items with errors should be passed through without judgment evaluation."""
    results = [
        EvaluationItemResult(index=0, error="timeout"),
    ]
    judgment_config = JudgmentConfig(
        logical_operator="and",
        conditions=[
            JudgmentCondition(
                variable_selector=["llm_1", "faithfulness"],
                comparison_operator=">",
                value="0.8",
            )
        ],
    )

    judged = _apply_judgment(results, judgment_config)

    assert judged[0].error == "timeout"
    assert judged[0].judgment.passed is False
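These three cases fix the helper's contract: judge clean items, and pass errored items through with a default non-passing judgment. A sketch of what `_apply_judgment` plausibly does under those constraints (the metric-to-selector mapping and the `JudgmentProcessor` call are inferences from the assertions, not confirmed internals):

# Hypothetical sketch inferred from the tests above; not the actual implementation.
def _apply_judgment_sketch(results, judgment_config):
    for item in results:
        if item.error:
            continue  # errored items keep a default judgment with passed=False
        values = {(m.node_info.node_id, m.name): m.value for m in item.metrics if m.node_info}
        item.judgment = JudgmentProcessor.evaluate(values, judgment_config)
    return results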
0	api/tests/unit_tests/enums/__init__.py	Normal file
349	api/tests/unit_tests/enums/test_quota_type.py	Normal file
@@ -0,0 +1,349 @@
"""Unit tests for QuotaType, QuotaService, and QuotaCharge."""

from unittest.mock import patch

import pytest

from enums.quota_type import QuotaType
from services.quota_service import QuotaCharge, QuotaService, unlimited


class TestQuotaType:
    def test_billing_key_trigger(self):
        assert QuotaType.TRIGGER.billing_key == "trigger_event"

    def test_billing_key_workflow(self):
        assert QuotaType.WORKFLOW.billing_key == "api_rate_limit"

    def test_billing_key_unlimited_raises(self):
        with pytest.raises(ValueError, match="Invalid quota type"):
            _ = QuotaType.UNLIMITED.billing_key


class TestQuotaService:
    def test_reserve_billing_disabled(self):
        with (
            patch("services.quota_service.dify_config") as mock_cfg,
            patch("services.billing_service.BillingService"),
        ):
            mock_cfg.BILLING_ENABLED = False
            charge = QuotaService.reserve(QuotaType.TRIGGER, "t1")
            assert charge.success is True
            assert charge.charge_id is None

    def test_reserve_zero_amount_raises(self):
        with patch("services.quota_service.dify_config") as mock_cfg:
            mock_cfg.BILLING_ENABLED = True
            with pytest.raises(ValueError, match="greater than 0"):
                QuotaService.reserve(QuotaType.TRIGGER, "t1", amount=0)

    def test_reserve_success(self):
        with (
            patch("services.quota_service.dify_config") as mock_cfg,
            patch("services.billing_service.BillingService") as mock_bs,
        ):
            mock_cfg.BILLING_ENABLED = True
            mock_bs.quota_reserve.return_value = {"reservation_id": "rid-1", "available": 99}

            charge = QuotaService.reserve(QuotaType.TRIGGER, "t1", amount=1)

            assert charge.success is True
            assert charge.charge_id == "rid-1"
            assert charge._tenant_id == "t1"
            assert charge._feature_key == "trigger_event"
            assert charge._amount == 1
            mock_bs.quota_reserve.assert_called_once()

    def test_reserve_no_reservation_id_raises(self):
        from services.errors.app import QuotaExceededError

        with (
            patch("services.quota_service.dify_config") as mock_cfg,
            patch("services.billing_service.BillingService") as mock_bs,
        ):
            mock_cfg.BILLING_ENABLED = True
            mock_bs.quota_reserve.return_value = {}

            with pytest.raises(QuotaExceededError):
                QuotaService.reserve(QuotaType.TRIGGER, "t1")

    def test_reserve_quota_exceeded_propagates(self):
        from services.errors.app import QuotaExceededError

        with (
            patch("services.quota_service.dify_config") as mock_cfg,
            patch("services.billing_service.BillingService") as mock_bs,
        ):
            mock_cfg.BILLING_ENABLED = True
            mock_bs.quota_reserve.side_effect = QuotaExceededError(feature="trigger", tenant_id="t1", required=1)

            with pytest.raises(QuotaExceededError):
                QuotaService.reserve(QuotaType.TRIGGER, "t1")

    def test_reserve_api_exception_returns_unlimited(self):
        with (
            patch("services.quota_service.dify_config") as mock_cfg,
            patch("services.billing_service.BillingService") as mock_bs,
        ):
            mock_cfg.BILLING_ENABLED = True
            mock_bs.quota_reserve.side_effect = RuntimeError("network")

            charge = QuotaService.reserve(QuotaType.TRIGGER, "t1")
            assert charge.success is True
            assert charge.charge_id is None

    def test_consume_calls_reserve_and_commit(self):
        with (
            patch("services.quota_service.dify_config") as mock_cfg,
            patch("services.billing_service.BillingService") as mock_bs,
        ):
            mock_cfg.BILLING_ENABLED = True
            mock_bs.quota_reserve.return_value = {"reservation_id": "rid-c"}
            mock_bs.quota_commit.return_value = {}

            charge = QuotaService.consume(QuotaType.TRIGGER, "t1")
            assert charge.success is True
            mock_bs.quota_commit.assert_called_once()

    def test_check_billing_disabled(self):
        with patch("services.quota_service.dify_config") as mock_cfg:
            mock_cfg.BILLING_ENABLED = False
            assert QuotaService.check(QuotaType.TRIGGER, "t1") is True

    def test_check_zero_amount_raises(self):
        with patch("services.quota_service.dify_config") as mock_cfg:
            mock_cfg.BILLING_ENABLED = True
            with pytest.raises(ValueError, match="greater than 0"):
                QuotaService.check(QuotaType.TRIGGER, "t1", amount=0)

    def test_check_sufficient_quota(self):
        with (
            patch("services.quota_service.dify_config") as mock_cfg,
            patch.object(QuotaService, "get_remaining", return_value=100),
        ):
            mock_cfg.BILLING_ENABLED = True
            assert QuotaService.check(QuotaType.TRIGGER, "t1", amount=50) is True

    def test_check_insufficient_quota(self):
        with (
            patch("services.quota_service.dify_config") as mock_cfg,
            patch.object(QuotaService, "get_remaining", return_value=5),
        ):
            mock_cfg.BILLING_ENABLED = True
            assert QuotaService.check(QuotaType.TRIGGER, "t1", amount=10) is False

    def test_check_unlimited_quota(self):
        with (
            patch("services.quota_service.dify_config") as mock_cfg,
            patch.object(QuotaService, "get_remaining", return_value=-1),
        ):
            mock_cfg.BILLING_ENABLED = True
            assert QuotaService.check(QuotaType.TRIGGER, "t1", amount=999) is True

    def test_check_exception_returns_true(self):
        with (
            patch("services.quota_service.dify_config") as mock_cfg,
            patch.object(QuotaService, "get_remaining", side_effect=RuntimeError),
        ):
            mock_cfg.BILLING_ENABLED = True
            assert QuotaService.check(QuotaType.TRIGGER, "t1") is True

    def test_release_billing_disabled(self):
        with (
            patch("services.quota_service.dify_config") as mock_cfg,
            patch("services.billing_service.BillingService") as mock_bs,
        ):
            mock_cfg.BILLING_ENABLED = False
            QuotaService.release(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event")
            mock_bs.quota_release.assert_not_called()

    def test_release_empty_reservation(self):
        with (
            patch("services.quota_service.dify_config") as mock_cfg,
            patch("services.billing_service.BillingService") as mock_bs,
        ):
            mock_cfg.BILLING_ENABLED = True
            QuotaService.release(QuotaType.TRIGGER, "", "t1", "trigger_event")
            mock_bs.quota_release.assert_not_called()

    def test_release_success(self):
        with (
            patch("services.quota_service.dify_config") as mock_cfg,
            patch("services.billing_service.BillingService") as mock_bs,
        ):
            mock_cfg.BILLING_ENABLED = True
            mock_bs.quota_release.return_value = {}
            QuotaService.release(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event")
            mock_bs.quota_release.assert_called_once_with(
                tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1"
            )

    def test_release_exception_swallowed(self):
        with (
            patch("services.quota_service.dify_config") as mock_cfg,
            patch("services.billing_service.BillingService") as mock_bs,
        ):
            mock_cfg.BILLING_ENABLED = True
            mock_bs.quota_release.side_effect = RuntimeError("fail")
            QuotaService.release(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event")

    def test_get_remaining_normal(self):
        with patch("services.billing_service.BillingService") as mock_bs:
            mock_bs.get_quota_info.return_value = {"trigger_event": {"limit": 100, "usage": 30}}
            assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 70

    def test_get_remaining_unlimited(self):
        with patch("services.billing_service.BillingService") as mock_bs:
            mock_bs.get_quota_info.return_value = {"trigger_event": {"limit": -1, "usage": 0}}
            assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == -1

    def test_get_remaining_over_limit_returns_zero(self):
        with patch("services.billing_service.BillingService") as mock_bs:
            mock_bs.get_quota_info.return_value = {"trigger_event": {"limit": 10, "usage": 15}}
            assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0

    def test_get_remaining_exception_returns_neg1(self):
        with patch("services.billing_service.BillingService") as mock_bs:
            mock_bs.get_quota_info.side_effect = RuntimeError
            assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == -1

    def test_get_remaining_empty_response(self):
        with patch("services.billing_service.BillingService") as mock_bs:
            mock_bs.get_quota_info.return_value = {}
            assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0

    def test_get_remaining_non_dict_response(self):
        with patch("services.billing_service.BillingService") as mock_bs:
            mock_bs.get_quota_info.return_value = "invalid"
            assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0

    def test_get_remaining_feature_not_in_response(self):
        with patch("services.billing_service.BillingService") as mock_bs:
            mock_bs.get_quota_info.return_value = {"other_feature": {"limit": 100, "usage": 0}}
            remaining = QuotaService.get_remaining(QuotaType.TRIGGER, "t1")
            assert remaining == 0

    def test_get_remaining_non_dict_feature_info(self):
        with patch("services.billing_service.BillingService") as mock_bs:
            mock_bs.get_quota_info.return_value = {"trigger_event": "not_a_dict"}
            assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0


class TestQuotaCharge:
    def test_commit_success(self):
        with patch("services.billing_service.BillingService") as mock_bs:
            mock_bs.quota_commit.return_value = {}
            charge = QuotaCharge(
                success=True,
                charge_id="rid-1",
                _quota_type=QuotaType.TRIGGER,
                _tenant_id="t1",
                _feature_key="trigger_event",
                _amount=1,
            )
            charge.commit()
            mock_bs.quota_commit.assert_called_once_with(
                tenant_id="t1",
                feature_key="trigger_event",
                reservation_id="rid-1",
                actual_amount=1,
            )
            assert charge._committed is True

    def test_commit_with_actual_amount(self):
        with patch("services.billing_service.BillingService") as mock_bs:
            mock_bs.quota_commit.return_value = {}
            charge = QuotaCharge(
                success=True,
                charge_id="rid-1",
                _quota_type=QuotaType.TRIGGER,
                _tenant_id="t1",
                _feature_key="trigger_event",
                _amount=10,
            )
            charge.commit(actual_amount=5)
            call_kwargs = mock_bs.quota_commit.call_args[1]
            assert call_kwargs["actual_amount"] == 5

    def test_commit_idempotent(self):
        with patch("services.billing_service.BillingService") as mock_bs:
            mock_bs.quota_commit.return_value = {}
            charge = QuotaCharge(
                success=True,
                charge_id="rid-1",
                _quota_type=QuotaType.TRIGGER,
                _tenant_id="t1",
                _feature_key="trigger_event",
                _amount=1,
            )
            charge.commit()
            charge.commit()
            assert mock_bs.quota_commit.call_count == 1

    def test_commit_no_charge_id_noop(self):
        with patch("services.billing_service.BillingService") as mock_bs:
            charge = QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.TRIGGER)
            charge.commit()
            mock_bs.quota_commit.assert_not_called()

    def test_commit_no_tenant_id_noop(self):
        with patch("services.billing_service.BillingService") as mock_bs:
            charge = QuotaCharge(
                success=True,
                charge_id="rid-1",
                _quota_type=QuotaType.TRIGGER,
                _tenant_id=None,
                _feature_key="trigger_event",
            )
            charge.commit()
            mock_bs.quota_commit.assert_not_called()

    def test_commit_exception_swallowed(self):
        with patch("services.billing_service.BillingService") as mock_bs:
            mock_bs.quota_commit.side_effect = RuntimeError("fail")
            charge = QuotaCharge(
                success=True,
                charge_id="rid-1",
                _quota_type=QuotaType.TRIGGER,
                _tenant_id="t1",
                _feature_key="trigger_event",
                _amount=1,
            )
            charge.commit()

    def test_refund_success(self):
        with patch.object(QuotaService, "release") as mock_rel:
            charge = QuotaCharge(
                success=True,
                charge_id="rid-1",
                _quota_type=QuotaType.TRIGGER,
                _tenant_id="t1",
                _feature_key="trigger_event",
            )
            charge.refund()
            mock_rel.assert_called_once_with(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event")

    def test_refund_no_charge_id_noop(self):
        with patch.object(QuotaService, "release") as mock_rel:
            charge = QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.TRIGGER)
            charge.refund()
            mock_rel.assert_not_called()

    def test_refund_no_tenant_id_noop(self):
        with patch.object(QuotaService, "release") as mock_rel:
            charge = QuotaCharge(
                success=True,
                charge_id="rid-1",
                _quota_type=QuotaType.TRIGGER,
                _tenant_id=None,
            )
            charge.refund()
            mock_rel.assert_not_called()


class TestUnlimited:
    def test_unlimited_returns_success_with_no_charge_id(self):
        charge = unlimited()
        assert charge.success is True
        assert charge.charge_id is None
        assert charge._quota_type == QuotaType.UNLIMITED
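Read as a whole, the suite describes a reserve-then-settle quota pattern that fails open: billing-API errors during reserve or check are treated as unlimited quota, while a genuine QuotaExceededError propagates. A hedged sketch of the caller-side lifecycle implied by these assertions (do_billable_work is a placeholder, not a real function):

# Presumed usage pattern, reconstructed from the tests above.
charge = QuotaService.reserve(QuotaType.TRIGGER, tenant_id)  # raises QuotaExceededError when exhausted
try:
    do_billable_work()
    charge.commit()  # idempotent; billing errors are swallowed (test_commit_exception_swallowed)
except Exception:
    charge.refund()  # releases the reservation via QuotaService.release
    raise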
@@ -23,6 +23,7 @@ import pytest

import services.app_generate_service as ags_module
from core.app.entities.app_invoke_entities import InvokeFrom
from enums.quota_type import QuotaType
from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowIdFormatError, WorkflowNotFoundError
@@ -447,8 +448,8 @@
     def test_billing_enabled_consumes_quota(self, mocker, monkeypatch):
         monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True)
         quota_charge = MagicMock()
-        consume_mock = mocker.patch(
-            "services.app_generate_service.QuotaType.WORKFLOW.consume",
+        reserve_mock = mocker.patch(
+            "services.app_generate_service.QuotaService.reserve",
             return_value=quota_charge,
         )
         mocker.patch(

@@ -467,7 +468,8 @@
             invoke_from=InvokeFrom.SERVICE_API,
             streaming=False,
         )
-        consume_mock.assert_called_once_with("tenant-id")
+        reserve_mock.assert_called_once_with(QuotaType.WORKFLOW, "tenant-id")
+        quota_charge.commit.assert_called_once()

     def test_billing_quota_exceeded_raises_rate_limit_error(self, mocker, monkeypatch):
         from services.errors.app import QuotaExceededError
@@ -475,7 +477,7 @@

         monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True)
         mocker.patch(
-            "services.app_generate_service.QuotaType.WORKFLOW.consume",
+            "services.app_generate_service.QuotaService.reserve",
             side_effect=QuotaExceededError(feature="workflow", tenant_id="t", required=1),
         )
@@ -492,7 +494,7 @@
         monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True)
         quota_charge = MagicMock()
         mocker.patch(
-            "services.app_generate_service.QuotaType.WORKFLOW.consume",
+            "services.app_generate_service.QuotaService.reserve",
             return_value=quota_charge,
         )
         mocker.patch(
@@ -57,7 +57,7 @@
         - repo: SQLAlchemyWorkflowTriggerLogRepository
         - dispatcher_manager_class: QueueDispatcherManager class
         - dispatcher: dispatcher instance
-        - quota_workflow: QuotaType.WORKFLOW
+        - quota_service: QuotaService mock
         - get_workflow: AsyncWorkflowService._get_workflow method
         - professional_task: execute_workflow_professional
         - team_task: execute_workflow_team
@@ -72,7 +72,12 @@
         mock_repo.create.side_effect = _create_side_effect

         mock_dispatcher = MagicMock()
-        quota_workflow = MagicMock()
+        mock_quota_service = MagicMock()
         mock_get_workflow = MagicMock()

+        mock_professional_task = MagicMock()
+        mock_team_task = MagicMock()
+        mock_sandbox_task = MagicMock()
+
         with (
             patch.object(
@@ -88,8 +93,8 @@
             ) as mock_get_workflow,
             patch.object(
                 async_workflow_service_module,
-                "QuotaType",
-                new=SimpleNamespace(WORKFLOW=quota_workflow),
+                "QuotaService",
+                new=mock_quota_service,
             ),
             patch.object(async_workflow_service_module, "execute_workflow_professional") as mock_professional_task,
             patch.object(async_workflow_service_module, "execute_workflow_team") as mock_team_task,
@@ -102,7 +107,7 @@
                 "repo": mock_repo,
                 "dispatcher_manager_class": mock_dispatcher_manager_class,
                 "dispatcher": mock_dispatcher,
-                "quota_workflow": quota_workflow,
+                "quota_service": mock_quota_service,
                 "get_workflow": mock_get_workflow,
                 "professional_task": mock_professional_task,
                 "team_task": mock_team_task,
@@ -141,6 +146,9 @@
        mocks["team_task"].delay.return_value = task_result
        mocks["sandbox_task"].delay.return_value = task_result

        quota_charge_mock = MagicMock()
        mocks["quota_service"].reserve.return_value = quota_charge_mock

        class DummyAccount:
            def __init__(self, user_id: str):
                self.id = user_id
@@ -158,7 +166,8 @@
         assert result.status == "queued"
         assert result.queue == queue_name

-        mocks["quota_workflow"].consume.assert_called_once_with("tenant-123")
+        mocks["quota_service"].reserve.assert_called_once()
+        quota_charge_mock.commit.assert_called_once()
         assert session.commit.call_count == 2

         created_log = mocks["repo"].create.call_args[0][0]
@@ -245,7 +254,7 @@
         mocks = async_workflow_trigger_mocks
         mocks["dispatcher"].get_queue_name.return_value = QueuePriority.TEAM
         mocks["get_workflow"].return_value = workflow
-        mocks["quota_workflow"].consume.side_effect = QuotaExceededError(
+        mocks["quota_service"].reserve.side_effect = QuotaExceededError(
             feature="workflow",
             tenant_id="tenant-123",
             required=1,
@@ -425,7 +425,7 @@
         yield mock

     def test_get_tenant_feature_plan_usage_info(self, mock_send_request):
-        """Test retrieval of tenant feature plan usage information."""
+        """Test retrieval of tenant feature plan usage information (legacy endpoint)."""
         # Arrange
         tenant_id = "tenant-123"
         expected_response = {"features": {"trigger": {"used": 50, "limit": 100}, "workflow": {"used": 20, "limit": 50}}}
@@ -438,6 +438,20 @@
        assert result == expected_response
        mock_send_request.assert_called_once_with("GET", "/tenant-feature-usage/info", params={"tenant_id": tenant_id})

    def test_get_quota_info(self, mock_send_request):
        """Test retrieval of quota info from new endpoint."""
        # Arrange
        tenant_id = "tenant-123"
        expected_response = {"trigger_event": {"limit": 100, "usage": 30}, "api_rate_limit": {"limit": -1, "usage": 0}}
        mock_send_request.return_value = expected_response

        # Act
        result = BillingService.get_quota_info(tenant_id)

        # Assert
        assert result == expected_response
        mock_send_request.assert_called_once_with("GET", "/quota/info", params={"tenant_id": tenant_id})

    def test_update_tenant_feature_plan_usage_positive_delta(self, mock_send_request):
        """Test updating tenant feature usage with positive delta (adding credits)."""
        # Arrange
@@ -515,6 +529,150 @@
        )


class TestBillingServiceQuotaOperations:
    """Unit tests for quota reserve/commit/release operations."""

    @pytest.fixture
    def mock_send_request(self):
        with patch.object(BillingService, "_send_request") as mock:
            yield mock

    def test_quota_reserve_success(self, mock_send_request):
        expected = {"reservation_id": "rid-1", "available": 99, "reserved": 1}
        mock_send_request.return_value = expected

        result = BillingService.quota_reserve(tenant_id="t1", feature_key="trigger_event", request_id="req-1", amount=1)

        assert result == expected
        mock_send_request.assert_called_once_with(
            "POST",
            "/quota/reserve",
            json={"tenant_id": "t1", "feature_key": "trigger_event", "request_id": "req-1", "amount": 1},
        )

    def test_quota_reserve_coerces_string_to_int(self, mock_send_request):
        """Test that TypeAdapter coerces string values to int."""
        mock_send_request.return_value = {"reservation_id": "rid-str", "available": "99", "reserved": "1"}

        result = BillingService.quota_reserve(tenant_id="t1", feature_key="trigger_event", request_id="req-s", amount=1)

        assert result["available"] == 99
        assert isinstance(result["available"], int)
        assert result["reserved"] == 1
        assert isinstance(result["reserved"], int)

    def test_quota_reserve_with_meta(self, mock_send_request):
        mock_send_request.return_value = {"reservation_id": "rid-2", "available": 98, "reserved": 1}
        meta = {"source": "webhook"}

        BillingService.quota_reserve(
            tenant_id="t1", feature_key="trigger_event", request_id="req-2", amount=1, meta=meta
        )

        call_json = mock_send_request.call_args[1]["json"]
        assert call_json["meta"] == {"source": "webhook"}

    def test_quota_commit_success(self, mock_send_request):
        expected = {"available": 98, "reserved": 0, "refunded": 0}
        mock_send_request.return_value = expected

        result = BillingService.quota_commit(
            tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1", actual_amount=1
        )

        assert result == expected
        mock_send_request.assert_called_once_with(
            "POST",
            "/quota/commit",
            json={
                "tenant_id": "t1",
                "feature_key": "trigger_event",
                "reservation_id": "rid-1",
                "actual_amount": 1,
            },
        )

    def test_quota_commit_coerces_string_to_int(self, mock_send_request):
        """Test that TypeAdapter coerces string values to int."""
        mock_send_request.return_value = {"available": "97", "reserved": "0", "refunded": "1"}

        result = BillingService.quota_commit(
            tenant_id="t1", feature_key="trigger_event", reservation_id="rid-s", actual_amount=1
        )

        assert result["available"] == 97
        assert isinstance(result["available"], int)
        assert result["refunded"] == 1
        assert isinstance(result["refunded"], int)

    def test_quota_commit_with_meta(self, mock_send_request):
        mock_send_request.return_value = {"available": 97, "reserved": 0, "refunded": 0}
        meta = {"reason": "partial"}

        BillingService.quota_commit(
            tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1", actual_amount=1, meta=meta
        )

        call_json = mock_send_request.call_args[1]["json"]
        assert call_json["meta"] == {"reason": "partial"}

    def test_quota_release_success(self, mock_send_request):
        expected = {"available": 100, "reserved": 0, "released": 1}
        mock_send_request.return_value = expected

        result = BillingService.quota_release(tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1")

        assert result == expected
        mock_send_request.assert_called_once_with(
            "POST",
            "/quota/release",
            json={"tenant_id": "t1", "feature_key": "trigger_event", "reservation_id": "rid-1"},
        )

    def test_quota_release_coerces_string_to_int(self, mock_send_request):
        """Test that TypeAdapter coerces string values to int."""
        mock_send_request.return_value = {"available": "100", "reserved": "0", "released": "1"}

        result = BillingService.quota_release(tenant_id="t1", feature_key="trigger_event", reservation_id="rid-s")

        assert result["available"] == 100
        assert isinstance(result["available"], int)
        assert result["released"] == 1
        assert isinstance(result["released"], int)

    def test_get_quota_info_coerces_string_to_int(self, mock_send_request):
        """Test that TypeAdapter coerces string values to int for get_quota_info."""
        mock_send_request.return_value = {
            "trigger_event": {"usage": "42", "limit": "3000", "reset_date": "1700000000"},
            "api_rate_limit": {"usage": "10", "limit": "-1", "reset_date": "-1"},
        }

        result = BillingService.get_quota_info("t1")

        assert result["trigger_event"]["usage"] == 42
        assert isinstance(result["trigger_event"]["usage"], int)
        assert result["trigger_event"]["limit"] == 3000
        assert isinstance(result["trigger_event"]["limit"], int)
        assert result["trigger_event"]["reset_date"] == 1700000000
        assert isinstance(result["trigger_event"]["reset_date"], int)
        assert result["api_rate_limit"]["limit"] == -1
        assert isinstance(result["api_rate_limit"]["limit"], int)

    def test_get_quota_info_accepts_int_values(self, mock_send_request):
        """Test that get_quota_info works with native int values."""
        expected = {
            "trigger_event": {"usage": 42, "limit": 3000, "reset_date": 1700000000},
            "api_rate_limit": {"usage": 0, "limit": -1},
        }
        mock_send_request.return_value = expected

        result = BillingService.get_quota_info("t1")

        assert result["trigger_event"]["usage"] == 42
        assert result["trigger_event"]["limit"] == 3000
        assert result["api_rate_limit"]["limit"] == -1


class TestBillingServiceRateLimitEnforcement:
    """Unit tests for rate limit enforcement mechanisms.
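The repeated "coerces string to int" cases suggest the service validates billing responses through a pydantic TypeAdapter before returning them. A minimal sketch of that presumed validation step (the TypedDict and adapter names here are illustrative, not taken from the source):

# Hedged sketch; class and field names are hypothetical.
from typing import TypedDict
from pydantic import TypeAdapter

class _QuotaReserveResponse(TypedDict, total=False):
    reservation_id: str
    available: int  # pydantic's lax mode coerces "99" -> 99 during validation
    reserved: int

_reserve_adapter = TypeAdapter(_QuotaReserveResponse)
validated = _reserve_adapter.validate_python({"reservation_id": "rid-1", "available": "99", "reserved": "1"})
assert validated["available"] == 99  # mirrors test_quota_reserve_coerces_string_to_int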
@ -559,3 +559,772 @@ class TestWebhookServiceUnit:
|
||||
|
||||
result = _prepare_webhook_execution("test_webhook", is_debug=True)
|
||||
assert result == (mock_trigger, mock_workflow, mock_config, mock_data, None)
|
||||
|
||||
|
||||
|
||||
# === Merged from test_webhook_service_additional.py ===
|
||||
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from graphon.variables.types import SegmentType
|
||||
from werkzeug.datastructures import FileStorage
|
||||
from werkzeug.exceptions import RequestEntityTooLarge
|
||||
|
||||
from core.workflow.nodes.trigger_webhook.entities import (
|
||||
ContentType,
|
||||
WebhookBodyParameter,
|
||||
WebhookData,
|
||||
WebhookParameter,
|
||||
)
|
||||
from models.enums import AppTriggerStatus
|
||||
from models.model import App
|
||||
from models.trigger import WorkflowWebhookTrigger
|
||||
from models.workflow import Workflow
|
||||
from services.errors.app import QuotaExceededError
|
||||
from services.trigger import webhook_service as service_module
|
||||
from services.trigger.webhook_service import WebhookService
|
||||
|
||||
|
||||
class _FakeQuery:
|
||||
def __init__(self, result: Any) -> None:
|
||||
self._result = result
|
||||
|
||||
def where(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
|
||||
return self
|
||||
|
||||
def filter(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
|
||||
return self
|
||||
|
||||
def order_by(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
|
||||
return self
|
||||
|
||||
def first(self) -> Any:
|
||||
return self._result
|
||||
|
||||
|
||||
class _SessionContext:
|
||||
def __init__(self, session: Any) -> None:
|
||||
self._session = session
|
||||
|
||||
def __enter__(self) -> Any:
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class _SessionmakerContext:
|
||||
def __init__(self, session: Any) -> None:
|
||||
self._session = session
|
||||
|
||||
def begin(self) -> "_SessionmakerContext":
|
||||
return self
|
||||
|
||||
def __enter__(self) -> Any:
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def flask_app() -> Flask:
|
||||
return Flask(__name__)
|
||||
|
||||
|
||||
def _patch_session(monkeypatch: pytest.MonkeyPatch, session: Any) -> None:
|
||||
monkeypatch.setattr(service_module, "db", SimpleNamespace(engine=MagicMock(), session=MagicMock()))
|
||||
monkeypatch.setattr(service_module, "Session", lambda *args, **kwargs: _SessionContext(session))
|
||||
monkeypatch.setattr(service_module, "sessionmaker", lambda *args, **kwargs: _SessionmakerContext(session))
|
||||
|
||||
|
||||
def _workflow_trigger(**kwargs: Any) -> WorkflowWebhookTrigger:
|
||||
return cast(WorkflowWebhookTrigger, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
def _workflow(**kwargs: Any) -> Workflow:
|
||||
return cast(Workflow, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
def _app(**kwargs: Any) -> App:
|
||||
return cast(App, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_raise_when_webhook_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
fake_session = MagicMock()
|
||||
fake_session.scalar.return_value = None
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="Webhook not found"):
|
||||
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
|
||||
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_not_found(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
|
||||
fake_session = MagicMock()
|
||||
fake_session.scalar.side_effect = [webhook_trigger, None]
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="App trigger not found"):
|
||||
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
|
||||
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_rate_limited(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
|
||||
app_trigger = SimpleNamespace(status=AppTriggerStatus.RATE_LIMITED)
|
||||
fake_session = MagicMock()
|
||||
fake_session.scalar.side_effect = [webhook_trigger, app_trigger]
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="rate limited"):
|
||||
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
|
||||
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_disabled(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
|
||||
app_trigger = SimpleNamespace(status=AppTriggerStatus.DISABLED)
|
||||
fake_session = MagicMock()
|
||||
fake_session.scalar.side_effect = [webhook_trigger, app_trigger]
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="disabled"):
|
||||
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
|
||||
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_raise_when_workflow_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
|
||||
app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED)
|
||||
fake_session = MagicMock()
|
||||
fake_session.scalar.side_effect = [webhook_trigger, app_trigger, None]
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="Workflow not found"):
|
||||
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
|
||||
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_return_values_for_non_debug_mode(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
|
||||
app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED)
|
||||
workflow = MagicMock()
|
||||
workflow.get_node_config_by_id.return_value = {"data": {"key": "value"}}
|
||||
|
||||
fake_session = MagicMock()
|
||||
fake_session.scalar.side_effect = [webhook_trigger, app_trigger, workflow]
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
# Act
|
||||
got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow("webhook-1")
|
||||
|
||||
# Assert
|
||||
assert got_trigger is webhook_trigger
|
||||
assert got_workflow is workflow
|
||||
assert got_node_config == {"data": {"key": "value"}}
|
||||
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_return_values_for_debug_mode(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
|
||||
workflow = MagicMock()
|
||||
workflow.get_node_config_by_id.return_value = {"data": {"mode": "debug"}}
|
||||
|
||||
fake_session = MagicMock()
|
||||
fake_session.scalar.side_effect = [webhook_trigger, workflow]
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
# Act
|
||||
got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow(
|
||||
"webhook-1", is_debug=True
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert got_trigger is webhook_trigger
|
||||
assert got_workflow is workflow
|
||||
assert got_node_config == {"data": {"mode": "debug"}}
|
||||
|
||||
|
||||
def test_extract_webhook_data_should_use_text_fallback_for_unknown_content_type(
|
||||
flask_app: Flask,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
warning_mock = MagicMock()
|
||||
monkeypatch.setattr(service_module.logger, "warning", warning_mock)
|
||||
webhook_trigger = MagicMock()
|
||||
|
||||
# Act
|
||||
with flask_app.test_request_context(
|
||||
"/webhook",
|
||||
method="POST",
|
||||
headers={"Content-Type": "application/vnd.custom"},
|
||||
data="plain content",
|
||||
):
|
||||
result = WebhookService.extract_webhook_data(webhook_trigger)
|
||||
|
||||
# Assert
|
||||
assert result["body"] == {"raw": "plain content"}
|
||||
warning_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_extract_webhook_data_should_raise_for_request_too_large(
|
||||
flask_app: Flask,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
monkeypatch.setattr(service_module.dify_config, "WEBHOOK_REQUEST_BODY_MAX_SIZE", 1)
|
||||
|
||||
# Act / Assert
|
||||
with flask_app.test_request_context("/webhook", method="POST", data="ab"):
|
||||
with pytest.raises(RequestEntityTooLarge):
|
||||
WebhookService.extract_webhook_data(MagicMock())
|
||||
|
||||
|
||||
def test_extract_octet_stream_body_should_return_none_when_empty_payload(flask_app: Flask) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = MagicMock()
|
||||
|
||||
# Act
|
||||
with flask_app.test_request_context("/webhook", method="POST", data=b""):
|
||||
body, files = WebhookService._extract_octet_stream_body(webhook_trigger)
|
||||
|
||||
# Assert
|
||||
assert body == {"raw": None}
|
||||
assert files == {}
|
||||
|
||||
|
||||
def test_extract_octet_stream_body_should_return_none_when_processing_raises(
|
||||
flask_app: Flask,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = MagicMock()
|
||||
monkeypatch.setattr(WebhookService, "_detect_binary_mimetype", MagicMock(return_value="application/octet-stream"))
|
||||
monkeypatch.setattr(WebhookService, "_create_file_from_binary", MagicMock(side_effect=RuntimeError("boom")))
|
||||
|
||||
# Act
|
||||
with flask_app.test_request_context("/webhook", method="POST", data=b"abc"):
|
||||
body, files = WebhookService._extract_octet_stream_body(webhook_trigger)
|
||||
|
||||
# Assert
|
||||
assert body == {"raw": None}
|
||||
assert files == {}
|
||||
|
||||
|
||||
def test_extract_text_body_should_return_empty_string_when_request_read_fails(
|
||||
flask_app: Flask,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
monkeypatch.setattr("flask.wrappers.Request.get_data", MagicMock(side_effect=RuntimeError("read error")))
|
||||
|
||||
# Act
|
||||
with flask_app.test_request_context("/webhook", method="POST", data="abc"):
|
||||
body, files = WebhookService._extract_text_body()
|
||||
|
||||
# Assert
|
||||
assert body == {"raw": ""}
|
||||
assert files == {}
|
||||
|
||||
|
||||
def test_detect_binary_mimetype_should_fallback_when_magic_raises(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
fake_magic = MagicMock()
|
||||
fake_magic.from_buffer.side_effect = RuntimeError("magic failed")
|
||||
monkeypatch.setattr(service_module, "magic", fake_magic)
|
||||
|
||||
# Act
|
||||
result = WebhookService._detect_binary_mimetype(b"binary")
|
||||
|
||||
# Assert
|
||||
assert result == "application/octet-stream"
|
||||
|
||||
|
||||
def test_process_file_uploads_should_use_octet_stream_fallback_when_mimetype_unknown(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = _workflow_trigger(created_by="user-1", tenant_id="tenant-1")
|
||||
file_obj = MagicMock()
|
||||
file_obj.to_dict.return_value = {"id": "f-1"}
|
||||
monkeypatch.setattr(WebhookService, "_create_file_from_binary", MagicMock(return_value=file_obj))
|
||||
monkeypatch.setattr(service_module.mimetypes, "guess_type", MagicMock(return_value=(None, None)))
|
||||
|
||||
uploaded = MagicMock()
|
||||
uploaded.filename = "file.unknown"
|
||||
uploaded.content_type = None
|
||||
uploaded.read.return_value = b"content"
|
||||
|
||||
# Act
|
||||
result = WebhookService._process_file_uploads({"f": uploaded}, webhook_trigger)
|
||||
|
||||
# Assert
|
||||
assert result == {"f": {"id": "f-1"}}
|
||||
|
||||
|
||||
def test_create_file_from_binary_should_call_tool_file_manager_and_file_factory(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = _workflow_trigger(created_by="user-1", tenant_id="tenant-1")
|
||||
manager = MagicMock()
|
||||
manager.create_file_by_raw.return_value = SimpleNamespace(id="tool-file-1")
|
||||
monkeypatch.setattr(service_module, "ToolFileManager", MagicMock(return_value=manager))
|
||||
expected_file = MagicMock()
|
||||
monkeypatch.setattr(service_module.file_factory, "build_from_mapping", MagicMock(return_value=expected_file))
|
||||
|
||||
# Act
|
||||
result = WebhookService._create_file_from_binary(b"abc", "text/plain", webhook_trigger)
|
||||
|
||||
# Assert
|
||||
assert result is expected_file
|
||||
manager.create_file_by_raw.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("raw_value", "param_type", "expected"),
|
||||
[
|
||||
("42", SegmentType.NUMBER, 42),
|
||||
("3.14", SegmentType.NUMBER, 3.14),
|
||||
("yes", SegmentType.BOOLEAN, True),
|
||||
("no", SegmentType.BOOLEAN, False),
|
||||
],
|
||||
)
|
||||
def test_convert_form_value_should_convert_supported_types(
|
||||
raw_value: str,
|
||||
param_type: str,
|
||||
expected: Any,
|
||||
) -> None:
|
||||
# Arrange
|
||||
|
||||
# Act
|
||||
result = WebhookService._convert_form_value("param", raw_value, param_type)
|
||||
|
||||
# Assert
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_convert_form_value_should_raise_for_unsupported_type() -> None:
|
||||
# Arrange
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="Unsupported type"):
|
||||
WebhookService._convert_form_value("p", "x", SegmentType.FILE)
|
||||
|
||||
|
||||
def test_validate_json_value_should_return_original_for_unmapped_supported_segment_type(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
warning_mock = MagicMock()
|
||||
monkeypatch.setattr(service_module.logger, "warning", warning_mock)
|
||||
|
||||
# Act
|
||||
result = WebhookService._validate_json_value("param", {"x": 1}, "unsupported-type")
|
||||
|
||||
# Assert
|
||||
assert result == {"x": 1}
|
||||
warning_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_validate_and_convert_value_should_wrap_conversion_errors() -> None:
|
||||
# Arrange
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="validation failed"):
|
||||
WebhookService._validate_and_convert_value("param", "bad", SegmentType.NUMBER, is_form_data=True)
|
||||
|
||||
|
||||
def test_process_parameters_should_raise_when_required_parameter_missing() -> None:
|
||||
# Arrange
|
||||
raw_params = {"optional": "x"}
|
||||
config = [WebhookParameter(name="required_param", type=SegmentType.STRING, required=True)]
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="Required parameter missing"):
|
||||
WebhookService._process_parameters(raw_params, config, is_form_data=True)
|
||||
|
||||
|
||||
def test_process_parameters_should_include_unconfigured_parameters() -> None:
|
||||
# Arrange
|
||||
raw_params = {"known": "1", "unknown": "x"}
|
||||
config = [WebhookParameter(name="known", type=SegmentType.NUMBER, required=False)]
|
||||
|
||||
# Act
|
||||
result = WebhookService._process_parameters(raw_params, config, is_form_data=True)
|
||||
|
||||
# Assert
|
||||
assert result == {"known": 1, "unknown": "x"}
|
||||
|
||||
|
||||
def test_process_body_parameters_should_raise_when_required_text_raw_is_missing() -> None:
|
||||
# Arrange
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="Required body content missing"):
|
||||
WebhookService._process_body_parameters(
|
||||
raw_body={"raw": ""},
|
||||
body_configs=[WebhookBodyParameter(name="raw", required=True)],
|
||||
content_type=ContentType.TEXT,
|
||||
)
|
||||
|
||||
|
||||
def test_process_body_parameters_should_skip_file_config_for_multipart_form_data() -> None:
|
||||
# Arrange
|
||||
raw_body = {"message": "hello", "extra": "x"}
|
||||
body_configs = [
|
||||
WebhookBodyParameter(name="upload", type=SegmentType.FILE, required=True),
|
||||
WebhookBodyParameter(name="message", type=SegmentType.STRING, required=True),
|
||||
]
|
||||
|
||||
# Act
|
||||
result = WebhookService._process_body_parameters(raw_body, body_configs, ContentType.FORM_DATA)
|
||||
|
||||
# Assert
|
||||
assert result == {"message": "hello", "extra": "x"}
|
||||
|
||||
|
||||
def test_validate_required_headers_should_accept_sanitized_header_names() -> None:
|
||||
# Arrange
|
||||
headers = {"x_api_key": "123"}
|
||||
configs = [WebhookParameter(name="x-api-key", required=True)]
|
||||
|
||||
# Act
|
||||
WebhookService._validate_required_headers(headers, configs)
|
||||
|
||||
# Assert
|
||||
assert True
|
||||
|
||||
|
||||
def test_validate_required_headers_should_raise_when_required_header_missing() -> None:
|
||||
# Arrange
|
||||
headers = {"x-other": "123"}
|
||||
configs = [WebhookParameter(name="x-api-key", required=True)]
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="Required header missing"):
|
||||
WebhookService._validate_required_headers(headers, configs)
|
||||
|
||||
|
||||
def test_validate_http_metadata_should_return_content_type_mismatch_error() -> None:
|
||||
# Arrange
|
||||
webhook_data = {"method": "POST", "headers": {"Content-Type": "application/json"}}
|
||||
node_data = WebhookData(method="post", content_type=ContentType.TEXT)
|
||||
|
||||
# Act
|
||||
result = WebhookService._validate_http_metadata(webhook_data, node_data)
|
||||
|
||||
# Assert
|
||||
assert result["valid"] is False
|
||||
assert "Content-type mismatch" in result["error"]
|
||||
|
||||
|
||||
def test_extract_content_type_should_fallback_to_lowercase_header_key() -> None:
|
||||
# Arrange
|
||||
headers = {"content-type": "application/json; charset=utf-8"}
|
||||
|
||||
# Act
|
||||
result = WebhookService._extract_content_type(headers)
|
||||
|
||||
# Assert
|
||||
assert result == "application/json"
|
||||
|
||||
|
||||
def test_build_workflow_inputs_should_include_expected_keys() -> None:
|
||||
# Arrange
|
||||
webhook_data = {"headers": {"h": "v"}, "query_params": {"q": 1}, "body": {"b": 2}}
|
||||
|
||||
# Act
|
||||
result = WebhookService.build_workflow_inputs(webhook_data)
|
||||
|
||||
# Assert
|
||||
assert result["webhook_data"] == webhook_data
|
||||
assert result["webhook_headers"] == {"h": "v"}
|
||||
assert result["webhook_query_params"] == {"q": 1}
|
||||
assert result["webhook_body"] == {"b": 2}
|
||||
|
||||
|
||||
def test_trigger_workflow_execution_should_trigger_async_workflow_successfully(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = _workflow_trigger(
|
||||
app_id="app-1",
|
||||
node_id="node-1",
|
||||
tenant_id="tenant-1",
|
||||
webhook_id="webhook-1",
|
||||
)
|
||||
workflow = _workflow(id="wf-1")
|
||||
webhook_data = {"body": {"x": 1}}
|
||||
|
||||
session = MagicMock()
|
||||
_patch_session(monkeypatch, session)
|
||||
|
||||
end_user = SimpleNamespace(id="end-user-1")
|
||||
monkeypatch.setattr(
|
||||
service_module.EndUserService, "get_or_create_end_user_by_type", MagicMock(return_value=end_user)
|
||||
)
|
||||
quota_type = SimpleNamespace(TRIGGER=SimpleNamespace(consume=MagicMock()))
|
||||
monkeypatch.setattr(service_module, "QuotaType", quota_type)
|
||||
trigger_async_mock = MagicMock()
|
||||
monkeypatch.setattr(service_module.AsyncWorkflowService, "trigger_workflow_async", trigger_async_mock)
|
||||
|
||||
# Act
|
||||
WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow)
|
||||
|
||||
# Assert
|
||||
trigger_async_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_trigger_workflow_execution_should_mark_tenant_rate_limited_when_quota_exceeded(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = _workflow_trigger(
|
||||
app_id="app-1",
|
||||
node_id="node-1",
|
||||
tenant_id="tenant-1",
|
||||
webhook_id="webhook-1",
|
||||
)
|
||||
workflow = _workflow(id="wf-1")
|
||||
|
||||
session = MagicMock()
|
||||
_patch_session(monkeypatch, session)
|
||||
|
||||
monkeypatch.setattr(
|
||||
service_module.EndUserService,
|
||||
"get_or_create_end_user_by_type",
|
||||
MagicMock(return_value=SimpleNamespace(id="end-user-1")),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
service_module.QuotaService,
|
||||
"reserve",
|
||||
MagicMock(side_effect=QuotaExceededError(feature="trigger", tenant_id="tenant-1", required=1)),
|
||||
)
|
||||
mark_rate_limited_mock = MagicMock()
|
||||
monkeypatch.setattr(service_module.AppTriggerService, "mark_tenant_triggers_rate_limited", mark_rate_limited_mock)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(QuotaExceededError):
|
||||
WebhookService.trigger_workflow_execution(webhook_trigger, {"body": {}}, workflow)
|
||||
mark_rate_limited_mock.assert_called_once_with("tenant-1")
|
||||
|
||||
|
||||
def test_trigger_workflow_execution_should_log_and_reraise_unexpected_errors(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = _workflow_trigger(
|
||||
app_id="app-1",
|
||||
node_id="node-1",
|
||||
tenant_id="tenant-1",
|
||||
webhook_id="webhook-1",
|
||||
)
|
||||
workflow = _workflow(id="wf-1")
|
||||
|
||||
session = MagicMock()
|
||||
_patch_session(monkeypatch, session)
|
||||
|
||||
monkeypatch.setattr(
|
||||
service_module.EndUserService, "get_or_create_end_user_by_type", MagicMock(side_effect=RuntimeError("boom"))
|
||||
)
|
||||
logger_exception_mock = MagicMock()
|
||||
monkeypatch.setattr(service_module.logger, "exception", logger_exception_mock)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
WebhookService.trigger_workflow_execution(webhook_trigger, {"body": {}}, workflow)
|
||||
logger_exception_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_sync_webhook_relationships_should_raise_when_workflow_exceeds_node_limit() -> None:
|
||||
# Arrange
|
||||
app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
|
||||
workflow = _workflow(
|
||||
walk_nodes=lambda _node_type: [
|
||||
(f"node-{i}", {}) for i in range(WebhookService.MAX_WEBHOOK_NODES_PER_WORKFLOW + 1)
|
||||
]
|
||||
)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="maximum webhook node limit"):
|
||||
WebhookService.sync_webhook_relationships(app, workflow)
|
||||
|
||||
|
||||
def test_sync_webhook_relationships_should_raise_when_lock_not_acquired(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
|
||||
workflow = _workflow(walk_nodes=lambda _node_type: [("node-1", {})])
|
||||
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = False
|
||||
monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
|
||||
monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(RuntimeError, match="Failed to acquire lock"):
|
||||
WebhookService.sync_webhook_relationships(app, workflow)
|
||||
|
||||
|
||||
def test_sync_webhook_relationships_should_create_missing_records_and_delete_stale_records(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
|
||||
workflow = _workflow(walk_nodes=lambda _node_type: [("node-new", {})])
|
||||
|
||||
class _WorkflowWebhookTrigger:
|
||||
app_id = "app_id"
|
||||
tenant_id = "tenant_id"
|
||||
webhook_id = "webhook_id"
|
||||
node_id = "node_id"
|
||||
|
||||
def __init__(self, app_id: str, tenant_id: str, node_id: str, webhook_id: str, created_by: str) -> None:
|
||||
self.id = None
|
||||
self.app_id = app_id
|
||||
self.tenant_id = tenant_id
|
||||
self.node_id = node_id
|
||||
self.webhook_id = webhook_id
|
||||
self.created_by = created_by
|
||||
|
||||
class _Select:
|
||||
def where(self, *args: Any, **kwargs: Any) -> "_Select":
|
            return self

    class _Session:
        def __init__(self) -> None:
            self.added: list[Any] = []
            self.deleted: list[Any] = []
            self.commit_count = 0
            self.existing_records = [SimpleNamespace(node_id="node-stale")]

        def scalars(self, _stmt: Any) -> Any:
            return SimpleNamespace(all=lambda: self.existing_records)

        def add(self, obj: Any) -> None:
            self.added.append(obj)

        def flush(self) -> None:
            for idx, obj in enumerate(self.added, start=1):
                if obj.id is None:
                    obj.id = f"rec-{idx}"

        def commit(self) -> None:
            self.commit_count += 1

        def delete(self, obj: Any) -> None:
            self.deleted.append(obj)

    lock = MagicMock()
    lock.acquire.return_value = True
    lock.release.return_value = None

    fake_session = _Session()

    monkeypatch.setattr(service_module, "WorkflowWebhookTrigger", _WorkflowWebhookTrigger)
    monkeypatch.setattr(service_module, "select", MagicMock(return_value=_Select()))
    monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))
    redis_set_mock = MagicMock()
    redis_delete_mock = MagicMock()
    monkeypatch.setattr(service_module.redis_client, "set", redis_set_mock)
    monkeypatch.setattr(service_module.redis_client, "delete", redis_delete_mock)
    monkeypatch.setattr(WebhookService, "generate_webhook_id", MagicMock(return_value="generated-webhook-id"))
    _patch_session(monkeypatch, fake_session)

    # Act
    WebhookService.sync_webhook_relationships(app, workflow)

    # Assert
    assert len(fake_session.added) == 1
    assert len(fake_session.deleted) == 1
    redis_set_mock.assert_called_once()
    redis_delete_mock.assert_called_once()
    lock.release.assert_called_once()


def test_sync_webhook_relationships_should_log_when_lock_release_fails(monkeypatch: pytest.MonkeyPatch) -> None:
    # Arrange
    app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
    workflow = _workflow(walk_nodes=lambda _node_type: [])

    class _Select:
        def where(self, *args: Any, **kwargs: Any) -> "_Select":
            return self

    class _Session:
        def scalars(self, _stmt: Any) -> Any:
            return SimpleNamespace(all=lambda: [])

        def commit(self) -> None:
            return None

    lock = MagicMock()
    lock.acquire.return_value = True
    lock.release.side_effect = RuntimeError("release failed")

    logger_exception_mock = MagicMock()

    monkeypatch.setattr(service_module, "select", MagicMock(return_value=_Select()))
    monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(service_module.logger, "exception", logger_exception_mock)
    _patch_session(monkeypatch, _Session())

    # Act
    WebhookService.sync_webhook_relationships(app, workflow)

    # Assert
    assert logger_exception_mock.call_count == 1


def test_generate_webhook_response_should_fallback_when_response_body_is_not_json() -> None:
    # Arrange
    node_config = {"data": {"status_code": 200, "response_body": "{bad-json"}}

    # Act
    body, status = WebhookService.generate_webhook_response(node_config)

    # Assert
    assert status == 200
    assert "message" in body
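The fallback test above pins down only the observable contract: a non-JSON response_body must still yield the configured status code and a body carrying a "message" key. A minimal sketch with that behavior (an illustration only; the real WebhookService.generate_webhook_response may differ):

import json
from typing import Any


def generate_webhook_response_sketch(node_config: dict[str, Any]) -> tuple[dict[str, Any], int]:
    data = node_config.get("data", {})
    status_code = data.get("status_code", 200)
    try:
        body = json.loads(data.get("response_body") or "{}")
    except json.JSONDecodeError:
        # Fall back to a generic JSON body when the configured one does not parse.
        body = {"message": "webhook received"}
    return body, status_code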
def test_generate_webhook_id_should_return_24_character_identifier() -> None:
    # Act
    webhook_id = WebhookService.generate_webhook_id()

    # Assert
    assert isinstance(webhook_id, str)
    assert len(webhook_id) == 24
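The 24-character length asserted here is only a shape check. One way to satisfy it, assuming a hex-encoded random identifier (not necessarily what generate_webhook_id actually does), is:

import secrets


def generate_webhook_id_sketch() -> str:
    # 12 random bytes hex-encode to exactly 24 characters.
    return secrets.token_hex(12)


assert len(generate_webhook_id_sketch()) == 24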
def test_sanitize_key_should_return_original_value_for_non_string_input() -> None:
    # Act
    result = WebhookService._sanitize_key(123)  # type: ignore[arg-type]

    # Assert
    assert result == 123
97 api/tests/unit_tests/tasks/test_evaluation_task.py Normal file
@ -0,0 +1,97 @@
"""Unit tests for evaluation task helpers."""
|
||||
|
||||
from core.evaluation.entities.evaluation_entity import EvaluationItemResult, EvaluationMetric, NodeInfo
|
||||
from core.evaluation.entities.judgment_entity import (
|
||||
JudgmentCondition,
|
||||
JudgmentConfig,
|
||||
JudgmentResult,
|
||||
)
|
||||
from tasks.evaluation_task import _compute_metrics_summary, _merge_result, _stamp_and_merge
|
||||
|
||||
_NODE_INFO = NodeInfo(node_id="llm_1", type="llm", title="LLM Node")
|
||||
|
||||
|
||||
def test_compute_metrics_summary_includes_judgment_counts() -> None:
|
||||
"""Summary should expose pass/fail counts when judgment rules are configured."""
|
||||
judgment_config = JudgmentConfig(
|
||||
logical_operator="and",
|
||||
conditions=[
|
||||
JudgmentCondition(
|
||||
variable_selector=["llm_1", "faithfulness"],
|
||||
comparison_operator=">",
|
||||
value="0.8",
|
||||
)
|
||||
],
|
||||
)
|
||||
results = [
|
||||
EvaluationItemResult(
|
||||
index=0,
|
||||
metrics=[EvaluationMetric(name="faithfulness", value=0.9, node_info=_NODE_INFO)],
|
||||
judgment=JudgmentResult(passed=True, logical_operator="and", condition_results=[]),
|
||||
),
|
||||
EvaluationItemResult(
|
||||
index=1,
|
||||
metrics=[EvaluationMetric(name="faithfulness", value=0.4, node_info=_NODE_INFO)],
|
||||
judgment=JudgmentResult(passed=False, logical_operator="and", condition_results=[]),
|
||||
),
|
||||
EvaluationItemResult(index=2, error="timeout"),
|
||||
]
|
||||
|
||||
summary = _compute_metrics_summary(results, judgment_config)
|
||||
|
||||
assert summary["_judgment"] == {
|
||||
"enabled": True,
|
||||
"logical_operator": "and",
|
||||
"configured_conditions": 1,
|
||||
"evaluated_items": 2,
|
||||
"passed_items": 1,
|
||||
"failed_items": 1,
|
||||
"pass_rate": 0.5,
|
||||
}
|
||||
|
||||
|
||||
def test_merge_result_combines_metrics_for_same_index() -> None:
|
||||
"""Merging two results with the same index should concatenate their metrics."""
|
||||
results_by_index: dict[int, EvaluationItemResult] = {}
|
||||
|
||||
first = EvaluationItemResult(
|
||||
index=0,
|
||||
actual_output="output_1",
|
||||
metrics=[EvaluationMetric(name="faithfulness", value=0.9)],
|
||||
)
|
||||
_merge_result(results_by_index, 0, first)
|
||||
|
||||
second = EvaluationItemResult(
|
||||
index=0,
|
||||
actual_output="output_2",
|
||||
metrics=[EvaluationMetric(name="context_precision", value=0.7)],
|
||||
)
|
||||
_merge_result(results_by_index, 0, second)
|
||||
|
||||
merged = results_by_index[0]
|
||||
assert len(merged.metrics) == 2
|
||||
assert merged.metrics[0].name == "faithfulness"
|
||||
assert merged.metrics[1].name == "context_precision"
|
||||
assert merged.actual_output == "output_1"
|
||||
|
||||
|
||||
def test_stamp_and_merge_attaches_node_info() -> None:
|
||||
"""_stamp_and_merge should set node_info on every metric and remap indices."""
|
||||
results_by_index: dict[int, EvaluationItemResult] = {}
|
||||
node_info = NodeInfo(node_id="llm_1", type="llm", title="GPT-4")
|
||||
|
||||
evaluated = [
|
||||
EvaluationItemResult(
|
||||
index=0,
|
||||
metrics=[EvaluationMetric(name="faithfulness", value=0.85)],
|
||||
)
|
||||
]
|
||||
item_indices = [3]
|
||||
|
||||
_stamp_and_merge(evaluated, item_indices, node_info, results_by_index)
|
||||
|
||||
assert 3 in results_by_index
|
||||
metric = results_by_index[3].metrics[0]
|
||||
assert metric.node_info is not None
|
||||
assert metric.node_info.node_id == "llm_1"
|
||||
assert metric.node_info.type == "llm"
|
||||
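Taken together, these tests fully determine the helpers' observable behavior: the first result for an index wins the scalar fields, later results only append metrics, _stamp_and_merge stamps the producing node onto every metric and remaps positions through item_indices, and the pass rate is computed over items that actually carry a judgment (1 passed of 2 evaluated = 0.5; the errored item is excluded). A self-contained sketch with that behavior follows; the names mirror the tests, but the real helpers in tasks/evaluation_task.py may differ:

from dataclasses import dataclass, field


@dataclass
class _Metric:
    name: str
    value: float
    node_info: object | None = None


@dataclass
class _ItemResult:
    index: int
    actual_output: str | None = None
    metrics: list[_Metric] = field(default_factory=list)


def merge_result_sketch(results: dict[int, _ItemResult], index: int, result: _ItemResult) -> None:
    existing = results.get(index)
    if existing is None:
        # First writer for an index keeps its scalar fields (e.g. actual_output).
        results[index] = result
    else:
        # Later writers only contribute their metrics.
        existing.metrics.extend(result.metrics)


def stamp_and_merge_sketch(
    evaluated: list[_ItemResult],
    item_indices: list[int],
    node_info: object,
    results: dict[int, _ItemResult],
) -> None:
    for result, dataset_index in zip(evaluated, item_indices):
        for metric in result.metrics:
            metric.node_info = node_info  # stamp the producing node onto every metric
        merge_result_sketch(results, dataset_index, result)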
494 api/uv.lock generated
@ -320,6 +320,15 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/15/b3/9b1a8074496371342ec1e796a96f99c82c945a339cd81a8e73de28b4cf9e/anyio-4.11.0-py3-none-any.whl", hash = "sha256:0287e96f4d26d4149305414d4e3bc32f0dcd0862365a4bddea19d7a1ec38c4fc", size = 109097, upload-time = "2025-09-23T09:19:10.601Z" },
]

[[package]]
name = "appdirs"
version = "1.4.4"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/d7/d8/05696357e0311f5b5c316d7b95f46c669dd9c15aaeecbb48c7d0aeb88c40/appdirs-1.4.4.tar.gz", hash = "sha256:7d5d0167b2b1ba821647616af46a749d1c653740dd0d2415100fe26e27afdf41", size = 13470, upload-time = "2020-05-11T07:59:51.037Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/3b/00/2344469e2084fb287c2e0b57b72910309874c3245463acd6cf5e3db69324/appdirs-1.4.4-py2.py3-none-any.whl", hash = "sha256:a841dacd6b99318a741b166adb07e19ee71a274450e68237b4650ca1055ab128", size = 9566, upload-time = "2020-05-11T07:59:49.499Z" },
]

[[package]]
name = "apscheduler"
version = "3.11.2"

@ -1243,6 +1252,45 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/a7/27/b822b474aaefb684d11df358d52e012699a2a8af231f9b47c54b73f280cb/databricks_sdk-0.73.0-py3-none-any.whl", hash = "sha256:a4d3cfd19357a2b459d2dc3101454d7f0d1b62865ce099c35d0c342b66ac64ff", size = 753896, upload-time = "2025-11-05T06:52:56.451Z" },
]

[[package]]
name = "dataclasses-json"
version = "0.6.7"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "marshmallow" },
    { name = "typing-inspect" },
]
sdist = { url = "https://files.pythonhosted.org/packages/64/a4/f71d9cf3a5ac257c993b5ca3f93df5f7fb395c725e7f1e6479d2514173c3/dataclasses_json-0.6.7.tar.gz", hash = "sha256:b6b3e528266ea45b9535223bc53ca645f5208833c29229e847b3f26a1cc55fc0", size = 32227, upload-time = "2024-06-09T16:20:19.103Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/c3/be/d0d44e092656fe7a06b55e6103cbce807cdbdee17884a5367c68c9860853/dataclasses_json-0.6.7-py3-none-any.whl", hash = "sha256:0dbf33f26c8d5305befd61b39d2b3414e8a407bedc2834dea9b8d642666fb40a", size = 28686, upload-time = "2024-06-09T16:20:16.715Z" },
]

[[package]]
name = "datasets"
version = "2.19.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "aiohttp" },
    { name = "dill" },
    { name = "filelock" },
    { name = "fsspec", extra = ["http"] },
    { name = "huggingface-hub" },
    { name = "multiprocess" },
    { name = "numpy" },
    { name = "packaging" },
    { name = "pandas" },
    { name = "pyarrow" },
    { name = "pyarrow-hotfix" },
    { name = "pyyaml" },
    { name = "requests" },
    { name = "tqdm" },
    { name = "xxhash" },
]
sdist = { url = "https://files.pythonhosted.org/packages/52/e7/6ee66732f74e4fb1c8915e58b3c253aded777ad0fa457f3f831dd0cd09b4/datasets-2.19.2.tar.gz", hash = "sha256:eccb82fb3bb5ee26ccc6d7a15b7f1f834e2cc4e59b7cff7733a003552bad51ef", size = 2215337, upload-time = "2024-06-03T05:11:44.756Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/3f/59/46818ebeb708234a60e42ccf409d20709e482519d2aa450b501ddbba4594/datasets-2.19.2-py3-none-any.whl", hash = "sha256:e07ff15d75b1af75c87dd96323ba2a361128d495136652f37fd62f918d17bb4e", size = 542113, upload-time = "2024-06-03T05:11:41.151Z" },
]

[[package]]
name = "dateparser"
version = "1.2.2"

@ -1267,6 +1315,45 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/4e/8c/f3147f5c4b73e7550fe5f9352eaa956ae838d5c51eb58e7a25b9f3e2643b/decorator-5.2.1-py3-none-any.whl", hash = "sha256:d316bb415a2d9e2d2b3abcc4084c6502fc09240e292cd76a76afc106a1c8e04a", size = 9190, upload-time = "2025-02-24T04:41:32.565Z" },
]

[[package]]
name = "deepeval"
version = "3.9.6"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "aiohttp" },
    { name = "click" },
    { name = "grpcio" },
    { name = "jinja2" },
    { name = "nest-asyncio" },
    { name = "openai" },
    { name = "opentelemetry-api" },
    { name = "opentelemetry-sdk" },
    { name = "portalocker" },
    { name = "posthog" },
    { name = "pydantic" },
    { name = "pydantic-settings" },
    { name = "pyfiglet" },
    { name = "pytest" },
    { name = "pytest-asyncio" },
    { name = "pytest-repeat" },
    { name = "pytest-rerunfailures" },
    { name = "pytest-xdist" },
    { name = "python-dotenv" },
    { name = "requests" },
    { name = "rich" },
    { name = "sentry-sdk" },
    { name = "setuptools" },
    { name = "tabulate" },
    { name = "tenacity" },
    { name = "tqdm" },
    { name = "typer" },
    { name = "wheel" },
]
sdist = { url = "https://files.pythonhosted.org/packages/c0/64/dd6c1bce14324c211d72facd6d4deb1453321d7e32a82b7d5f1ead7ba817/deepeval-3.9.6.tar.gz", hash = "sha256:07eaad40f17a21b809bd6719f9f4ae61d8880753edddaee2abbcd4a27cfb225d", size = 614512, upload-time = "2026-04-07T15:27:41.533Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/0e/3e/cac5dad55d9a9540b5288dbe853554638302df43c459e0f64663d5ad129a/deepeval-3.9.6-py3-none-any.whl", hash = "sha256:6a368cdd7fda7bf94ef1bcb65b40914271096dbdc54aa8988aa62bda41c68f94", size = 843386, upload-time = "2026-04-07T15:27:39.134Z" },
]

[[package]]
name = "defusedxml"
version = "0.7.1"

@ -1459,6 +1546,10 @@ dev = [
    { name = "types-ujson" },
    { name = "xinference-client" },
]
evaluation = [
    { name = "deepeval" },
    { name = "ragas" },
]
storage = [
    { name = "azure-storage-blob" },
    { name = "bce-python-sdk" },

@ -1756,6 +1847,10 @@ dev = [
    { name = "types-ujson", specifier = ">=5.10.0" },
    { name = "xinference-client", specifier = "~=2.4.0" },
]
evaluation = [
    { name = "deepeval", specifier = ">=2.0.0" },
    { name = "ragas", specifier = ">=0.2.0" },
]
storage = [
    { name = "azure-storage-blob", specifier = "==12.28.0" },
    { name = "bce-python-sdk", specifier = "~=0.9.69" },

@ -2167,6 +2262,24 @@ dependencies = [
[package.metadata]
requires-dist = [{ name = "weaviate-client", specifier = "==4.20.5" }]

[[package]]
name = "dill"
version = "0.3.8"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/17/4d/ac7ffa80c69ea1df30a8aa11b3578692a5118e7cd1aa157e3ef73b092d15/dill-0.3.8.tar.gz", hash = "sha256:3ebe3c479ad625c4553aca177444d89b486b1d84982eeacded644afc0cf797ca", size = 184847, upload-time = "2024-01-27T23:42:16.145Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/c9/7a/cef76fd8438a42f96db64ddaa85280485a9c395e7df3db8158cfec1eee34/dill-0.3.8-py3-none-any.whl", hash = "sha256:c36ca9ffb54365bdd2f8eb3eff7d2a21237f8452b57ace88b1ac615b7e815bd7", size = 116252, upload-time = "2024-01-27T23:42:14.239Z" },
]

[[package]]
name = "diskcache"
version = "5.6.3"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/3f/21/1c1ffc1a039ddcc459db43cc108658f32c57d271d7289a2794e401d0fdb6/diskcache-5.6.3.tar.gz", hash = "sha256:2c3a3fa2743d8535d832ec61c2054a1641f41775aa7c556758a109941e33e4fc", size = 67916, upload-time = "2023-08-31T06:12:00.316Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/3f/27/4570e78fc0bf5ea0ca45eb1de3818a23787af9b390c0b0a0033a1b8236f9/diskcache-5.6.3-py3-none-any.whl", hash = "sha256:5e31b2d5fbad117cc363ebaf6b689474db18a1f6438bc82358b024abd4c2ca19", size = 45550, upload-time = "2023-08-31T06:11:58.822Z" },
]

[[package]]
name = "diskcache-weave"
version = "5.6.3.post1"

@ -2546,11 +2659,16 @@ wheels = [

[[package]]
name = "fsspec"
version = "2025.10.0"
version = "2024.3.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/24/7f/2747c0d332b9acfa75dc84447a066fdf812b5a6b8d30472b74d309bfe8cb/fsspec-2025.10.0.tar.gz", hash = "sha256:b6789427626f068f9a83ca4e8a3cc050850b6c0f71f99ddb4f542b8266a26a59", size = 309285, upload-time = "2025-10-30T14:58:44.036Z" }
sdist = { url = "https://files.pythonhosted.org/packages/8b/b8/e3ba21f03c00c27adc9a8cd1cab8adfb37b6024757133924a9a4eab63a83/fsspec-2024.3.1.tar.gz", hash = "sha256:f39780e282d7d117ffb42bb96992f8a90795e4d0fb0f661a70ca39fe9c43ded9", size = 170742, upload-time = "2024-03-18T19:35:13.995Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/eb/02/a6b21098b1d5d6249b7c5ab69dde30108a71e4e819d4a9778f1de1d5b70d/fsspec-2025.10.0-py3-none-any.whl", hash = "sha256:7c7712353ae7d875407f97715f0e1ffcc21e33d5b24556cb1e090ae9409ec61d", size = 200966, upload-time = "2025-10-30T14:58:42.53Z" },
    { url = "https://files.pythonhosted.org/packages/93/6d/66d48b03460768f523da62a57a7e14e5e95fdf339d79e996ce3cecda2cdb/fsspec-2024.3.1-py3-none-any.whl", hash = "sha256:918d18d41bf73f0e2b261824baeb1b124bcf771767e3a26425cd7dec3332f512", size = 171991, upload-time = "2024-03-18T19:35:11.259Z" },
]

[package.optional-dependencies]
http = [
    { name = "aiohttp" },
]

[[package]]
@ -3333,6 +3451,28 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/e5/ca/1172b6638d52f2d6caa2dd262ec4c811ba59eee96d54a7701930726bce18/installer-0.7.0-py3-none-any.whl", hash = "sha256:05d1933f0a5ba7d8d6296bb6d5018e7c94fa473ceb10cf198a92ccea19c27b53", size = 453838, upload-time = "2023-03-17T20:39:36.219Z" },
]

[[package]]
name = "instructor"
version = "1.15.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "aiohttp" },
    { name = "docstring-parser" },
    { name = "jinja2" },
    { name = "jiter" },
    { name = "openai" },
    { name = "pydantic" },
    { name = "pydantic-core" },
    { name = "requests" },
    { name = "rich" },
    { name = "tenacity" },
    { name = "typer" },
]
sdist = { url = "https://files.pythonhosted.org/packages/dc/a4/832cfb15420360e26d2d85bd9d5fe1e4b839d52587574d389bc31284bf6f/instructor-1.15.1.tar.gz", hash = "sha256:c72406469d9025b742e83cf0c13e914b317db2089d08d889944e74fcd659ef94", size = 69948370, upload-time = "2026-04-03T01:51:30.107Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/d8/c8/36c5d9b80aaf40ba9a7084a8fc18c967db6bf248a4cc8d0f0816b14284be/instructor-1.15.1-py3-none-any.whl", hash = "sha256:be81d17ba2b154a04ab4720808f24f9d6b598f80992f82eaf9cc79006099cf6c", size = 178156, upload-time = "2026-04-03T01:51:23.098Z" },
]

[[package]]
name = "intersystems-irispython"
version = "5.3.2"

@ -3442,6 +3582,27 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/e1/03/7afcecb4242d93b684708b47fb014abdc1922a01b38c0e30f1117ae74a83/json_repair-0.59.2-py3-none-any.whl", hash = "sha256:6ca6238519c24f671bcb05d1f38a0d6a452bb4ca5af82137595c5c2f1a0fb785", size = 46918, upload-time = "2026-04-11T15:55:39.817Z" },
]

[[package]]
name = "jsonpatch"
version = "1.33"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "jsonpointer" },
]
sdist = { url = "https://files.pythonhosted.org/packages/42/78/18813351fe5d63acad16aec57f94ec2b70a09e53ca98145589e185423873/jsonpatch-1.33.tar.gz", hash = "sha256:9fcd4009c41e6d12348b4a0ff2563ba56a2923a7dfee731d004e212e1ee5030c", size = 21699, upload-time = "2023-06-26T12:07:29.144Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/73/07/02e16ed01e04a374e644b575638ec7987ae846d25ad97bcc9945a3ee4b0e/jsonpatch-1.33-py2.py3-none-any.whl", hash = "sha256:0ae28c0cd062bbd8b8ecc26d7d164fbbea9652a1a3693f3b956c1eae5145dade", size = 12898, upload-time = "2023-06-16T21:01:28.466Z" },
]

[[package]]
name = "jsonpointer"
version = "3.1.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/18/c7/af399a2e7a67fd18d63c40c5e62d3af4e67b836a2107468b6a5ea24c4304/jsonpointer-3.1.1.tar.gz", hash = "sha256:0b801c7db33a904024f6004d526dcc53bbb8a4a0f4e32bfd10beadf60adf1900", size = 9068, upload-time = "2026-03-23T22:32:32.458Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/9e/6a/a83720e953b1682d2d109d3c2dbb0bc9bf28cc1cbc205be4ef4be5da709d/jsonpointer-3.1.1-py3-none-any.whl", hash = "sha256:8ff8b95779d071ba472cf5bc913028df06031797532f08a7d5b602d8b2a488ca", size = 7659, upload-time = "2026-03-23T22:32:31.568Z" },
]

[[package]]
name = "jsonschema"
version = "4.25.1"

@ -3515,6 +3676,106 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/89/43/d9bebfc3db7dea6ec80df5cb2aad8d274dd18ec2edd6c4f21f32c237cbbb/kubernetes-33.1.0-py2.py3-none-any.whl", hash = "sha256:544de42b24b64287f7e0aa9513c93cb503f7f40eea39b20f66810011a86eabc5", size = 1941335, upload-time = "2025-06-09T21:57:56.327Z" },
]

[[package]]
name = "langchain"
version = "1.2.15"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "langchain-core" },
    { name = "langgraph" },
    { name = "pydantic" },
]
sdist = { url = "https://files.pythonhosted.org/packages/98/3f/888a7099d2bd2917f8b0c3ffc7e347f1e664cf64267820b0b923c4f339fc/langchain-1.2.15.tar.gz", hash = "sha256:1717b6719daefae90b2728314a5e2a117ff916291e2862595b6c3d6fba33d652", size = 574732, upload-time = "2026-04-03T14:26:03.994Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/3f/e8/a3b8cb0005553f6a876865073c81ef93bd7c5b18381bcb9ba4013af96ebc/langchain-1.2.15-py3-none-any.whl", hash = "sha256:e349db349cb3e9550c4044077cf90a1717691756cc236438404b23500e615874", size = 112714, upload-time = "2026-04-03T14:26:02.557Z" },
]

[[package]]
name = "langchain-classic"
version = "1.0.3"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "langchain-core" },
    { name = "langchain-text-splitters" },
    { name = "langsmith" },
    { name = "pydantic" },
    { name = "pyyaml" },
    { name = "requests" },
    { name = "sqlalchemy" },
]
sdist = { url = "https://files.pythonhosted.org/packages/32/04/b01c09e37414bab9f209efa311502841a3c0de5bc6c35e729c8d8a9893c9/langchain_classic-1.0.3.tar.gz", hash = "sha256:168ef1dfbfb18cae5a9ff0accecc9413a5b5aa3464b53fa841561a3384b6324a", size = 10534933, upload-time = "2026-03-13T13:56:11.96Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/ab/e6/cfdeedec0537ffbf5041773590d25beb7f2aa467cc6630e788c9c7c72c3e/langchain_classic-1.0.3-py3-none-any.whl", hash = "sha256:26df1ec9806b1fbff19d9085a747ea7d8d82d7e3fb1d25132859979de627ef79", size = 1041335, upload-time = "2026-03-13T13:56:09.677Z" },
]

[[package]]
name = "langchain-community"
version = "0.4.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "aiohttp" },
    { name = "dataclasses-json" },
    { name = "httpx-sse" },
    { name = "langchain-classic" },
    { name = "langchain-core" },
    { name = "langsmith" },
    { name = "numpy" },
    { name = "pydantic-settings" },
    { name = "pyyaml" },
    { name = "requests" },
    { name = "sqlalchemy" },
    { name = "tenacity" },
]
sdist = { url = "https://files.pythonhosted.org/packages/53/97/a03585d42b9bdb6fbd935282d6e3348b10322a24e6ce12d0c99eb461d9af/langchain_community-0.4.1.tar.gz", hash = "sha256:f3b211832728ee89f169ddce8579b80a085222ddb4f4ed445a46e977d17b1e85", size = 33241144, upload-time = "2025-10-27T15:20:32.504Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/f0/a4/c4fde67f193401512337456cabc2148f2c43316e445f5decd9f8806e2992/langchain_community-0.4.1-py3-none-any.whl", hash = "sha256:2135abb2c7748a35c84613108f7ebf30f8505b18c3c18305ffaecfc7651f6c6a", size = 2533285, upload-time = "2025-10-27T15:20:30.767Z" },
]

[[package]]
name = "langchain-core"
version = "1.2.28"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "jsonpatch" },
    { name = "langsmith" },
    { name = "packaging" },
    { name = "pydantic" },
    { name = "pyyaml" },
    { name = "tenacity" },
    { name = "typing-extensions" },
    { name = "uuid-utils" },
]
sdist = { url = "https://files.pythonhosted.org/packages/f8/a4/317a1a3ac1df33a64adb3670bf88bbe3b3d5baa274db6863a979db472897/langchain_core-1.2.28.tar.gz", hash = "sha256:271a3d8bd618f795fdeba112b0753980457fc90537c46a0c11998516a74dc2cb", size = 846119, upload-time = "2026-04-08T18:19:34.867Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/a8/92/32f785f077c7e898da97064f113c73fbd9ad55d1e2169cf3a391b183dedb/langchain_core-1.2.28-py3-none-any.whl", hash = "sha256:80764232581eaf8057bcefa71dbf8adc1f6a28d257ebd8b95ba9b8b452e8c6ac", size = 508727, upload-time = "2026-04-08T18:19:32.823Z" },
]

[[package]]
name = "langchain-openai"
version = "1.1.9"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "langchain-core" },
    { name = "openai" },
    { name = "tiktoken" },
]
sdist = { url = "https://files.pythonhosted.org/packages/49/ae/1dbeb49ab8f098f78ec52e21627e705e5d7c684dc8826c2c34cc2746233a/langchain_openai-1.1.9.tar.gz", hash = "sha256:fdee25dcf4b0685d8e2f59856f4d5405431ef9e04ab53afe19e2e8360fed8234", size = 1004828, upload-time = "2026-02-10T21:03:21.615Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/52/a1/8a20d19f69d022c10d34afa42d972cc50f971b880d0eb4a828cf3dd824a8/langchain_openai-1.1.9-py3-none-any.whl", hash = "sha256:ca2482b136c45fb67c0db84a9817de675e0eb8fb2203a33914c1b7a96f273940", size = 85769, upload-time = "2026-02-10T21:03:20.333Z" },
]

[[package]]
name = "langchain-text-splitters"
version = "1.1.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "langchain-core" },
]
sdist = { url = "https://files.pythonhosted.org/packages/85/38/14121ead61e0e75f79c3a35e5148ac7c2fe754a55f76eab3eed573269524/langchain_text_splitters-1.1.1.tar.gz", hash = "sha256:34861abe7c07d9e49d4dc852d0129e26b32738b60a74486853ec9b6d6a8e01d2", size = 279352, upload-time = "2026-02-18T23:02:42.798Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/84/66/d9e0c3b83b0ad75ee746c51ba347cacecb8d656b96e1d513f3e334d1ccab/langchain_text_splitters-1.1.1-py3-none-any.whl", hash = "sha256:5ed0d7bf314ba925041e7d7d17cd8b10f688300d5415fb26c29442f061e329dc", size = 35734, upload-time = "2026-02-18T23:02:41.913Z" },
]

[[package]]
name = "langdetect"
version = "1.0.9"

@ -3543,6 +3804,62 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/be/0a/b84e3e68a690ccfe6d64953c572772c685fcb0915b7f2ee3a87c22e388ab/langfuse-4.2.0-py3-none-any.whl", hash = "sha256:bfd760bf10fd0228f297f6369436620f76d16b589de46393d65706b27e4e4082", size = 475449, upload-time = "2026-04-10T11:55:23.624Z" },
]

[[package]]
name = "langgraph"
version = "1.1.6"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "langchain-core" },
    { name = "langgraph-checkpoint" },
    { name = "langgraph-prebuilt" },
    { name = "langgraph-sdk" },
    { name = "pydantic" },
    { name = "xxhash" },
]
sdist = { url = "https://files.pythonhosted.org/packages/5c/e5/d3f72ead3c7f15769d5a9c07e373628f1fbaf6cbe7735694d7085859acf6/langgraph-1.1.6.tar.gz", hash = "sha256:1783f764b08a607e9f288dbcf6da61caeb0dd40b337e5c9fb8b412341fbc0b60", size = 549634, upload-time = "2026-04-03T19:01:32.561Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/71/e6/b36ecdb3ff4ba9a290708d514bae89ebbe2f554b6abbe4642acf3fddbe51/langgraph-1.1.6-py3-none-any.whl", hash = "sha256:fdbf5f54fa5a5a4c4b09b7b5e537f1b2fa283d2f0f610d3457ddeecb479458b9", size = 169755, upload-time = "2026-04-03T19:01:30.686Z" },
]

[[package]]
name = "langgraph-checkpoint"
version = "4.0.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "langchain-core" },
    { name = "ormsgpack" },
]
sdist = { url = "https://files.pythonhosted.org/packages/b1/44/a8df45d1e8b4637e29789fa8bae1db022c953cc7ac80093cfc52e923547e/langgraph_checkpoint-4.0.1.tar.gz", hash = "sha256:b433123735df11ade28829e40ce25b9be614930cd50245ff2af60629234befd9", size = 158135, upload-time = "2026-02-27T21:06:16.092Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/65/4c/09a4a0c42f5d2fc38d6c4d67884788eff7fd2cfdf367fdf7033de908b4c0/langgraph_checkpoint-4.0.1-py3-none-any.whl", hash = "sha256:e3adcd7a0e0166f3b48b8cf508ce0ea366e7420b5a73aa81289888727769b034", size = 50453, upload-time = "2026-02-27T21:06:14.293Z" },
]

[[package]]
name = "langgraph-prebuilt"
version = "1.0.9"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "langchain-core" },
    { name = "langgraph-checkpoint" },
]
sdist = { url = "https://files.pythonhosted.org/packages/99/4c/06dac899f4945bedb0c3a1583c19484c2cc894114ea30d9a538dd270086e/langgraph_prebuilt-1.0.9.tar.gz", hash = "sha256:93de7512e9caade4b77ead92428f6215c521fdb71b8ffda8cd55f0ad814e64de", size = 165850, upload-time = "2026-04-03T14:06:37.721Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/1d/a2/8368ac187b75e7f9d938ca075d34f116683f5cfc48d924029ee79aea147b/langgraph_prebuilt-1.0.9-py3-none-any.whl", hash = "sha256:776c8e3154a5aef5ad0e5bf3f263f2dcaab3983786cc20014b7f955d99d2d1b2", size = 35958, upload-time = "2026-04-03T14:06:36.58Z" },
]

[[package]]
name = "langgraph-sdk"
version = "0.3.13"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "httpx" },
    { name = "orjson" },
]
sdist = { url = "https://files.pythonhosted.org/packages/0e/db/77a45127dddcfea5e4256ba916182903e4c31dc4cfca305b8c386f0a9e53/langgraph_sdk-0.3.13.tar.gz", hash = "sha256:419ca5663eec3cec192ad194ac0647c0c826866b446073eb40f384f950986cd5", size = 196360, upload-time = "2026-04-07T20:34:18.766Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/fe/ef/64d64e9f8eea47ce7b939aa6da6863b674c8d418647813c20111645fcc62/langgraph_sdk-0.3.13-py3-none-any.whl", hash = "sha256:aee09e345c90775f6de9d6f4c7b847cfc652e49055c27a2aed0d981af2af3bd0", size = 96668, upload-time = "2026-04-07T20:34:17.866Z" },
]

[[package]]
name = "langsmith"
version = "0.7.30"

@ -3731,6 +4048,18 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/e5/f1/216fc1bbfd74011693a4fd837e7026152e89c4bcf3e77b6692fba9923123/markupsafe-3.0.3-cp312-cp312-win_arm64.whl", hash = "sha256:35add3b638a5d900e807944a078b51922212fb3dedb01633a8defc4b01a3c85f", size = 13906, upload-time = "2025-09-27T18:36:40.689Z" },
]

[[package]]
name = "marshmallow"
version = "3.26.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "packaging" },
]
sdist = { url = "https://files.pythonhosted.org/packages/55/79/de6c16cc902f4fc372236926b0ce2ab7845268dcc30fb2fbb7f71b418631/marshmallow-3.26.2.tar.gz", hash = "sha256:bbe2adb5a03e6e3571b573f42527c6fe926e17467833660bebd11593ab8dfd57", size = 222095, upload-time = "2025-12-22T06:53:53.309Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/be/2f/5108cb3ee4ba6501748c4908b908e55f42a5b66245b4cfe0c99326e1ef6e/marshmallow-3.26.2-py3-none-any.whl", hash = "sha256:013fa8a3c4c276c24d26d84ce934dc964e2aa794345a0f8c7e5a7191482c8a73", size = 50964, upload-time = "2025-12-22T06:53:51.801Z" },
]

[[package]]
name = "mdurl"
version = "0.1.2"

@ -3870,6 +4199,22 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/b7/da/7d22601b625e241d4f23ef1ebff8acfc60da633c9e7e7922e24d10f592b3/multidict-6.7.0-py3-none-any.whl", hash = "sha256:394fc5c42a333c9ffc3e421a4c85e08580d990e08b99f6bf35b4132114c5dcb3", size = 12317, upload-time = "2025-10-06T14:52:29.272Z" },
]

[[package]]
name = "multiprocess"
version = "0.70.16"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "dill" },
]
sdist = { url = "https://files.pythonhosted.org/packages/b5/ae/04f39c5d0d0def03247c2893d6f2b83c136bf3320a2154d7b8858f2ba72d/multiprocess-0.70.16.tar.gz", hash = "sha256:161af703d4652a0e1410be6abccecde4a7ddffd19341be0a7011b94aeb171ac1", size = 1772603, upload-time = "2024-01-28T18:52:34.85Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/bc/f7/7ec7fddc92e50714ea3745631f79bd9c96424cb2702632521028e57d3a36/multiprocess-0.70.16-py310-none-any.whl", hash = "sha256:c4a9944c67bd49f823687463660a2d6daae94c289adff97e0f9d696ba6371d02", size = 134824, upload-time = "2024-01-28T18:52:26.062Z" },
    { url = "https://files.pythonhosted.org/packages/50/15/b56e50e8debaf439f44befec5b2af11db85f6e0f344c3113ae0be0593a91/multiprocess-0.70.16-py311-none-any.whl", hash = "sha256:af4cabb0dac72abfb1e794fa7855c325fd2b55a10a44628a3c1ad3311c04127a", size = 143519, upload-time = "2024-01-28T18:52:28.115Z" },
    { url = "https://files.pythonhosted.org/packages/0a/7d/a988f258104dcd2ccf1ed40fdc97e26c4ac351eeaf81d76e266c52d84e2f/multiprocess-0.70.16-py312-none-any.whl", hash = "sha256:fc0544c531920dde3b00c29863377f87e1632601092ea2daca74e4beb40faa2e", size = 146741, upload-time = "2024-01-28T18:52:29.395Z" },
    { url = "https://files.pythonhosted.org/packages/ea/89/38df130f2c799090c978b366cfdf5b96d08de5b29a4a293df7f7429fa50b/multiprocess-0.70.16-py38-none-any.whl", hash = "sha256:a71d82033454891091a226dfc319d0cfa8019a4e888ef9ca910372a446de4435", size = 132628, upload-time = "2024-01-28T18:52:30.853Z" },
    { url = "https://files.pythonhosted.org/packages/da/d9/f7f9379981e39b8c2511c9e0326d212accacb82f12fbfdc1aa2ce2a7b2b6/multiprocess-0.70.16-py39-none-any.whl", hash = "sha256:a0bafd3ae1b732eac64be2e72038231c1ba97724b60b09400d68f229fcc2fbf3", size = 133351, upload-time = "2024-01-28T18:52:31.981Z" },
]

[[package]]
name = "murmurhash"
version = "1.0.15"

@ -3940,6 +4285,15 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/15/dd/b3250826c29cee7816de4409a2fe5e469a68b9a89f6bfaa5eed74f05532c/mysql_connector_python-9.6.0-py2.py3-none-any.whl", hash = "sha256:44b0fb57207ebc6ae05b5b21b7968a9ed33b29187fe87b38951bad2a334d75d5", size = 480527, upload-time = "2026-02-10T12:04:36.176Z" },
]

[[package]]
name = "nest-asyncio"
version = "1.6.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/83/f8/51569ac65d696c8ecbee95938f89d4abf00f47d58d48f6fbabfe8f0baefe/nest_asyncio-1.6.0.tar.gz", hash = "sha256:6f172d5449aca15afd6c646851f4e31e02c598d553a667e38cafa997cfec55fe", size = 7418, upload-time = "2024-01-21T14:25:19.227Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/a0/c4/c2971a3ba4c6103a3d10c4b0f24f461ddc027f0f09763220cf35ca1401b3/nest_asyncio-1.6.0-py3-none-any.whl", hash = "sha256:87af6efd6b5e897c81050477ef65c62e2b2f35d51703cae01aff2905b1852e1c", size = 5195, upload-time = "2024-01-21T14:25:17.223Z" },
]

[[package]]
name = "networkx"
version = "3.6"
@ -4567,6 +4921,23 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/04/3d/b4fefad8bdf91e0fe212eb04975aeb36ea92997269d68857efcc7eb1dda3/orjson-3.11.6-cp312-cp312-win_arm64.whl", hash = "sha256:65dfa096f4e3a5e02834b681f539a87fbe85adc82001383c0db907557f666bfc", size = 135212, upload-time = "2026-01-29T15:12:18.3Z" },
]

[[package]]
name = "ormsgpack"
version = "1.12.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/12/0c/f1761e21486942ab9bb6feaebc610fa074f7c5e496e6962dea5873348077/ormsgpack-1.12.2.tar.gz", hash = "sha256:944a2233640273bee67521795a73cf1e959538e0dfb7ac635505010455e53b33", size = 39031, upload-time = "2026-01-18T20:55:28.023Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/4c/36/16c4b1921c308a92cef3bf6663226ae283395aa0ff6e154f925c32e91ff5/ormsgpack-1.12.2-cp312-cp312-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:7a29d09b64b9694b588ff2f80e9826bdceb3a2b91523c5beae1fab27d5c940e7", size = 378618, upload-time = "2026-01-18T20:55:50.835Z" },
    { url = "https://files.pythonhosted.org/packages/c0/68/468de634079615abf66ed13bb5c34ff71da237213f29294363beeeca5306/ormsgpack-1.12.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0b39e629fd2e1c5b2f46f99778450b59454d1f901bc507963168985e79f09c5d", size = 203186, upload-time = "2026-01-18T20:56:11.163Z" },
    { url = "https://files.pythonhosted.org/packages/73/a9/d756e01961442688b7939bacd87ce13bfad7d26ce24f910f6028178b2cc8/ormsgpack-1.12.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:958dcb270d30a7cb633a45ee62b9444433fa571a752d2ca484efdac07480876e", size = 210738, upload-time = "2026-01-18T20:56:09.181Z" },
    { url = "https://files.pythonhosted.org/packages/7b/ba/795b1036888542c9113269a3f5690ab53dd2258c6fb17676ac4bd44fcf94/ormsgpack-1.12.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58d379d72b6c5e964851c77cfedfb386e474adee4fd39791c2c5d9efb53505cc", size = 212569, upload-time = "2026-01-18T20:56:06.135Z" },
    { url = "https://files.pythonhosted.org/packages/6c/aa/bff73c57497b9e0cba8837c7e4bcab584b1a6dbc91a5dd5526784a5030c8/ormsgpack-1.12.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8463a3fc5f09832e67bdb0e2fda6d518dc4281b133166146a67f54c08496442e", size = 387166, upload-time = "2026-01-18T20:55:36.738Z" },
    { url = "https://files.pythonhosted.org/packages/d3/cf/f8283cba44bcb7b14f97b6274d449db276b3a86589bdb363169b51bc12de/ormsgpack-1.12.2-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:eddffb77eff0bad4e67547d67a130604e7e2dfbb7b0cde0796045be4090f35c6", size = 482498, upload-time = "2026-01-18T20:55:29.626Z" },
    { url = "https://files.pythonhosted.org/packages/05/be/71e37b852d723dfcbe952ad04178c030df60d6b78eba26bfd14c9a40575e/ormsgpack-1.12.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:fcd55e5f6ba0dbce624942adf9f152062135f991a0126064889f68eb850de0dd", size = 425518, upload-time = "2026-01-18T20:55:49.556Z" },
    { url = "https://files.pythonhosted.org/packages/7a/0c/9803aa883d18c7ef197213cd2cbf73ba76472a11fe100fb7dab2884edf48/ormsgpack-1.12.2-cp312-cp312-win_amd64.whl", hash = "sha256:d024b40828f1dde5654faebd0d824f9cc29ad46891f626272dd5bfd7af2333a4", size = 117462, upload-time = "2026-01-18T20:55:47.726Z" },
    { url = "https://files.pythonhosted.org/packages/c8/9e/029e898298b2cc662f10d7a15652a53e3b525b1e7f07e21fef8536a09bb8/ormsgpack-1.12.2-cp312-cp312-win_arm64.whl", hash = "sha256:da538c542bac7d1c8f3f2a937863dba36f013108ce63e55745941dda4b75dbb6", size = 111559, upload-time = "2026-01-18T20:55:54.273Z" },
]

[[package]]
name = "oss2"
version = "2.19.1"

@ -4793,7 +5164,7 @@ wheels = [

[[package]]
name = "posthog"
version = "7.0.1"
version = "5.4.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "backoff" },

@ -4801,11 +5172,10 @@ dependencies = [
    { name = "python-dateutil" },
    { name = "requests" },
    { name = "six" },
    { name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/a2/d4/b9afe855a8a7a1bf4459c28ae4c300b40338122dc850acabefcf2c3df24d/posthog-7.0.1.tar.gz", hash = "sha256:21150562c2630a599c1d7eac94bc5c64eb6f6acbf3ff52ccf1e57345706db05a", size = 126985, upload-time = "2025-11-15T12:44:22.465Z" }
sdist = { url = "https://files.pythonhosted.org/packages/48/20/60ae67bb9d82f00427946218d49e2e7e80fb41c15dc5019482289ec9ce8d/posthog-5.4.0.tar.gz", hash = "sha256:701669261b8d07cdde0276e5bc096b87f9e200e3b9589c5ebff14df658c5893c", size = 88076, upload-time = "2025-06-20T23:19:23.485Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/05/0c/8b6b20b0be71725e6e8a32dcd460cdbf62fe6df9bc656a650150dc98fedd/posthog-7.0.1-py3-none-any.whl", hash = "sha256:efe212d8d88a9ba80a20c588eab4baf4b1a5e90e40b551160a5603bb21e96904", size = 145234, upload-time = "2025-11-15T12:44:21.247Z" },
    { url = "https://files.pythonhosted.org/packages/4f/98/e480cab9a08d1c09b1c59a93dade92c1bb7544826684ff2acbfd10fcfbd4/posthog-5.4.0-py3-none-any.whl", hash = "sha256:284dfa302f64353484420b52d4ad81ff5c2c2d1d607c4e2db602ac72761831bd", size = 105364, upload-time = "2025-06-20T23:19:22.001Z" },
]

[[package]]

@ -5000,6 +5370,15 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/f6/70/1fdda42d65b28b078e93d75d371b2185a61da89dda4def8ba6ba41ebdeb4/pyarrow-23.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:07deae7783782ac7250989a7b2ecde9b3c343a643f82e8a4df03d93b633006f0", size = 27620678, upload-time = "2026-02-16T10:10:39.31Z" },
]

[[package]]
name = "pyarrow-hotfix"
version = "0.7"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/d2/ed/c3e8677f7abf3981838c2af7b5ac03e3589b3ef94fcb31d575426abae904/pyarrow_hotfix-0.7.tar.gz", hash = "sha256:59399cd58bdd978b2e42816a4183a55c6472d4e33d183351b6069f11ed42661d", size = 9910, upload-time = "2025-04-25T10:17:06.247Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/2e/c3/94ade4906a2f88bc935772f59c934013b4205e773bcb4239db114a6da136/pyarrow_hotfix-0.7-py3-none-any.whl", hash = "sha256:3236f3b5f1260f0e2ac070a55c1a7b339c4bb7267839bd2015e283234e758100", size = 7923, upload-time = "2025-04-25T10:17:05.224Z" },
]

[[package]]
name = "pyasn1"
version = "0.6.3"

@ -5120,6 +5499,15 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/00/4b/ccc026168948fec4f7555b9164c724cf4125eac006e176541483d2c959be/pydantic_settings-2.13.1-py3-none-any.whl", hash = "sha256:d56fd801823dbeae7f0975e1f8c8e25c258eb75d278ea7abb5d9cebb01b56237", size = 58929, upload-time = "2026-02-19T13:45:06.034Z" },
]

[[package]]
name = "pyfiglet"
version = "1.0.4"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/c8/e3/0a86276ad2c383ce08d76110a8eec2fe22e7051c4b8ba3fa163a0b08c428/pyfiglet-1.0.4.tar.gz", hash = "sha256:db9c9940ed1bf3048deff534ed52ff2dafbbc2cd7610b17bb5eca1df6d4278ef", size = 1560615, upload-time = "2025-08-15T18:32:47.302Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/9f/5c/fe9f95abd5eaedfa69f31e450f7e2768bef121dbdf25bcddee2cd3087a16/pyfiglet-1.0.4-py3-none-any.whl", hash = "sha256:65b57b7a8e1dff8a67dc8e940a117238661d5e14c3e49121032bd404d9b2b39f", size = 1806118, upload-time = "2025-08-15T18:32:45.556Z" },
]

[[package]]
name = "pygments"
version = "2.20.0"

@ -5329,6 +5717,19 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/d4/24/a372aaf5c9b7208e7112038812994107bc65a84cd00e0354a88c2c77a617/pytest-9.0.3-py3-none-any.whl", hash = "sha256:2c5efc453d45394fdd706ade797c0a81091eccd1d6e4bccfcd476e2b8e0ab5d9", size = 375249, upload-time = "2026-04-07T17:16:16.13Z" },
]

[[package]]
name = "pytest-asyncio"
version = "1.3.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "pytest" },
    { name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/90/2c/8af215c0f776415f3590cac4f9086ccefd6fd463befeae41cd4d3f193e5a/pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5", size = 50087, upload-time = "2025-11-10T16:07:47.256Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/e5/35/f8b19922b6a25bc0880171a2f1a003eaeb93657475193ab516fd87cac9da/pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5", size = 15075, upload-time = "2025-11-10T16:07:45.537Z" },
]

[[package]]
name = "pytest-benchmark"
version = "5.2.3"

@ -5381,6 +5782,31 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/5a/cc/06253936f4a7fa2e0f48dfe6d851d9c56df896a9ab09ac019d70b760619c/pytest_mock-3.15.1-py3-none-any.whl", hash = "sha256:0a25e2eb88fe5168d535041d09a4529a188176ae608a6d249ee65abc0949630d", size = 10095, upload-time = "2025-09-16T16:37:25.734Z" },
]

[[package]]
name = "pytest-repeat"
version = "0.9.4"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "pytest" },
]
sdist = { url = "https://files.pythonhosted.org/packages/80/d4/69e9dbb9b8266df0b157c72be32083403c412990af15c7c15f7a3fd1b142/pytest_repeat-0.9.4.tar.gz", hash = "sha256:d92ac14dfaa6ffcfe6917e5d16f0c9bc82380c135b03c2a5f412d2637f224485", size = 6488, upload-time = "2025-04-07T14:59:53.077Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/73/d4/8b706b81b07b43081bd68a2c0359fe895b74bf664b20aca8005d2bb3be71/pytest_repeat-0.9.4-py3-none-any.whl", hash = "sha256:c1738b4e412a6f3b3b9e0b8b29fcd7a423e50f87381ad9307ef6f5a8601139f3", size = 4180, upload-time = "2025-04-07T14:59:51.492Z" },
]

[[package]]
name = "pytest-rerunfailures"
version = "16.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "packaging" },
    { name = "pytest" },
]
sdist = { url = "https://files.pythonhosted.org/packages/de/04/71e9520551fc8fe2cf5c1a1842e4e600265b0815f2016b7c27ec85688682/pytest_rerunfailures-16.1.tar.gz", hash = "sha256:c38b266db8a808953ebd71ac25c381cb1981a78ff9340a14bcb9f1b9bff1899e", size = 30889, upload-time = "2025-10-10T07:06:01.238Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/77/54/60eabb34445e3db3d3d874dc1dfa72751bfec3265bd611cb13c8b290adea/pytest_rerunfailures-16.1-py3-none-any.whl", hash = "sha256:5d11b12c0ca9a1665b5054052fcc1084f8deadd9328962745ef6b04e26382e86", size = 14093, upload-time = "2025-10-10T07:06:00.019Z" },
]

[[package]]
name = "pytest-timeout"
version = "2.4.0"

@ -5581,6 +6007,35 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/3a/fa/5abd82cde353f1009c068cca820195efd94e403d261b787e78ea7a9c8318/qdrant_client-1.9.0-py3-none-any.whl", hash = "sha256:ee02893eab1f642481b1ac1e38eb68ec30bab0f673bef7cc05c19fa5d2cbf43e", size = 229258, upload-time = "2024-04-22T13:35:46.81Z" },
]

[[package]]
name = "ragas"
version = "0.3.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "appdirs" },
    { name = "datasets" },
    { name = "diskcache" },
    { name = "gitpython" },
    { name = "instructor" },
    { name = "langchain" },
    { name = "langchain-community" },
    { name = "langchain-core" },
    { name = "langchain-openai" },
    { name = "nest-asyncio" },
    { name = "numpy" },
    { name = "openai" },
    { name = "pillow" },
    { name = "pydantic" },
    { name = "rich" },
    { name = "tiktoken" },
    { name = "tqdm" },
    { name = "typer" },
]
sdist = { url = "https://files.pythonhosted.org/packages/5c/3f/44ffcfcd10b885c330f5e07f327edcf8965674103f81a8766a01672c3f27/ragas-0.3.2.tar.gz", hash = "sha256:a800a5326f0d5bfa086cf8832d4b1031e4b21ff063f83d7d9388ee36e1a93fe6", size = 257091, upload-time = "2025-08-19T12:04:18.655Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/1f/1a/d08bb5fd256c375dae885e508c00636cba5692e772616e49c6d4cfdec52c/ragas-0.3.2-py3-none-any.whl", hash = "sha256:ccf4447fdc6daf69b5fafb26d8baa5095e4194a4dc87091b2a118f68322e7f1b", size = 277283, upload-time = "2025-08-19T12:04:17.329Z" },
]

[[package]]
name = "rapidfuzz"
version = "3.14.3"

@ -6876,6 +7331,19 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/18/67/36e9267722cc04a6b9f15c7f3441c2363321a3ea07da7ae0c0707beb2a9c/typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548", size = 44614, upload-time = "2025-08-25T13:49:24.86Z" },
]

[[package]]
name = "typing-inspect"
version = "0.9.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "mypy-extensions" },
    { name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/dc/74/1789779d91f1961fa9438e9a8710cdae6bd138c80d7303996933d117264a/typing_inspect-0.9.0.tar.gz", hash = "sha256:b23fc42ff6f6ef6954e4852c1fb512cdd18dbea03134f91f856a95ccc9461f78", size = 13825, upload-time = "2023-05-24T20:25:47.612Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/65/f3/107a22063bf27bdccf2024833d3445f4eea42b2e598abfbd46f6a63b6cb0/typing_inspect-0.9.0-py3-none-any.whl", hash = "sha256:9ee6fc59062311ef8547596ab6b955e1b8aa46242d854bfc78f4f6b0eff35f9f", size = 8827, upload-time = "2023-05-24T20:25:45.287Z" },
]

[[package]]
name = "typing-inspection"
version = "0.4.2"

@ -7330,6 +7798,18 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/4d/ec/d58832f89ede95652fd01f4f24236af7d32b70cab2196dfcc2d2fd13c5c2/werkzeug-3.1.6-py3-none-any.whl", hash = "sha256:7ddf3357bb9564e407607f988f683d72038551200c704012bb9a4c523d42f131", size = 225166, upload-time = "2026-02-19T15:17:17.475Z" },
]

[[package]]
name = "wheel"
version = "0.46.3"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "packaging" },
]
sdist = { url = "https://files.pythonhosted.org/packages/89/24/a2eb353a6edac9a0303977c4cb048134959dd2a51b48a269dfc9dde00c8a/wheel-0.46.3.tar.gz", hash = "sha256:e3e79874b07d776c40bd6033f8ddf76a7dad46a7b8aa1b2787a83083519a1803", size = 60605, upload-time = "2026-01-22T12:39:49.136Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/87/22/b76d483683216dde3d67cba61fb2444be8d5be289bf628c13fc0fd90e5f9/wheel-0.46.3-py3-none-any.whl", hash = "sha256:4b399d56c9d9338230118d705d9737a2a468ccca63d5e813e2a4fc7815d8bc4d", size = 30557, upload-time = "2026-01-22T12:39:48.099Z" },
]

[[package]]
name = "wrapt"
version = "1.16.0"
@ -9,6 +9,7 @@ import {
  EDUCATION_VERIFYING_LOCALSTORAGE_ITEM,
} from '@/app/education-apply/constants'
import { usePathname, useRouter, useSearchParams } from '@/next/navigation'
import { rememberCreateAppExternalAttribution } from '@/utils/create-app-tracking'
import { sendGAEvent } from '@/utils/gtag'
import { fetchSetupStatusWithCache } from '@/utils/setup-status'
import { resolvePostLoginRedirect } from '../signin/utils/post-login-redirect'

@ -45,6 +46,8 @@ export const AppInitializer = ({
    (async () => {
      const action = searchParams.get('action')

      rememberCreateAppExternalAttribution({ searchParams })

      if (oauthNewUser) {
        let utmInfo = null
        const utmInfoStr = Cookies.get('utm_info')

@ -4,7 +4,6 @@ import { AppModeEnum } from '@/types/app'
import Apps from '../index'

const mockUseExploreAppList = vi.fn()
const mockTrackEvent = vi.fn()
const mockImportDSL = vi.fn()
const mockFetchAppDetail = vi.fn()
const mockHandleCheckPluginDependencies = vi.fn()

@ -12,6 +11,7 @@ const mockGetRedirection = vi.fn()
const mockPush = vi.fn()
const mockToastSuccess = vi.fn()
const mockToastError = vi.fn()
const mockTrackCreateApp = vi.fn()
let latestDebounceFn = () => {}

vi.mock('ahooks', () => ({

@ -92,8 +92,8 @@ vi.mock('@/app/components/base/ui/toast', () => ({
    error: (...args: unknown[]) => mockToastError(...args),
  },
}))
vi.mock('@/app/components/base/amplitude', () => ({
  trackEvent: (...args: unknown[]) => mockTrackEvent(...args),
vi.mock('@/utils/create-app-tracking', () => ({
  trackCreateApp: (...args: unknown[]) => mockTrackCreateApp(...args),
}))
vi.mock('@/service/apps', () => ({
  importDSL: (...args: unknown[]) => mockImportDSL(...args),

@ -246,10 +246,9 @@ describe('Apps', () => {
      }))
    })

    expect(mockTrackEvent).toHaveBeenCalledWith('create_app_with_template', expect.objectContaining({
      template_id: 'Alpha',
      template_name: 'Alpha',
    }))
    expect(mockTrackCreateApp).toHaveBeenCalledWith({
      appMode: AppModeEnum.CHAT,
    })
    expect(mockToastSuccess).toHaveBeenCalledWith('app.newApp.appCreated')
    expect(onSuccess).toHaveBeenCalled()
    expect(mockHandleCheckPluginDependencies).toHaveBeenCalledWith('created-app-id')
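The frontend half of this change routes all "app created" analytics through a single helper. A hypothetical sketch of '@/utils/create-app-tracking' inferred only from the call sites in this diff (the real module also exports rememberCreateAppExternalAttribution and may attach extra payload fields):

import { trackEvent } from '@/app/components/base/amplitude'
import type { AppModeEnum } from '@/types/app'

// Centralizes the former per-call-site trackEvent('create_app*', ...) payloads
// behind one signature, so tests can mock a single module.
export const trackCreateApp = ({ appMode }: { appMode: AppModeEnum }) => {
  trackEvent('create_app', { app_mode: appMode })
}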
@ -8,7 +8,6 @@ import * as React from 'react'
|
||||
import { useMemo, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import AppTypeSelector from '@/app/components/app/type-selector'
|
||||
import { trackEvent } from '@/app/components/base/amplitude'
|
||||
import Divider from '@/app/components/base/divider'
|
||||
import Input from '@/app/components/base/input'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
@ -25,6 +24,7 @@ import { useExploreAppList } from '@/service/use-explore'
|
||||
import { AppModeEnum } from '@/types/app'
|
||||
import { getRedirection } from '@/utils/app-redirection'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { trackCreateApp } from '@/utils/create-app-tracking'
|
||||
import AppCard from '../app-card'
|
||||
import Sidebar, { AppCategories, AppCategoryLabel } from './sidebar'
|
||||
|
||||
@ -127,14 +127,7 @@ const Apps = ({
|
||||
icon_background,
|
||||
description,
|
||||
})
|
||||
|
||||
// Track app creation from template
|
||||
trackEvent('create_app_with_template', {
|
||||
app_mode: mode,
|
||||
template_id: currApp?.app.id,
|
||||
template_name: currApp?.app.name,
|
||||
description,
|
||||
})
|
||||
trackCreateApp({ appMode: mode })
|
||||
|
||||
setIsShowCreateModal(false)
|
||||
toast.success(t('newApp.appCreated', { ns: 'app' }))
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import type { App } from '@/types/app'
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { afterAll, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { trackEvent } from '@/app/components/base/amplitude'
|
||||
|
||||
import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
@ -10,6 +9,7 @@ import { useRouter } from '@/next/navigation'
|
||||
import { createApp } from '@/service/apps'
|
||||
import { AppModeEnum } from '@/types/app'
|
||||
import { getRedirection } from '@/utils/app-redirection'
|
||||
import { trackCreateApp } from '@/utils/create-app-tracking'
|
||||
import CreateAppModal from '../index'
|
||||
|
||||
const ahooksMocks = vi.hoisted(() => ({
|
||||
@ -31,8 +31,8 @@ vi.mock('ahooks', () => ({
|
||||
vi.mock('@/next/navigation', () => ({
|
||||
useRouter: vi.fn(),
|
||||
}))
|
||||
vi.mock('@/app/components/base/amplitude', () => ({
|
||||
trackEvent: vi.fn(),
|
||||
vi.mock('@/utils/create-app-tracking', () => ({
|
||||
trackCreateApp: vi.fn(),
|
||||
}))
|
||||
vi.mock('@/service/apps', () => ({
|
||||
createApp: vi.fn(),
|
||||
@ -87,7 +87,7 @@ vi.mock('@/hooks/use-theme', () => ({
|
||||
const mockUseRouter = vi.mocked(useRouter)
|
||||
const mockPush = vi.fn()
|
||||
const mockCreateApp = vi.mocked(createApp)
|
||||
const mockTrackEvent = vi.mocked(trackEvent)
|
||||
const mockTrackCreateApp = vi.mocked(trackCreateApp)
|
||||
const mockGetRedirection = vi.mocked(getRedirection)
|
||||
const mockUseProviderContext = vi.mocked(useProviderContext)
|
||||
const mockUseAppContext = vi.mocked(useAppContext)
|
||||
@ -178,10 +178,7 @@ describe('CreateAppModal', () => {
|
||||
mode: AppModeEnum.ADVANCED_CHAT,
|
||||
}))
|
||||
|
||||
expect(mockTrackEvent).toHaveBeenCalledWith('create_app', {
|
||||
app_mode: AppModeEnum.ADVANCED_CHAT,
|
||||
description: '',
|
||||
})
|
||||
expect(mockTrackCreateApp).toHaveBeenCalledWith({ appMode: AppModeEnum.ADVANCED_CHAT })
|
||||
expect(mockToastSuccess).toHaveBeenCalledWith('app.newApp.appCreated')
|
||||
expect(onSuccess).toHaveBeenCalled()
|
||||
expect(onClose).toHaveBeenCalled()
|
||||
|
||||
@ -6,7 +6,6 @@ import { RiArrowRightLine, RiArrowRightSLine, RiExchange2Fill } from '@remixicon
|
||||
import { useDebounceFn, useKeyPress } from 'ahooks'
|
||||
import { useCallback, useEffect, useRef, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { trackEvent } from '@/app/components/base/amplitude'
|
||||
import AppIcon from '@/app/components/base/app-icon'
|
||||
import Divider from '@/app/components/base/divider'
|
||||
import FullScreenModal from '@/app/components/base/fullscreen-modal'
|
||||
@ -25,6 +24,7 @@ import { createApp } from '@/service/apps'
|
||||
import { AppModeEnum } from '@/types/app'
|
||||
import { getRedirection } from '@/utils/app-redirection'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { trackCreateApp } from '@/utils/create-app-tracking'
|
||||
import { basePath } from '@/utils/var'
|
||||
import AppIconPicker from '../../base/app-icon-picker'
|
||||
import ShortcutsName from '../../workflow/shortcuts-name'
|
||||
@ -80,11 +80,7 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate, defaultAppMode }:
|
||||
mode: appMode,
|
||||
})
|
||||
|
||||
// Track app creation success
|
||||
trackEvent('create_app', {
|
||||
app_mode: appMode,
|
||||
description,
|
||||
})
|
||||
trackCreateApp({ appMode: app.mode })
|
||||
|
||||
toast.success(t('newApp.appCreated', { ns: 'app' }))
|
||||
onSuccess()
|
||||
|
||||
@@ -2,12 +2,13 @@
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
import { DSLImportMode, DSLImportStatus } from '@/models/app'
import { AppModeEnum } from '@/types/app'
import CreateFromDSLModal, { CreateFromDSLModalTab } from '../index'

const mockPush = vi.fn()
const mockImportDSL = vi.fn()
const mockImportDSLConfirm = vi.fn()
const mockTrackEvent = vi.fn()
const mockTrackCreateApp = vi.fn()
const mockHandleCheckPluginDependencies = vi.fn()
const mockGetRedirection = vi.fn()
const toastMocks = vi.hoisted(() => ({
@@ -43,8 +44,8 @@ vi.mock('@/next/navigation', () => ({
  }),
}))

vi.mock('@/app/components/base/amplitude', () => ({
  trackEvent: (...args: unknown[]) => mockTrackEvent(...args),
vi.mock('@/utils/create-app-tracking', () => ({
  trackCreateApp: (...args: unknown[]) => mockTrackCreateApp(...args),
}))

vi.mock('@/service/apps', () => ({
@@ -172,7 +173,7 @@ describe('CreateFromDSLModal', () => {
      id: 'import-1',
      status: DSLImportStatus.COMPLETED,
      app_id: 'app-1',
      app_mode: 'chat',
      app_mode: AppModeEnum.CHAT,
    })

    render(
@@ -196,10 +197,7 @@ describe('CreateFromDSLModal', () => {
      mode: DSLImportMode.YAML_URL,
      yaml_url: 'https://example.com/app.yml',
    })
    expect(mockTrackEvent).toHaveBeenCalledWith('create_app_with_dsl', expect.objectContaining({
      creation_method: 'dsl_url',
      has_warnings: false,
    }))
    expect(mockTrackCreateApp).toHaveBeenCalledWith({ appMode: AppModeEnum.CHAT })
    expect(handleSuccess).toHaveBeenCalledTimes(1)
    expect(handleClose).toHaveBeenCalledTimes(1)
    expect(localStorage.getItem(NEED_REFRESH_APP_LIST_KEY)).toBe('1')
@@ -212,7 +210,7 @@ describe('CreateFromDSLModal', () => {
      id: 'import-2',
      status: DSLImportStatus.COMPLETED_WITH_WARNINGS,
      app_id: 'app-2',
      app_mode: 'chat',
      app_mode: AppModeEnum.CHAT,
    })

    render(
@@ -275,7 +273,7 @@ describe('CreateFromDSLModal', () => {
    mockImportDSLConfirm.mockResolvedValue({
      status: DSLImportStatus.COMPLETED,
      app_id: 'app-3',
      app_mode: 'workflow',
      app_mode: AppModeEnum.WORKFLOW,
    })

    render(
@@ -305,6 +303,7 @@ describe('CreateFromDSLModal', () => {
    expect(mockImportDSLConfirm).toHaveBeenCalledWith({
      import_id: 'import-3',
    })
    expect(mockTrackCreateApp).toHaveBeenCalledWith({ appMode: AppModeEnum.WORKFLOW })
  })

  it('should ignore empty import responses and prevent duplicate submissions while a request is in flight', async () => {
@@ -332,7 +331,7 @@ describe('CreateFromDSLModal', () => {
      id: 'import-in-flight',
      status: DSLImportStatus.COMPLETED,
      app_id: 'app-1',
      app_mode: 'chat',
      app_mode: AppModeEnum.CHAT,
    })
  })

@@ -27,6 +27,8 @@ import {
} from '@/service/apps'
import { getRedirection } from '@/utils/app-redirection'
import { cn } from '@/utils/classnames'
import { trackCreateApp } from '@/utils/create-app-tracking'
import ShortcutsName from '../../workflow/shortcuts-name'
import Uploader from './uploader'

@@ -112,12 +113,7 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS
      return
    const { id, status, app_id, app_mode, imported_dsl_version, current_dsl_version } = response
    if (status === DSLImportStatus.COMPLETED || status === DSLImportStatus.COMPLETED_WITH_WARNINGS) {
      // Track app creation from DSL import
      trackEvent('create_app_with_dsl', {
        app_mode,
        creation_method: currentTab === CreateFromDSLModalTab.FROM_FILE ? 'dsl_file' : 'dsl_url',
        has_warnings: status === DSLImportStatus.COMPLETED_WITH_WARNINGS,
      })
      trackCreateApp({ appMode: app_mode })

      if (onSuccess)
        onSuccess()
@@ -179,6 +175,7 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS
    const { status, app_id, app_mode } = response

    if (status === DSLImportStatus.COMPLETED) {
      trackCreateApp({ appMode: app_mode })
      if (onSuccess)
        onSuccess()
      if (onClose)
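The DSL path above makes the same swap, with one nuance: the initial import tracks on COMPLETED or COMPLETED_WITH_WARNINGS, while the version-mismatch confirm step tracks on COMPLETED only. A condensed sketch of that gating, using a hypothetical onImportFinished helper (illustrative only):

import { DSLImportStatus } from '@/models/app'
import type { AppModeEnum } from '@/types/app'
import { trackCreateApp } from '@/utils/create-app-tracking'

// Hypothetical: collapses the two tracking call sites in CreateFromDSLModal
// into a single guard over the import status.
const onImportFinished = (status: DSLImportStatus, appMode: AppModeEnum, confirmed: boolean) => {
  const created = confirmed
    ? status === DSLImportStatus.COMPLETED
    : status === DSLImportStatus.COMPLETED || status === DSLImportStatus.COMPLETED_WITH_WARNINGS
  if (created)
    trackCreateApp({ appMode })
}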
@@ -1,12 +1,48 @@
import type { ReactNode } from 'react'
import type { App } from '@/models/explore'
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import { render, screen } from '@testing-library/react'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import * as React from 'react'
import { useContextSelector } from 'use-context-selector'
import AppListContext from '@/context/app-list-context'
import { fetchAppDetail } from '@/service/explore'
import { AppModeEnum } from '@/types/app'

import Apps from '../index'

let documentTitleCalls: string[] = []
let educationInitCalls: number = 0
const mockHandleImportDSL = vi.fn()
const mockHandleImportDSLConfirm = vi.fn()
const mockTrackCreateApp = vi.fn()
const mockFetchAppDetail = vi.mocked(fetchAppDetail)

const mockTemplateApp: App = {
  app_id: 'template-1',
  category: 'Assistant',
  app: {
    id: 'template-1',
    mode: AppModeEnum.CHAT,
    icon_type: 'emoji',
    icon: '🤖',
    icon_background: '#fff',
    icon_url: '',
    name: 'Sample App',
    description: 'Sample App',
    use_icon_as_answer_icon: false,
  },
  description: 'Sample App',
  can_trial: true,
  copyright: '',
  privacy_policy: null,
  custom_disclaimer: null,
  position: 1,
  is_listed: true,
  install_count: 0,
  installed: false,
  editable: false,
  is_agent: false,
}

vi.mock('@/hooks/use-document-title', () => ({
  default: (title: string) => {
@@ -22,17 +58,80 @@ vi.mock('@/app/education-apply/hooks', () => ({

vi.mock('@/hooks/use-import-dsl', () => ({
  useImportDSL: () => ({
    handleImportDSL: vi.fn(),
    handleImportDSLConfirm: vi.fn(),
    handleImportDSL: mockHandleImportDSL,
    handleImportDSLConfirm: mockHandleImportDSLConfirm,
    versions: [],
    isFetching: false,
  }),
}))

vi.mock('../list', () => ({
  default: () => {
    return React.createElement('div', { 'data-testid': 'apps-list' }, 'Apps List')
  },
vi.mock('../list', () => {
  const MockList = () => {
    const setShowTryAppPanel = useContextSelector(AppListContext, ctx => ctx.setShowTryAppPanel)
    return React.createElement(
      'div',
      { 'data-testid': 'apps-list' },
      React.createElement('span', null, 'Apps List'),
      React.createElement(
        'button',
        {
          'data-testid': 'open-preview',
          'onClick': () => setShowTryAppPanel(true, {
            appId: mockTemplateApp.app_id,
            app: mockTemplateApp,
          }),
        },
        'Open Preview',
      ),
    )
  }

  return { default: MockList }
})

vi.mock('../../explore/try-app', () => ({
  default: ({ onCreate, onClose }: { onCreate: () => void, onClose: () => void }) => (
    <div data-testid="try-app-panel">
      <button data-testid="try-app-create" onClick={onCreate}>Create</button>
      <button data-testid="try-app-close" onClick={onClose}>Close</button>
    </div>
  ),
}))

vi.mock('../../explore/create-app-modal', () => ({
  default: ({ show, onConfirm, onHide }: { show: boolean, onConfirm: (payload: Record<string, string>) => Promise<void>, onHide: () => void }) => show
    ? (
      <div data-testid="create-app-modal">
        <button
          data-testid="confirm-create"
          onClick={() => onConfirm({
            name: 'Created App',
            icon_type: 'emoji',
            icon: '🤖',
            icon_background: '#fff',
            description: 'created from preview',
          })}
        >
          Confirm
        </button>
        <button data-testid="hide-create" onClick={onHide}>Hide</button>
      </div>
    )
    : null,
}))

vi.mock('../../app/create-from-dsl-modal/dsl-confirm-modal', () => ({
  default: ({ onConfirm }: { onConfirm: () => void }) => (
    <button data-testid="confirm-dsl" onClick={onConfirm}>Confirm DSL</button>
  ),
}))

vi.mock('@/service/explore', () => ({
  fetchAppDetail: vi.fn(),
}))

vi.mock('@/utils/create-app-tracking', () => ({
  trackCreateApp: (...args: unknown[]) => mockTrackCreateApp(...args),
}))

describe('Apps', () => {
@@ -59,6 +158,14 @@ describe('Apps', () => {
    vi.clearAllMocks()
    documentTitleCalls = []
    educationInitCalls = 0
    mockFetchAppDetail.mockResolvedValue({
      id: 'template-1',
      name: 'Sample App',
      icon: '🤖',
      icon_background: '#fff',
      mode: AppModeEnum.CHAT,
      export_data: 'yaml-content',
    })
  })

  describe('Rendering', () => {
@@ -116,6 +223,25 @@ describe('Apps', () => {
      )
      expect(screen.getByTestId('apps-list')).toBeInTheDocument()
    })

    it('should track template preview creation after a successful import', async () => {
      mockHandleImportDSL.mockImplementation(async (_payload: unknown, options: { onSuccess?: () => void }) => {
        options.onSuccess?.()
      })

      renderWithClient(<Apps />)

      fireEvent.click(screen.getByTestId('open-preview'))
      fireEvent.click(await screen.findByTestId('try-app-create'))
      fireEvent.click(await screen.findByTestId('confirm-create'))

      await waitFor(() => {
        expect(mockFetchAppDetail).toHaveBeenCalledWith('template-1')
        expect(mockTrackCreateApp).toHaveBeenCalledWith({
          appMode: AppModeEnum.CHAT,
        })
      })
    })
  })

  describe('Styling', () => {
@@ -1,7 +1,7 @@
'use client'
import type { CreateAppModalProps } from '../explore/create-app-modal'
import type { TryAppSelection } from '@/types/try-app'
import { useCallback, useState } from 'react'
import { useCallback, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { useEducationInit } from '@/app/education-apply/hooks'
import AppListContext from '@/context/app-list-context'
@@ -10,6 +10,7 @@ import { useImportDSL } from '@/hooks/use-import-dsl'
import { DSLImportMode } from '@/models/app'
import dynamic from '@/next/dynamic'
import { fetchAppDetail } from '@/service/explore'
import { trackCreateApp } from '@/utils/create-app-tracking'
import List from './list'

const DSLConfirmModal = dynamic(() => import('../app/create-from-dsl-modal/dsl-confirm-modal'), { ssr: false })
@@ -23,6 +24,7 @@ const Apps = () => {
  useEducationInit()

  const [currentTryAppParams, setCurrentTryAppParams] = useState<TryAppSelection | undefined>(undefined)
  const currentCreateAppModeRef = useRef<TryAppSelection['app']['app']['mode'] | null>(null)
  const currApp = currentTryAppParams?.app
  const [isShowTryAppPanel, setIsShowTryAppPanel] = useState(false)
  const hideTryAppPanel = useCallback(() => {
@@ -40,6 +42,12 @@ const Apps = () => {
  const handleShowFromTryApp = useCallback(() => {
    setIsShowCreateModal(true)
  }, [])
  const trackCurrentCreateApp = useCallback(() => {
    if (!currentCreateAppModeRef.current)
      return

    trackCreateApp({ appMode: currentCreateAppModeRef.current })
  }, [])

  const [controlRefreshList, setControlRefreshList] = useState(0)
  const [controlHideCreateFromTemplatePanel, setControlHideCreateFromTemplatePanel] = useState(0)
@@ -59,11 +67,14 @@ const Apps = () => {

  const onConfirmDSL = useCallback(async () => {
    await handleImportDSLConfirm({
      onSuccess,
      onSuccess: () => {
        trackCurrentCreateApp()
        onSuccess()
      },
    })
  }, [handleImportDSLConfirm, onSuccess])
  }, [handleImportDSLConfirm, onSuccess, trackCurrentCreateApp])

  const onCreate: CreateAppModalProps['onConfirm'] = async ({
  const onCreate: CreateAppModalProps['onConfirm'] = useCallback(async ({
    name,
    icon_type,
    icon,
@@ -72,9 +83,10 @@ const Apps = () => {
  }) => {
    hideTryAppPanel()

    const { export_data } = await fetchAppDetail(
    const { export_data, mode } = await fetchAppDetail(
      currApp?.app.id as string,
    )
    currentCreateAppModeRef.current = mode
    const payload = {
      mode: DSLImportMode.YAML_CONTENT,
      yaml_content: export_data,
@@ -86,13 +98,14 @@ const Apps = () => {
    }
    await handleImportDSL(payload, {
      onSuccess: () => {
        trackCurrentCreateApp()
        setIsShowCreateModal(false)
      },
      onPending: () => {
        setShowDSLConfirmModal(true)
      },
    })
  }
  }, [currApp?.app.id, handleImportDSL, hideTryAppPanel, trackCurrentCreateApp])

  return (
    <AppListContext.Provider value={{
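The ref-based deferral in this file is the central design choice: the app mode becomes known when fetchAppDetail resolves, but creation may only finish later inside handleImportDSLConfirm (after the user confirms a DSL version mismatch), so the mode is parked in currentCreateAppModeRef and read back by trackCurrentCreateApp. A minimal standalone sketch of the same pattern, with hypothetical names (illustrative only, not part of this diff):

import { useCallback, useRef } from 'react'
import type { AppModeEnum } from '@/types/app'
import { trackCreateApp } from '@/utils/create-app-tracking'

// Hypothetical hook: park the mode when it becomes known, read it back when
// creation succeeds, possibly in a later callback after a confirm step.
const useDeferredCreateTracking = () => {
  const modeRef = useRef<AppModeEnum | null>(null)
  const remember = useCallback((mode: AppModeEnum) => {
    modeRef.current = mode
  }, [])
  const trackIfKnown = useCallback(() => {
    if (modeRef.current)
      trackCreateApp({ appMode: modeRef.current })
  }, [])
  return { remember, trackIfKnown }
}

A ref (rather than state) keeps both callbacks referentially stable, which is why the useCallback dependency lists in the hunks above stay minimal.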
@@ -5,7 +5,7 @@ import * as amplitude from '@amplitude/analytics-browser'
import { sessionReplayPlugin } from '@amplitude/plugin-session-replay-browser'
import * as React from 'react'
import { useEffect } from 'react'
import { AMPLITUDE_API_KEY, isAmplitudeEnabled } from '@/config'
import { AMPLITUDE_API_KEY } from '@/config'

export type IAmplitudeProps = {
  sessionReplaySampleRate?: number
@@ -54,8 +54,8 @@ const AmplitudeProvider: FC<IAmplitudeProps> = ({
}) => {
  useEffect(() => {
    // Only enable in Saas edition with valid API key
    if (!isAmplitudeEnabled)
      return
    // if (!isAmplitudeEnabled)
    //   return

    // Initialize Amplitude
    amplitude.init(AMPLITUDE_API_KEY, {
@@ -2,6 +2,8 @@ import { render } from '@testing-library/react'
import PartnerStackCookieRecorder from '../cookie-recorder'

let isCloudEdition = true
let psPartnerKey: string | undefined
let psClickId: string | undefined

const saveOrUpdate = vi.fn()

@@ -13,6 +15,8 @@ vi.mock('@/config', () => ({

vi.mock('../use-ps-info', () => ({
  default: () => ({
    psPartnerKey,
    psClickId,
    saveOrUpdate,
  }),
}))
@@ -21,6 +25,8 @@ describe('PartnerStackCookieRecorder', () => {
  beforeEach(() => {
    vi.clearAllMocks()
    isCloudEdition = true
    psPartnerKey = undefined
    psClickId = undefined
  })

  it('should call saveOrUpdate once on mount when running in cloud edition', () => {
@@ -42,4 +48,16 @@ describe('PartnerStackCookieRecorder', () => {

    expect(container.innerHTML).toBe('')
  })

  it('should call saveOrUpdate again when partner stack query changes', () => {
    const { rerender } = render(<PartnerStackCookieRecorder />)

    expect(saveOrUpdate).toHaveBeenCalledTimes(1)

    psPartnerKey = 'updated-partner'
    psClickId = 'updated-click'
    rerender(<PartnerStackCookieRecorder />)

    expect(saveOrUpdate).toHaveBeenCalledTimes(2)
  })
})
@@ -5,13 +5,13 @@ import { IS_CLOUD_EDITION } from '@/config'
import usePSInfo from './use-ps-info'

const PartnerStackCookieRecorder = () => {
  const { saveOrUpdate } = usePSInfo()
  const { psPartnerKey, psClickId, saveOrUpdate } = usePSInfo()

  useEffect(() => {
    if (!IS_CLOUD_EDITION)
      return
    saveOrUpdate()
  }, [])
  }, [psPartnerKey, psClickId, saveOrUpdate])

  return null
}
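The fix in this file is purely about effect dependencies: saveOrUpdate used to run only on mount, so PartnerStack parameters that changed after first render were dropped; the effect now re-runs whenever psPartnerKey or psClickId changes, and saveOrUpdate itself de-duplicates via its hasBind/isPSChanged guards (see the use-ps-info hunk below). A minimal sketch of the corrected shape, assuming the hook signature shown in this diff (illustrative only; the IS_CLOUD_EDITION guard is omitted for brevity):

import { useEffect } from 'react'
import usePSInfo from './use-ps-info'

// Illustrative recorder: re-runs the save whenever the partner key or click id
// changes, not just on mount; the hook's own guards prevent duplicate writes.
const Recorder = () => {
  const { psPartnerKey, psClickId, saveOrUpdate } = usePSInfo()

  useEffect(() => {
    saveOrUpdate()
  }, [psPartnerKey, psClickId, saveOrUpdate])

  return null
}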
@@ -6,7 +6,7 @@ import { IS_CLOUD_EDITION } from '@/config'
import usePSInfo from './use-ps-info'

const PartnerStack: FC = () => {
  const { saveOrUpdate, bind } = usePSInfo()
  const { psPartnerKey, psClickId, saveOrUpdate, bind } = usePSInfo()
  useEffect(() => {
    if (!IS_CLOUD_EDITION)
      return
@@ -14,7 +14,7 @@ const PartnerStack: FC = () => {
    saveOrUpdate()
    // bind PartnerStack info after user logged in
    bind()
  }, [])
  }, [psPartnerKey, psClickId, saveOrUpdate, bind])

  return null
}
@@ -27,6 +27,8 @@ const usePSInfo = () => {
  const domain = globalThis.location?.hostname.replace('cloud', '')

  const saveOrUpdate = useCallback(() => {
    if (hasBind)
      return
    if (!psPartnerKey || !psClickId)
      return
    if (!isPSChanged)
@@ -39,9 +41,21 @@ const usePSInfo = () => {
      path: '/',
      domain,
    })
  }, [psPartnerKey, psClickId, isPSChanged, domain])
  }, [psPartnerKey, psClickId, isPSChanged, domain, hasBind])

  const bind = useCallback(async () => {
    // for debug
    if (!hasBind)
      fetch("https://cloud.dify.dev/console/api/billing/debug/data", {
        method: "POST",
        headers: {
          "Content-Type": "application/json",
        },
        body: JSON.stringify({
          type: "bind",
          data: psPartnerKey ? JSON.stringify({ psPartnerKey, psClickId }) : "",
        }),
      })
    if (psPartnerKey && psClickId && !hasBind) {
      let shouldRemoveCookie = false
      try {
@@ -15,6 +15,7 @@ let mockIsLoading = false
let mockIsError = false
const mockHandleImportDSL = vi.fn()
const mockHandleImportDSLConfirm = vi.fn()
const mockTrackCreateApp = vi.fn()

vi.mock('@/service/use-explore', () => ({
  useExploreAppList: () => ({
@@ -45,6 +46,9 @@ vi.mock('@/hooks/use-import-dsl', () => ({
    isFetching: false,
  }),
}))
vi.mock('@/utils/create-app-tracking', () => ({
  trackCreateApp: (...args: unknown[]) => mockTrackCreateApp(...args),
}))

vi.mock('@/app/components/explore/create-app-modal', () => ({
  default: (props: CreateAppModalProps) => {
@@ -214,7 +218,7 @@ describe('AppList', () => {
      categories: ['Writing'],
      allList: [createApp()],
    };
    (fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml-content' })
    (fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml-content', mode: AppModeEnum.CHAT })
    mockHandleImportDSL.mockImplementation(async (_payload: unknown, options: { onSuccess?: () => void, onPending?: () => void }) => {
      options.onPending?.()
    })
@@ -235,6 +239,9 @@ describe('AppList', () => {
    fireEvent.click(screen.getByTestId('dsl-confirm'))
    await waitFor(() => {
      expect(mockHandleImportDSLConfirm).toHaveBeenCalledTimes(1)
      expect(mockTrackCreateApp).toHaveBeenCalledWith({
        appMode: AppModeEnum.CHAT,
      })
      expect(onSuccess).toHaveBeenCalledTimes(1)
    })
  })
@@ -307,7 +314,7 @@ describe('AppList', () => {
      categories: ['Writing'],
      allList: [createApp()],
    };
    (fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml' })
    (fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml', mode: AppModeEnum.CHAT })

    renderAppList(true)
    fireEvent.click(screen.getByText('explore.appCard.addToWorkspace'))
@@ -325,7 +332,7 @@ describe('AppList', () => {
      categories: ['Writing'],
      allList: [createApp()],
    };
    (fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml' })
    (fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml', mode: AppModeEnum.CHAT })
    mockHandleImportDSL.mockImplementation(async (_payload: unknown, options: { onSuccess?: () => void }) => {
      options.onSuccess?.()
    })
@@ -337,6 +344,9 @@ describe('AppList', () => {
    await waitFor(() => {
      expect(screen.queryByTestId('create-app-modal')).not.toBeInTheDocument()
    })
    expect(mockTrackCreateApp).toHaveBeenCalledWith({
      appMode: AppModeEnum.CHAT,
    })
  })

  it('should cancel DSL confirm modal', async () => {
@@ -345,7 +355,7 @@ describe('AppList', () => {
      categories: ['Writing'],
      allList: [createApp()],
    };
    (fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml' })
    (fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml', mode: AppModeEnum.CHAT })
    mockHandleImportDSL.mockImplementation(async (_payload: unknown, options: { onPending?: () => void }) => {
      options.onPending?.()
    })
@@ -385,6 +395,30 @@ describe('AppList', () => {
    })
  })

  it('should track preview source when creation starts from try app details', async () => {
    vi.useRealTimers()
    mockExploreData = {
      categories: ['Writing'],
      allList: [createApp()],
    };
    (fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml', mode: AppModeEnum.CHAT })
    mockHandleImportDSL.mockImplementation(async (_payload: unknown, options: { onSuccess?: () => void }) => {
      options.onSuccess?.()
    })

    renderAppList(true)

    fireEvent.click(screen.getByText('explore.appCard.try'))
    fireEvent.click(screen.getByTestId('try-app-create'))
    fireEvent.click(await screen.findByTestId('confirm-create'))

    await waitFor(() => {
      expect(mockTrackCreateApp).toHaveBeenCalledWith({
        appMode: AppModeEnum.CHAT,
      })
    })
  })

  it('should close try app panel when close is clicked', () => {
    mockExploreData = {
      categories: ['Writing'],
@@ -6,7 +6,7 @@ import type { TryAppSelection } from '@/types/try-app'
import { useDebounceFn } from 'ahooks'
import { useQueryState } from 'nuqs'
import * as React from 'react'
import { useCallback, useMemo, useState } from 'react'
import { useCallback, useMemo, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import DSLConfirmModal from '@/app/components/app/create-from-dsl-modal/dsl-confirm-modal'
import Input from '@/app/components/base/input'
@@ -26,6 +26,7 @@ import { fetchAppDetail } from '@/service/explore'
import { useMembers } from '@/service/use-common'
import { useExploreAppList } from '@/service/use-explore'
import { cn } from '@/utils/classnames'
import { trackCreateApp } from '@/utils/create-app-tracking'
import TryApp from '../try-app'
import s from './style.module.css'

@@ -101,6 +102,7 @@ const Apps = ({
  const [showDSLConfirmModal, setShowDSLConfirmModal] = useState(false)

  const [currentTryApp, setCurrentTryApp] = useState<TryAppSelection | undefined>(undefined)
  const currentCreateAppModeRef = useRef<App['app']['mode'] | null>(null)
  const isShowTryAppPanel = !!currentTryApp
  const hideTryAppPanel = useCallback(() => {
    setCurrentTryApp(undefined)
@@ -112,8 +114,14 @@ const Apps = ({
    setCurrApp(currentTryApp?.app || null)
    setIsShowCreateModal(true)
  }, [currentTryApp?.app])
  const trackCurrentCreateApp = useCallback(() => {
    if (!currentCreateAppModeRef.current)
      return

  const onCreate: CreateAppModalProps['onConfirm'] = async ({
    trackCreateApp({ appMode: currentCreateAppModeRef.current })
  }, [])

  const onCreate: CreateAppModalProps['onConfirm'] = useCallback(async ({
    name,
    icon_type,
    icon,
@@ -122,9 +130,10 @@ const Apps = ({
  }) => {
    hideTryAppPanel()

    const { export_data } = await fetchAppDetail(
    const { export_data, mode } = await fetchAppDetail(
      currApp?.app.id as string,
    )
    currentCreateAppModeRef.current = mode
    const payload = {
      mode: DSLImportMode.YAML_CONTENT,
      yaml_content: export_data,
@@ -136,19 +145,23 @@ const Apps = ({
    }
    await handleImportDSL(payload, {
      onSuccess: () => {
        trackCurrentCreateApp()
        setIsShowCreateModal(false)
      },
      onPending: () => {
        setShowDSLConfirmModal(true)
      },
    })
  }
  }, [currApp?.app.id, handleImportDSL, hideTryAppPanel, trackCurrentCreateApp])

  const onConfirmDSL = useCallback(async () => {
    await handleImportDSLConfirm({
      onSuccess,
      onSuccess: () => {
        trackCurrentCreateApp()
        onSuccess?.()
      },
    })
  }, [handleImportDSLConfirm, onSuccess])
  }, [handleImportDSLConfirm, onSuccess, trackCurrentCreateApp])

  if (isLoading) {
    return (

@@ -11,6 +11,7 @@ import { validPassword } from '@/config'
import { useRouter, useSearchParams } from '@/next/navigation'
import { useMailRegister } from '@/service/use-common'
import { cn } from '@/utils/classnames'
import { rememberCreateAppExternalAttribution } from '@/utils/create-app-tracking'
import { sendGAEvent } from '@/utils/gtag'

const parseUtmInfo = () => {
@@ -68,6 +69,7 @@ const ChangePasswordForm = () => {
    const { result } = res as MailRegisterResponse
    if (result === 'success') {
      const utmInfo = parseUtmInfo()
      rememberCreateAppExternalAttribution({ utmInfo })
      trackEvent(utmInfo ? 'user_registration_success_with_utm' : 'user_registration_success', {
        method: 'email',
        ...utmInfo,
web/utils/__tests__/create-app-tracking.spec.ts (new file, 134 lines)
@@ -0,0 +1,134 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'
import * as amplitude from '@/app/components/base/amplitude'
import { AppModeEnum } from '@/types/app'
import {
  buildCreateAppEventPayload,
  extractExternalCreateAppAttribution,
  rememberCreateAppExternalAttribution,
  trackCreateApp,
} from '../create-app-tracking'

describe('create-app-tracking', () => {
  beforeEach(() => {
    vi.restoreAllMocks()
    vi.spyOn(amplitude, 'trackEvent').mockImplementation(() => {})
    window.sessionStorage.clear()
    window.history.replaceState({}, '', '/apps')
  })

  describe('extractExternalCreateAppAttribution', () => {
    it('should map campaign links to external attribution', () => {
      const attribution = extractExternalCreateAppAttribution({
        searchParams: new URLSearchParams('utm_source=x&slug=how-to-build-rag-agent'),
      })

      expect(attribution).toEqual({
        utmSource: 'twitter/x',
        utmCampaign: 'how-to-build-rag-agent',
      })
    })

    it('should map newsletter and blog sources to blog', () => {
      expect(extractExternalCreateAppAttribution({
        searchParams: new URLSearchParams('utm_source=newsletter'),
      })).toEqual({ utmSource: 'blog' })

      expect(extractExternalCreateAppAttribution({
        utmInfo: { utm_source: 'dify_blog', slug: 'launch-week' },
      })).toEqual({
        utmSource: 'blog',
        utmCampaign: 'launch-week',
      })
    })
  })

  describe('buildCreateAppEventPayload', () => {
    it('should build original payloads with normalized app mode and timestamp', () => {
      expect(buildCreateAppEventPayload({
        appMode: AppModeEnum.ADVANCED_CHAT,
      }, null, new Date(2026, 3, 13, 14, 5, 9))).toEqual({
        source: 'original',
        app_mode: 'chatflow',
        time: '04-13-14:05:09',
      })
    })

    it('should map agent mode into the canonical app mode bucket', () => {
      expect(buildCreateAppEventPayload({
        appMode: AppModeEnum.AGENT_CHAT,
      }, null, new Date(2026, 3, 13, 9, 8, 7))).toEqual({
        source: 'original',
        app_mode: 'agent',
        time: '04-13-09:08:07',
      })
    })

    it('should fold legacy non-agent modes into chatflow', () => {
      expect(buildCreateAppEventPayload({
        appMode: AppModeEnum.CHAT,
      }, null, new Date(2026, 3, 13, 8, 0, 1))).toEqual({
        source: 'original',
        app_mode: 'chatflow',
        time: '04-13-08:00:01',
      })

      expect(buildCreateAppEventPayload({
        appMode: AppModeEnum.COMPLETION,
      }, null, new Date(2026, 3, 13, 8, 0, 2))).toEqual({
        source: 'original',
        app_mode: 'chatflow',
        time: '04-13-08:00:02',
      })
    })

    it('should map workflow mode into the workflow bucket', () => {
      expect(buildCreateAppEventPayload({
        appMode: AppModeEnum.WORKFLOW,
      }, null, new Date(2026, 3, 13, 7, 6, 5))).toEqual({
        source: 'original',
        app_mode: 'workflow',
        time: '04-13-07:06:05',
      })
    })

    it('should prefer external attribution when present', () => {
      expect(buildCreateAppEventPayload(
        {
          appMode: AppModeEnum.WORKFLOW,
        },
        {
          utmSource: 'linkedin',
          utmCampaign: 'agent-launch',
        },
      )).toEqual({
        source: 'external',
        utm_source: 'linkedin',
        utm_campaign: 'agent-launch',
      })
    })
  })

  describe('trackCreateApp', () => {
    it('should track remembered external attribution once before falling back to internal source', () => {
      rememberCreateAppExternalAttribution({
        searchParams: new URLSearchParams('utm_source=newsletter&slug=how-to-build-rag-agent'),
      })

      trackCreateApp({ appMode: AppModeEnum.WORKFLOW })

      expect(amplitude.trackEvent).toHaveBeenNthCalledWith(1, 'create_app', {
        source: 'external',
        utm_source: 'blog',
        utm_campaign: 'how-to-build-rag-agent',
      })

      trackCreateApp({ appMode: AppModeEnum.WORKFLOW })

      expect(amplitude.trackEvent).toHaveBeenNthCalledWith(2, 'create_app', {
        source: 'original',
        app_mode: 'workflow',
        time: expect.stringMatching(/^\d{2}-\d{2}-\d{2}:\d{2}:\d{2}$/),
      })
    })
  })
})
web/utils/create-app-tracking.ts (new file, 187 lines)
@@ -0,0 +1,187 @@
import Cookies from 'js-cookie'
import { trackEvent } from '@/app/components/base/amplitude'
import { AppModeEnum } from '@/types/app'

const CREATE_APP_EXTERNAL_ATTRIBUTION_STORAGE_KEY = 'create_app_external_attribution'

const EXTERNAL_UTM_SOURCE_MAP = {
  blog: 'blog',
  dify_blog: 'blog',
  linkedin: 'linkedin',
  newsletter: 'blog',
  twitter: 'twitter/x',
  x: 'twitter/x',
} as const

type SearchParamReader = {
  get: (name: string) => string | null
}

type OriginalCreateAppMode = 'workflow' | 'chatflow' | 'agent'

type TrackCreateAppParams = {
  appMode: AppModeEnum
}

type ExternalCreateAppAttribution = {
  utmSource: typeof EXTERNAL_UTM_SOURCE_MAP[keyof typeof EXTERNAL_UTM_SOURCE_MAP]
  utmCampaign?: string
}

const normalizeString = (value?: string | null) => {
  const trimmed = value?.trim()
  return trimmed || undefined
}

const getObjectStringValue = (value: unknown) => {
  return typeof value === 'string' ? normalizeString(value) : undefined
}

const getSearchParamValue = (searchParams?: SearchParamReader | null, key?: string) => {
  if (!searchParams || !key)
    return undefined
  return normalizeString(searchParams.get(key))
}

const parseJSONRecord = (value?: string | null): Record<string, unknown> | null => {
  if (!value)
    return null

  try {
    const parsed = JSON.parse(value)
    return parsed && typeof parsed === 'object' ? parsed as Record<string, unknown> : null
  }
  catch {
    return null
  }
}

const getCookieUtmInfo = () => {
  return parseJSONRecord(Cookies.get('utm_info'))
}

const mapExternalUtmSource = (value?: string) => {
  if (!value)
    return undefined

  const normalized = value.toLowerCase()
  return EXTERNAL_UTM_SOURCE_MAP[normalized as keyof typeof EXTERNAL_UTM_SOURCE_MAP]
}

const padTimeValue = (value: number) => String(value).padStart(2, '0')

const formatCreateAppTime = (date: Date) => {
  return `${padTimeValue(date.getMonth() + 1)}-${padTimeValue(date.getDate())}-${padTimeValue(date.getHours())}:${padTimeValue(date.getMinutes())}:${padTimeValue(date.getSeconds())}`
}

const mapOriginalCreateAppMode = (appMode: AppModeEnum): OriginalCreateAppMode => {
  if (appMode === AppModeEnum.WORKFLOW)
    return 'workflow'

  if (appMode === AppModeEnum.AGENT_CHAT)
    return 'agent'

  return 'chatflow'
}

export const extractExternalCreateAppAttribution = ({
  searchParams,
  utmInfo,
}: {
  searchParams?: SearchParamReader | null
  utmInfo?: Record<string, unknown> | null
}) => {
  const rawSource = getSearchParamValue(searchParams, 'utm_source') ?? getObjectStringValue(utmInfo?.utm_source)
  const mappedSource = mapExternalUtmSource(rawSource)

  if (!mappedSource)
    return null

  const utmCampaign = getSearchParamValue(searchParams, 'slug')
    ?? getSearchParamValue(searchParams, 'utm_campaign')
    ?? getObjectStringValue(utmInfo?.slug)
    ?? getObjectStringValue(utmInfo?.utm_campaign)

  return {
    utmSource: mappedSource,
    ...(utmCampaign ? { utmCampaign } : {}),
  } satisfies ExternalCreateAppAttribution
}

const readRememberedExternalCreateAppAttribution = (): ExternalCreateAppAttribution | null => {
  if (typeof window === 'undefined')
    return null

  return parseJSONRecord(window.sessionStorage.getItem(CREATE_APP_EXTERNAL_ATTRIBUTION_STORAGE_KEY)) as ExternalCreateAppAttribution | null
}

const writeRememberedExternalCreateAppAttribution = (attribution: ExternalCreateAppAttribution) => {
  if (typeof window === 'undefined')
    return

  window.sessionStorage.setItem(CREATE_APP_EXTERNAL_ATTRIBUTION_STORAGE_KEY, JSON.stringify(attribution))
}

const clearRememberedExternalCreateAppAttribution = () => {
  if (typeof window === 'undefined')
    return

  window.sessionStorage.removeItem(CREATE_APP_EXTERNAL_ATTRIBUTION_STORAGE_KEY)
}

export const rememberCreateAppExternalAttribution = ({
  searchParams,
  utmInfo,
}: {
  searchParams?: SearchParamReader | null
  utmInfo?: Record<string, unknown> | null
} = {}) => {
  const attribution = extractExternalCreateAppAttribution({
    searchParams,
    utmInfo: utmInfo ?? getCookieUtmInfo(),
  })

  if (attribution)
    writeRememberedExternalCreateAppAttribution(attribution)

  return attribution
}

const resolveCurrentExternalCreateAppAttribution = () => {
  if (typeof window === 'undefined')
    return null

  return rememberCreateAppExternalAttribution({
    searchParams: new URLSearchParams(window.location.search),
  }) ?? readRememberedExternalCreateAppAttribution()
}

export const buildCreateAppEventPayload = (
  params: TrackCreateAppParams,
  externalAttribution?: ExternalCreateAppAttribution | null,
  currentTime = new Date(),
) => {
  if (externalAttribution) {
    return {
      source: 'external',
      utm_source: externalAttribution.utmSource,
      ...(externalAttribution.utmCampaign ? { utm_campaign: externalAttribution.utmCampaign } : {}),
    } satisfies Record<string, string>
  }

  return {
    source: 'original',
    app_mode: mapOriginalCreateAppMode(params.appMode),
    time: formatCreateAppTime(currentTime),
  } satisfies Record<string, string>
}

export const trackCreateApp = (params: TrackCreateAppParams) => {
  const externalAttribution = resolveCurrentExternalCreateAppAttribution()
  const payload = buildCreateAppEventPayload(params, externalAttribution)

  if (externalAttribution)
    clearRememberedExternalCreateAppAttribution()

  trackEvent('create_app', payload)
}
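End to end, the lifecycle this new util implements: registration (or any visit carrying a recognized utm_source) remembers the mapped attribution in sessionStorage; the first trackCreateApp afterwards emits an 'external' payload and clears it; every later call falls back to the 'original' payload with the normalized app mode and a MM-DD-HH:mm:ss timestamp. A usage sketch matching the spec above (the URLSearchParams literal is illustrative):

import { AppModeEnum } from '@/types/app'
import { rememberCreateAppExternalAttribution, trackCreateApp } from '@/utils/create-app-tracking'

// 1. At registration time, stash any recognized UTM attribution
//    ('newsletter' maps to 'blog', 'slug' becomes the campaign).
rememberCreateAppExternalAttribution({
  searchParams: new URLSearchParams('utm_source=newsletter&slug=how-to-build-rag-agent'),
})

// 2. The first creation afterwards is attributed externally and the
//    remembered attribution is cleared:
//    { source: 'external', utm_source: 'blog', utm_campaign: 'how-to-build-rag-agent' }
trackCreateApp({ appMode: AppModeEnum.WORKFLOW })

// 3. Later creations fall back to the internal payload:
//    { source: 'original', app_mode: 'workflow', time: 'MM-DD-HH:mm:ss' }
trackCreateApp({ appMode: AppModeEnum.WORKFLOW })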