From f9b76f0f52c252d68c236be8c04e1030421f784a Mon Sep 17 00:00:00 2001
From: FFXN <31929997+FFXN@users.noreply.github.com>
Date: Wed, 15 Apr 2026 16:09:40 +0800
Subject: [PATCH] feat: evaluation (#35251)

Co-authored-by: jyong <718720800@qq.com>
Co-authored-by: Yansong Zhang <916125788@qq.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: hj24
Co-authored-by: hj24
Co-authored-by: Joel
Co-authored-by: Stephen Zhou <38493346+hyoban@users.noreply.github.com>
Co-authored-by: CodingOnStar
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
---
 api/configs/feature/__init__.py | 27 +
 api/controllers/console/__init__.py | 11 +
 api/controllers/console/app/app.py | 21 +
 api/controllers/console/app/workflow.py | 116 ++-
 api/controllers/console/billing/billing.py | 39 +
 api/controllers/console/datasets/datasets.py | 447 +++++++-
 .../console/evaluation/__init__.py | 1 +
 .../console/evaluation/evaluation.py | 869 +++++++++++++++
 api/controllers/console/snippets/payloads.py | 135 +++
 .../console/snippets/snippet_workflow.py | 534 ++++++++++
 .../snippet_workflow_draft_variable.py | 319 ++++++
 api/controllers/console/workspace/snippets.py | 380 +++++++
 api/core/app/apps/workflow/app_generator.py | 23 +-
 api/core/evaluation/__init__.py | 0
 .../evaluation/base_evaluation_instance.py | 279 +++++
 api/core/evaluation/entities/__init__.py | 0
 api/core/evaluation/entities/config_entity.py | 27 +
 .../evaluation/entities/evaluation_entity.py | 226 ++++
 .../evaluation/entities/judgment_entity.py | 96 ++
 api/core/evaluation/evaluation_manager.py | 61 ++
 api/core/evaluation/frameworks/__init__.py | 0
 .../frameworks/deepeval/__init__.py | 1 +
 .../frameworks/deepeval/deepeval_evaluator.py | 299 ++++++
 .../evaluation/frameworks/ragas/__init__.py | 0
 .../frameworks/ragas/ragas_evaluator.py | 312 ++++++
 .../frameworks/ragas/ragas_model_wrapper.py | 48 +
 api/core/evaluation/judgment/__init__.py | 0
 api/core/evaluation/judgment/processor.py | 160 +++
 api/core/evaluation/runners/__init__.py | 52 +
 .../runners/agent_evaluation_runner.py | 62 ++
 .../runners/base_evaluation_runner.py | 51 +
 .../runners/llm_evaluation_runner.py | 83 ++
 .../runners/retrieval_evaluation_runner.py | 61 ++
 .../runners/snippet_evaluation_runner.py | 68 ++
 .../runners/workflow_evaluation_runner.py | 62 ++
 api/enums/quota_type.py | 188 ----
 api/fields/snippet_fields.py | 45 +
 api/fields/workflow_app_log_fields.py | 1 +
 ...5e80d2380_add_customized_snippets_table.py | 83 ++
 ...0001-a1b2c3d4e5f6_add_evaluation_tables.py | 116 +++
 ...1721-4c60d8d3ee74_merge_migration_heads.py | 25 +
 api/models/__init__.py | 15 +
 api/models/evaluation.py | 205 ++++
 api/models/snippet.py | 101 ++
 api/models/workflow.py | 2 +
 api/pyproject.toml | 6 +
 api/services/app_generate_service.py | 6 +-
 api/services/async_workflow_service.py | 25 +-
 api/services/billing_service.py | 150 ++-
 api/services/errors/evaluation.py | 21 +
 api/services/evaluation_service.py | 985 ++++++++++++++++++
 api/services/feature_service.py | 2 +-
 api/services/quota_service.py | 233 +++++
 api/services/snippet_dsl_service.py | 555 ++++++++++
 api/services/snippet_generate_service.py | 421 ++++++++
 api/services/snippet_service.py | 608 +++++++++++
 api/services/trigger/webhook_service.py | 20 +-
 api/services/workflow_app_service.py | 66 +-
 .../workflow_draft_variable_service.py | 12 +-
 api/services/workflow_service.py | 194 +++-
 api/tasks/evaluation_task.py | 541 ++++++++++
 api/tasks/trigger_processing_tasks.py | 8 +-
api/tasks/workflow_schedule_tasks.py | 6 +- .../services/test_app_generate_service.py | 23 +- .../trigger/test_trigger_e2e.py | 4 +- .../evaluation/judgment/test_processor.py | 145 +++ .../runners/test_base_evaluation_runner.py | 78 ++ api/tests/unit_tests/enums/__init__.py | 0 api/tests/unit_tests/enums/test_quota_type.py | 349 +++++++ .../services/test_app_generate_service.py | 12 +- .../services/test_async_workflow_service.py | 23 +- .../services/test_billing_service.py | 160 ++- .../services/test_webhook_service.py | 769 ++++++++++++++ .../unit_tests/tasks/test_evaluation_task.py | 97 ++ api/uv.lock | 494 ++++++++- web/app/components/app-initializer.tsx | 3 + .../app-list/__tests__/index.spec.tsx | 13 +- .../app/create-app-dialog/app-list/index.tsx | 11 +- .../create-app-modal/__tests__/index.spec.tsx | 13 +- .../components/app/create-app-modal/index.tsx | 8 +- .../__tests__/index.spec.tsx | 21 +- .../app/create-from-dsl-modal/index.tsx | 9 +- .../components/apps/__tests__/index.spec.tsx | 140 ++- web/app/components/apps/index.tsx | 25 +- .../base/amplitude/AmplitudeProvider.tsx | 6 +- .../__tests__/cookie-recorder.spec.tsx | 18 + .../billing/partner-stack/cookie-recorder.tsx | 4 +- .../billing/partner-stack/index.tsx | 4 +- .../billing/partner-stack/use-ps-info.ts | 16 +- .../explore/app-list/__tests__/index.spec.tsx | 42 +- web/app/components/explore/app-list/index.tsx | 25 +- web/app/signup/set-password/page.tsx | 2 + .../__tests__/create-app-tracking.spec.ts | 134 +++ web/utils/create-app-tracking.ts | 187 ++++ 94 files changed, 12010 insertions(+), 335 deletions(-) create mode 100644 api/controllers/console/evaluation/__init__.py create mode 100644 api/controllers/console/evaluation/evaluation.py create mode 100644 api/controllers/console/snippets/payloads.py create mode 100644 api/controllers/console/snippets/snippet_workflow.py create mode 100644 api/controllers/console/snippets/snippet_workflow_draft_variable.py create mode 100644 api/controllers/console/workspace/snippets.py create mode 100644 api/core/evaluation/__init__.py create mode 100644 api/core/evaluation/base_evaluation_instance.py create mode 100644 api/core/evaluation/entities/__init__.py create mode 100644 api/core/evaluation/entities/config_entity.py create mode 100644 api/core/evaluation/entities/evaluation_entity.py create mode 100644 api/core/evaluation/entities/judgment_entity.py create mode 100644 api/core/evaluation/evaluation_manager.py create mode 100644 api/core/evaluation/frameworks/__init__.py create mode 100644 api/core/evaluation/frameworks/deepeval/__init__.py create mode 100644 api/core/evaluation/frameworks/deepeval/deepeval_evaluator.py create mode 100644 api/core/evaluation/frameworks/ragas/__init__.py create mode 100644 api/core/evaluation/frameworks/ragas/ragas_evaluator.py create mode 100644 api/core/evaluation/frameworks/ragas/ragas_model_wrapper.py create mode 100644 api/core/evaluation/judgment/__init__.py create mode 100644 api/core/evaluation/judgment/processor.py create mode 100644 api/core/evaluation/runners/__init__.py create mode 100644 api/core/evaluation/runners/agent_evaluation_runner.py create mode 100644 api/core/evaluation/runners/base_evaluation_runner.py create mode 100644 api/core/evaluation/runners/llm_evaluation_runner.py create mode 100644 api/core/evaluation/runners/retrieval_evaluation_runner.py create mode 100644 api/core/evaluation/runners/snippet_evaluation_runner.py create mode 100644 api/core/evaluation/runners/workflow_evaluation_runner.py create mode 100644 
api/fields/snippet_fields.py create mode 100644 api/migrations/versions/2026_01_29_1200-1c05e80d2380_add_customized_snippets_table.py create mode 100644 api/migrations/versions/2026_03_03_0001-a1b2c3d4e5f6_add_evaluation_tables.py create mode 100644 api/migrations/versions/2026_03_17_1721-4c60d8d3ee74_merge_migration_heads.py create mode 100644 api/models/evaluation.py create mode 100644 api/models/snippet.py create mode 100644 api/services/errors/evaluation.py create mode 100644 api/services/evaluation_service.py create mode 100644 api/services/quota_service.py create mode 100644 api/services/snippet_dsl_service.py create mode 100644 api/services/snippet_generate_service.py create mode 100644 api/services/snippet_service.py create mode 100644 api/tasks/evaluation_task.py create mode 100644 api/tests/unit_tests/core/evaluation/judgment/test_processor.py create mode 100644 api/tests/unit_tests/core/evaluation/runners/test_base_evaluation_runner.py create mode 100644 api/tests/unit_tests/enums/__init__.py create mode 100644 api/tests/unit_tests/enums/test_quota_type.py create mode 100644 api/tests/unit_tests/tasks/test_evaluation_task.py create mode 100644 web/utils/__tests__/create-app-tracking.spec.ts create mode 100644 web/utils/create-app-tracking.ts diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index d37cff63e9..78862c1384 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -1366,6 +1366,32 @@ class SandboxExpiredRecordsCleanConfig(BaseSettings): ) +class EvaluationConfig(BaseSettings): + """ + Configuration for evaluation runtime + """ + + EVALUATION_FRAMEWORK: str = Field( + description="Evaluation framework to use (ragas/deepeval/none)", + default="none", + ) + + EVALUATION_MAX_CONCURRENT_RUNS: PositiveInt = Field( + description="Maximum number of concurrent evaluation runs per tenant", + default=3, + ) + + EVALUATION_MAX_DATASET_ROWS: PositiveInt = Field( + description="Maximum number of rows allowed in an evaluation dataset", + default=500, + ) + + EVALUATION_TASK_TIMEOUT: PositiveInt = Field( + description="Timeout in seconds for a single evaluation task", + default=3600, + ) + + class FeatureConfig( # place the configs in alphabet order AppExecutionConfig, @@ -1378,6 +1404,7 @@ class FeatureConfig( MarketplaceConfig, DataSetConfig, EndpointConfig, + EvaluationConfig, FileAccessConfig, FileUploadConfig, HttpConfig, diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index d624b10b22..e952e33465 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -107,6 +107,9 @@ from .datasets.rag_pipeline import ( rag_pipeline_workflow, ) +# Import evaluation controllers +from .evaluation import evaluation + # Import explore controllers from .explore import ( banner, @@ -117,6 +120,9 @@ from .explore import ( trial, ) +# Import snippet controllers +from .snippets import snippet_workflow, snippet_workflow_draft_variable + # Import tag controllers from .tag import tags @@ -130,6 +136,7 @@ from .workspace import ( model_providers, models, plugin, + snippets, tool_providers, trigger_providers, workspace, @@ -167,6 +174,7 @@ __all__ = [ "datasource_content_preview", "email_register", "endpoint", + "evaluation", "extension", "external", "feature", @@ -201,6 +209,9 @@ __all__ = [ "saved_message", "setup", "site", + "snippet_workflow", + "snippet_workflow_draft_variable", + "snippets", "spec", "statistic", "tags", diff --git a/api/controllers/console/app/app.py 
b/api/controllers/console/app/app.py index ac0682486b..06810d7094 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -329,6 +329,7 @@ class AppPartial(ResponseModel): create_user_name: str | None = None author_name: str | None = None has_draft_trigger: bool | None = None + workflow_type: str | None = None @computed_field(return_type=str | None) # type: ignore @property @@ -363,6 +364,7 @@ class AppDetail(ResponseModel): updated_by: str | None = None updated_at: int | None = None access_mode: str | None = None + workflow_type: str | None = None tags: list[Tag] = Field(default_factory=list) @field_validator("created_at", "updated_at", mode="before") @@ -505,6 +507,17 @@ class AppListApi(Resource): for app in app_pagination.items: app.has_draft_trigger = str(app.id) in draft_trigger_app_ids + workflow_ids = [str(app.workflow_id) for app in app_pagination.items if app.workflow_id] + workflow_type_map: dict[str, str] = {} + if workflow_ids: + rows = db.session.execute( + select(Workflow.id, Workflow.type).where(Workflow.id.in_(workflow_ids)) + ).all() + workflow_type_map = {str(row.id): row.type for row in rows} + + for app in app_pagination.items: + app.workflow_type = workflow_type_map.get(str(app.workflow_id)) if app.workflow_id else None + pagination_model = AppPagination.model_validate(app_pagination, from_attributes=True) return pagination_model.model_dump(mode="json"), 200 @@ -551,6 +564,14 @@ class AppApi(Resource): app_setting = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=str(app_model.id)) app_model.access_mode = app_setting.access_mode + if app_model.workflow_id: + row = db.session.execute( + select(Workflow.type).where(Workflow.id == app_model.workflow_id) + ).scalar() + app_model.workflow_type = row if row else None + else: + app_model.workflow_type = None + response_model = AppDetailWithSite.model_validate(app_model, from_attributes=True) return response_model.model_dump(mode="json") diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 5e6ff87d62..a149960a23 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -1,7 +1,7 @@ import json import logging from collections.abc import Sequence -from typing import Any +from typing import Any, Literal from flask import abort, request from flask_restx import Resource, fields, marshal, marshal_with @@ -46,7 +46,7 @@ from libs.helper import TimestampField, uuid_value from libs.login import current_account_with_tenant, login_required from models import App from models.model import AppMode -from models.workflow import Workflow +from models.workflow import Workflow, WorkflowType from services.app_generate_service import AppGenerateService from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError from services.errors.llm import InvokeRateLimitError @@ -150,6 +150,24 @@ class ConvertToWorkflowPayload(BaseModel): icon_background: str | None = None +class WorkflowListQuery(BaseModel): + page: int = Field(default=1, ge=1, le=99999) + limit: int = Field(default=10, ge=1, le=100) + user_id: str | None = None + named_only: bool = False + keyword: str | None = Field(default=None, max_length=255) + + +class WorkflowUpdatePayload(BaseModel): + marked_name: str | None = Field(default=None, max_length=20) + marked_comment: str | None = Field(default=None, max_length=100) + + +class WorkflowTypeConvertQuery(BaseModel): + target_type: Literal["workflow", 
"evaluation"] + + + class DraftWorkflowTriggerRunPayload(BaseModel): node_id: str @@ -173,6 +191,7 @@ reg(DefaultBlockConfigQuery) reg(ConvertToWorkflowPayload) reg(WorkflowListQuery) reg(WorkflowUpdatePayload) +reg(WorkflowTypeConvertQuery) reg(DraftWorkflowTriggerRunPayload) reg(DraftWorkflowTriggerRunAllPayload) @@ -845,6 +864,54 @@ class PublishedWorkflowApi(Resource): } +@console_ns.route("/apps//workflows/publish/evaluation") +class EvaluationPublishedWorkflowApi(Resource): + @console_ns.doc("publish_evaluation_workflow") + @console_ns.doc(description="Publish draft workflow as evaluation workflow") + @console_ns.doc(params={"app_id": "Application ID"}) + @console_ns.expect(console_ns.models[PublishWorkflowPayload.__name__]) + @console_ns.response(200, "Evaluation workflow published successfully") + @console_ns.response(400, "Invalid workflow or unsupported node type") + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + @edit_permission_required + def post(self, app_model: App): + """ + Publish draft workflow as evaluation workflow. + + Evaluation workflows cannot include trigger or human-input nodes. + """ + current_user, _ = current_account_with_tenant() + args = PublishWorkflowPayload.model_validate(console_ns.payload or {}) + + workflow_service = WorkflowService() + with Session(db.engine) as session: + workflow = workflow_service.publish_evaluation_workflow( + session=session, + app_model=app_model, + account=current_user, + marked_name=args.marked_name or "", + marked_comment=args.marked_comment or "", + ) + + # Keep workflow_id aligned with the latest published workflow. + app_model_in_session = session.get(App, app_model.id) + if app_model_in_session: + app_model_in_session.workflow_id = workflow.id + app_model_in_session.updated_by = current_user.id + app_model_in_session.updated_at = naive_utc_now() + + workflow_created_at = TimestampField().format(workflow.created_at) + session.commit() + + return { + "result": "success", + "created_at": workflow_created_at, + } + + @console_ns.route("/apps//workflows/default-workflow-block-configs") class DefaultBlockConfigsApi(Resource): @console_ns.doc("get_default_block_configs") @@ -1016,6 +1083,51 @@ class DraftWorkflowRestoreApi(Resource): } +@console_ns.route("/apps//workflows/convert-type") +class WorkflowTypeConvertApi(Resource): + @console_ns.doc("convert_published_workflow_type") + @console_ns.doc(description="Convert current effective published workflow type in-place") + @console_ns.doc(params={"app_id": "Application ID"}) + @console_ns.expect(console_ns.models[WorkflowTypeConvertQuery.__name__]) + @console_ns.response(200, "Workflow type converted successfully") + @console_ns.response(400, "Invalid workflow type or unsupported workflow graph") + @console_ns.response(404, "Workflow not found") + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + @edit_permission_required + def post(self, app_model: App): + current_user, _ = current_account_with_tenant() + args = WorkflowTypeConvertQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + target_type = WorkflowType.value_of(args.target_type) + + workflow_service = WorkflowService() + with Session(db.engine) as session: + try: + workflow = workflow_service.convert_published_workflow_type( + session=session, + app_model=app_model, + target_type=target_type, + account=current_user, + ) + except 
WorkflowNotFoundError as exc: + raise NotFound(str(exc)) from exc + except IsDraftWorkflowError as exc: + raise BadRequest(str(exc)) from exc + except ValueError as exc: + raise BadRequest(str(exc)) from exc + + session.commit() + + return { + "result": "success", + "workflow_id": workflow.id, + "type": workflow.type.value, + "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at), + } + + @console_ns.route("/apps//workflows/") class WorkflowByIdApi(Resource): @console_ns.doc("update_workflow_by_id") diff --git a/api/controllers/console/billing/billing.py b/api/controllers/console/billing/billing.py index 45de338559..ce2870c82e 100644 --- a/api/controllers/console/billing/billing.py +++ b/api/controllers/console/billing/billing.py @@ -1,4 +1,6 @@ import base64 +import json +from datetime import UTC, datetime, timedelta from typing import Literal from flask import request @@ -10,6 +12,7 @@ from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required from enums.cloud_plan import CloudPlan +from extensions.ext_redis import redis_client from libs.login import current_account_with_tenant, login_required from services.billing_service import BillingService @@ -77,3 +80,39 @@ class PartnerTenants(Resource): raise BadRequest("Invalid partner information") return BillingService.sync_partner_tenants_bindings(current_user.id, decoded_partner_key, click_id) + + +_DEBUG_KEY = "billing:debug" +_DEBUG_TTL = timedelta(days=7) + + +class DebugDataPayload(BaseModel): + type: str = Field(..., min_length=1, description="Data type key") + data: str = Field(..., min_length=1, description="Data value to append") + + +@console_ns.route("/billing/debug/data") +class DebugData(Resource): + def post(self): + body = DebugDataPayload.model_validate(request.get_json(force=True)) + item = json.dumps({ + "type": body.type, + "data": body.data, + "createTime": datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ"), + }) + redis_client.lpush(_DEBUG_KEY, item) + redis_client.expire(_DEBUG_KEY, _DEBUG_TTL) + return {"result": "ok"}, 201 + + def get(self): + recent = request.args.get("recent", 10, type=int) + items = redis_client.lrange(_DEBUG_KEY, 0, recent - 1) + return { + "data": [ + json.loads(item.decode("utf-8") if isinstance(item, bytes) else item) for item in items + ] + } + + def delete(self): + redis_client.delete(_DEBUG_KEY) + return {"result": "ok"} diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index b2a905366a..14ca27acbd 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -1,11 +1,14 @@ +import json from typing import Any, cast +from urllib.parse import quote -from flask import request +from flask import Response, request from flask_restx import Resource, fields, marshal, marshal_with from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, Field, field_validator from sqlalchemy import func, select -from werkzeug.exceptions import Forbidden, NotFound +from sqlalchemy.orm import Session +from werkzeug.exceptions import BadRequest, Forbidden, NotFound import services from configs import dify_config @@ -22,6 +25,7 @@ from controllers.console.wraps import ( setup_required, ) from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError +from 
core.evaluation.entities.evaluation_entity import EvaluationCategory, EvaluationConfigData, EvaluationRunRequest from core.indexing_runner import IndexingRunner from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from core.rag.datasource.vdb.vector_type import VectorType @@ -30,6 +34,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db +from extensions.ext_storage import storage from fields.app_fields import app_detail_kernel_fields, related_app_list from fields.dataset_fields import ( content_fields, @@ -50,12 +55,19 @@ from fields.dataset_fields import ( ) from fields.document_fields import document_status_fields from libs.login import current_account_with_tenant, login_required -from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile +from models import ApiToken, Dataset, Document, DocumentSegment, EvaluationRun, EvaluationTargetType, UploadFile from models.dataset import DatasetPermission, DatasetPermissionEnum from models.enums import ApiTokenType, SegmentStatus from models.provider_ids import ModelProviderID from services.api_token_service import ApiTokenCache from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService +from services.errors.evaluation import ( + EvaluationDatasetInvalidError, + EvaluationFrameworkNotConfiguredError, + EvaluationMaxConcurrentRunsError, + EvaluationNotFoundError, +) +from services.evaluation_service import EvaluationService # Register models for flask_restx to avoid dict type issues in Swagger dataset_base_model = get_or_create_model("DatasetBase", dataset_fields) @@ -983,3 +995,432 @@ class DatasetAutoDisableLogApi(Resource): if dataset is None: raise NotFound("Dataset not found.") return DatasetService.get_dataset_auto_disable_logs(dataset_id_str), 200 + + +# ---- Knowledge Base Retrieval Evaluation ---- + + +def _serialize_dataset_evaluation_run(run: EvaluationRun) -> dict[str, Any]: + return { + "id": run.id, + "tenant_id": run.tenant_id, + "target_type": run.target_type, + "target_id": run.target_id, + "evaluation_config_id": run.evaluation_config_id, + "status": run.status, + "dataset_file_id": run.dataset_file_id, + "result_file_id": run.result_file_id, + "total_items": run.total_items, + "completed_items": run.completed_items, + "failed_items": run.failed_items, + "progress": run.progress, + "metrics_summary": json.loads(run.metrics_summary) if run.metrics_summary else {}, + "error": run.error, + "created_by": run.created_by, + "started_at": int(run.started_at.timestamp()) if run.started_at else None, + "completed_at": int(run.completed_at.timestamp()) if run.completed_at else None, + "created_at": int(run.created_at.timestamp()) if run.created_at else None, + } + + +def _serialize_dataset_evaluation_run_item(item: Any) -> dict[str, Any]: + return { + "id": item.id, + "item_index": item.item_index, + "inputs": item.inputs_dict, + "expected_output": item.expected_output, + "actual_output": item.actual_output, + "metrics": item.metrics_list, + "judgment": item.judgment_dict, + "metadata": item.metadata_dict, + "error": item.error, + "overall_score": item.overall_score, + } + + +@console_ns.route("/datasets//evaluation/template/download") +class DatasetEvaluationTemplateDownloadApi(Resource): + @console_ns.doc("download_dataset_evaluation_template") + 
@console_ns.response(200, "Template file streamed as XLSX attachment") + @console_ns.response(403, "Permission denied") + @console_ns.response(404, "Dataset not found") + @setup_required + @login_required + @account_initialization_required + def post(self, dataset_id): + """Download evaluation dataset template for knowledge base retrieval.""" + current_user, _ = current_account_with_tenant() + dataset_id_str = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id_str) + if dataset is None: + raise NotFound("Dataset not found.") + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + + xlsx_content, filename = EvaluationService.generate_retrieval_dataset_template() + encoded_filename = quote(filename) + response = Response( + xlsx_content, + mimetype="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + ) + response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}" + response.headers["Content-Length"] = str(len(xlsx_content)) + return response + + +@console_ns.route("/datasets//evaluation") +class DatasetEvaluationDetailApi(Resource): + @console_ns.doc("get_dataset_evaluation_config") + @console_ns.response(200, "Evaluation configuration retrieved") + @console_ns.response(403, "Permission denied") + @console_ns.response(404, "Dataset not found") + @setup_required + @login_required + @account_initialization_required + def get(self, dataset_id): + """Get evaluation configuration for the knowledge base.""" + current_user, current_tenant_id = current_account_with_tenant() + dataset_id_str = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id_str) + if dataset is None: + raise NotFound("Dataset not found.") + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + + with Session(db.engine, expire_on_commit=False) as session: + config = EvaluationService.get_evaluation_config( + session, current_tenant_id, "dataset", dataset_id_str + ) + + if config is None: + return { + "evaluation_model": None, + "evaluation_model_provider": None, + "default_metrics": None, + "customized_metrics": None, + "judgment_config": None, + } + + return { + "evaluation_model": config.evaluation_model, + "evaluation_model_provider": config.evaluation_model_provider, + "default_metrics": config.default_metrics_list, + "customized_metrics": config.customized_metrics_dict, + "judgment_config": config.judgment_config_dict, + } + + @console_ns.doc("save_dataset_evaluation_config") + @console_ns.response(200, "Evaluation configuration saved") + @console_ns.response(403, "Permission denied") + @console_ns.response(404, "Dataset not found") + @setup_required + @login_required + @account_initialization_required + def put(self, dataset_id): + """Save evaluation configuration for the knowledge base.""" + current_user, current_tenant_id = current_account_with_tenant() + dataset_id_str = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id_str) + if dataset is None: + raise NotFound("Dataset not found.") + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + + body = request.get_json(force=True) + try: + config_data = EvaluationConfigData.model_validate(body) + except Exception as e: + raise BadRequest(f"Invalid request body: {e}") + + with 
Session(db.engine, expire_on_commit=False) as session: + config = EvaluationService.save_evaluation_config( + session=session, + tenant_id=current_tenant_id, + target_type="dataset", + target_id=dataset_id_str, + account_id=str(current_user.id), + data=config_data, + ) + + return { + "evaluation_model": config.evaluation_model, + "evaluation_model_provider": config.evaluation_model_provider, + "default_metrics": config.default_metrics_list, + "customized_metrics": config.customized_metrics_dict, + "judgment_config": config.judgment_config_dict, + } + + +@console_ns.route("/datasets//evaluation/run") +class DatasetEvaluationRunApi(Resource): + @console_ns.doc("start_dataset_evaluation_run") + @console_ns.response(200, "Evaluation run started") + @console_ns.response(400, "Invalid request") + @console_ns.response(403, "Permission denied") + @console_ns.response(404, "Dataset not found") + @setup_required + @login_required + @account_initialization_required + def post(self, dataset_id): + """Start an evaluation run for the knowledge base retrieval.""" + current_user, current_tenant_id = current_account_with_tenant() + dataset_id_str = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id_str) + if dataset is None: + raise NotFound("Dataset not found.") + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + + body = request.get_json(force=True) + if not body: + raise BadRequest("Request body is required.") + + try: + run_request = EvaluationRunRequest.model_validate(body) + except Exception as e: + raise BadRequest(f"Invalid request body: {e}") + + upload_file = ( + db.session.query(UploadFile).filter_by(id=run_request.file_id, tenant_id=current_tenant_id).first() + ) + if not upload_file: + raise NotFound("Dataset file not found.") + + try: + dataset_content = storage.load_once(upload_file.key) + except Exception: + raise BadRequest("Failed to read dataset file.") + + if not dataset_content: + raise BadRequest("Dataset file is empty.") + + try: + with Session(db.engine, expire_on_commit=False) as session: + evaluation_run = EvaluationService.start_evaluation_run( + session=session, + tenant_id=current_tenant_id, + target_type=EvaluationTargetType.KNOWLEDGE_BASE, + target_id=dataset_id_str, + account_id=str(current_user.id), + dataset_file_content=dataset_content, + run_request=run_request, + ) + return _serialize_dataset_evaluation_run(evaluation_run), 200 + except EvaluationFrameworkNotConfiguredError as e: + return {"message": str(e.description)}, 400 + except EvaluationNotFoundError as e: + return {"message": str(e.description)}, 404 + except EvaluationMaxConcurrentRunsError as e: + return {"message": str(e.description)}, 429 + except EvaluationDatasetInvalidError as e: + return {"message": str(e.description)}, 400 + + +@console_ns.route("/datasets//evaluation/logs") +class DatasetEvaluationLogsApi(Resource): + @console_ns.doc("get_dataset_evaluation_logs") + @console_ns.response(200, "Evaluation logs retrieved") + @console_ns.response(403, "Permission denied") + @console_ns.response(404, "Dataset not found") + @setup_required + @login_required + @account_initialization_required + def get(self, dataset_id): + """Get evaluation run history for the knowledge base.""" + current_user, current_tenant_id = current_account_with_tenant() + dataset_id_str = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id_str) + if dataset is None: + raise NotFound("Dataset not found.") 
+ try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + + page = request.args.get("page", 1, type=int) + page_size = request.args.get("page_size", 20, type=int) + + with Session(db.engine, expire_on_commit=False) as session: + runs, total = EvaluationService.get_evaluation_runs( + session=session, + tenant_id=current_tenant_id, + target_type="dataset", + target_id=dataset_id_str, + page=page, + page_size=page_size, + ) + + return { + "data": [_serialize_dataset_evaluation_run(run) for run in runs], + "total": total, + "page": page, + "page_size": page_size, + } + + +@console_ns.route("/datasets//evaluation/runs/") +class DatasetEvaluationRunDetailApi(Resource): + @console_ns.doc("get_dataset_evaluation_run_detail") + @console_ns.response(200, "Evaluation run detail retrieved") + @console_ns.response(403, "Permission denied") + @console_ns.response(404, "Dataset or run not found") + @setup_required + @login_required + @account_initialization_required + def get(self, dataset_id, run_id): + """Get evaluation run detail including per-item results.""" + current_user, current_tenant_id = current_account_with_tenant() + dataset_id_str = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id_str) + if dataset is None: + raise NotFound("Dataset not found.") + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + + run_id_str = str(run_id) + page = request.args.get("page", 1, type=int) + page_size = request.args.get("page_size", 50, type=int) + + try: + with Session(db.engine, expire_on_commit=False) as session: + run = EvaluationService.get_evaluation_run_detail( + session=session, + tenant_id=current_tenant_id, + run_id=run_id_str, + ) + items, total_items = EvaluationService.get_evaluation_run_items( + session=session, + run_id=run_id_str, + page=page, + page_size=page_size, + ) + return { + "run": _serialize_dataset_evaluation_run(run), + "items": { + "data": [_serialize_dataset_evaluation_run_item(item) for item in items], + "total": total_items, + "page": page, + "page_size": page_size, + }, + } + except EvaluationNotFoundError as e: + return {"message": str(e.description)}, 404 + + +@console_ns.route("/datasets//evaluation/runs//cancel") +class DatasetEvaluationRunCancelApi(Resource): + @console_ns.doc("cancel_dataset_evaluation_run") + @console_ns.response(200, "Evaluation run cancelled") + @console_ns.response(403, "Permission denied") + @console_ns.response(404, "Dataset or run not found") + @setup_required + @login_required + @account_initialization_required + def post(self, dataset_id, run_id): + """Cancel a running knowledge base evaluation.""" + current_user, current_tenant_id = current_account_with_tenant() + dataset_id_str = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id_str) + if dataset is None: + raise NotFound("Dataset not found.") + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + + run_id_str = str(run_id) + try: + with Session(db.engine, expire_on_commit=False) as session: + run = EvaluationService.cancel_evaluation_run( + session=session, + tenant_id=current_tenant_id, + run_id=run_id_str, + ) + return _serialize_dataset_evaluation_run(run) + except EvaluationNotFoundError as e: + return {"message": str(e.description)}, 404 + except ValueError 
as e: + return {"message": str(e)}, 400 + + +@console_ns.route("/datasets//evaluation/metrics") +class DatasetEvaluationMetricsApi(Resource): + @console_ns.doc("get_dataset_evaluation_metrics") + @console_ns.response(200, "Available retrieval metrics retrieved") + @console_ns.response(403, "Permission denied") + @console_ns.response(404, "Dataset not found") + @setup_required + @login_required + @account_initialization_required + def get(self, dataset_id): + """Get available evaluation metrics for knowledge base retrieval.""" + current_user, _ = current_account_with_tenant() + dataset_id_str = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id_str) + if dataset is None: + raise NotFound("Dataset not found.") + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + + return { + "metrics": EvaluationService.get_supported_metrics(EvaluationCategory.KNOWLEDGE_BASE) + } + + +@console_ns.route("/datasets//evaluation/files/") +class DatasetEvaluationFileDownloadApi(Resource): + @console_ns.doc("download_dataset_evaluation_file") + @console_ns.response(200, "File download URL generated") + @console_ns.response(403, "Permission denied") + @console_ns.response(404, "Dataset or file not found") + @setup_required + @login_required + @account_initialization_required + def get(self, dataset_id, file_id): + """Download evaluation test file or result file for the knowledge base.""" + from core.workflow.file import helpers as file_helpers + + current_user, current_tenant_id = current_account_with_tenant() + dataset_id_str = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id_str) + if dataset is None: + raise NotFound("Dataset not found.") + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + + file_id_str = str(file_id) + with Session(db.engine, expire_on_commit=False) as session: + stmt = select(UploadFile).where( + UploadFile.id == file_id_str, + UploadFile.tenant_id == current_tenant_id, + ) + upload_file = session.execute(stmt).scalar_one_or_none() + + if not upload_file: + raise NotFound("File not found.") + + download_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id, as_attachment=True) + + return { + "id": upload_file.id, + "name": upload_file.name, + "size": upload_file.size, + "extension": upload_file.extension, + "mime_type": upload_file.mime_type, + "created_at": int(upload_file.created_at.timestamp()) if upload_file.created_at else None, + "download_url": download_url, + } diff --git a/api/controllers/console/evaluation/__init__.py b/api/controllers/console/evaluation/__init__.py new file mode 100644 index 0000000000..65c6eacd84 --- /dev/null +++ b/api/controllers/console/evaluation/__init__.py @@ -0,0 +1 @@ +# Evaluation controller module diff --git a/api/controllers/console/evaluation/evaluation.py b/api/controllers/console/evaluation/evaluation.py new file mode 100644 index 0000000000..31490020c3 --- /dev/null +++ b/api/controllers/console/evaluation/evaluation.py @@ -0,0 +1,869 @@ +from __future__ import annotations + +import logging +from collections.abc import Callable +from functools import wraps +from typing import TYPE_CHECKING, ParamSpec, TypeVar, Union +from urllib.parse import quote + +from flask import Response, request +from flask_restx import Resource, fields, marshal +from pydantic import BaseModel +from sqlalchemy import select 
+from sqlalchemy.orm import Session +from werkzeug.exceptions import BadRequest, Forbidden, NotFound + +from controllers.common.schema import register_schema_models +from controllers.console import console_ns +from controllers.console.app.workflow import WorkflowListQuery +from controllers.console.wraps import ( + account_initialization_required, + edit_permission_required, + setup_required, +) +from core.evaluation.entities.evaluation_entity import EvaluationCategory, EvaluationConfigData, EvaluationRunRequest +from extensions.ext_database import db +from extensions.ext_storage import storage +from fields.member_fields import simple_account_fields +from graphon.file import helpers as file_helpers +from libs.helper import TimestampField +from libs.login import current_account_with_tenant, login_required +from models import App, Dataset +from models.model import UploadFile +from models.snippet import CustomizedSnippet +from services.errors.evaluation import ( + EvaluationDatasetInvalidError, + EvaluationFrameworkNotConfiguredError, + EvaluationMaxConcurrentRunsError, + EvaluationNotFoundError, +) +from services.evaluation_service import EvaluationService +from services.workflow_service import WorkflowService + +if TYPE_CHECKING: + from models.evaluation import EvaluationRun, EvaluationRunItem + +logger = logging.getLogger(__name__) + +P = ParamSpec("P") +R = TypeVar("R") + +# Valid evaluation target types +EVALUATE_TARGET_TYPES = {"app", "snippets"} + + +class VersionQuery(BaseModel): + """Query parameters for version endpoint.""" + + version: str + + +register_schema_models( + console_ns, + VersionQuery, +) + + +# Response field definitions +file_info_fields = { + "id": fields.String, + "name": fields.String, +} + +evaluation_log_fields = { + "created_at": TimestampField, + "created_by": fields.String, + "test_file": fields.Nested( + console_ns.model( + "EvaluationTestFile", + file_info_fields, + ) + ), + "result_file": fields.Nested( + console_ns.model( + "EvaluationResultFile", + file_info_fields, + ), + allow_null=True, + ), + "version": fields.String, +} + +evaluation_log_list_model = console_ns.model( + "EvaluationLogList", + { + "data": fields.List(fields.Nested(console_ns.model("EvaluationLog", evaluation_log_fields))), + }, +) + +evaluation_default_metric_node_info_fields = { + "node_id": fields.String, + "type": fields.String, + "title": fields.String, +} +evaluation_default_metric_item_fields = { + "metric": fields.String, + "value_type": fields.String, + "node_info_list": fields.List( + fields.Nested( + console_ns.model("EvaluationDefaultMetricNodeInfo", evaluation_default_metric_node_info_fields), + ), + ), +} + +customized_metrics_fields = { + "evaluation_workflow_id": fields.String, + "input_fields": fields.Raw, + "output_fields": fields.Raw, +} + +judgment_condition_fields = { + "variable_selector": fields.List(fields.String), + "comparison_operator": fields.String, + "value": fields.String, +} + +judgment_config_fields = { + "logical_operator": fields.String, + "conditions": fields.List(fields.Nested(console_ns.model("JudgmentCondition", judgment_condition_fields))), +} + +evaluation_detail_fields = { + "evaluation_model": fields.String, + "evaluation_model_provider": fields.String, + "default_metrics": fields.List( + fields.Nested(console_ns.model("EvaluationDefaultMetricItem_Detail", evaluation_default_metric_item_fields)), + allow_null=True, + ), + "customized_metrics": fields.Nested( + console_ns.model("EvaluationCustomizedMetrics", customized_metrics_fields), + 
allow_null=True, + ), + "judgment_config": fields.Nested( + console_ns.model("EvaluationJudgmentConfig", judgment_config_fields), + allow_null=True, + ), +} + +evaluation_detail_model = console_ns.model("EvaluationDetail", evaluation_detail_fields) + +available_evaluation_workflow_list_fields = { + "id": fields.String, + "app_id": fields.String, + "app_name": fields.String, + "type": fields.String, + "version": fields.String, + "marked_name": fields.String, + "marked_comment": fields.String, + "hash": fields.String, + "created_by": fields.Nested(simple_account_fields), + "created_at": TimestampField, + "updated_by": fields.Nested(simple_account_fields, allow_null=True), + "updated_at": TimestampField, +} + +available_evaluation_workflow_pagination_fields = { + "items": fields.List(fields.Nested(available_evaluation_workflow_list_fields)), + "page": fields.Integer, + "limit": fields.Integer, + "has_more": fields.Boolean, +} + +available_evaluation_workflow_pagination_model = console_ns.model( + "AvailableEvaluationWorkflowPagination", + available_evaluation_workflow_pagination_fields, +) + +evaluation_default_metrics_response_model = console_ns.model( + "EvaluationDefaultMetricsResponse", + { + "default_metrics": fields.List( + fields.Nested(console_ns.model("EvaluationDefaultMetricItem", evaluation_default_metric_item_fields)), + ), + }, +) + + +def get_evaluation_target(view_func: Callable[P, R]): + """ + Decorator to resolve polymorphic evaluation target (app or snippet). + + Validates the target_type parameter and fetches the corresponding + model (App or CustomizedSnippet) with tenant isolation. + """ + + @wraps(view_func) + def decorated_view(*args: P.args, **kwargs: P.kwargs): + target_type = kwargs.get("evaluate_target_type") + target_id = kwargs.get("evaluate_target_id") + + if target_type not in EVALUATE_TARGET_TYPES: + raise NotFound(f"Invalid evaluation target type: {target_type}") + + _, current_tenant_id = current_account_with_tenant() + + target_id = str(target_id) + + # Remove path parameters + del kwargs["evaluate_target_type"] + del kwargs["evaluate_target_id"] + + target: Union[App, CustomizedSnippet, Dataset] | None = None + + if target_type == "app": + target = db.session.query(App).where(App.id == target_id, App.tenant_id == current_tenant_id).first() + elif target_type == "snippets": + target = ( + db.session.query(CustomizedSnippet) + .where(CustomizedSnippet.id == target_id, CustomizedSnippet.tenant_id == current_tenant_id) + .first() + ) + elif target_type == "knowledge": + target = (db.session.query(Dataset) + .where(Dataset.id == target_id, Dataset.tenant_id == current_tenant_id) + .first()) + + if not target: + raise NotFound(f"{str(target_type)} not found") + + kwargs["target"] = target + kwargs["target_type"] = target_type + + return view_func(*args, **kwargs) + + return decorated_view + + +@console_ns.route("///dataset-template/download") +class EvaluationDatasetTemplateDownloadApi(Resource): + @console_ns.doc("download_evaluation_dataset_template") + @console_ns.response(200, "Template file streamed as XLSX attachment") + @console_ns.response(400, "Invalid target type or excluded app mode") + @console_ns.response(404, "Target not found") + @setup_required + @login_required + @account_initialization_required + @get_evaluation_target + @edit_permission_required + def post(self, target: Union[App, CustomizedSnippet], target_type: str): + """ + Download evaluation dataset template. 
+ + Generates an XLSX template based on the target's input parameters + and streams it directly as a file attachment. + """ + try: + xlsx_content, filename = EvaluationService.generate_dataset_template( + target=target, + target_type=target_type, + ) + except ValueError as e: + return {"message": str(e)}, 400 + + encoded_filename = quote(filename) + response = Response( + xlsx_content, + mimetype="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + ) + response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}" + response.headers["Content-Length"] = str(len(xlsx_content)) + return response + + +@console_ns.route("///evaluation") +class EvaluationDetailApi(Resource): + @console_ns.doc("get_evaluation_detail") + @console_ns.response(200, "Evaluation details retrieved successfully", evaluation_detail_model) + @console_ns.response(404, "Target not found") + @setup_required + @login_required + @account_initialization_required + @get_evaluation_target + def get(self, target: Union[App, CustomizedSnippet], target_type: str): + """ + Get evaluation configuration for the target. + + Returns evaluation configuration including model settings, + metrics config, and judgement conditions. + """ + _, current_tenant_id = current_account_with_tenant() + + with Session(db.engine, expire_on_commit=False) as session: + config = EvaluationService.get_evaluation_config(session, current_tenant_id, target_type, str(target.id)) + + if config is None: + return { + "evaluation_model": None, + "evaluation_model_provider": None, + "default_metrics": None, + "customized_metrics": None, + "judgment_config": None, + } + + return { + "evaluation_model": config.evaluation_model, + "evaluation_model_provider": config.evaluation_model_provider, + "default_metrics": config.default_metrics_list, + "customized_metrics": config.customized_metrics_dict, + "judgment_config": config.judgment_config_dict, + } + + @console_ns.doc("save_evaluation_detail") + @console_ns.response(200, "Evaluation configuration saved successfully") + @console_ns.response(404, "Target not found") + @setup_required + @login_required + @account_initialization_required + @get_evaluation_target + @edit_permission_required + def put(self, target: Union[App, CustomizedSnippet], target_type: str): + """ + Save evaluation configuration for the target. 
+ """ + current_account, current_tenant_id = current_account_with_tenant() + body = request.get_json(force=True) + + try: + config_data = EvaluationConfigData.model_validate(body) + except Exception as e: + raise BadRequest(f"Invalid request body: {e}") + + with Session(db.engine, expire_on_commit=False) as session: + config = EvaluationService.save_evaluation_config( + session=session, + tenant_id=current_tenant_id, + target_type=target_type, + target_id=str(target.id), + account_id=str(current_account.id), + data=config_data, + ) + + return { + "evaluation_model": config.evaluation_model, + "evaluation_model_provider": config.evaluation_model_provider, + "default_metrics": config.default_metrics_list, + "customized_metrics": config.customized_metrics_dict, + "judgment_config": config.judgment_config_dict, + } + + +@console_ns.route("///evaluation/logs") +class EvaluationLogsApi(Resource): + @console_ns.doc("get_evaluation_logs") + @console_ns.response(200, "Evaluation logs retrieved successfully") + @console_ns.response(404, "Target not found") + @setup_required + @login_required + @account_initialization_required + @get_evaluation_target + def get(self, target: Union[App, CustomizedSnippet], target_type: str): + """ + Get evaluation run history for the target. + + Returns a paginated list of evaluation runs. + """ + _, current_tenant_id = current_account_with_tenant() + page = request.args.get("page", 1, type=int) + page_size = request.args.get("page_size", 20, type=int) + + with Session(db.engine, expire_on_commit=False) as session: + runs, total = EvaluationService.get_evaluation_runs( + session=session, + tenant_id=current_tenant_id, + target_type=target_type, + target_id=str(target.id), + page=page, + page_size=page_size, + ) + + return { + "data": [_serialize_evaluation_run(run) for run in runs], + "total": total, + "page": page, + "page_size": page_size, + } + + +@console_ns.route("///evaluation/run") +class EvaluationRunApi(Resource): + @console_ns.doc("start_evaluation_run") + @console_ns.response(200, "Evaluation run started") + @console_ns.response(400, "Invalid request") + @console_ns.response(404, "Target not found") + @setup_required + @login_required + @account_initialization_required + @get_evaluation_target + @edit_permission_required + def post(self, target: Union[App, CustomizedSnippet, Dataset], target_type: str): + """ + Start an evaluation run. 
+ + Expects JSON body with: + - file_id: uploaded dataset file ID + - evaluation_model: evaluation model name + - evaluation_model_provider: evaluation model provider + - default_metrics: list of default metric objects + - customized_metrics: customized metrics object (optional) + - judgment_config: judgment conditions config (optional) + """ + current_account, current_tenant_id = current_account_with_tenant() + + body = request.get_json(force=True) + if not body: + raise BadRequest("Request body is required.") + + # Validate and parse request body + try: + run_request = EvaluationRunRequest.model_validate(body) + except Exception as e: + raise BadRequest(f"Invalid request body: {e}") + + # Load dataset file + upload_file = ( + db.session.query(UploadFile).filter_by(id=run_request.file_id, tenant_id=current_tenant_id).first() + ) + if not upload_file: + raise NotFound("Dataset file not found.") + + try: + dataset_content = storage.load_once(upload_file.key) + except Exception: + raise BadRequest("Failed to read dataset file.") + + if not dataset_content: + raise BadRequest("Dataset file is empty.") + + try: + with Session(db.engine, expire_on_commit=False) as session: + evaluation_run = EvaluationService.start_evaluation_run( + session=session, + tenant_id=current_tenant_id, + target_type=target_type, + target_id=str(target.id), + account_id=str(current_account.id), + dataset_file_content=dataset_content, + run_request=run_request, + ) + return _serialize_evaluation_run(evaluation_run), 200 + except EvaluationFrameworkNotConfiguredError as e: + return {"message": str(e.description)}, 400 + except EvaluationNotFoundError as e: + return {"message": str(e.description)}, 404 + except EvaluationMaxConcurrentRunsError as e: + return {"message": str(e.description)}, 429 + except EvaluationDatasetInvalidError as e: + return {"message": str(e.description)}, 400 + + +@console_ns.route("///evaluation/runs/") +class EvaluationRunDetailApi(Resource): + @console_ns.doc("get_evaluation_run_detail") + @console_ns.response(200, "Evaluation run detail retrieved") + @console_ns.response(404, "Run not found") + @setup_required + @login_required + @account_initialization_required + @get_evaluation_target + def get(self, target: Union[App, CustomizedSnippet], target_type: str, run_id: str): + """ + Get evaluation run detail including items. 
+ """ + _, current_tenant_id = current_account_with_tenant() + run_id = str(run_id) + page = request.args.get("page", 1, type=int) + page_size = request.args.get("page_size", 50, type=int) + + try: + with Session(db.engine, expire_on_commit=False) as session: + run = EvaluationService.get_evaluation_run_detail( + session=session, + tenant_id=current_tenant_id, + run_id=run_id, + ) + items, total_items = EvaluationService.get_evaluation_run_items( + session=session, + run_id=run_id, + page=page, + page_size=page_size, + ) + + return { + "run": _serialize_evaluation_run(run), + "items": { + "data": [_serialize_evaluation_run_item(item) for item in items], + "total": total_items, + "page": page, + "page_size": page_size, + }, + } + except EvaluationNotFoundError as e: + return {"message": str(e.description)}, 404 + + +@console_ns.route("///evaluation/runs//cancel") +class EvaluationRunCancelApi(Resource): + @console_ns.doc("cancel_evaluation_run") + @console_ns.response(200, "Evaluation run cancelled") + @console_ns.response(404, "Run not found") + @setup_required + @login_required + @account_initialization_required + @get_evaluation_target + @edit_permission_required + def post(self, target: Union[App, CustomizedSnippet], target_type: str, run_id: str): + """Cancel a running evaluation.""" + _, current_tenant_id = current_account_with_tenant() + run_id = str(run_id) + + try: + with Session(db.engine, expire_on_commit=False) as session: + run = EvaluationService.cancel_evaluation_run( + session=session, + tenant_id=current_tenant_id, + run_id=run_id, + ) + return _serialize_evaluation_run(run) + except EvaluationNotFoundError as e: + return {"message": str(e.description)}, 404 + except ValueError as e: + return {"message": str(e)}, 400 + + +@console_ns.route("///evaluation/metrics") +class EvaluationMetricsApi(Resource): + @console_ns.doc("get_evaluation_metrics") + @console_ns.response(200, "Available metrics retrieved") + @setup_required + @login_required + @account_initialization_required + @get_evaluation_target + def get(self, target: Union[App, CustomizedSnippet], target_type: str): + """ + Get available evaluation metrics for the current framework. + """ + result = {} + for category in EvaluationCategory: + result[category.value] = EvaluationService.get_supported_metrics(category) + return {"metrics": result} + + +@console_ns.route("///evaluation/default-metrics") +class EvaluationDefaultMetricsApi(Resource): + @console_ns.doc( + "get_evaluation_default_metrics_with_nodes", + description=( + "List default metrics supported by the current evaluation framework with matching nodes " + "from the target's published workflow only (draft is ignored)." 
+ ), + ) + @console_ns.response( + 200, + "Default metrics and node candidates for the published workflow", + evaluation_default_metrics_response_model, + ) + @setup_required + @login_required + @account_initialization_required + @get_evaluation_target + def get(self, target: Union[App, CustomizedSnippet], target_type: str): + default_metrics = EvaluationService.get_default_metrics_with_nodes_for_published_target( + target=target, + target_type=target_type, + ) + return {"default_metrics": [m.model_dump() for m in default_metrics]} + + +@console_ns.route("///evaluation/node-info") +class EvaluationNodeInfoApi(Resource): + @console_ns.doc("get_evaluation_node_info") + @console_ns.response(200, "Node info grouped by metric") + @console_ns.response(404, "Target not found") + @setup_required + @login_required + @account_initialization_required + @get_evaluation_target + def post(self, target: Union[App, CustomizedSnippet], target_type: str): + """Return workflow/snippet node info grouped by requested metrics. + + Request body (JSON): + - metrics: list[str] | None – metric names to query; omit or pass + an empty list to get all nodes under key ``"all"``. + + Response: + ``{metric_or_all: [{"node_id": ..., "type": ..., "title": ...}, ...]}`` + """ + body = request.get_json(silent=True) or {} + metrics: list[str] | None = body.get("metrics") or None + + result = EvaluationService.get_nodes_for_metrics( + target=target, + target_type=target_type, + metrics=metrics, + ) + return result + + +@console_ns.route("/evaluation/available-metrics") +class EvaluationAvailableMetricsApi(Resource): + @console_ns.doc("get_available_evaluation_metrics") + @console_ns.response(200, "Available metrics list") + @setup_required + @login_required + @account_initialization_required + def get(self): + """Return the centrally-defined list of evaluation metrics.""" + return {"metrics": EvaluationService.get_available_metrics()} + + +@console_ns.route("///evaluation/files/") +class EvaluationFileDownloadApi(Resource): + @console_ns.doc("download_evaluation_file") + @console_ns.response(200, "File download URL generated successfully") + @console_ns.response(404, "Target or file not found") + @setup_required + @login_required + @account_initialization_required + @get_evaluation_target + def get(self, target: Union[App, CustomizedSnippet], target_type: str, file_id: str): + """ + Download evaluation test file or result file. + + Looks up the specified file, verifies it belongs to the same tenant, + and returns file info and download URL. 
+ """ + file_id = str(file_id) + _, current_tenant_id = current_account_with_tenant() + + with Session(db.engine, expire_on_commit=False) as session: + stmt = select(UploadFile).where( + UploadFile.id == file_id, + UploadFile.tenant_id == current_tenant_id, + ) + upload_file = session.execute(stmt).scalar_one_or_none() + + if not upload_file: + raise NotFound("File not found") + + download_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id, as_attachment=True) + + return { + "id": upload_file.id, + "name": upload_file.name, + "size": upload_file.size, + "extension": upload_file.extension, + "mime_type": upload_file.mime_type, + "created_at": int(upload_file.created_at.timestamp()) if upload_file.created_at else None, + "download_url": download_url, + } + + +@console_ns.route("///evaluation/version") +class EvaluationVersionApi(Resource): + @console_ns.doc("get_evaluation_version_detail") + @console_ns.expect(console_ns.models.get(VersionQuery.__name__)) + @console_ns.response(200, "Version details retrieved successfully") + @console_ns.response(404, "Target or version not found") + @setup_required + @login_required + @account_initialization_required + @get_evaluation_target + def get(self, target: Union[App, CustomizedSnippet], target_type: str): + """ + Get evaluation target version details. + + Returns the workflow graph for the specified version. + """ + version = request.args.get("version") + + if not version: + return {"message": "version parameter is required"}, 400 + + graph = {} + if target_type == "snippets" and isinstance(target, CustomizedSnippet): + graph = target.graph_dict + + return { + "graph": graph, + } + + +@console_ns.route("/workspaces/current/available-evaluation-workflows") +class AvailableEvaluationWorkflowsApi(Resource): + @console_ns.expect(console_ns.models[WorkflowListQuery.__name__]) + @console_ns.doc("list_available_evaluation_workflows") + @console_ns.doc(description="List published evaluation workflows in the current workspace (all apps)") + @console_ns.response( + 200, + "Available evaluation workflows retrieved", + available_evaluation_workflow_pagination_model, + ) + @setup_required + @login_required + @account_initialization_required + @edit_permission_required + def get(self): + """List published evaluation-type workflows for the current tenant (cross-app).""" + current_user, current_tenant_id = current_account_with_tenant() + + args = WorkflowListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + page = args.page + limit = args.limit + user_id = args.user_id + named_only = args.named_only + keyword = args.keyword + + if user_id and user_id != current_user.id: + raise Forbidden() + + workflow_service = WorkflowService() + with Session(db.engine) as session: + workflows, has_more = workflow_service.list_published_evaluation_workflows( + session=session, + tenant_id=current_tenant_id, + page=page, + limit=limit, + user_id=user_id, + named_only=named_only, + keyword=keyword, + ) + + app_ids = {w.app_id for w in workflows} + if app_ids: + apps = session.scalars(select(App).where(App.id.in_(app_ids))).all() + app_names = {a.id: a.name for a in apps} + else: + app_names = {} + + items = [] + for wf in workflows: + items.append( + { + "id": wf.id, + "app_id": wf.app_id, + "app_name": app_names.get(wf.app_id, ""), + "type": wf.type.value, + "version": wf.version, + "marked_name": wf.marked_name, + "marked_comment": wf.marked_comment, + "hash": wf.unique_hash, + "created_by": wf.created_by_account, + "created_at": 
wf.created_at, + "updated_by": wf.updated_by_account, + "updated_at": wf.updated_at, + } + ) + + return ( + marshal( + {"items": items, "page": page, "limit": limit, "has_more": has_more}, + available_evaluation_workflow_pagination_fields, + ), + 200, + ) + + +@console_ns.route("/workspaces/current/evaluation-workflows//associated-targets") +class EvaluationWorkflowAssociatedTargetsApi(Resource): + @console_ns.doc("list_evaluation_workflow_associated_targets") + @console_ns.doc( + description="List targets (apps / snippets / knowledge bases) that use the given workflow as customized metrics" + ) + @setup_required + @login_required + @account_initialization_required + @edit_permission_required + def get(self, workflow_id: str): + """Return all evaluation targets that reference this workflow as customized metrics.""" + _, current_tenant_id = current_account_with_tenant() + + with Session(db.engine) as session: + configs = EvaluationService.list_targets_by_customized_workflow( + session=session, + tenant_id=current_tenant_id, + customized_workflow_id=workflow_id, + ) + + target_ids_by_type: dict[str, list[str]] = {} + for cfg in configs: + target_ids_by_type.setdefault(cfg.target_type, []).append(cfg.target_id) + + app_names: dict[str, str] = {} + if "app" in target_ids_by_type: + apps = session.scalars(select(App).where(App.id.in_(target_ids_by_type["app"]))).all() + app_names = {a.id: a.name for a in apps} + + snippet_names: dict[str, str] = {} + if "snippets" in target_ids_by_type: + snippets = session.scalars( + select(CustomizedSnippet).where(CustomizedSnippet.id.in_(target_ids_by_type["snippets"])) + ).all() + snippet_names = {s.id: s.name for s in snippets} + + dataset_names: dict[str, str] = {} + if "knowledge_base" in target_ids_by_type: + datasets = session.scalars( + select(Dataset).where(Dataset.id.in_(target_ids_by_type["knowledge_base"])) + ).all() + dataset_names = {d.id: d.name for d in datasets} + + items = [] + for cfg in configs: + name = "" + if cfg.target_type == "app": + name = app_names.get(cfg.target_id, "") + elif cfg.target_type == "snippets": + name = snippet_names.get(cfg.target_id, "") + elif cfg.target_type == "knowledge_base": + name = dataset_names.get(cfg.target_id, "") + + items.append( + { + "target_type": cfg.target_type, + "target_id": cfg.target_id, + "target_name": name, + } + ) + + return {"items": items}, 200 + + +# ---- Serialization Helpers ---- + + +def _serialize_evaluation_run(run: EvaluationRun) -> dict[str, object]: + return { + "id": run.id, + "tenant_id": run.tenant_id, + "target_type": run.target_type, + "target_id": run.target_id, + "evaluation_config_id": run.evaluation_config_id, + "status": run.status, + "dataset_file_id": run.dataset_file_id, + "result_file_id": run.result_file_id, + "total_items": run.total_items, + "completed_items": run.completed_items, + "failed_items": run.failed_items, + "progress": run.progress, + "metrics_summary": run.metrics_summary_dict, + "error": run.error, + "created_by": run.created_by, + "started_at": int(run.started_at.timestamp()) if run.started_at else None, + "completed_at": int(run.completed_at.timestamp()) if run.completed_at else None, + "created_at": int(run.created_at.timestamp()) if run.created_at else None, + } + + +def _serialize_evaluation_run_item(item: EvaluationRunItem) -> dict[str, object]: + return { + "id": item.id, + "item_index": item.item_index, + "inputs": item.inputs_dict, + "expected_output": item.expected_output, + "actual_output": item.actual_output, + "metrics": 
item.metrics_list, + "judgment": item.judgment_dict, + "metadata": item.metadata_dict, + "error": item.error, + "overall_score": item.overall_score, + } diff --git a/api/controllers/console/snippets/payloads.py b/api/controllers/console/snippets/payloads.py new file mode 100644 index 0000000000..980506ccc4 --- /dev/null +++ b/api/controllers/console/snippets/payloads.py @@ -0,0 +1,135 @@ +from typing import Any, Literal + +from pydantic import BaseModel, Field, field_validator + + +class SnippetListQuery(BaseModel): + """Query parameters for listing snippets.""" + + page: int = Field(default=1, ge=1, le=99999) + limit: int = Field(default=20, ge=1, le=100) + keyword: str | None = None + is_published: bool | None = Field(default=None, description="Filter by published status") + creators: list[str] | None = Field(default=None, description="Filter by creator account IDs") + + @field_validator("creators", mode="before") + @classmethod + def parse_creators(cls, value: object) -> list[str] | None: + """Normalize creators filter from query string or list input.""" + if value is None: + return None + if isinstance(value, str): + return [creator.strip() for creator in value.split(",") if creator.strip()] or None + if isinstance(value, list): + return [str(creator).strip() for creator in value if str(creator).strip()] or None + return None + + +class IconInfo(BaseModel): + """Icon information model.""" + + icon: str | None = None + icon_type: Literal["emoji", "image"] | None = None + icon_background: str | None = None + icon_url: str | None = None + + +class InputFieldDefinition(BaseModel): + """Input field definition for snippet parameters.""" + + default: str | None = None + hint: bool | None = None + label: str | None = None + max_length: int | None = None + options: list[str] | None = None + placeholder: str | None = None + required: bool | None = None + type: str | None = None # e.g., "text-input" + + +class CreateSnippetPayload(BaseModel): + """Payload for creating a new snippet.""" + + name: str = Field(..., min_length=1, max_length=255) + description: str | None = Field(default=None, max_length=2000) + type: Literal["node", "group"] = "node" + icon_info: IconInfo | None = None + graph: dict[str, Any] | None = None + input_fields: list[InputFieldDefinition] | None = Field(default_factory=list) + + +class UpdateSnippetPayload(BaseModel): + """Payload for updating a snippet.""" + + name: str | None = Field(default=None, min_length=1, max_length=255) + description: str | None = Field(default=None, max_length=2000) + icon_info: IconInfo | None = None + + +class SnippetDraftSyncPayload(BaseModel): + """Payload for syncing snippet draft workflow.""" + + graph: dict[str, Any] + hash: str | None = None + conversation_variables: list[dict[str, Any]] | None = Field( + default=None, + description="Ignored. 
Snippet workflows do not persist conversation variables.", + ) + input_fields: list[dict[str, Any]] | None = None + + +class WorkflowRunQuery(BaseModel): + """Query parameters for workflow runs.""" + + last_id: str | None = None + limit: int = Field(default=20, ge=1, le=100) + + +class SnippetDraftRunPayload(BaseModel): + """Payload for running snippet draft workflow.""" + + inputs: dict[str, Any] + files: list[dict[str, Any]] | None = None + + +class SnippetDraftNodeRunPayload(BaseModel): + """Payload for running a single node in snippet draft workflow.""" + + inputs: dict[str, Any] + query: str = "" + files: list[dict[str, Any]] | None = None + + +class SnippetIterationNodeRunPayload(BaseModel): + """Payload for running an iteration node in snippet draft workflow.""" + + inputs: dict[str, Any] | None = None + + +class SnippetLoopNodeRunPayload(BaseModel): + """Payload for running a loop node in snippet draft workflow.""" + + inputs: dict[str, Any] | None = None + + +class PublishWorkflowPayload(BaseModel): + """Payload for publishing snippet workflow.""" + + knowledge_base_setting: dict[str, Any] | None = None + + +class SnippetImportPayload(BaseModel): + """Payload for importing snippet from DSL.""" + + mode: str = Field(..., description="Import mode: yaml-content or yaml-url") + yaml_content: str | None = Field(default=None, description="YAML content (required for yaml-content mode)") + yaml_url: str | None = Field(default=None, description="YAML URL (required for yaml-url mode)") + name: str | None = Field(default=None, description="Override snippet name") + description: str | None = Field(default=None, description="Override snippet description") + snippet_id: str | None = Field(default=None, description="Snippet ID to update (optional)") + + +class IncludeSecretQuery(BaseModel): + """Query parameter for including secret variables in export.""" + + include_secret: str = Field(default="false", description="Whether to include secret variables") diff --git a/api/controllers/console/snippets/snippet_workflow.py b/api/controllers/console/snippets/snippet_workflow.py new file mode 100644 index 0000000000..d4b3ec5e4e --- /dev/null +++ b/api/controllers/console/snippets/snippet_workflow.py @@ -0,0 +1,534 @@ +import logging +from collections.abc import Callable +from functools import wraps +from typing import ParamSpec, TypeVar + +from flask import request +from flask_restx import Resource, marshal_with +from sqlalchemy.orm import Session +from werkzeug.exceptions import InternalServerError, NotFound + +from controllers.common.schema import register_schema_models +from controllers.console import console_ns +from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync +from controllers.console.app.workflow import workflow_model +from controllers.console.app.workflow_run import ( + workflow_run_detail_model, + workflow_run_node_execution_list_model, + workflow_run_node_execution_model, + workflow_run_pagination_model, +) +from controllers.console.snippets.payloads import ( + PublishWorkflowPayload, + SnippetDraftNodeRunPayload, + SnippetDraftRunPayload, + SnippetDraftSyncPayload, + SnippetIterationNodeRunPayload, + SnippetLoopNodeRunPayload, + WorkflowRunQuery, +) +from controllers.console.wraps import ( + account_initialization_required, + edit_permission_required, + setup_required, +) +from core.app.apps.base_app_queue_manager import AppQueueManager +from core.app.entities.app_invoke_entities import InvokeFrom +from extensions.ext_database import db +from 
extensions.ext_redis import redis_client +from graphon.graph_engine.manager import GraphEngineManager +from libs import helper +from libs.helper import TimestampField +from libs.login import current_account_with_tenant, login_required +from models.snippet import CustomizedSnippet +from services.errors.app import WorkflowHashNotEqualError +from services.snippet_generate_service import SnippetGenerateService +from services.snippet_service import SnippetService + +logger = logging.getLogger(__name__) + +P = ParamSpec("P") +R = TypeVar("R") + +# Register Pydantic models with Swagger +register_schema_models( + console_ns, + SnippetDraftSyncPayload, + SnippetDraftNodeRunPayload, + SnippetDraftRunPayload, + SnippetIterationNodeRunPayload, + SnippetLoopNodeRunPayload, + WorkflowRunQuery, + PublishWorkflowPayload, +) + + +class SnippetNotFoundError(Exception): + """Snippet not found error.""" + + pass + + +def get_snippet(view_func: Callable[P, R]): + """Decorator to fetch and validate snippet access.""" + + @wraps(view_func) + def decorated_view(*args: P.args, **kwargs: P.kwargs): + if not kwargs.get("snippet_id"): + raise ValueError("missing snippet_id in path parameters") + + _, current_tenant_id = current_account_with_tenant() + + snippet_id = str(kwargs.get("snippet_id")) + del kwargs["snippet_id"] + + snippet = SnippetService.get_snippet_by_id( + snippet_id=snippet_id, + tenant_id=current_tenant_id, + ) + + if not snippet: + raise NotFound("Snippet not found") + + kwargs["snippet"] = snippet + + return view_func(*args, **kwargs) + + return decorated_view + + +@console_ns.route("/snippets//workflows/draft") +class SnippetDraftWorkflowApi(Resource): + @console_ns.doc("get_snippet_draft_workflow") + @console_ns.response(200, "Draft workflow retrieved successfully", workflow_model) + @console_ns.response(404, "Snippet or draft workflow not found") + @setup_required + @login_required + @account_initialization_required + @get_snippet + @edit_permission_required + @marshal_with(workflow_model) + def get(self, snippet: CustomizedSnippet): + """Get draft workflow for snippet.""" + snippet_service = SnippetService() + workflow = snippet_service.get_draft_workflow(snippet=snippet) + + if not workflow: + raise DraftWorkflowNotExist() + + db.session.expunge(workflow) + workflow.conversation_variables = [] + return workflow + + @console_ns.doc("sync_snippet_draft_workflow") + @console_ns.expect(console_ns.models.get(SnippetDraftSyncPayload.__name__)) + @console_ns.response(200, "Draft workflow synced successfully") + @console_ns.response(400, "Hash mismatch") + @setup_required + @login_required + @account_initialization_required + @get_snippet + @edit_permission_required + def post(self, snippet: CustomizedSnippet): + """Sync draft workflow for snippet.""" + current_user, _ = current_account_with_tenant() + + payload = SnippetDraftSyncPayload.model_validate(console_ns.payload or {}) + + try: + snippet_service = SnippetService() + workflow = snippet_service.sync_draft_workflow( + snippet=snippet, + graph=payload.graph, + unique_hash=payload.hash, + account=current_user, + input_fields=payload.input_fields, + ) + except WorkflowHashNotEqualError: + raise DraftWorkflowNotSync() + except ValueError as e: + return {"message": str(e)}, 400 + + return { + "result": "success", + "hash": workflow.unique_hash, + "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at), + } + + +@console_ns.route("/snippets//workflows/draft/config") +class SnippetDraftConfigApi(Resource): + 
@console_ns.doc("get_snippet_draft_config") + @console_ns.response(200, "Draft config retrieved successfully") + @setup_required + @login_required + @account_initialization_required + @get_snippet + @edit_permission_required + def get(self, snippet: CustomizedSnippet): + """Get snippet draft workflow configuration limits.""" + return { + "parallel_depth_limit": 3, + } + + +@console_ns.route("/snippets//workflows/publish") +class SnippetPublishedWorkflowApi(Resource): + @console_ns.doc("get_snippet_published_workflow") + @console_ns.response(200, "Published workflow retrieved successfully", workflow_model) + @console_ns.response(404, "Snippet not found") + @setup_required + @login_required + @account_initialization_required + @get_snippet + @edit_permission_required + @marshal_with(workflow_model) + def get(self, snippet: CustomizedSnippet): + """Get published workflow for snippet.""" + if not snippet.is_published: + return None + + snippet_service = SnippetService() + workflow = snippet_service.get_published_workflow(snippet=snippet) + + return workflow + + @console_ns.doc("publish_snippet_workflow") + @console_ns.expect(console_ns.models.get(PublishWorkflowPayload.__name__)) + @console_ns.response(200, "Workflow published successfully") + @console_ns.response(400, "No draft workflow found") + @setup_required + @login_required + @account_initialization_required + @get_snippet + @edit_permission_required + def post(self, snippet: CustomizedSnippet): + """Publish snippet workflow.""" + current_user, _ = current_account_with_tenant() + snippet_service = SnippetService() + + with Session(db.engine) as session: + snippet = session.merge(snippet) + try: + workflow = snippet_service.publish_workflow( + session=session, + snippet=snippet, + account=current_user, + ) + workflow_created_at = TimestampField().format(workflow.created_at) + session.commit() + except ValueError as e: + return {"message": str(e)}, 400 + + return { + "result": "success", + "created_at": workflow_created_at, + } + + +@console_ns.route("/snippets//workflows/default-workflow-block-configs") +class SnippetDefaultBlockConfigsApi(Resource): + @console_ns.doc("get_snippet_default_block_configs") + @console_ns.response(200, "Default block configs retrieved successfully") + @setup_required + @login_required + @account_initialization_required + @get_snippet + @edit_permission_required + def get(self, snippet: CustomizedSnippet): + """Get default block configurations for snippet workflow.""" + snippet_service = SnippetService() + return snippet_service.get_default_block_configs() + + +@console_ns.route("/snippets//workflow-runs") +class SnippetWorkflowRunsApi(Resource): + @console_ns.doc("list_snippet_workflow_runs") + @console_ns.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_model) + @setup_required + @login_required + @account_initialization_required + @get_snippet + @marshal_with(workflow_run_pagination_model) + def get(self, snippet: CustomizedSnippet): + """List workflow runs for snippet.""" + query = WorkflowRunQuery.model_validate( + { + "last_id": request.args.get("last_id"), + "limit": request.args.get("limit", type=int, default=20), + } + ) + args = { + "last_id": query.last_id, + "limit": query.limit, + } + + snippet_service = SnippetService() + result = snippet_service.get_snippet_workflow_runs(snippet=snippet, args=args) + + return result + + +@console_ns.route("/snippets//workflow-runs/") +class SnippetWorkflowRunDetailApi(Resource): + @console_ns.doc("get_snippet_workflow_run_detail") + 
@console_ns.response(200, "Workflow run detail retrieved successfully", workflow_run_detail_model) + @console_ns.response(404, "Workflow run not found") + @setup_required + @login_required + @account_initialization_required + @get_snippet + @marshal_with(workflow_run_detail_model) + def get(self, snippet: CustomizedSnippet, run_id): + """Get workflow run detail for snippet.""" + run_id = str(run_id) + + snippet_service = SnippetService() + workflow_run = snippet_service.get_snippet_workflow_run(snippet=snippet, run_id=run_id) + + if not workflow_run: + raise NotFound("Workflow run not found") + + return workflow_run + + +@console_ns.route("/snippets//workflow-runs//node-executions") +class SnippetWorkflowRunNodeExecutionsApi(Resource): + @console_ns.doc("list_snippet_workflow_run_node_executions") + @console_ns.response(200, "Node executions retrieved successfully", workflow_run_node_execution_list_model) + @setup_required + @login_required + @account_initialization_required + @get_snippet + @marshal_with(workflow_run_node_execution_list_model) + def get(self, snippet: CustomizedSnippet, run_id): + """List node executions for a workflow run.""" + run_id = str(run_id) + + snippet_service = SnippetService() + node_executions = snippet_service.get_snippet_workflow_run_node_executions( + snippet=snippet, + run_id=run_id, + ) + + return {"data": node_executions} + + +@console_ns.route("/snippets//workflows/draft/nodes//run") +class SnippetDraftNodeRunApi(Resource): + @console_ns.doc("run_snippet_draft_node") + @console_ns.doc(description="Run a single node in snippet draft workflow (single-step debugging)") + @console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"}) + @console_ns.expect(console_ns.models.get(SnippetDraftNodeRunPayload.__name__)) + @console_ns.response(200, "Node run completed successfully", workflow_run_node_execution_model) + @console_ns.response(404, "Snippet or draft workflow not found") + @setup_required + @login_required + @account_initialization_required + @get_snippet + @marshal_with(workflow_run_node_execution_model) + @edit_permission_required + def post(self, snippet: CustomizedSnippet, node_id: str): + """ + Run a single node in snippet draft workflow. + + Executes a specific node with provided inputs for single-step debugging. + Returns the node execution result including status, outputs, and timing. 
+ """ + current_user, _ = current_account_with_tenant() + payload = SnippetDraftNodeRunPayload.model_validate(console_ns.payload or {}) + + user_inputs = payload.inputs + + # Get draft workflow for file parsing + snippet_service = SnippetService() + draft_workflow = snippet_service.get_draft_workflow(snippet=snippet) + if not draft_workflow: + raise NotFound("Draft workflow not found") + + files = SnippetGenerateService.parse_files(draft_workflow, payload.files) + + workflow_node_execution = SnippetGenerateService.run_draft_node( + snippet=snippet, + node_id=node_id, + user_inputs=user_inputs, + account=current_user, + query=payload.query, + files=files, + ) + + return workflow_node_execution + + +@console_ns.route("/snippets//workflows/draft/nodes//last-run") +class SnippetDraftNodeLastRunApi(Resource): + @console_ns.doc("get_snippet_draft_node_last_run") + @console_ns.doc(description="Get last run result for a node in snippet draft workflow") + @console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"}) + @console_ns.response(200, "Node last run retrieved successfully", workflow_run_node_execution_model) + @console_ns.response(404, "Snippet, draft workflow, or node last run not found") + @setup_required + @login_required + @account_initialization_required + @get_snippet + @marshal_with(workflow_run_node_execution_model) + def get(self, snippet: CustomizedSnippet, node_id: str): + """ + Get the last run result for a specific node in snippet draft workflow. + + Returns the most recent execution record for the given node, + including status, inputs, outputs, and timing information. + """ + snippet_service = SnippetService() + draft_workflow = snippet_service.get_draft_workflow(snippet=snippet) + if not draft_workflow: + raise NotFound("Draft workflow not found") + + node_exec = snippet_service.get_snippet_node_last_run( + snippet=snippet, + workflow=draft_workflow, + node_id=node_id, + ) + if node_exec is None: + raise NotFound("Node last run not found") + + return node_exec + + +@console_ns.route("/snippets//workflows/draft/iteration/nodes//run") +class SnippetDraftRunIterationNodeApi(Resource): + @console_ns.doc("run_snippet_draft_iteration_node") + @console_ns.doc(description="Run draft workflow iteration node for snippet") + @console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"}) + @console_ns.expect(console_ns.models.get(SnippetIterationNodeRunPayload.__name__)) + @console_ns.response(200, "Iteration node run started successfully (SSE stream)") + @console_ns.response(404, "Snippet or draft workflow not found") + @setup_required + @login_required + @account_initialization_required + @get_snippet + @edit_permission_required + def post(self, snippet: CustomizedSnippet, node_id: str): + """ + Run a draft workflow iteration node for snippet. + + Iteration nodes execute their internal sub-graph multiple times over an input list. + Returns an SSE event stream with iteration progress and results. 
+ """ + current_user, _ = current_account_with_tenant() + args = SnippetIterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True) + + try: + response = SnippetGenerateService.generate_single_iteration( + snippet=snippet, user=current_user, node_id=node_id, args=args, streaming=True + ) + + return helper.compact_generate_response(response) + except ValueError as e: + raise e + except Exception: + logger.exception("internal server error.") + raise InternalServerError() + + +@console_ns.route("/snippets//workflows/draft/loop/nodes//run") +class SnippetDraftRunLoopNodeApi(Resource): + @console_ns.doc("run_snippet_draft_loop_node") + @console_ns.doc(description="Run draft workflow loop node for snippet") + @console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"}) + @console_ns.expect(console_ns.models.get(SnippetLoopNodeRunPayload.__name__)) + @console_ns.response(200, "Loop node run started successfully (SSE stream)") + @console_ns.response(404, "Snippet or draft workflow not found") + @setup_required + @login_required + @account_initialization_required + @get_snippet + @edit_permission_required + def post(self, snippet: CustomizedSnippet, node_id: str): + """ + Run a draft workflow loop node for snippet. + + Loop nodes execute their internal sub-graph repeatedly until a condition is met. + Returns an SSE event stream with loop progress and results. + """ + current_user, _ = current_account_with_tenant() + args = SnippetLoopNodeRunPayload.model_validate(console_ns.payload or {}) + + try: + response = SnippetGenerateService.generate_single_loop( + snippet=snippet, user=current_user, node_id=node_id, args=args, streaming=True + ) + + return helper.compact_generate_response(response) + except ValueError as e: + raise e + except Exception: + logger.exception("internal server error.") + raise InternalServerError() + + +@console_ns.route("/snippets//workflows/draft/run") +class SnippetDraftWorkflowRunApi(Resource): + @console_ns.doc("run_snippet_draft_workflow") + @console_ns.expect(console_ns.models.get(SnippetDraftRunPayload.__name__)) + @console_ns.response(200, "Draft workflow run started successfully (SSE stream)") + @console_ns.response(404, "Snippet or draft workflow not found") + @setup_required + @login_required + @account_initialization_required + @get_snippet + @edit_permission_required + def post(self, snippet: CustomizedSnippet): + """ + Run draft workflow for snippet. + + Executes the snippet's draft workflow with the provided inputs + and returns an SSE event stream with execution progress and results. 
+ """ + current_user, _ = current_account_with_tenant() + + payload = SnippetDraftRunPayload.model_validate(console_ns.payload or {}) + args = payload.model_dump(exclude_none=True) + + try: + response = SnippetGenerateService.generate( + snippet=snippet, + user=current_user, + args=args, + invoke_from=InvokeFrom.DEBUGGER, + streaming=True, + ) + + return helper.compact_generate_response(response) + except ValueError as e: + raise e + except Exception: + logger.exception("internal server error.") + raise InternalServerError() + + +@console_ns.route("/snippets//workflow-runs/tasks//stop") +class SnippetWorkflowTaskStopApi(Resource): + @console_ns.doc("stop_snippet_workflow_task") + @console_ns.response(200, "Task stopped successfully") + @console_ns.response(404, "Snippet not found") + @setup_required + @login_required + @account_initialization_required + @get_snippet + @edit_permission_required + def post(self, snippet: CustomizedSnippet, task_id: str): + """ + Stop a running snippet workflow task. + + Uses both the legacy stop flag mechanism and the graph engine + command channel for backward compatibility. + """ + # Stop using both mechanisms for backward compatibility + # Legacy stop flag mechanism (without user check) + AppQueueManager.set_stop_flag_no_user_check(task_id) + + # New graph engine command channel mechanism + GraphEngineManager(redis_client).send_stop_command(task_id) + + return {"result": "success"} diff --git a/api/controllers/console/snippets/snippet_workflow_draft_variable.py b/api/controllers/console/snippets/snippet_workflow_draft_variable.py new file mode 100644 index 0000000000..ce3f5cef52 --- /dev/null +++ b/api/controllers/console/snippets/snippet_workflow_draft_variable.py @@ -0,0 +1,319 @@ +""" +Snippet draft workflow variable APIs. + +Mirrors console app routes under /apps/.../workflows/draft/variables for snippet scope, +using CustomizedSnippet.id as WorkflowDraftVariable.app_id (same invariant as snippet execution). + +Snippet workflows do not expose system variables (`node_id == sys`) or conversation variables +(`node_id == conversation`): paginated list queries exclude those rows; single-variable GET/PATCH/DELETE/reset +reject them; `GET .../system-variables` and `GET .../conversation-variables` return empty lists for API parity. +Other routes mirror `workflow_draft_variable` app APIs under `/snippets/...`. 
+""" + +from collections.abc import Callable +from functools import wraps +from typing import Any, ParamSpec, TypeVar + +from flask import Response, request +from flask_restx import Resource, marshal, marshal_with +from sqlalchemy.orm import Session + +from controllers.console import console_ns +from controllers.console.app.error import DraftWorkflowNotExist +from controllers.console.app.workflow_draft_variable import ( + WorkflowDraftVariableListQuery, + WorkflowDraftVariableUpdatePayload, + _ensure_variable_access, + _file_access_controller, + validate_node_id, + workflow_draft_variable_list_model, + workflow_draft_variable_list_without_value_model, + workflow_draft_variable_model, +) +from controllers.console.snippets.snippet_workflow import get_snippet +from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required +from controllers.web.error import InvalidArgumentError, NotFoundError +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from extensions.ext_database import db +from factories.file_factory import build_from_mapping, build_from_mappings +from factories.variable_factory import build_segment_with_type +from graphon.variables.types import SegmentType +from libs.login import current_user, login_required +from models.snippet import CustomizedSnippet +from models.workflow import WorkflowDraftVariable +from services.snippet_service import SnippetService +from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService + +P = ParamSpec("P") +R = TypeVar("R") + +_SNIPPET_EXCLUDED_DRAFT_VARIABLE_NODE_IDS: frozenset[str] = frozenset( + {SYSTEM_VARIABLE_NODE_ID, CONVERSATION_VARIABLE_NODE_ID} +) + + +def _ensure_snippet_draft_variable_row_allowed( + *, + variable: WorkflowDraftVariable, + variable_id: str, +) -> None: + """Snippet scope only supports canvas-node draft variables; treat sys/conversation rows as not found.""" + if variable.node_id in _SNIPPET_EXCLUDED_DRAFT_VARIABLE_NODE_IDS: + raise NotFoundError(description=f"variable not found, id={variable_id}") + + +def _snippet_draft_var_prerequisite(f: Callable[P, R]) -> Callable[P, R]: + """Setup, auth, snippet resolution, and tenant edit permission (same stack as snippet workflow APIs).""" + + @setup_required + @login_required + @account_initialization_required + @get_snippet + @edit_permission_required + @wraps(f) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + return f(*args, **kwargs) + + return wrapper + + +@console_ns.route("/snippets//workflows/draft/variables") +class SnippetWorkflowVariableCollectionApi(Resource): + @console_ns.expect(console_ns.models[WorkflowDraftVariableListQuery.__name__]) + @console_ns.doc("get_snippet_workflow_variables") + @console_ns.doc(description="List draft workflow variables without values (paginated, snippet scope)") + @console_ns.response( + 200, + "Workflow variables retrieved successfully", + workflow_draft_variable_list_without_value_model, + ) + @_snippet_draft_var_prerequisite + @marshal_with(workflow_draft_variable_list_without_value_model) + def get(self, snippet: CustomizedSnippet) -> WorkflowDraftVariableList: + args = WorkflowDraftVariableListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + + snippet_service = SnippetService() + if snippet_service.get_draft_workflow(snippet=snippet) is None: + raise DraftWorkflowNotExist() + + with Session(bind=db.engine, expire_on_commit=False) as session: + draft_var_srv = 
WorkflowDraftVariableService(session=session) + workflow_vars = draft_var_srv.list_variables_without_values( + app_id=snippet.id, + page=args.page, + limit=args.limit, + user_id=current_user.id, + exclude_node_ids=_SNIPPET_EXCLUDED_DRAFT_VARIABLE_NODE_IDS, + ) + + return workflow_vars + + @console_ns.doc("delete_snippet_workflow_variables") + @console_ns.doc(description="Delete all draft workflow variables for the current user (snippet scope)") + @console_ns.response(204, "Workflow variables deleted successfully") + @_snippet_draft_var_prerequisite + def delete(self, snippet: CustomizedSnippet) -> Response: + draft_var_srv = WorkflowDraftVariableService(session=db.session()) + draft_var_srv.delete_user_workflow_variables(snippet.id, user_id=current_user.id) + db.session.commit() + return Response("", 204) + + +@console_ns.route("/snippets//workflows/draft/nodes//variables") +class SnippetNodeVariableCollectionApi(Resource): + @console_ns.doc("get_snippet_node_variables") + @console_ns.doc(description="Get variables for a specific node (snippet draft workflow)") + @console_ns.response(200, "Node variables retrieved successfully", workflow_draft_variable_list_model) + @_snippet_draft_var_prerequisite + @marshal_with(workflow_draft_variable_list_model) + def get(self, snippet: CustomizedSnippet, node_id: str) -> WorkflowDraftVariableList: + validate_node_id(node_id) + with Session(bind=db.engine, expire_on_commit=False) as session: + draft_var_srv = WorkflowDraftVariableService(session=session) + node_vars = draft_var_srv.list_node_variables(snippet.id, node_id, user_id=current_user.id) + + return node_vars + + @console_ns.doc("delete_snippet_node_variables") + @console_ns.doc(description="Delete all variables for a specific node (snippet draft workflow)") + @console_ns.response(204, "Node variables deleted successfully") + @_snippet_draft_var_prerequisite + def delete(self, snippet: CustomizedSnippet, node_id: str) -> Response: + validate_node_id(node_id) + srv = WorkflowDraftVariableService(db.session()) + srv.delete_node_variables(snippet.id, node_id, user_id=current_user.id) + db.session.commit() + return Response("", 204) + + +@console_ns.route("/snippets//workflows/draft/variables/") +class SnippetVariableApi(Resource): + @console_ns.doc("get_snippet_workflow_variable") + @console_ns.doc(description="Get a specific draft workflow variable (snippet scope)") + @console_ns.response(200, "Variable retrieved successfully", workflow_draft_variable_model) + @console_ns.response(404, "Variable not found") + @_snippet_draft_var_prerequisite + @marshal_with(workflow_draft_variable_model) + def get(self, snippet: CustomizedSnippet, variable_id: str) -> WorkflowDraftVariable: + draft_var_srv = WorkflowDraftVariableService(session=db.session()) + variable = _ensure_variable_access( + variable=draft_var_srv.get_variable(variable_id=variable_id), + app_id=snippet.id, + variable_id=variable_id, + ) + _ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id) + return variable + + @console_ns.doc("update_snippet_workflow_variable") + @console_ns.doc(description="Update a draft workflow variable (snippet scope)") + @console_ns.expect(console_ns.models[WorkflowDraftVariableUpdatePayload.__name__]) + @console_ns.response(200, "Variable updated successfully", workflow_draft_variable_model) + @console_ns.response(404, "Variable not found") + @_snippet_draft_var_prerequisite + @marshal_with(workflow_draft_variable_model) + def patch(self, snippet: CustomizedSnippet, variable_id: 
str) -> WorkflowDraftVariable: + draft_var_srv = WorkflowDraftVariableService(session=db.session()) + args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {}) + + variable = _ensure_variable_access( + variable=draft_var_srv.get_variable(variable_id=variable_id), + app_id=snippet.id, + variable_id=variable_id, + ) + _ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id) + + new_name = args_model.name + raw_value = args_model.value + if new_name is None and raw_value is None: + return variable + + new_value = None + if raw_value is not None: + if variable.value_type == SegmentType.FILE: + if not isinstance(raw_value, dict): + raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}") + raw_value = build_from_mapping( + mapping=raw_value, + tenant_id=snippet.tenant_id, + access_controller=_file_access_controller, + ) + elif variable.value_type == SegmentType.ARRAY_FILE: + if not isinstance(raw_value, list): + raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}") + if len(raw_value) > 0 and not isinstance(raw_value[0], dict): + raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}") + raw_value = build_from_mappings( + mappings=raw_value, + tenant_id=snippet.tenant_id, + access_controller=_file_access_controller, + ) + new_value = build_segment_with_type(variable.value_type, raw_value) + draft_var_srv.update_variable(variable, name=new_name, value=new_value) + db.session.commit() + return variable + + @console_ns.doc("delete_snippet_workflow_variable") + @console_ns.doc(description="Delete a draft workflow variable (snippet scope)") + @console_ns.response(204, "Variable deleted successfully") + @console_ns.response(404, "Variable not found") + @_snippet_draft_var_prerequisite + def delete(self, snippet: CustomizedSnippet, variable_id: str) -> Response: + draft_var_srv = WorkflowDraftVariableService(session=db.session()) + variable = _ensure_variable_access( + variable=draft_var_srv.get_variable(variable_id=variable_id), + app_id=snippet.id, + variable_id=variable_id, + ) + _ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id) + draft_var_srv.delete_variable(variable) + db.session.commit() + return Response("", 204) + + +@console_ns.route("/snippets//workflows/draft/variables//reset") +class SnippetVariableResetApi(Resource): + @console_ns.doc("reset_snippet_workflow_variable") + @console_ns.doc(description="Reset a draft workflow variable to its default value (snippet scope)") + @console_ns.response(200, "Variable reset successfully", workflow_draft_variable_model) + @console_ns.response(204, "Variable reset (no content)") + @console_ns.response(404, "Variable not found") + @_snippet_draft_var_prerequisite + def put(self, snippet: CustomizedSnippet, variable_id: str) -> Response | Any: + draft_var_srv = WorkflowDraftVariableService(session=db.session()) + snippet_service = SnippetService() + draft_workflow = snippet_service.get_draft_workflow(snippet=snippet) + if draft_workflow is None: + raise NotFoundError( + f"Draft workflow not found, snippet_id={snippet.id}", + ) + variable = _ensure_variable_access( + variable=draft_var_srv.get_variable(variable_id=variable_id), + app_id=snippet.id, + variable_id=variable_id, + ) + _ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id) + + resetted = draft_var_srv.reset_variable(draft_workflow, variable) + db.session.commit() 
+ if resetted is None: + return Response("", 204) + return marshal(resetted, workflow_draft_variable_model) + + +@console_ns.route("/snippets//workflows/draft/conversation-variables") +class SnippetConversationVariableCollectionApi(Resource): + @console_ns.doc("get_snippet_conversation_variables") + @console_ns.doc( + description="Conversation variables are not used in snippet workflows; returns an empty list for API parity" + ) + @console_ns.response(200, "Conversation variables retrieved successfully", workflow_draft_variable_list_model) + @_snippet_draft_var_prerequisite + @marshal_with(workflow_draft_variable_list_model) + def get(self, snippet: CustomizedSnippet) -> WorkflowDraftVariableList: + return WorkflowDraftVariableList(variables=[]) + + +@console_ns.route("/snippets//workflows/draft/system-variables") +class SnippetSystemVariableCollectionApi(Resource): + @console_ns.doc("get_snippet_system_variables") + @console_ns.doc( + description="System variables are not used in snippet workflows; returns an empty list for API parity" + ) + @console_ns.response(200, "System variables retrieved successfully", workflow_draft_variable_list_model) + @_snippet_draft_var_prerequisite + @marshal_with(workflow_draft_variable_list_model) + def get(self, snippet: CustomizedSnippet) -> WorkflowDraftVariableList: + return WorkflowDraftVariableList(variables=[]) + + +@console_ns.route("/snippets//workflows/draft/environment-variables") +class SnippetEnvironmentVariableCollectionApi(Resource): + @console_ns.doc("get_snippet_environment_variables") + @console_ns.doc(description="Get environment variables from snippet draft workflow graph") + @console_ns.response(200, "Environment variables retrieved successfully") + @console_ns.response(404, "Draft workflow not found") + @_snippet_draft_var_prerequisite + def get(self, snippet: CustomizedSnippet) -> dict[str, list[dict[str, Any]]]: + snippet_service = SnippetService() + workflow = snippet_service.get_draft_workflow(snippet=snippet) + if workflow is None: + raise DraftWorkflowNotExist() + + env_vars_list: list[dict[str, Any]] = [] + for v in workflow.environment_variables: + env_vars_list.append( + { + "id": v.id, + "type": "env", + "name": v.name, + "description": v.description, + "selector": v.selector, + "value_type": v.value_type.exposed_type().value, + "value": v.value, + "edited": False, + "visible": True, + "editable": True, + } + ) + + return {"items": env_vars_list} diff --git a/api/controllers/console/workspace/snippets.py b/api/controllers/console/workspace/snippets.py new file mode 100644 index 0000000000..13bfc37f71 --- /dev/null +++ b/api/controllers/console/workspace/snippets.py @@ -0,0 +1,380 @@ +import logging +from urllib.parse import quote + +from flask import Response, request +from flask_restx import Resource, marshal +from sqlalchemy.orm import Session +from werkzeug.exceptions import NotFound + +from controllers.common.schema import register_schema_models +from controllers.console import console_ns +from controllers.console.snippets.payloads import ( + CreateSnippetPayload, + IncludeSecretQuery, + SnippetImportPayload, + SnippetListQuery, + UpdateSnippetPayload, +) +from controllers.console.wraps import ( + account_initialization_required, + edit_permission_required, + setup_required, +) +from extensions.ext_database import db +from fields.snippet_fields import snippet_fields, snippet_list_fields, snippet_pagination_fields +from libs.login import current_account_with_tenant, login_required +from models.snippet import SnippetType 
+from services.app_dsl_service import ImportStatus +from services.snippet_dsl_service import SnippetDslService +from services.snippet_service import SnippetService + +logger = logging.getLogger(__name__) + +# Register Pydantic models with Swagger +register_schema_models( + console_ns, + SnippetListQuery, + CreateSnippetPayload, + UpdateSnippetPayload, + SnippetImportPayload, + IncludeSecretQuery, +) + +# Create namespace models for marshaling +snippet_model = console_ns.model("Snippet", snippet_fields) +snippet_list_model = console_ns.model("SnippetList", snippet_list_fields) +snippet_pagination_model = console_ns.model("SnippetPagination", snippet_pagination_fields) + + +@console_ns.route("/workspaces/current/customized-snippets") +class CustomizedSnippetsApi(Resource): + @console_ns.doc("list_customized_snippets") + @console_ns.expect(console_ns.models.get(SnippetListQuery.__name__)) + @console_ns.response(200, "Snippets retrieved successfully", snippet_pagination_model) + @setup_required + @login_required + @account_initialization_required + def get(self): + """List customized snippets with pagination and search.""" + _, current_tenant_id = current_account_with_tenant() + + query_params = request.args.to_dict() + query = SnippetListQuery.model_validate(query_params) + + snippets, total, has_more = SnippetService.get_snippets( + tenant_id=current_tenant_id, + page=query.page, + limit=query.limit, + keyword=query.keyword, + is_published=query.is_published, + creators=query.creators, + ) + + return { + "data": marshal(snippets, snippet_list_fields), + "page": query.page, + "limit": query.limit, + "total": total, + "has_more": has_more, + }, 200 + + @console_ns.doc("create_customized_snippet") + @console_ns.expect(console_ns.models.get(CreateSnippetPayload.__name__)) + @console_ns.response(201, "Snippet created successfully", snippet_model) + @console_ns.response(400, "Invalid request or name already exists") + @setup_required + @login_required + @account_initialization_required + @edit_permission_required + def post(self): + """Create a new customized snippet.""" + current_user, current_tenant_id = current_account_with_tenant() + + payload = CreateSnippetPayload.model_validate(console_ns.payload or {}) + + try: + snippet_type = SnippetType(payload.type) + except ValueError: + snippet_type = SnippetType.NODE + + try: + snippet = SnippetService.create_snippet( + tenant_id=current_tenant_id, + name=payload.name, + description=payload.description, + snippet_type=snippet_type, + icon_info=payload.icon_info.model_dump() if payload.icon_info else None, + input_fields=[f.model_dump() for f in payload.input_fields] if payload.input_fields else None, + account=current_user, + ) + except ValueError as e: + return {"message": str(e)}, 400 + + return marshal(snippet, snippet_fields), 201 + + +@console_ns.route("/workspaces/current/customized-snippets/") +class CustomizedSnippetDetailApi(Resource): + @console_ns.doc("get_customized_snippet") + @console_ns.response(200, "Snippet retrieved successfully", snippet_model) + @console_ns.response(404, "Snippet not found") + @setup_required + @login_required + @account_initialization_required + def get(self, snippet_id: str): + """Get customized snippet details.""" + _, current_tenant_id = current_account_with_tenant() + + snippet = SnippetService.get_snippet_by_id( + snippet_id=str(snippet_id), + tenant_id=current_tenant_id, + ) + + if not snippet: + raise NotFound("Snippet not found") + + return marshal(snippet, snippet_fields), 200 + + 
@console_ns.doc("update_customized_snippet") + @console_ns.expect(console_ns.models.get(UpdateSnippetPayload.__name__)) + @console_ns.response(200, "Snippet updated successfully", snippet_model) + @console_ns.response(400, "Invalid request or name already exists") + @console_ns.response(404, "Snippet not found") + @setup_required + @login_required + @account_initialization_required + @edit_permission_required + def patch(self, snippet_id: str): + """Update customized snippet.""" + current_user, current_tenant_id = current_account_with_tenant() + + snippet = SnippetService.get_snippet_by_id( + snippet_id=str(snippet_id), + tenant_id=current_tenant_id, + ) + + if not snippet: + raise NotFound("Snippet not found") + + payload = UpdateSnippetPayload.model_validate(console_ns.payload or {}) + update_data = payload.model_dump(exclude_unset=True) + + if "icon_info" in update_data and update_data["icon_info"] is not None: + update_data["icon_info"] = payload.icon_info.model_dump() if payload.icon_info else None + + if not update_data: + return {"message": "No valid fields to update"}, 400 + + try: + with Session(db.engine, expire_on_commit=False) as session: + snippet = session.merge(snippet) + snippet = SnippetService.update_snippet( + session=session, + snippet=snippet, + account_id=current_user.id, + data=update_data, + ) + session.commit() + except ValueError as e: + return {"message": str(e)}, 400 + + return marshal(snippet, snippet_fields), 200 + + @console_ns.doc("delete_customized_snippet") + @console_ns.response(204, "Snippet deleted successfully") + @console_ns.response(404, "Snippet not found") + @setup_required + @login_required + @account_initialization_required + @edit_permission_required + def delete(self, snippet_id: str): + """Delete customized snippet.""" + _, current_tenant_id = current_account_with_tenant() + + snippet = SnippetService.get_snippet_by_id( + snippet_id=str(snippet_id), + tenant_id=current_tenant_id, + ) + + if not snippet: + raise NotFound("Snippet not found") + + with Session(db.engine) as session: + snippet = session.merge(snippet) + SnippetService.delete_snippet( + session=session, + snippet=snippet, + ) + session.commit() + + return "", 204 + + +@console_ns.route("/workspaces/current/customized-snippets//export") +class CustomizedSnippetExportApi(Resource): + @console_ns.doc("export_customized_snippet") + @console_ns.doc(description="Export snippet configuration as DSL") + @console_ns.doc(params={"snippet_id": "Snippet ID to export"}) + @console_ns.response(200, "Snippet exported successfully") + @console_ns.response(404, "Snippet not found") + @setup_required + @login_required + @account_initialization_required + @edit_permission_required + def get(self, snippet_id: str): + """Export snippet as DSL.""" + _, current_tenant_id = current_account_with_tenant() + + snippet = SnippetService.get_snippet_by_id( + snippet_id=str(snippet_id), + tenant_id=current_tenant_id, + ) + + if not snippet: + raise NotFound("Snippet not found") + + # Get include_secret parameter + query = IncludeSecretQuery.model_validate(request.args.to_dict()) + + with Session(db.engine) as session: + export_service = SnippetDslService(session) + result = export_service.export_snippet_dsl(snippet=snippet, include_secret=query.include_secret == "true") + + # Set filename with .snippet extension + filename = f"{snippet.name}.snippet" + encoded_filename = quote(filename) + + response = Response( + result, + mimetype="application/x-yaml", + ) + response.headers["Content-Disposition"] = 
f"attachment; filename*=UTF-8''{encoded_filename}" + response.headers["Content-Type"] = "application/x-yaml" + + return response + + +@console_ns.route("/workspaces/current/customized-snippets/imports") +class CustomizedSnippetImportApi(Resource): + @console_ns.doc("import_customized_snippet") + @console_ns.doc(description="Import snippet from DSL") + @console_ns.expect(console_ns.models.get(SnippetImportPayload.__name__)) + @console_ns.response(200, "Snippet imported successfully") + @console_ns.response(202, "Import pending confirmation") + @console_ns.response(400, "Import failed") + @setup_required + @login_required + @account_initialization_required + @edit_permission_required + def post(self): + """Import snippet from DSL.""" + current_user, _ = current_account_with_tenant() + payload = SnippetImportPayload.model_validate(console_ns.payload or {}) + + with Session(db.engine) as session: + import_service = SnippetDslService(session) + result = import_service.import_snippet( + account=current_user, + import_mode=payload.mode, + yaml_content=payload.yaml_content, + yaml_url=payload.yaml_url, + snippet_id=payload.snippet_id, + name=payload.name, + description=payload.description, + ) + session.commit() + + # Return appropriate status code based on result + status = result.status + if status == ImportStatus.FAILED: + return result.model_dump(mode="json"), 400 + elif status == ImportStatus.PENDING: + return result.model_dump(mode="json"), 202 + return result.model_dump(mode="json"), 200 + + +@console_ns.route("/workspaces/current/customized-snippets/imports//confirm") +class CustomizedSnippetImportConfirmApi(Resource): + @console_ns.doc("confirm_snippet_import") + @console_ns.doc(description="Confirm a pending snippet import") + @console_ns.doc(params={"import_id": "Import ID to confirm"}) + @console_ns.response(200, "Import confirmed successfully") + @console_ns.response(400, "Import failed") + @setup_required + @login_required + @account_initialization_required + @edit_permission_required + def post(self, import_id: str): + """Confirm a pending snippet import.""" + current_user, _ = current_account_with_tenant() + + with Session(db.engine) as session: + import_service = SnippetDslService(session) + result = import_service.confirm_import(import_id=import_id, account=current_user) + session.commit() + + if result.status == ImportStatus.FAILED: + return result.model_dump(mode="json"), 400 + return result.model_dump(mode="json"), 200 + + +@console_ns.route("/workspaces/current/customized-snippets//check-dependencies") +class CustomizedSnippetCheckDependenciesApi(Resource): + @console_ns.doc("check_snippet_dependencies") + @console_ns.doc(description="Check dependencies for a snippet") + @console_ns.doc(params={"snippet_id": "Snippet ID"}) + @console_ns.response(200, "Dependencies checked successfully") + @console_ns.response(404, "Snippet not found") + @setup_required + @login_required + @account_initialization_required + @edit_permission_required + def get(self, snippet_id: str): + """Check dependencies for a snippet.""" + _, current_tenant_id = current_account_with_tenant() + + snippet = SnippetService.get_snippet_by_id( + snippet_id=str(snippet_id), + tenant_id=current_tenant_id, + ) + + if not snippet: + raise NotFound("Snippet not found") + + with Session(db.engine) as session: + import_service = SnippetDslService(session) + result = import_service.check_dependencies(snippet=snippet) + + return result.model_dump(mode="json"), 200 + + 
+@console_ns.route("/workspaces/current/customized-snippets//use-count/increment") +class CustomizedSnippetUseCountIncrementApi(Resource): + @console_ns.doc("increment_snippet_use_count") + @console_ns.doc(description="Increment snippet use count by 1") + @console_ns.doc(params={"snippet_id": "Snippet ID"}) + @console_ns.response(200, "Use count incremented successfully") + @console_ns.response(404, "Snippet not found") + @setup_required + @login_required + @account_initialization_required + @edit_permission_required + def post(self, snippet_id: str): + """Increment snippet use count when it is inserted into a workflow.""" + _, current_tenant_id = current_account_with_tenant() + + snippet = SnippetService.get_snippet_by_id( + snippet_id=str(snippet_id), + tenant_id=current_tenant_id, + ) + + if not snippet: + raise NotFound("Snippet not found") + + with Session(db.engine) as session: + snippet = session.merge(snippet) + SnippetService.increment_use_count(session=session, snippet=snippet) + session.commit() + session.refresh(snippet) + + return {"result": "success", "use_count": snippet.use_count}, 200 diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 6074e81d1e..ba070ffa94 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -14,7 +14,7 @@ from graphon.runtime import GraphRuntimeState from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from pydantic import ValidationError from sqlalchemy import select -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import Session, sessionmaker import contexts from configs import dify_config @@ -54,6 +54,25 @@ logger = logging.getLogger(__name__) class WorkflowAppGenerator(BaseAppGenerator): + @staticmethod + def _ensure_snippet_start_node_in_worker(*, session: Session, workflow: Workflow) -> Workflow: + """Re-apply snippet virtual Start injection after worker reloads workflow from DB.""" + if workflow.type != "snippet": + return workflow + + from models.snippet import CustomizedSnippet + from services.snippet_generate_service import SnippetGenerateService + + snippet = session.scalar( + select(CustomizedSnippet).where( + CustomizedSnippet.id == workflow.app_id, + CustomizedSnippet.tenant_id == workflow.tenant_id, + ) + ) + if snippet is None: + return workflow + return SnippetGenerateService.ensure_start_node_for_worker(workflow, snippet) + @staticmethod def _should_prepare_user_inputs(args: Mapping[str, Any]) -> bool: return not bool(args.get(SKIP_PREPARE_USER_INPUTS_KEY)) @@ -557,6 +576,8 @@ class WorkflowAppGenerator(BaseAppGenerator): if workflow is None: raise ValueError("Workflow not found") + workflow = self._ensure_snippet_start_node_in_worker(session=session, workflow=workflow) + # Determine system_user_id based on invocation source is_external_api_call = application_generate_entity.invoke_from in { InvokeFrom.WEB_APP, diff --git a/api/core/evaluation/__init__.py b/api/core/evaluation/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/evaluation/base_evaluation_instance.py b/api/core/evaluation/base_evaluation_instance.py new file mode 100644 index 0000000000..67fbf0374c --- /dev/null +++ b/api/core/evaluation/base_evaluation_instance.py @@ -0,0 +1,279 @@ +import logging +from abc import ABC, abstractmethod +from collections.abc import Mapping +from typing import Any + +from core.evaluation.entities.evaluation_entity import ( + CustomizedMetrics, + 
EvaluationCategory, + EvaluationItemInput, + EvaluationItemResult, + EvaluationMetric, + NodeInfo, +) +from graphon.node_events.base import NodeRunResult + +logger = logging.getLogger(__name__) + + +class BaseEvaluationInstance(ABC): + """Abstract base class for evaluation framework adapters.""" + + @abstractmethod + def evaluate_llm( + self, + items: list[EvaluationItemInput], + metric_names: list[str], + model_provider: str, + model_name: str, + tenant_id: str, + ) -> list[EvaluationItemResult]: + """Evaluate LLM outputs using the configured framework.""" + ... + + @abstractmethod + def evaluate_retrieval( + self, + items: list[EvaluationItemInput], + metric_names: list[str], + model_provider: str, + model_name: str, + tenant_id: str, + ) -> list[EvaluationItemResult]: + """Evaluate retrieval quality using the configured framework.""" + ... + + @abstractmethod + def evaluate_agent( + self, + items: list[EvaluationItemInput], + metric_names: list[str], + model_provider: str, + model_name: str, + tenant_id: str, + ) -> list[EvaluationItemResult]: + """Evaluate agent outputs using the configured framework.""" + ... + + @abstractmethod + def get_supported_metrics(self, category: EvaluationCategory) -> list[str]: + """Return the list of supported metric names for a given evaluation category.""" + ... + + def evaluate_with_customized_workflow( + self, + node_run_result_mapping_list: list[dict[str, NodeRunResult]], + customized_metrics: CustomizedMetrics, + tenant_id: str, + ) -> list[EvaluationItemResult]: + """Evaluate using a published workflow as the evaluator. + + The evaluator workflow's output variables are treated as metrics: + each output variable name becomes a metric name, and its value + becomes the score. + + Args: + node_run_result_mapping_list: One mapping per test-data item, + where each mapping is ``{node_id: NodeRunResult}`` from the + target execution. + customized_metrics: Contains ``evaluation_workflow_id`` (the + published evaluator workflow) and ``input_fields`` (value + sources for the evaluator's input variables). + tenant_id: Tenant scope. + + Returns: + A list of ``EvaluationItemResult`` with metrics extracted from + the evaluator workflow's output variables. 
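+
+        Example (illustrative sketch; identifiers such as ``wf_eval_1`` and
+        ``llm_node`` are placeholders, not real IDs)::
+
+            customized_metrics = CustomizedMetrics(
+                evaluation_workflow_id="wf_eval_1",
+                input_fields={"answer": "{{#llm_node.text#}}"},
+                output_fields=[
+                    CustomizedMetricOutputField(variable="accuracy", value_type="number"),
+                ],
+            )
+            # If the evaluator workflow outputs {"accuracy": 0.9}, the item result
+            # contains EvaluationMetric(name="accuracy", value=0.9) whose node_info
+            # uses the evaluation_workflow_id as node_id.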
+ """ + from sqlalchemy.orm import Session + + from core.app.apps.workflow.app_generator import WorkflowAppGenerator + from core.app.entities.app_invoke_entities import InvokeFrom + from core.evaluation.runners import get_service_account_for_app + from models.engine import db + from models.model import App + from services.workflow_service import WorkflowService + + workflow_id = customized_metrics.evaluation_workflow_id + if not workflow_id: + raise ValueError("customized_metrics must contain 'evaluation_workflow_id' for customized evaluator") + + # Load the evaluator workflow resources using a dedicated session + with Session(db.engine, expire_on_commit=False) as session, session.begin(): + app = session.query(App).filter_by(id=workflow_id, tenant_id=tenant_id).first() + if not app: + raise ValueError(f"Evaluation workflow app {workflow_id} not found in tenant {tenant_id}") + service_account = get_service_account_for_app(session, workflow_id) + + workflow_service = WorkflowService() + published_workflow = workflow_service.get_published_workflow(app_model=app) + if not published_workflow: + raise ValueError(f"No published workflow found for evaluation app {workflow_id}") + + eval_results: list[EvaluationItemResult] = [] + for idx, node_run_result_mapping in enumerate(node_run_result_mapping_list): + try: + workflow_inputs = self._build_workflow_inputs( + customized_metrics.input_fields, + node_run_result_mapping, + ) + + generator = WorkflowAppGenerator() + response: Mapping[str, Any] = generator.generate( + app_model=app, + workflow=published_workflow, + user=service_account, + args={"inputs": workflow_inputs}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=False, + call_depth=0, + ) + + metrics = self._extract_workflow_metrics(response, workflow_id) + eval_results.append( + EvaluationItemResult( + index=idx, + metrics=metrics, + ) + ) + except Exception: + logger.exception( + "Customized evaluator failed for item %d with workflow %s", + idx, + workflow_id, + ) + eval_results.append(EvaluationItemResult(index=idx)) + + return eval_results + + @staticmethod + def _build_workflow_inputs( + input_fields: dict[str, Any], + node_run_result_mapping: dict[str, NodeRunResult], + ) -> dict[str, Any]: + """Build customized workflow inputs by resolving value sources. + + Each entry in ``input_fields`` maps a workflow input variable name + to its value source, which can be: + + - **Constant**: a plain string without ``{{#…#}}`` used as-is. + - **Expression**: a string containing one or more + ``{{#node_id.output_key#}}`` selectors (same format as + ``VariableTemplateParser``) resolved from + ``node_run_result_mapping``. + + """ + from graphon.nodes.base.variable_template_parser import REGEX as VARIABLE_REGEX + + workflow_inputs: dict[str, Any] = {} + + for field_name, value_source in input_fields.items(): + if not isinstance(value_source, str): + # Non-string values (numbers, bools, dicts) are used directly. + workflow_inputs[field_name] = value_source + continue + + # Check if the entire value is a single expression. + full_match = VARIABLE_REGEX.fullmatch(value_source) + if full_match: + workflow_inputs[field_name] = resolve_variable_selector( + full_match.group(1), + node_run_result_mapping, + ) + elif VARIABLE_REGEX.search(value_source): + # Mixed template: interpolate all expressions as strings. 
+ workflow_inputs[field_name] = VARIABLE_REGEX.sub( + lambda m: str(resolve_variable_selector(m.group(1), node_run_result_mapping)), + value_source, + ) + else: + # Plain constant — no expression markers. + workflow_inputs[field_name] = value_source + + return workflow_inputs + + @staticmethod + def _extract_workflow_metrics( + response: Mapping[str, object], + evaluation_workflow_id: str, + ) -> list[EvaluationMetric]: + """Extract evaluation metrics from workflow output variables. + + Each metric's ``node_info`` is set with *evaluation_workflow_id* as + the ``node_id``, so that judgment conditions can reference customized + metrics via ``variable_selector: [evaluation_workflow_id, metric_name]``. + """ + metrics: list[EvaluationMetric] = [] + node_info = NodeInfo(node_id=evaluation_workflow_id, type="customized", title="customized") + + data = response.get("data") + if not isinstance(data, Mapping): + logger.warning("Unexpected workflow response format: missing 'data' dict") + return metrics + + outputs = data.get("outputs") + if not isinstance(outputs, dict): + logger.warning("Unexpected workflow response format: 'outputs' is not a dict") + return metrics + + for key, raw_value in outputs.items(): + if not isinstance(key, str): + continue + metrics.append(EvaluationMetric(name=key, value=raw_value, node_info=node_info)) + + return metrics + + +def resolve_variable_selector( + selector_raw: str, + node_run_result_mapping: dict[str, NodeRunResult], +) -> object: + """ + Resolve a ``#node_id.output_key#`` selector against node run results. + """ + # + cleaned = selector_raw.strip("#") + parts = cleaned.split(".") + + if len(parts) < 2: + logger.warning( + "Selector '%s' must have at least node_id.output_key", + selector_raw, + ) + return "" + + node_id = parts[0] + output_path = parts[1:] + + node_result = node_run_result_mapping.get(node_id) + if not node_result or not node_result.outputs: + logger.warning( + "Selector '%s': node '%s' not found or has no outputs", + selector_raw, + node_id, + ) + return "" + + # Traverse the output path to support nested keys. 
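+    # Illustrative example: selector "#node_1.result.score#" with outputs
+    # {"result": {"score": 0.8}} walks ["result", "score"] and returns 0.8.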
+ current: object = node_result.outputs + for key in output_path: + if isinstance(current, Mapping): + next_val = current.get(key) + if next_val is None: + logger.warning( + "Selector '%s': key '%s' not found in node '%s' outputs", + selector_raw, + key, + node_id, + ) + return "" + current = next_val + else: + logger.warning( + "Selector '%s': cannot traverse into non-dict value at key '%s'", + selector_raw, + key, + ) + return "" + + return current if current is not None else "" diff --git a/api/core/evaluation/entities/__init__.py b/api/core/evaluation/entities/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/evaluation/entities/config_entity.py b/api/core/evaluation/entities/config_entity.py new file mode 100644 index 0000000000..ef9f22f185 --- /dev/null +++ b/api/core/evaluation/entities/config_entity.py @@ -0,0 +1,27 @@ +from enum import StrEnum + +from pydantic import BaseModel + + +class EvaluationFrameworkEnum(StrEnum): + RAGAS = "ragas" + DEEPEVAL = "deepeval" + NONE = "none" + + +class BaseEvaluationConfig(BaseModel): + """Base configuration for evaluation frameworks.""" + + pass + + +class RagasConfig(BaseEvaluationConfig): + """RAGAS-specific configuration.""" + + pass + + +class DeepEvalConfig(BaseEvaluationConfig): + """DeepEval-specific configuration.""" + + pass diff --git a/api/core/evaluation/entities/evaluation_entity.py b/api/core/evaluation/entities/evaluation_entity.py new file mode 100644 index 0000000000..a87354b526 --- /dev/null +++ b/api/core/evaluation/entities/evaluation_entity.py @@ -0,0 +1,226 @@ +from enum import StrEnum +from typing import Any + +from pydantic import BaseModel, Field + +from core.evaluation.entities.judgment_entity import JudgmentConfig, JudgmentResult + + +class EvaluationCategory(StrEnum): + LLM = "llm" + RETRIEVAL = "knowledge_retrieval" + AGENT = "agent" + WORKFLOW = "workflow" + SNIPPET = "snippet" + KNOWLEDGE_BASE = "knowledge_base" + + +class EvaluationMetricName(StrEnum): + """Canonical metric names shared across all evaluation frameworks. + + Each framework maps these names to its own internal implementation. + A framework that does not support a given metric should log a warning + and skip it rather than raising an error. + + ── LLM / general text-quality metrics ────────────────────────────────── + FAITHFULNESS + Measures whether every claim in the model's response is grounded in + the provided retrieved context. A high score means the answer + contains no hallucinated content — each statement can be traced back + to a passage in the context. + Required fields: user_input, response, retrieved_contexts. + + ANSWER_RELEVANCY + Measures how well the model's response addresses the user's question. + A high score means the answer stays on-topic; a low score indicates + irrelevant content or a failure to answer the actual question. + Required fields: user_input, response. + + ANSWER_CORRECTNESS + Measures the factual accuracy and completeness of the model's answer + relative to a ground-truth reference. It combines semantic similarity + with key-fact coverage, so both meaning and content matter. + Required fields: user_input, response, reference (expected_output). + + SEMANTIC_SIMILARITY + Measures the cosine similarity between the model's response and the + reference answer in an embedding space. It evaluates whether the two + texts convey the same meaning, independent of factual correctness. + Required fields: response, reference (expected_output). 
+ + ── Retrieval-quality metrics ──────────────────────────────────────────── + CONTEXT_PRECISION + Measures the proportion of retrieved context chunks that are actually + relevant to the question (precision). A high score means the retrieval + pipeline returns little noise. + Required fields: user_input, reference, retrieved_contexts. + + CONTEXT_RECALL + Measures the proportion of ground-truth information that is covered by + the retrieved context chunks (recall). A high score means the retrieval + pipeline does not miss important supporting evidence. + Required fields: user_input, reference, retrieved_contexts. + + CONTEXT_RELEVANCE + Measures how relevant each individual retrieved chunk is to the query. + Similar to CONTEXT_PRECISION but evaluated at the chunk level rather + than against a reference answer. + Required fields: user_input, retrieved_contexts. + + ── Agent-quality metrics ──────────────────────────────────────────────── + TOOL_CORRECTNESS + Measures the correctness of the tool calls made by the agent during + task execution — both the choice of tool and the arguments passed. + A high score means the agent's tool-use strategy matches the expected + behavior. + Required fields: actual tool calls vs. expected tool calls. + + TASK_COMPLETION + Measures whether the agent ultimately achieves the user's stated goal. + It evaluates the reasoning chain, intermediate steps, and final output + holistically; a high score means the task was fully accomplished. + Required fields: user_input, actual_output. + """ + + # LLM / general text-quality metrics + FAITHFULNESS = "faithfulness" + ANSWER_RELEVANCY = "answer_relevancy" + ANSWER_CORRECTNESS = "answer_correctness" + SEMANTIC_SIMILARITY = "semantic_similarity" + + # Retrieval-quality metrics + CONTEXT_PRECISION = "context_precision" + CONTEXT_RECALL = "context_recall" + CONTEXT_RELEVANCE = "context_relevance" + + # Agent-quality metrics + TOOL_CORRECTNESS = "tool_correctness" + TASK_COMPLETION = "task_completion" + + +# Per-category canonical metric lists used by get_supported_metrics(). +LLM_METRIC_NAMES: list[EvaluationMetricName] = [ + EvaluationMetricName.FAITHFULNESS, # Every claim is grounded in context; no hallucinations + EvaluationMetricName.ANSWER_RELEVANCY, # Response stays on-topic and addresses the question + EvaluationMetricName.ANSWER_CORRECTNESS, # Factual accuracy and completeness vs. 
reference + EvaluationMetricName.SEMANTIC_SIMILARITY, # Semantic closeness to the reference answer +] + +RETRIEVAL_METRIC_NAMES: list[EvaluationMetricName] = [ + EvaluationMetricName.CONTEXT_PRECISION, # Fraction of retrieved chunks that are relevant (precision) + EvaluationMetricName.CONTEXT_RECALL, # Fraction of ground-truth info covered by retrieval (recall) + EvaluationMetricName.CONTEXT_RELEVANCE, # Per-chunk relevance to the query +] + +AGENT_METRIC_NAMES: list[EvaluationMetricName] = [ + EvaluationMetricName.TOOL_CORRECTNESS, # Correct tool selection and arguments + EvaluationMetricName.TASK_COMPLETION, # Whether the agent fully achieves the user's goal +] + +WORKFLOW_METRIC_NAMES: list[EvaluationMetricName] = [ + EvaluationMetricName.FAITHFULNESS, + EvaluationMetricName.ANSWER_RELEVANCY, + EvaluationMetricName.ANSWER_CORRECTNESS, +] + +METRIC_NODE_TYPE_MAPPING: dict[str, str] = { + **{m.value: "llm" for m in LLM_METRIC_NAMES}, + **{m.value: "knowledge-retrieval" for m in RETRIEVAL_METRIC_NAMES}, + **{m.value: "agent" for m in AGENT_METRIC_NAMES}, +} + +METRIC_VALUE_TYPE_MAPPING: dict[str, str] = { + EvaluationMetricName.FAITHFULNESS: "number", + EvaluationMetricName.ANSWER_RELEVANCY: "number", + EvaluationMetricName.ANSWER_CORRECTNESS: "number", + EvaluationMetricName.SEMANTIC_SIMILARITY: "number", + EvaluationMetricName.CONTEXT_PRECISION: "number", + EvaluationMetricName.CONTEXT_RECALL: "number", + EvaluationMetricName.CONTEXT_RELEVANCE: "number", + EvaluationMetricName.TOOL_CORRECTNESS: "number", + EvaluationMetricName.TASK_COMPLETION: "number", +} + + +class NodeInfo(BaseModel): + node_id: str + type: str + title: str + + +class EvaluationMetric(BaseModel): + name: str + value: Any + details: dict[str, Any] = Field(default_factory=dict) + node_info: NodeInfo | None = None + + +class EvaluationItemInput(BaseModel): + index: int + inputs: dict[str, Any] + output: str + expected_output: str | None = None + context: list[str] | None = None + + +class EvaluationDatasetInput(BaseModel): + index: int + inputs: dict[str, Any] + expected_output: str | None = None + + +class EvaluationItemResult(BaseModel): + index: int + actual_output: str | None = None + metrics: list[EvaluationMetric] = Field(default_factory=list) + metadata: dict[str, Any] = Field(default_factory=dict) + judgment: JudgmentResult = Field(default_factory=JudgmentResult) + error: str | None = None + + +class DefaultMetric(BaseModel): + metric: str + value_type: str = "" + node_info_list: list[NodeInfo] + + +class CustomizedMetricOutputField(BaseModel): + variable: str + value_type: str + + +class CustomizedMetrics(BaseModel): + evaluation_workflow_id: str + input_fields: dict[str, Any] + output_fields: list[CustomizedMetricOutputField] + + +class EvaluationConfigData(BaseModel): + """Structured data for saving evaluation configuration.""" + + evaluation_model: str = "" + evaluation_model_provider: str = "" + default_metrics: list[DefaultMetric] = Field(default_factory=list) + customized_metrics: CustomizedMetrics | None = None + judgment_config: JudgmentConfig | None = None + + +class EvaluationRunRequest(EvaluationConfigData): + """Request body for starting an evaluation run.""" + + file_id: str + + +class EvaluationRunData(BaseModel): + """Serializable data for Celery task.""" + + evaluation_run_id: str + tenant_id: str + target_type: str + target_id: str + evaluation_model_provider: str + evaluation_model: str + default_metrics: list[DefaultMetric] = Field(default_factory=list) + customized_metrics: CustomizedMetrics 
| None = None + judgment_config: JudgmentConfig | None = None + input_list: list[EvaluationDatasetInput] diff --git a/api/core/evaluation/entities/judgment_entity.py b/api/core/evaluation/entities/judgment_entity.py new file mode 100644 index 0000000000..4a59879c06 --- /dev/null +++ b/api/core/evaluation/entities/judgment_entity.py @@ -0,0 +1,96 @@ +"""Judgment condition entities for evaluation metric assessment. + +Condition structure mirrors the workflow if-else ``Condition`` model from +``graphon.utils.condition.entities``. The left-hand side uses +``variable_selector`` — a two-element list ``[node_id, metric_name]`` — to +uniquely identify an evaluation metric (different nodes may produce metrics +with the same name). + +Operators reuse ``SupportedComparisonOperator`` from the workflow engine so +that type semantics stay consistent across the platform. + +Typical usage:: + + judgment_config = JudgmentConfig( + logical_operator="and", + conditions=[ + JudgmentCondition( + variable_selector=["node_abc", "faithfulness"], + comparison_operator=">", + value="0.8", + ) + ], + ) +""" + +from collections.abc import Sequence +from typing import Any, Literal + +from pydantic import BaseModel, Field + +from graphon.utils.condition.entities import SupportedComparisonOperator + + +class JudgmentCondition(BaseModel): + """A single judgment condition that checks one metric value. + + Mirrors ``graphon.utils.condition.entities.Condition`` with the left-hand + side being a metric selector instead of a workflow variable selector. + + Attributes: + variable_selector: ``[node_id, metric_name]`` identifying the metric. + comparison_operator: Reuses workflow's ``SupportedComparisonOperator``. + value: The comparison target (right side). For unary operators such + as ``empty`` or ``null`` this can be ``None``. + """ + + variable_selector: list[str] + comparison_operator: SupportedComparisonOperator + value: str | Sequence[str] | bool | None = None + + +class JudgmentConfig(BaseModel): + """A group of judgment conditions combined with a logical operator. + + Attributes: + logical_operator: How to combine condition results — "and" requires + all conditions to pass, "or" requires at least one. + conditions: The list of individual conditions to evaluate. + """ + + logical_operator: Literal["and", "or"] = "and" + conditions: list[JudgmentCondition] = Field(default_factory=list) + + +class JudgmentConditionResult(BaseModel): + """Result of evaluating a single judgment condition. + + Attributes: + variable_selector: ``[node_id, metric_name]`` that was checked. + comparison_operator: The operator that was applied. + expected_value: The resolved comparison value. + actual_value: The actual metric value that was evaluated. + passed: Whether this individual condition passed. + error: Error message if the condition evaluation failed. + """ + + variable_selector: list[str] + comparison_operator: str + expected_value: Any = None + actual_value: Any = None + passed: bool = False + error: str | None = None + + +class JudgmentResult(BaseModel): + """Overall result of evaluating all judgment conditions for one item. + + Attributes: + passed: Whether the overall judgment passed (based on logical_operator). + logical_operator: The logical operator used to combine conditions. + condition_results: Detailed result for each individual condition. 
+ """ + + passed: bool = False + logical_operator: Literal["and", "or"] = "and" + condition_results: list[JudgmentConditionResult] = Field(default_factory=list) diff --git a/api/core/evaluation/evaluation_manager.py b/api/core/evaluation/evaluation_manager.py new file mode 100644 index 0000000000..5499b96cad --- /dev/null +++ b/api/core/evaluation/evaluation_manager.py @@ -0,0 +1,61 @@ +import collections +import logging +from typing import Any + +from configs import dify_config +from core.evaluation.base_evaluation_instance import BaseEvaluationInstance +from core.evaluation.entities.config_entity import EvaluationFrameworkEnum +from core.evaluation.entities.evaluation_entity import EvaluationCategory + +logger = logging.getLogger(__name__) + + +class EvaluationFrameworkConfigMap(collections.UserDict[str, dict[str, Any]]): + """Registry mapping framework enum -> {config_class, evaluator_class}.""" + + def __getitem__(self, framework: str) -> dict[str, Any]: + match framework: + case EvaluationFrameworkEnum.RAGAS: + from core.evaluation.entities.config_entity import RagasConfig + from core.evaluation.frameworks.ragas.ragas_evaluator import RagasEvaluator + + return { + "config_class": RagasConfig, + "evaluator_class": RagasEvaluator, + } + case EvaluationFrameworkEnum.DEEPEVAL: + raise NotImplementedError("DeepEval adapter is not yet implemented.") + case _: + raise ValueError(f"Unknown evaluation framework: {framework}") + + +evaluation_framework_config_map = EvaluationFrameworkConfigMap() + + +class EvaluationManager: + """Factory for evaluation instances based on global configuration.""" + + @staticmethod + def get_evaluation_instance() -> BaseEvaluationInstance | None: + """Create and return an evaluation instance based on EVALUATION_FRAMEWORK env var.""" + framework = dify_config.EVALUATION_FRAMEWORK + if not framework or framework == EvaluationFrameworkEnum.NONE: + return None + + try: + config_map = evaluation_framework_config_map[framework] + evaluator_class = config_map["evaluator_class"] + config_class = config_map["config_class"] + config = config_class() + return evaluator_class(config) + except Exception: + logger.exception("Failed to create evaluation instance for framework: %s", framework) + return None + + @staticmethod + def get_supported_metrics(category: EvaluationCategory) -> list[str]: + """Return supported metrics for the current framework and given category.""" + instance = EvaluationManager.get_evaluation_instance() + if instance is None: + return [] + return instance.get_supported_metrics(category) diff --git a/api/core/evaluation/frameworks/__init__.py b/api/core/evaluation/frameworks/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/evaluation/frameworks/deepeval/__init__.py b/api/core/evaluation/frameworks/deepeval/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/api/core/evaluation/frameworks/deepeval/__init__.py @@ -0,0 +1 @@ + diff --git a/api/core/evaluation/frameworks/deepeval/deepeval_evaluator.py b/api/core/evaluation/frameworks/deepeval/deepeval_evaluator.py new file mode 100644 index 0000000000..e0c65792f3 --- /dev/null +++ b/api/core/evaluation/frameworks/deepeval/deepeval_evaluator.py @@ -0,0 +1,299 @@ +import logging +from typing import Any + +from core.evaluation.base_evaluation_instance import BaseEvaluationInstance +from core.evaluation.entities.config_entity import DeepEvalConfig +from core.evaluation.entities.evaluation_entity import ( + AGENT_METRIC_NAMES, + LLM_METRIC_NAMES, + 
RETRIEVAL_METRIC_NAMES, + WORKFLOW_METRIC_NAMES, + EvaluationCategory, + EvaluationItemInput, + EvaluationItemResult, + EvaluationMetric, + EvaluationMetricName, +) +from core.evaluation.frameworks.ragas.ragas_model_wrapper import DifyModelWrapper + +logger = logging.getLogger(__name__) + +# Maps canonical EvaluationMetricName to the corresponding deepeval metric class name. +# deepeval metric field requirements (LLMTestCase fields): +# - faithfulness: input, actual_output, retrieval_context +# - answer_relevancy: input, actual_output +# - context_precision: input, actual_output, expected_output, retrieval_context +# - context_recall: input, actual_output, expected_output, retrieval_context +# - context_relevance: input, actual_output, retrieval_context +# - tool_correctness: input, actual_output, expected_tools +# - task_completion: input, actual_output +# Metrics not listed here are unsupported by deepeval and will be skipped. +_DEEPEVAL_METRIC_MAP: dict[EvaluationMetricName, str] = { + EvaluationMetricName.FAITHFULNESS: "FaithfulnessMetric", + EvaluationMetricName.ANSWER_RELEVANCY: "AnswerRelevancyMetric", + EvaluationMetricName.CONTEXT_PRECISION: "ContextualPrecisionMetric", + EvaluationMetricName.CONTEXT_RECALL: "ContextualRecallMetric", + EvaluationMetricName.CONTEXT_RELEVANCE: "ContextualRelevancyMetric", + EvaluationMetricName.TOOL_CORRECTNESS: "ToolCorrectnessMetric", + EvaluationMetricName.TASK_COMPLETION: "TaskCompletionMetric", +} + + +class DeepEvalEvaluator(BaseEvaluationInstance): + """DeepEval framework adapter for evaluation.""" + + def __init__(self, config: DeepEvalConfig): + self.config = config + + def get_supported_metrics(self, category: EvaluationCategory) -> list[str]: + match category: + case EvaluationCategory.LLM: + candidates = LLM_METRIC_NAMES + case EvaluationCategory.RETRIEVAL: + candidates = RETRIEVAL_METRIC_NAMES + case EvaluationCategory.AGENT: + candidates = AGENT_METRIC_NAMES + case EvaluationCategory.WORKFLOW | EvaluationCategory.SNIPPET: + candidates = WORKFLOW_METRIC_NAMES + case _: + return [] + return [m for m in candidates if m in _DEEPEVAL_METRIC_MAP] + + def evaluate_llm( + self, + items: list[EvaluationItemInput], + metric_names: list[str], + model_provider: str, + model_name: str, + tenant_id: str, + ) -> list[EvaluationItemResult]: + return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.LLM) + + def evaluate_retrieval( + self, + items: list[EvaluationItemInput], + metric_names: list[str], + model_provider: str, + model_name: str, + tenant_id: str, + ) -> list[EvaluationItemResult]: + return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.RETRIEVAL) + + def evaluate_agent( + self, + items: list[EvaluationItemInput], + metric_names: list[str], + model_provider: str, + model_name: str, + tenant_id: str, + ) -> list[EvaluationItemResult]: + return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.AGENT) + + def evaluate_workflow( + self, + items: list[EvaluationItemInput], + metric_names: list[str], + model_provider: str, + model_name: str, + tenant_id: str, + ) -> list[EvaluationItemResult]: + return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.WORKFLOW) + + def _evaluate( + self, + items: list[EvaluationItemInput], + metric_names: list[str], + model_provider: str, + model_name: str, + tenant_id: str, + category: EvaluationCategory, + ) -> 
list[EvaluationItemResult]: + """Core evaluation logic using DeepEval.""" + model_wrapper = DifyModelWrapper(model_provider, model_name, tenant_id) + requested_metrics = metric_names or self.get_supported_metrics(category) + + try: + return self._evaluate_with_deepeval(items, requested_metrics, category) + except ImportError: + logger.warning("DeepEval not installed, falling back to simple evaluation") + return self._evaluate_simple(items, requested_metrics, model_wrapper) + + def _evaluate_with_deepeval( + self, + items: list[EvaluationItemInput], + requested_metrics: list[str], + category: EvaluationCategory, + ) -> list[EvaluationItemResult]: + """Evaluate using DeepEval library. + + Builds LLMTestCase differently per category: + - LLM/Workflow: input=prompt, actual_output=output, retrieval_context=context + - Retrieval: input=query, actual_output=output, expected_output, retrieval_context=context + - Agent: input=query, actual_output=output + """ + metric_pairs = _build_deepeval_metrics(requested_metrics) + if not metric_pairs: + logger.warning("No valid DeepEval metrics found for: %s", requested_metrics) + return [EvaluationItemResult(index=item.index) for item in items] + + results: list[EvaluationItemResult] = [] + for item in items: + test_case = self._build_test_case(item, category) + metrics: list[EvaluationMetric] = [] + for canonical_name, metric in metric_pairs: + try: + metric.measure(test_case) + if metric.score is not None: + metrics.append(EvaluationMetric(name=canonical_name, value=float(metric.score))) + except Exception: + logger.exception( + "Failed to compute metric %s for item %d", + canonical_name, + item.index, + ) + results.append(EvaluationItemResult(index=item.index, metrics=metrics)) + return results + + @staticmethod + def _build_test_case(item: EvaluationItemInput, category: EvaluationCategory) -> Any: + """Build a deepeval LLMTestCase with the correct fields per category.""" + from deepeval.test_case import LLMTestCase + + user_input = _format_input(item.inputs, category) + + match category: + case EvaluationCategory.LLM | EvaluationCategory.WORKFLOW: + # faithfulness needs: input, actual_output, retrieval_context + # answer_relevancy needs: input, actual_output + return LLMTestCase( + input=user_input, + actual_output=item.output, + expected_output=item.expected_output or None, + retrieval_context=item.context or None, + ) + case EvaluationCategory.RETRIEVAL: + # contextual_precision/recall needs: input, actual_output, expected_output, retrieval_context + return LLMTestCase( + input=user_input, + actual_output=item.output or "", + expected_output=item.expected_output or "", + retrieval_context=item.context or [], + ) + case _: + return LLMTestCase( + input=user_input, + actual_output=item.output, + ) + + def _evaluate_simple( + self, + items: list[EvaluationItemInput], + requested_metrics: list[str], + model_wrapper: DifyModelWrapper, + ) -> list[EvaluationItemResult]: + """Simple LLM-as-judge fallback when DeepEval is not available.""" + results: list[EvaluationItemResult] = [] + for item in items: + metrics: list[EvaluationMetric] = [] + for m_name in requested_metrics: + try: + score = self._judge_with_llm(model_wrapper, m_name, item) + metrics.append(EvaluationMetric(name=m_name, value=score)) + except Exception: + logger.exception("Failed to compute metric %s for item %d", m_name, item.index) + results.append(EvaluationItemResult(index=item.index, metrics=metrics)) + return results + + def _judge_with_llm( + self, + model_wrapper: DifyModelWrapper, + 
metric_name: str, + item: EvaluationItemInput, + ) -> float: + """Use the LLM to judge a single metric for a single item.""" + prompt = self._build_judge_prompt(metric_name, item) + response = model_wrapper.invoke(prompt) + return self._parse_score(response) + + @staticmethod + def _build_judge_prompt(metric_name: str, item: EvaluationItemInput) -> str: + """Build a scoring prompt for the LLM judge.""" + parts = [ + f"Evaluate the following on the metric '{metric_name}' using a scale of 0.0 to 1.0.", + f"\nInput: {item.inputs}", + f"\nOutput: {item.output}", + ] + if item.expected_output: + parts.append(f"\nExpected Output: {item.expected_output}") + if item.context: + parts.append(f"\nContext: {'; '.join(item.context)}") + parts.append("\nRespond with ONLY a single floating point number between 0.0 and 1.0, nothing else.") + return "\n".join(parts) + + @staticmethod + def _parse_score(response: str) -> float: + """Parse a float score from LLM response.""" + import re + + cleaned = response.strip() + try: + score = float(cleaned) + return max(0.0, min(1.0, score)) + except ValueError: + match = re.search(r"(\d+\.?\d*)", cleaned) + if match: + score = float(match.group(1)) + return max(0.0, min(1.0, score)) + return 0.0 + + +def _format_input(inputs: dict[str, Any], category: EvaluationCategory) -> str: + """Extract the user-facing input string from the inputs dict.""" + match category: + case EvaluationCategory.LLM | EvaluationCategory.WORKFLOW: + return str(inputs.get("prompt", "")) + case EvaluationCategory.RETRIEVAL: + return str(inputs.get("query", "")) + case _: + return str(next(iter(inputs.values()), "")) if inputs else "" + + +def _build_deepeval_metrics(requested_metrics: list[str]) -> list[tuple[str, Any]]: + """Build DeepEval metric instances from canonical metric names. + + Returns a list of (canonical_name, metric_instance) pairs so that callers + can record the canonical name rather than the framework-internal class name. 
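+
+    For example (illustrative), ``["faithfulness", "not_a_metric"]`` yields
+    ``[("faithfulness", FaithfulnessMetric(threshold=0.5))]`` and logs a
+    warning for the unsupported name.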
+ """ + try: + from deepeval.metrics import ( + AnswerRelevancyMetric, + ContextualPrecisionMetric, + ContextualRecallMetric, + ContextualRelevancyMetric, + FaithfulnessMetric, + TaskCompletionMetric, + ToolCorrectnessMetric, + ) + + # Maps canonical name → deepeval metric class + deepeval_class_map: dict[str, Any] = { + EvaluationMetricName.FAITHFULNESS: FaithfulnessMetric, + EvaluationMetricName.ANSWER_RELEVANCY: AnswerRelevancyMetric, + EvaluationMetricName.CONTEXT_PRECISION: ContextualPrecisionMetric, + EvaluationMetricName.CONTEXT_RECALL: ContextualRecallMetric, + EvaluationMetricName.CONTEXT_RELEVANCE: ContextualRelevancyMetric, + EvaluationMetricName.TOOL_CORRECTNESS: ToolCorrectnessMetric, + EvaluationMetricName.TASK_COMPLETION: TaskCompletionMetric, + } + + pairs: list[tuple[str, Any]] = [] + for name in requested_metrics: + metric_class = deepeval_class_map.get(name) + if metric_class: + pairs.append((name, metric_class(threshold=0.5))) + else: + logger.warning("Metric '%s' is not supported by DeepEval, skipping", name) + return pairs + except ImportError: + logger.warning("DeepEval metrics not available") + return [] diff --git a/api/core/evaluation/frameworks/ragas/__init__.py b/api/core/evaluation/frameworks/ragas/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/evaluation/frameworks/ragas/ragas_evaluator.py b/api/core/evaluation/frameworks/ragas/ragas_evaluator.py new file mode 100644 index 0000000000..ec2320439d --- /dev/null +++ b/api/core/evaluation/frameworks/ragas/ragas_evaluator.py @@ -0,0 +1,312 @@ +import logging +from typing import Any + +from core.evaluation.base_evaluation_instance import BaseEvaluationInstance +from core.evaluation.entities.config_entity import RagasConfig +from core.evaluation.entities.evaluation_entity import ( + AGENT_METRIC_NAMES, + LLM_METRIC_NAMES, + RETRIEVAL_METRIC_NAMES, + WORKFLOW_METRIC_NAMES, + EvaluationCategory, + EvaluationItemInput, + EvaluationItemResult, + EvaluationMetric, + EvaluationMetricName, +) +from core.evaluation.frameworks.ragas.ragas_model_wrapper import DifyModelWrapper + +logger = logging.getLogger(__name__) + +# Maps canonical EvaluationMetricName to the corresponding ragas metric class. +# Metrics not listed here are unsupported by ragas and will be skipped. 
+_RAGAS_METRIC_MAP: dict[EvaluationMetricName, str] = { + EvaluationMetricName.FAITHFULNESS: "Faithfulness", + EvaluationMetricName.ANSWER_RELEVANCY: "AnswerRelevancy", + EvaluationMetricName.ANSWER_CORRECTNESS: "AnswerCorrectness", + EvaluationMetricName.SEMANTIC_SIMILARITY: "SemanticSimilarity", + EvaluationMetricName.CONTEXT_PRECISION: "ContextPrecision", + EvaluationMetricName.CONTEXT_RECALL: "ContextRecall", + EvaluationMetricName.CONTEXT_RELEVANCE: "ContextRelevance", + EvaluationMetricName.TOOL_CORRECTNESS: "ToolCallAccuracy", +} + + +class RagasEvaluator(BaseEvaluationInstance): + """RAGAS framework adapter for evaluation.""" + + def __init__(self, config: RagasConfig): + self.config = config + + def get_supported_metrics(self, category: EvaluationCategory) -> list[str]: + match category: + case EvaluationCategory.LLM: + candidates = LLM_METRIC_NAMES + case EvaluationCategory.RETRIEVAL: + candidates = RETRIEVAL_METRIC_NAMES + case EvaluationCategory.AGENT: + candidates = AGENT_METRIC_NAMES + case EvaluationCategory.WORKFLOW | EvaluationCategory.SNIPPET: + candidates = WORKFLOW_METRIC_NAMES + case _: + return [] + return [m for m in candidates if m in _RAGAS_METRIC_MAP] + + def evaluate_llm( + self, + items: list[EvaluationItemInput], + metric_names: list[str], + model_provider: str, + model_name: str, + tenant_id: str, + ) -> list[EvaluationItemResult]: + return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.LLM) + + def evaluate_retrieval( + self, + items: list[EvaluationItemInput], + metric_names: list[str], + model_provider: str, + model_name: str, + tenant_id: str, + ) -> list[EvaluationItemResult]: + return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.RETRIEVAL) + + def evaluate_agent( + self, + items: list[EvaluationItemInput], + metric_names: list[str], + model_provider: str, + model_name: str, + tenant_id: str, + ) -> list[EvaluationItemResult]: + return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.AGENT) + + def evaluate_workflow( + self, + items: list[EvaluationItemInput], + metric_names: list[str], + model_provider: str, + model_name: str, + tenant_id: str, + ) -> list[EvaluationItemResult]: + return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.WORKFLOW) + + def _evaluate( + self, + items: list[EvaluationItemInput], + metric_names: list[str], + model_provider: str, + model_name: str, + tenant_id: str, + category: EvaluationCategory, + ) -> list[EvaluationItemResult]: + """Core evaluation logic using RAGAS.""" + model_wrapper = DifyModelWrapper(model_provider, model_name, tenant_id) + requested_metrics = metric_names or self.get_supported_metrics(category) + + try: + return self._evaluate_with_ragas(items, requested_metrics, model_wrapper, category) + except ImportError: + logger.warning("RAGAS not installed, falling back to simple evaluation") + return self._evaluate_simple(items, requested_metrics, model_wrapper) + + def _evaluate_with_ragas( + self, + items: list[EvaluationItemInput], + requested_metrics: list[str], + model_wrapper: DifyModelWrapper, + category: EvaluationCategory, + ) -> list[EvaluationItemResult]: + """Evaluate using RAGAS library. 
+ + Builds SingleTurnSample differently per category to match ragas requirements: + - LLM/Workflow: user_input=prompt, response=output, reference=expected_output + - Retrieval: user_input=query, reference=expected_output, retrieved_contexts=context + - Agent: Not supported via EvaluationDataset (requires message-based API) + """ + from ragas import evaluate as ragas_evaluate + from ragas.dataset_schema import EvaluationDataset + + samples: list[Any] = [] + for item in items: + sample = self._build_sample(item, category) + samples.append(sample) + + dataset = EvaluationDataset(samples=samples) + + ragas_metrics = self._build_ragas_metrics(requested_metrics) + if not ragas_metrics: + logger.warning("No valid RAGAS metrics found for: %s", requested_metrics) + return [EvaluationItemResult(index=item.index) for item in items] + + try: + result = ragas_evaluate( + dataset=dataset, + metrics=ragas_metrics, + ) + + results: list[EvaluationItemResult] = [] + result_df = result.to_pandas() + for i, item in enumerate(items): + metrics: list[EvaluationMetric] = [] + for m_name in requested_metrics: + if m_name in result_df.columns: + score = result_df.iloc[i][m_name] + if score is not None and not (isinstance(score, float) and score != score): + metrics.append(EvaluationMetric(name=m_name, value=float(score))) + results.append(EvaluationItemResult(index=item.index, metrics=metrics)) + return results + except Exception: + logger.exception("RAGAS evaluation failed, falling back to simple evaluation") + return self._evaluate_simple(items, requested_metrics, model_wrapper) + + @staticmethod + def _build_sample(item: EvaluationItemInput, category: EvaluationCategory) -> Any: + """Build a ragas SingleTurnSample with the correct fields per category. + + ragas metric field requirements: + - faithfulness: user_input, response, retrieved_contexts + - answer_relevancy: user_input, response + - answer_correctness: user_input, response, reference + - semantic_similarity: user_input, response, reference + - context_precision: user_input, reference, retrieved_contexts + - context_recall: user_input, reference, retrieved_contexts + - context_relevance: user_input, retrieved_contexts + """ + from ragas.dataset_schema import SingleTurnSample + + user_input = _format_input(item.inputs, category) + + match category: + case EvaluationCategory.LLM: + # response = actual LLM output, reference = expected output + return SingleTurnSample( + user_input=user_input, + response=item.output, + reference=item.expected_output or "", + retrieved_contexts=item.context or [], + ) + case EvaluationCategory.RETRIEVAL: + # context_precision/recall only need reference + retrieved_contexts + return SingleTurnSample( + user_input=user_input, + reference=item.expected_output or "", + retrieved_contexts=item.context or [], + ) + case _: + return SingleTurnSample( + user_input=user_input, + response=item.output, + ) + + def _evaluate_simple( + self, + items: list[EvaluationItemInput], + requested_metrics: list[str], + model_wrapper: DifyModelWrapper, + ) -> list[EvaluationItemResult]: + """Simple LLM-as-judge fallback when RAGAS is not available.""" + results: list[EvaluationItemResult] = [] + for item in items: + metrics: list[EvaluationMetric] = [] + for m_name in requested_metrics: + try: + score = self._judge_with_llm(model_wrapper, m_name, item) + metrics.append(EvaluationMetric(name=m_name, value=score)) + except Exception: + logger.exception("Failed to compute metric %s for item %d", m_name, item.index) + 
results.append(EvaluationItemResult(index=item.index, metrics=metrics)) + return results + + def _judge_with_llm( + self, + model_wrapper: DifyModelWrapper, + metric_name: str, + item: EvaluationItemInput, + ) -> float: + """Use the LLM to judge a single metric for a single item.""" + prompt = self._build_judge_prompt(metric_name, item) + response = model_wrapper.invoke(prompt) + return self._parse_score(response) + + @staticmethod + def _build_judge_prompt(metric_name: str, item: EvaluationItemInput) -> str: + """Build a scoring prompt for the LLM judge.""" + parts = [ + f"Evaluate the following on the metric '{metric_name}' using a scale of 0.0 to 1.0.", + f"\nInput: {item.inputs}", + f"\nOutput: {item.output}", + ] + if item.expected_output: + parts.append(f"\nExpected Output: {item.expected_output}") + if item.context: + parts.append(f"\nContext: {'; '.join(item.context)}") + parts.append("\nRespond with ONLY a single floating point number between 0.0 and 1.0, nothing else.") + return "\n".join(parts) + + @staticmethod + def _parse_score(response: str) -> float: + """Parse a float score from LLM response.""" + import re + + cleaned = response.strip() + try: + score = float(cleaned) + return max(0.0, min(1.0, score)) + except ValueError: + match = re.search(r"(\d+\.?\d*)", cleaned) + if match: + score = float(match.group(1)) + return max(0.0, min(1.0, score)) + return 0.0 + + @staticmethod + def _build_ragas_metrics(requested_metrics: list[str]) -> list[Any]: + """Build RAGAS metric instances from canonical metric names.""" + try: + from ragas.metrics.collections import ( + AnswerCorrectness, + AnswerRelevancy, + ContextPrecision, + ContextRecall, + ContextRelevance, + Faithfulness, + SemanticSimilarity, + ToolCallAccuracy, + ) + + # Maps canonical name → ragas metric class + ragas_class_map: dict[str, Any] = { + EvaluationMetricName.FAITHFULNESS: Faithfulness, + EvaluationMetricName.ANSWER_RELEVANCY: AnswerRelevancy, + EvaluationMetricName.ANSWER_CORRECTNESS: AnswerCorrectness, + EvaluationMetricName.SEMANTIC_SIMILARITY: SemanticSimilarity, + EvaluationMetricName.CONTEXT_PRECISION: ContextPrecision, + EvaluationMetricName.CONTEXT_RECALL: ContextRecall, + EvaluationMetricName.CONTEXT_RELEVANCE: ContextRelevance, + EvaluationMetricName.TOOL_CORRECTNESS: ToolCallAccuracy, + } + + metrics = [] + for name in requested_metrics: + metric_class = ragas_class_map.get(name) + if metric_class: + metrics.append(metric_class()) + else: + logger.warning("Metric '%s' is not supported by RAGAS, skipping", name) + return metrics + except ImportError: + logger.warning("RAGAS metrics not available") + return [] + + +def _format_input(inputs: dict[str, Any], category: EvaluationCategory) -> str: + """Extract the user-facing input string from the inputs dict.""" + match category: + case EvaluationCategory.LLM | EvaluationCategory.WORKFLOW: + return str(inputs.get("prompt", "")) + case EvaluationCategory.RETRIEVAL: + return str(inputs.get("query", "")) + case _: + return str(next(iter(inputs.values()), "")) if inputs else "" diff --git a/api/core/evaluation/frameworks/ragas/ragas_model_wrapper.py b/api/core/evaluation/frameworks/ragas/ragas_model_wrapper.py new file mode 100644 index 0000000000..e0a5e14914 --- /dev/null +++ b/api/core/evaluation/frameworks/ragas/ragas_model_wrapper.py @@ -0,0 +1,48 @@ +import logging +from typing import Any + +logger = logging.getLogger(__name__) + + +class DifyModelWrapper: + """Wraps Dify's model invocation interface for use by RAGAS as an LLM judge. 
+ + RAGAS requires an LLM to compute certain metrics (faithfulness, answer_relevancy, etc.). + This wrapper bridges Dify's ModelInstance to a callable that RAGAS can use. + """ + + def __init__(self, model_provider: str, model_name: str, tenant_id: str): + self.model_provider = model_provider + self.model_name = model_name + self.tenant_id = tenant_id + + def _get_model_instance(self) -> Any: + from core.model_manager import ModelManager + from core.model_runtime.entities.model_entities import ModelType + + model_manager = ModelManager() + model_instance = model_manager.get_model_instance( + tenant_id=self.tenant_id, + provider=self.model_provider, + model_type=ModelType.LLM, + model=self.model_name, + ) + return model_instance + + def invoke(self, prompt: str) -> str: + """Invoke the model with a text prompt and return the text response.""" + from core.model_runtime.entities.message_entities import ( + SystemPromptMessage, + UserPromptMessage, + ) + + model_instance = self._get_model_instance() + result = model_instance.invoke_llm( + prompt_messages=[ + SystemPromptMessage(content="You are an evaluation judge. Answer precisely and concisely."), + UserPromptMessage(content=prompt), + ], + model_parameters={"temperature": 0.0, "max_tokens": 2048}, + stream=False, + ) + return result.message.content diff --git a/api/core/evaluation/judgment/__init__.py b/api/core/evaluation/judgment/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/evaluation/judgment/processor.py b/api/core/evaluation/judgment/processor.py new file mode 100644 index 0000000000..7a0ce38b75 --- /dev/null +++ b/api/core/evaluation/judgment/processor.py @@ -0,0 +1,160 @@ +"""Judgment condition processor for evaluation metrics. + +Evaluates pass/fail judgment conditions against evaluation metric values. +Each condition uses ``variable_selector`` (``[node_id, metric_name]``) to +look up the metric value, then delegates the actual comparison to the +workflow condition engine (``graphon.utils.condition.processor``). + +The processor is intentionally decoupled from evaluation frameworks and +runners. It operates on plain ``dict`` mappings and can be invoked +anywhere that already has per-item metric results. +""" + +import logging +from collections.abc import Sequence +from typing import Any, cast + +from core.evaluation.entities.judgment_entity import ( + JudgmentCondition, + JudgmentConditionResult, + JudgmentConfig, + JudgmentResult, +) +from graphon.utils.condition.entities import SupportedComparisonOperator +from graphon.utils.condition.processor import _evaluate_condition # pyright: ignore[reportPrivateUsage] + +logger = logging.getLogger(__name__) + +_UNARY_OPERATORS = frozenset({"null", "not null", "empty", "not empty"}) + + +class JudgmentProcessor: + @staticmethod + def evaluate( + metric_values: dict[tuple[str, str], Any], + config: JudgmentConfig, + ) -> JudgmentResult: + """Evaluate all judgment conditions against the given metric values. + + Args: + metric_values: Mapping of ``(node_id, metric_name)`` → metric + value (e.g. ``{("node_abc", "faithfulness"): 0.85}``). + config: The judgment configuration with logical_operator and + conditions. + + Returns: + JudgmentResult with overall pass/fail and per-condition details. 
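+
+        Example (illustrative; assumes the workflow comparison engine parses
+        the numeric string on the right-hand side)::
+
+            metric_values = {("node_abc", "faithfulness"): 0.85}
+            config = JudgmentConfig(
+                logical_operator="and",
+                conditions=[
+                    JudgmentCondition(
+                        variable_selector=["node_abc", "faithfulness"],
+                        comparison_operator=">",
+                        value="0.8",
+                    )
+                ],
+            )
+            result = JudgmentProcessor.evaluate(metric_values, config)
+            # Expected: result.passed is True and
+            # result.condition_results[0].actual_value == 0.85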
+ """ + if not config.conditions: + return JudgmentResult( + passed=True, + logical_operator=config.logical_operator, + condition_results=[], + ) + + condition_results: list[JudgmentConditionResult] = [] + + for condition in config.conditions: + result = JudgmentProcessor._evaluate_single_condition(metric_values, condition) + condition_results.append(result) + + if config.logical_operator == "and" and not result.passed: + return JudgmentResult( + passed=False, + logical_operator=config.logical_operator, + condition_results=condition_results, + ) + if config.logical_operator == "or" and result.passed: + return JudgmentResult( + passed=True, + logical_operator=config.logical_operator, + condition_results=condition_results, + ) + + if config.logical_operator == "and": + final_passed = all(r.passed for r in condition_results) + else: + final_passed = any(r.passed for r in condition_results) + + return JudgmentResult( + passed=final_passed, + logical_operator=config.logical_operator, + condition_results=condition_results, + ) + + @staticmethod + def _evaluate_single_condition( + metric_values: dict[tuple[str, str], Any], + condition: JudgmentCondition, + ) -> JudgmentConditionResult: + """Evaluate a single judgment condition. + + Steps: + 1. Extract ``(node_id, metric_name)`` from ``variable_selector``. + 2. Look up the metric value from ``metric_values``. + 3. Delegate comparison to the workflow condition engine. + """ + selector = condition.variable_selector + if len(selector) < 2: + return JudgmentConditionResult( + variable_selector=selector, + comparison_operator=condition.comparison_operator, + expected_value=condition.value, + actual_value=None, + passed=False, + error=f"variable_selector must have at least 2 elements, got {len(selector)}", + ) + + node_id, metric_name = selector[0], selector[1] + actual_value = metric_values.get((node_id, metric_name)) + + if actual_value is None and condition.comparison_operator not in _UNARY_OPERATORS: + return JudgmentConditionResult( + variable_selector=selector, + comparison_operator=condition.comparison_operator, + expected_value=condition.value, + actual_value=None, + passed=False, + error=f"Metric '{metric_name}' on node '{node_id}' not found in evaluation results", + ) + + try: + expected = condition.value + # Numeric operators need the actual value coerced to int/float + # so that the workflow engine's numeric assertions work correctly. 
+ coerced_actual: object = actual_value + if ( + condition.comparison_operator in {"=", "≠", ">", "<", "≥", "≤"} + and actual_value is not None + and not isinstance(actual_value, (int, float, bool)) + ): + coerced_actual = float(actual_value) + + passed = _evaluate_condition( + operator=cast(SupportedComparisonOperator, condition.comparison_operator), + value=coerced_actual, + expected=cast(str | Sequence[str] | bool | Sequence[bool] | None, expected), + ) + + return JudgmentConditionResult( + variable_selector=selector, + comparison_operator=condition.comparison_operator, + expected_value=expected, + actual_value=actual_value, + passed=passed, + ) + except Exception as e: + logger.warning( + "Judgment condition evaluation failed for [%s, %s]: %s", + node_id, + metric_name, + str(e), + ) + return JudgmentConditionResult( + variable_selector=selector, + comparison_operator=condition.comparison_operator, + expected_value=condition.value, + actual_value=actual_value, + passed=False, + error=str(e), + ) diff --git a/api/core/evaluation/runners/__init__.py b/api/core/evaluation/runners/__init__.py new file mode 100644 index 0000000000..0a69087432 --- /dev/null +++ b/api/core/evaluation/runners/__init__.py @@ -0,0 +1,52 @@ +from sqlalchemy import select +from sqlalchemy.orm import Session + +from models import Account, App, CustomizedSnippet, TenantAccountJoin + + +def get_service_account_for_app(session: Session, app_id: str) -> Account: + """Get the creator account for an app with tenant context set up. + + This follows the same pattern as BaseTraceInstance.get_service_account_with_tenant(). + """ + app = session.scalar(select(App).where(App.id == app_id)) + if not app: + raise ValueError(f"App with id {app_id} not found") + + if not app.created_by: + raise ValueError(f"App with id {app_id} has no creator") + + account = session.scalar(select(Account).where(Account.id == app.created_by)) + if not account: + raise ValueError(f"Creator account not found for app {app_id}") + + current_tenant = session.query(TenantAccountJoin).filter_by(account_id=account.id, current=True).first() + if not current_tenant: + raise ValueError(f"Current tenant not found for account {account.id}") + + account.set_tenant_id(current_tenant.tenant_id) + return account + + +def get_service_account_for_snippet(session: Session, snippet_id: str) -> Account: + """Get the creator account for a snippet with tenant context set up. + + Mirrors :func:`get_service_account_for_app` but queries CustomizedSnippet. 
+ """ + snippet = session.scalar(select(CustomizedSnippet).where(CustomizedSnippet.id == snippet_id)) + if not snippet: + raise ValueError(f"Snippet with id {snippet_id} not found") + + if not snippet.created_by: + raise ValueError(f"Snippet with id {snippet_id} has no creator") + + account = session.scalar(select(Account).where(Account.id == snippet.created_by)) + if not account: + raise ValueError(f"Creator account not found for snippet {snippet_id}") + + current_tenant = session.query(TenantAccountJoin).filter_by(account_id=account.id, current=True).first() + if not current_tenant: + raise ValueError(f"Current tenant not found for account {account.id}") + + account.set_tenant_id(current_tenant.tenant_id) + return account diff --git a/api/core/evaluation/runners/agent_evaluation_runner.py b/api/core/evaluation/runners/agent_evaluation_runner.py new file mode 100644 index 0000000000..ef3bbe704c --- /dev/null +++ b/api/core/evaluation/runners/agent_evaluation_runner.py @@ -0,0 +1,62 @@ +import logging +from collections.abc import Mapping +from typing import Any + +from core.evaluation.base_evaluation_instance import BaseEvaluationInstance +from core.evaluation.entities.evaluation_entity import ( + DefaultMetric, + EvaluationItemInput, + EvaluationItemResult, +) +from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner +from graphon.node_events import NodeRunResult + +logger = logging.getLogger(__name__) + + +class AgentEvaluationRunner(BaseEvaluationRunner): + """Runner for agent evaluation: collects tool calls and final output.""" + + def __init__(self, evaluation_instance: BaseEvaluationInstance): + super().__init__(evaluation_instance) + + def evaluate_metrics( + self, + node_run_result_list: list[NodeRunResult], + default_metric: DefaultMetric, + model_provider: str, + model_name: str, + tenant_id: str, + ) -> list[EvaluationItemResult]: + """Compute agent evaluation metrics.""" + if not node_run_result_list: + return [] + merged_items = self._merge_results_into_items(node_run_result_list) + return self.evaluation_instance.evaluate_agent( + merged_items, [default_metric.metric], model_provider, model_name, tenant_id + ) + + @staticmethod + def _merge_results_into_items(items: list[NodeRunResult]) -> list[EvaluationItemInput]: + """Create EvaluationItemInput list from NodeRunResult for agent evaluation.""" + merged = [] + for i, item in enumerate(items): + output = _extract_agent_output(item.outputs) + merged.append( + EvaluationItemInput( + index=i, + inputs=dict(item.inputs), + output=output, + ) + ) + return merged + + +def _extract_agent_output(outputs: Mapping[str, Any]) -> str: + """Extract the primary output text from agent NodeRunResult.outputs.""" + if "answer" in outputs: + return str(outputs["answer"]) + if "text" in outputs: + return str(outputs["text"]) + values = list(outputs.values()) + return str(values[0]) if values else "" diff --git a/api/core/evaluation/runners/base_evaluation_runner.py b/api/core/evaluation/runners/base_evaluation_runner.py new file mode 100644 index 0000000000..9046c2ddad --- /dev/null +++ b/api/core/evaluation/runners/base_evaluation_runner.py @@ -0,0 +1,51 @@ +"""Base evaluation runner. + +Provides the abstract interface for metric computation. Each concrete runner +(LLM, Retrieval, Agent, Workflow, Snippet) implements ``evaluate_metrics`` +to compute scores for a specific node type. 
+ +Orchestration (merging results from multiple runners, applying judgment, and +persisting to the database) is handled by the evaluation task, not the runner. +""" + +import logging +from abc import ABC, abstractmethod + +from core.evaluation.base_evaluation_instance import BaseEvaluationInstance +from core.evaluation.entities.evaluation_entity import ( + DefaultMetric, + EvaluationItemResult, +) +from graphon.node_events import NodeRunResult + +logger = logging.getLogger(__name__) + + +class BaseEvaluationRunner(ABC): + """Abstract base class for evaluation runners. + + Runners are stateless metric calculators: they receive node execution + results and a metric specification, then return scored results. They + do **not** touch the database or apply judgment logic. + """ + + def __init__(self, evaluation_instance: BaseEvaluationInstance): + self.evaluation_instance = evaluation_instance + + @abstractmethod + def evaluate_metrics( + self, + node_run_result_list: list[NodeRunResult], + default_metric: DefaultMetric, + model_provider: str, + model_name: str, + tenant_id: str, + ) -> list[EvaluationItemResult]: + """Compute evaluation metrics on the collected results. + + The returned ``EvaluationItemResult.index`` values are positional + (0-based) and correspond to the order of *node_run_result_list*. + The caller is responsible for mapping them back to the original + dataset indices. + """ + ... diff --git a/api/core/evaluation/runners/llm_evaluation_runner.py b/api/core/evaluation/runners/llm_evaluation_runner.py new file mode 100644 index 0000000000..4b1c244838 --- /dev/null +++ b/api/core/evaluation/runners/llm_evaluation_runner.py @@ -0,0 +1,83 @@ +import logging +from collections.abc import Mapping +from typing import Any + +from core.evaluation.base_evaluation_instance import BaseEvaluationInstance +from core.evaluation.entities.evaluation_entity import ( + DefaultMetric, + EvaluationItemInput, + EvaluationItemResult, +) +from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner +from graphon.node_events import NodeRunResult + +logger = logging.getLogger(__name__) + + +class LLMEvaluationRunner(BaseEvaluationRunner): + """Runner for LLM evaluation: extracts prompts/outputs then evaluates.""" + + def __init__(self, evaluation_instance: BaseEvaluationInstance): + super().__init__(evaluation_instance) + + def evaluate_metrics( + self, + node_run_result_list: list[NodeRunResult], + default_metric: DefaultMetric, + model_provider: str, + model_name: str, + tenant_id: str, + ) -> list[EvaluationItemResult]: + """Use the evaluation instance to compute LLM metrics.""" + if not node_run_result_list: + return [] + merged_items = self._merge_results_into_items(node_run_result_list) + return self.evaluation_instance.evaluate_llm( + merged_items, [default_metric.metric], model_provider, model_name, tenant_id + ) + + @staticmethod + def _merge_results_into_items( + items: list[NodeRunResult], + ) -> list[EvaluationItemInput]: + """Create new items from NodeRunResult for ragas evaluation. + + Extracts prompts from process_data and concatenates them into a single + string with role prefixes (e.g. "system: ...\nuser: ...\nassistant: ..."). + The last assistant message in outputs is used as the actual output. 
+ """ + merged = [] + for i, item in enumerate(items): + prompt = _format_prompts(item.process_data.get("prompts", [])) + output = _extract_llm_output(item.outputs) + merged.append( + EvaluationItemInput( + index=i, + inputs={"prompt": prompt}, + output=output, + ) + ) + return merged + + +def _format_prompts(prompts: list[dict[str, Any]]) -> str: + """Concatenate a list of prompt messages into a single string for evaluation. + + Each message is formatted as "role: text" and joined with newlines. + """ + parts: list[str] = [] + for msg in prompts: + role = msg.get("role", "unknown") + text = msg.get("text", "") + parts.append(f"{role}: {text}") + return "\n".join(parts) + + +def _extract_llm_output(outputs: Mapping[str, Any]) -> str: + """Extract the LLM output text from NodeRunResult.outputs.""" + if "text" in outputs: + return str(outputs["text"]) + if "answer" in outputs: + return str(outputs["answer"]) + values = list(outputs.values()) + return str(values[0]) if values else "" diff --git a/api/core/evaluation/runners/retrieval_evaluation_runner.py b/api/core/evaluation/runners/retrieval_evaluation_runner.py new file mode 100644 index 0000000000..66b8ab7360 --- /dev/null +++ b/api/core/evaluation/runners/retrieval_evaluation_runner.py @@ -0,0 +1,61 @@ +import logging +from typing import Any + +from core.evaluation.base_evaluation_instance import BaseEvaluationInstance +from core.evaluation.entities.evaluation_entity import ( + DefaultMetric, + EvaluationItemInput, + EvaluationItemResult, +) +from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner +from graphon.node_events import NodeRunResult + +logger = logging.getLogger(__name__) + + +class RetrievalEvaluationRunner(BaseEvaluationRunner): + """Runner for retrieval evaluation: performs knowledge base retrieval, then evaluates.""" + + def __init__(self, evaluation_instance: BaseEvaluationInstance): + super().__init__(evaluation_instance) + + def evaluate_metrics( + self, + node_run_result_list: list[NodeRunResult], + default_metric: DefaultMetric, + model_provider: str, + model_name: str, + tenant_id: str, + ) -> list[EvaluationItemResult]: + """Compute retrieval evaluation metrics.""" + if not node_run_result_list: + return [] + + merged_items = [] + for i, node_result in enumerate(node_run_result_list): + outputs = node_result.outputs + query = self._extract_query(dict(node_result.inputs)) + result_list = outputs.get("result", []) + contexts = [item.get("content", "") for item in result_list if item.get("content")] + output = "\n---\n".join(contexts) + + merged_items.append( + EvaluationItemInput( + index=i, + inputs={"query": query}, + output=output, + context=contexts, + ) + ) + + return self.evaluation_instance.evaluate_retrieval( + merged_items, [default_metric.metric], model_provider, model_name, tenant_id + ) + + @staticmethod + def _extract_query(inputs: dict[str, Any]) -> str: + for key in ("query", "question", "input", "text"): + if key in inputs: + return str(inputs[key]) + values = list(inputs.values()) + return str(values[0]) if values else "" diff --git a/api/core/evaluation/runners/snippet_evaluation_runner.py b/api/core/evaluation/runners/snippet_evaluation_runner.py new file mode 100644 index 0000000000..bc516f9ee8 --- /dev/null +++ b/api/core/evaluation/runners/snippet_evaluation_runner.py @@ -0,0 +1,68 @@ +"""Runner for Snippet evaluation. + +Snippets are essentially workflows, so we reuse ``evaluate_workflow`` from +the evaluation instance for metric computation. 
+""" + +import logging +from collections.abc import Mapping +from typing import Any + +from core.evaluation.base_evaluation_instance import BaseEvaluationInstance +from core.evaluation.entities.evaluation_entity import ( + DefaultMetric, + EvaluationItemInput, + EvaluationItemResult, +) +from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner +from graphon.node_events import NodeRunResult + +logger = logging.getLogger(__name__) + + +class SnippetEvaluationRunner(BaseEvaluationRunner): + """Runner for snippet evaluation: evaluates a published Snippet workflow.""" + + def __init__(self, evaluation_instance: BaseEvaluationInstance): + super().__init__(evaluation_instance) + + def evaluate_metrics( + self, + node_run_result_list: list[NodeRunResult], + default_metric: DefaultMetric, + model_provider: str, + model_name: str, + tenant_id: str, + ) -> list[EvaluationItemResult]: + """Compute evaluation metrics for snippet outputs.""" + if not node_run_result_list: + return [] + merged_items = self._merge_results_into_items(node_run_result_list) + return self.evaluation_instance.evaluate_workflow( + merged_items, [default_metric.metric], model_provider, model_name, tenant_id + ) + + @staticmethod + def _merge_results_into_items(items: list[NodeRunResult]) -> list[EvaluationItemInput]: + """Create EvaluationItemInput list from NodeRunResult for snippet evaluation.""" + merged = [] + for i, item in enumerate(items): + output = _extract_snippet_output(item.outputs) + merged.append( + EvaluationItemInput( + index=i, + inputs=dict(item.inputs), + output=output, + ) + ) + return merged + + +def _extract_snippet_output(outputs: Mapping[str, Any]) -> str: + """Extract the primary output text from snippet NodeRunResult.outputs.""" + if "answer" in outputs: + return str(outputs["answer"]) + if "text" in outputs: + return str(outputs["text"]) + values = list(outputs.values()) + return str(values[0]) if values else "" diff --git a/api/core/evaluation/runners/workflow_evaluation_runner.py b/api/core/evaluation/runners/workflow_evaluation_runner.py new file mode 100644 index 0000000000..e1cc9defdb --- /dev/null +++ b/api/core/evaluation/runners/workflow_evaluation_runner.py @@ -0,0 +1,62 @@ +import logging +from collections.abc import Mapping +from typing import Any + +from core.evaluation.base_evaluation_instance import BaseEvaluationInstance +from core.evaluation.entities.evaluation_entity import ( + DefaultMetric, + EvaluationItemInput, + EvaluationItemResult, +) +from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner +from graphon.node_events import NodeRunResult + +logger = logging.getLogger(__name__) + + +class WorkflowEvaluationRunner(BaseEvaluationRunner): + """Runner for workflow evaluation: executes workflow App in non-streaming mode.""" + + def __init__(self, evaluation_instance: BaseEvaluationInstance): + super().__init__(evaluation_instance) + + def evaluate_metrics( + self, + node_run_result_list: list[NodeRunResult], + default_metric: DefaultMetric, + model_provider: str, + model_name: str, + tenant_id: str, + ) -> list[EvaluationItemResult]: + """Compute workflow evaluation metrics (end-to-end).""" + if not node_run_result_list: + return [] + merged_items = self._merge_results_into_items(node_run_result_list) + return self.evaluation_instance.evaluate_workflow( + merged_items, [default_metric.metric], model_provider, model_name, tenant_id + ) + + @staticmethod + def _merge_results_into_items(items: list[NodeRunResult]) -> 
list[EvaluationItemInput]: + """Create EvaluationItemInput list from NodeRunResult for workflow evaluation.""" + merged = [] + for i, item in enumerate(items): + output = _extract_workflow_output(item.outputs) + merged.append( + EvaluationItemInput( + index=i, + inputs=dict(item.inputs), + output=output, + ) + ) + return merged + + +def _extract_workflow_output(outputs: Mapping[str, Any]) -> str: + """Extract the primary output text from workflow NodeRunResult.outputs.""" + if "answer" in outputs: + return str(outputs["answer"]) + if "text" in outputs: + return str(outputs["text"]) + values = list(outputs.values()) + return str(values[0]) if values else "" diff --git a/api/enums/quota_type.py b/api/enums/quota_type.py index 9f511b88ef..a10ac21f69 100644 --- a/api/enums/quota_type.py +++ b/api/enums/quota_type.py @@ -1,56 +1,17 @@ -import logging -from dataclasses import dataclass from enum import StrEnum, auto -logger = logging.getLogger(__name__) - - -@dataclass -class QuotaCharge: - """ - Result of a quota consumption operation. - - Attributes: - success: Whether the quota charge succeeded - charge_id: UUID for refund, or None if failed/disabled - """ - - success: bool - charge_id: str | None - _quota_type: "QuotaType" - - def refund(self) -> None: - """ - Refund this quota charge. - - Safe to call even if charge failed or was disabled. - This method guarantees no exceptions will be raised. - """ - if self.charge_id: - self._quota_type.refund(self.charge_id) - logger.info("Refunded quota for %s with charge_id: %s", self._quota_type.value, self.charge_id) - class QuotaType(StrEnum): """ Supported quota types for tenant feature usage. - - Add additional types here whenever new billable features become available. """ - # Trigger execution quota TRIGGER = auto() - - # Workflow execution quota WORKFLOW = auto() - UNLIMITED = auto() @property def billing_key(self) -> str: - """ - Get the billing key for the feature. - """ match self: case QuotaType.TRIGGER: return "trigger_event" @@ -58,152 +19,3 @@ class QuotaType(StrEnum): return "api_rate_limit" case _: raise ValueError(f"Invalid quota type: {self}") - - def consume(self, tenant_id: str, amount: int = 1) -> QuotaCharge: - """ - Consume quota for the feature. 
- - Args: - tenant_id: The tenant identifier - amount: Amount to consume (default: 1) - - Returns: - QuotaCharge with success status and charge_id for refund - - Raises: - QuotaExceededError: When quota is insufficient - """ - from configs import dify_config - from services.billing_service import BillingService - from services.errors.app import QuotaExceededError - - if not dify_config.BILLING_ENABLED: - logger.debug("Billing disabled, allowing request for %s", tenant_id) - return QuotaCharge(success=True, charge_id=None, _quota_type=self) - - logger.info("Consuming %d %s quota for tenant %s", amount, self.value, tenant_id) - - if amount <= 0: - raise ValueError("Amount to consume must be greater than 0") - - try: - response = BillingService.update_tenant_feature_plan_usage(tenant_id, self.billing_key, delta=amount) - - if response.get("result") != "success": - logger.warning( - "Failed to consume quota for %s, feature %s details: %s", - tenant_id, - self.value, - response.get("detail"), - ) - raise QuotaExceededError(feature=self.value, tenant_id=tenant_id, required=amount) - - charge_id = response.get("history_id") - logger.debug( - "Successfully consumed %d %s quota for tenant %s, charge_id: %s", - amount, - self.value, - tenant_id, - charge_id, - ) - return QuotaCharge(success=True, charge_id=charge_id, _quota_type=self) - - except QuotaExceededError: - raise - except Exception: - # fail-safe: allow request on billing errors - logger.exception("Failed to consume quota for %s, feature %s", tenant_id, self.value) - return unlimited() - - def check(self, tenant_id: str, amount: int = 1) -> bool: - """ - Check if tenant has sufficient quota without consuming. - - Args: - tenant_id: The tenant identifier - amount: Amount to check (default: 1) - - Returns: - True if quota is sufficient, False otherwise - """ - from configs import dify_config - - if not dify_config.BILLING_ENABLED: - return True - - if amount <= 0: - raise ValueError("Amount to check must be greater than 0") - - try: - remaining = self.get_remaining(tenant_id) - return remaining >= amount if remaining != -1 else True - except Exception: - logger.exception("Failed to check quota for %s, feature %s", tenant_id, self.value) - # fail-safe: allow request on billing errors - return True - - def refund(self, charge_id: str) -> None: - """ - Refund quota using charge_id from consume(). - - This method guarantees no exceptions will be raised. - All errors are logged but silently handled. - - Args: - charge_id: The UUID returned from consume() - """ - try: - from configs import dify_config - from services.billing_service import BillingService - - if not dify_config.BILLING_ENABLED: - return - - if not charge_id: - logger.warning("Cannot refund: charge_id is empty") - return - - logger.info("Refunding %s quota with charge_id: %s", self.value, charge_id) - - response = BillingService.refund_tenant_feature_plan_usage(charge_id) - if response.get("result") == "success": - logger.debug("Successfully refunded %s quota, charge_id: %s", self.value, charge_id) - else: - logger.warning("Refund failed for charge_id: %s", charge_id) - - except Exception: - # Catch ALL exceptions - refund must never fail - logger.exception("Failed to refund quota for charge_id: %s", charge_id) - # Don't raise - refund is best-effort and must be silent - - def get_remaining(self, tenant_id: str) -> int: - """ - Get remaining quota for the tenant. 
- - Args: - tenant_id: The tenant identifier - - Returns: - Remaining quota amount - """ - from services.billing_service import BillingService - - try: - usage_info = BillingService.get_tenant_feature_plan_usage(tenant_id, self.billing_key) - # Assuming the API returns a dict with 'remaining' or 'limit' and 'used' - if isinstance(usage_info, dict): - return usage_info.get("remaining", 0) - # If it returns a simple number, treat it as remaining - return int(usage_info) if usage_info else 0 - except Exception: - logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, self.value) - return -1 - - -def unlimited() -> QuotaCharge: - """ - Return a quota charge for unlimited quota. - - This is useful for features that are not subject to quota limits, such as the UNLIMITED quota type. - """ - return QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.UNLIMITED) diff --git a/api/fields/snippet_fields.py b/api/fields/snippet_fields.py new file mode 100644 index 0000000000..2562d7cff0 --- /dev/null +++ b/api/fields/snippet_fields.py @@ -0,0 +1,45 @@ +from flask_restx import fields + +from fields.member_fields import simple_account_fields +from libs.helper import TimestampField + +# Snippet list item fields (lightweight for list display) +snippet_list_fields = { + "id": fields.String, + "name": fields.String, + "description": fields.String, + "type": fields.String, + "version": fields.Integer, + "use_count": fields.Integer, + "is_published": fields.Boolean, + "icon_info": fields.Raw, + "created_at": TimestampField, + "updated_at": TimestampField, +} + +# Full snippet fields (includes creator info and graph data) +snippet_fields = { + "id": fields.String, + "name": fields.String, + "description": fields.String, + "type": fields.String, + "version": fields.Integer, + "use_count": fields.Integer, + "is_published": fields.Boolean, + "icon_info": fields.Raw, + "graph": fields.Raw(attribute="graph_dict"), + "input_fields": fields.Raw(attribute="input_fields_list"), + "created_by": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), + "created_at": TimestampField, + "updated_by": fields.Nested(simple_account_fields, attribute="updated_by_account", allow_null=True), + "updated_at": TimestampField, +} + +# Pagination response fields +snippet_pagination_fields = { + "data": fields.List(fields.Nested(snippet_list_fields)), + "page": fields.Integer, + "limit": fields.Integer, + "total": fields.Integer, + "has_more": fields.Boolean, +} diff --git a/api/fields/workflow_app_log_fields.py b/api/fields/workflow_app_log_fields.py index d0e762f62b..195b720285 100644 --- a/api/fields/workflow_app_log_fields.py +++ b/api/fields/workflow_app_log_fields.py @@ -14,6 +14,7 @@ workflow_app_log_partial_fields = { "id": fields.String, "workflow_run": fields.Nested(workflow_run_for_log_fields, attribute="workflow_run", allow_null=True), "details": fields.Raw(attribute="details"), + "evaluation": fields.Raw(attribute="evaluation", default=None), "created_from": fields.String, "created_by_role": fields.String, "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), diff --git a/api/migrations/versions/2026_01_29_1200-1c05e80d2380_add_customized_snippets_table.py b/api/migrations/versions/2026_01_29_1200-1c05e80d2380_add_customized_snippets_table.py new file mode 100644 index 0000000000..5d487b121b --- /dev/null +++ b/api/migrations/versions/2026_01_29_1200-1c05e80d2380_add_customized_snippets_table.py @@ -0,0 
+1,83 @@ +"""add_customized_snippets_table + +Revision ID: 1c05e80d2380 +Revises: 788d3099ae3a +Create Date: 2026-01-29 12:00:00.000000 + +""" + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +import models as models + + +def _is_pg(conn): + return conn.dialect.name == "postgresql" + + +# revision identifiers, used by Alembic. +revision = "1c05e80d2380" +down_revision = "788d3099ae3a" +branch_labels = None +depends_on = None + + +def upgrade(): + conn = op.get_bind() + + if _is_pg(conn): + op.create_table( + "customized_snippets", + sa.Column("id", models.types.StringUUID(), server_default=sa.text("uuidv7()"), nullable=False), + sa.Column("tenant_id", models.types.StringUUID(), nullable=False), + sa.Column("name", sa.String(length=255), nullable=False), + sa.Column("description", sa.Text(), nullable=True), + sa.Column("type", sa.String(length=50), server_default=sa.text("'node'"), nullable=False), + sa.Column("workflow_id", models.types.StringUUID(), nullable=True), + sa.Column("is_published", sa.Boolean(), server_default=sa.text("false"), nullable=False), + sa.Column("version", sa.Integer(), server_default=sa.text("1"), nullable=False), + sa.Column("use_count", sa.Integer(), server_default=sa.text("0"), nullable=False), + sa.Column("icon_info", postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.Column("graph", sa.Text(), nullable=True), + sa.Column("input_fields", sa.Text(), nullable=True), + sa.Column("created_by", models.types.StringUUID(), nullable=True), + sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + sa.Column("updated_by", models.types.StringUUID(), nullable=True), + sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + sa.PrimaryKeyConstraint("id", name="customized_snippet_pkey"), + sa.UniqueConstraint("tenant_id", "name", name="customized_snippet_tenant_name_key"), + ) + else: + op.create_table( + "customized_snippets", + sa.Column("id", models.types.StringUUID(), nullable=False), + sa.Column("tenant_id", models.types.StringUUID(), nullable=False), + sa.Column("name", sa.String(length=255), nullable=False), + sa.Column("description", models.types.LongText(), nullable=True), + sa.Column("type", sa.String(length=50), server_default=sa.text("'node'"), nullable=False), + sa.Column("workflow_id", models.types.StringUUID(), nullable=True), + sa.Column("is_published", sa.Boolean(), server_default=sa.text("false"), nullable=False), + sa.Column("version", sa.Integer(), server_default=sa.text("1"), nullable=False), + sa.Column("use_count", sa.Integer(), server_default=sa.text("0"), nullable=False), + sa.Column("icon_info", models.types.AdjustedJSON(astext_type=sa.Text()), nullable=True), + sa.Column("graph", models.types.LongText(), nullable=True), + sa.Column("input_fields", models.types.LongText(), nullable=True), + sa.Column("created_by", models.types.StringUUID(), nullable=True), + sa.Column("created_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False), + sa.Column("updated_by", models.types.StringUUID(), nullable=True), + sa.Column("updated_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False), + sa.PrimaryKeyConstraint("id", name="customized_snippet_pkey"), + sa.UniqueConstraint("tenant_id", "name", name="customized_snippet_tenant_name_key"), + ) + + with op.batch_alter_table("customized_snippets", schema=None) as batch_op: + 
batch_op.create_index("customized_snippet_tenant_idx", ["tenant_id"], unique=False) + + +def downgrade(): + with op.batch_alter_table("customized_snippets", schema=None) as batch_op: + batch_op.drop_index("customized_snippet_tenant_idx") + + op.drop_table("customized_snippets") diff --git a/api/migrations/versions/2026_03_03_0001-a1b2c3d4e5f6_add_evaluation_tables.py b/api/migrations/versions/2026_03_03_0001-a1b2c3d4e5f6_add_evaluation_tables.py new file mode 100644 index 0000000000..986b4ca8bd --- /dev/null +++ b/api/migrations/versions/2026_03_03_0001-a1b2c3d4e5f6_add_evaluation_tables.py @@ -0,0 +1,116 @@ +"""add_evaluation_tables + +Revision ID: a1b2c3d4e5f6 +Revises: 1c05e80d2380 +Create Date: 2026-03-03 00:01:00.000000 + +""" + +import sqlalchemy as sa +from alembic import op + +import models as models + + +# revision identifiers, used by Alembic. +revision = "a1b2c3d4e5f6" +down_revision = "1c05e80d2380" +branch_labels = None +depends_on = None + + +def upgrade(): + # evaluation_configurations + op.create_table( + "evaluation_configurations", + sa.Column("id", models.types.StringUUID(), nullable=False), + sa.Column("tenant_id", models.types.StringUUID(), nullable=False), + sa.Column("target_type", sa.String(length=20), nullable=False), + sa.Column("target_id", models.types.StringUUID(), nullable=False), + sa.Column("evaluation_model_provider", sa.String(length=255), nullable=True), + sa.Column("evaluation_model", sa.String(length=255), nullable=True), + sa.Column("metrics_config", models.types.LongText(), nullable=True), + sa.Column("judgement_conditions", models.types.LongText(), nullable=True), + sa.Column("created_by", models.types.StringUUID(), nullable=False), + sa.Column("updated_by", models.types.StringUUID(), nullable=False), + sa.Column("created_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False), + sa.Column("updated_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False), + sa.PrimaryKeyConstraint("id", name="evaluation_configuration_pkey"), + sa.UniqueConstraint("tenant_id", "target_type", "target_id", name="evaluation_configuration_unique"), + ) + with op.batch_alter_table("evaluation_configurations", schema=None) as batch_op: + batch_op.create_index( + "evaluation_configuration_target_idx", ["tenant_id", "target_type", "target_id"], unique=False + ) + + # evaluation_runs + op.create_table( + "evaluation_runs", + sa.Column("id", models.types.StringUUID(), nullable=False), + sa.Column("tenant_id", models.types.StringUUID(), nullable=False), + sa.Column("target_type", sa.String(length=20), nullable=False), + sa.Column("target_id", models.types.StringUUID(), nullable=False), + sa.Column("evaluation_config_id", models.types.StringUUID(), nullable=False), + sa.Column("status", sa.String(length=20), nullable=False, server_default=sa.text("'pending'")), + sa.Column("dataset_file_id", models.types.StringUUID(), nullable=True), + sa.Column("result_file_id", models.types.StringUUID(), nullable=True), + sa.Column("total_items", sa.Integer(), nullable=False, server_default=sa.text("0")), + sa.Column("completed_items", sa.Integer(), nullable=False, server_default=sa.text("0")), + sa.Column("failed_items", sa.Integer(), nullable=False, server_default=sa.text("0")), + sa.Column("metrics_summary", models.types.LongText(), nullable=True), + sa.Column("error", sa.Text(), nullable=True), + sa.Column("celery_task_id", sa.String(length=255), nullable=True), + sa.Column("created_by", models.types.StringUUID(), nullable=False), + 
sa.Column("started_at", sa.DateTime(), nullable=True), + sa.Column("completed_at", sa.DateTime(), nullable=True), + sa.Column("created_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False), + sa.Column("updated_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False), + sa.PrimaryKeyConstraint("id", name="evaluation_run_pkey"), + ) + with op.batch_alter_table("evaluation_runs", schema=None) as batch_op: + batch_op.create_index( + "evaluation_run_target_idx", ["tenant_id", "target_type", "target_id"], unique=False + ) + batch_op.create_index("evaluation_run_status_idx", ["tenant_id", "status"], unique=False) + + # evaluation_run_items + op.create_table( + "evaluation_run_items", + sa.Column("id", models.types.StringUUID(), nullable=False), + sa.Column("evaluation_run_id", models.types.StringUUID(), nullable=False), + sa.Column("workflow_run_id", models.types.StringUUID(), nullable=True), + sa.Column("item_index", sa.Integer(), nullable=False), + sa.Column("inputs", models.types.LongText(), nullable=True), + sa.Column("expected_output", models.types.LongText(), nullable=True), + sa.Column("context", models.types.LongText(), nullable=True), + sa.Column("actual_output", models.types.LongText(), nullable=True), + sa.Column("metrics", models.types.LongText(), nullable=True), + sa.Column("metadata_json", models.types.LongText(), nullable=True), + sa.Column("error", sa.Text(), nullable=True), + sa.Column("overall_score", sa.Float(), nullable=True), + sa.Column("created_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False), + sa.PrimaryKeyConstraint("id", name="evaluation_run_item_pkey"), + ) + with op.batch_alter_table("evaluation_run_items", schema=None) as batch_op: + batch_op.create_index("evaluation_run_item_run_idx", ["evaluation_run_id"], unique=False) + batch_op.create_index( + "evaluation_run_item_index_idx", ["evaluation_run_id", "item_index"], unique=False + ) + batch_op.create_index("evaluation_run_item_workflow_run_idx", ["workflow_run_id"], unique=False) + + +def downgrade(): + with op.batch_alter_table("evaluation_run_items", schema=None) as batch_op: + batch_op.drop_index("evaluation_run_item_workflow_run_idx") + batch_op.drop_index("evaluation_run_item_index_idx") + batch_op.drop_index("evaluation_run_item_run_idx") + op.drop_table("evaluation_run_items") + + with op.batch_alter_table("evaluation_runs", schema=None) as batch_op: + batch_op.drop_index("evaluation_run_status_idx") + batch_op.drop_index("evaluation_run_target_idx") + op.drop_table("evaluation_runs") + + with op.batch_alter_table("evaluation_configurations", schema=None) as batch_op: + batch_op.drop_index("evaluation_configuration_target_idx") + op.drop_table("evaluation_configurations") diff --git a/api/migrations/versions/2026_03_17_1721-4c60d8d3ee74_merge_migration_heads.py b/api/migrations/versions/2026_03_17_1721-4c60d8d3ee74_merge_migration_heads.py new file mode 100644 index 0000000000..4b1b32c8d3 --- /dev/null +++ b/api/migrations/versions/2026_03_17_1721-4c60d8d3ee74_merge_migration_heads.py @@ -0,0 +1,25 @@ +"""merge migration heads + +Revision ID: 4c60d8d3ee74 +Revises: fce013ca180e, a1b2c3d4e5f6 +Create Date: 2026-03-17 17:21:12.105536 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
+revision = '4c60d8d3ee74' +down_revision = ('fce013ca180e', 'a1b2c3d4e5f6') +branch_labels = None +depends_on = None + + +def upgrade(): + pass + + +def downgrade(): + pass diff --git a/api/models/__init__.py b/api/models/__init__.py index fcae07f948..8a605d5195 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -33,6 +33,13 @@ from .enums import ( WorkflowRunTriggeredFrom, WorkflowTriggerStatus, ) +from .evaluation import ( + EvaluationConfiguration, + EvaluationRun, + EvaluationRunItem, + EvaluationRunStatus, + EvaluationTargetType, +) from .execution_extra_content import ExecutionExtraContent, HumanInputContent from .human_input import HumanInputForm from .model import ( @@ -80,6 +87,7 @@ from .provider import ( TenantDefaultModel, TenantPreferredModelProvider, ) +from .snippet import CustomizedSnippet, SnippetType from .source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding from .task import CeleryTask, CeleryTaskSet from .tools import ( @@ -139,6 +147,7 @@ __all__ = [ "Conversation", "ConversationVariable", "CreatorUserRole", + "CustomizedSnippet", "DataSourceApiKeyAuthBinding", "DataSourceOauthBinding", "Dataset", @@ -156,6 +165,11 @@ __all__ = [ "DocumentSegment", "Embedding", "EndUser", + "EvaluationConfiguration", + "EvaluationRun", + "EvaluationRunItem", + "EvaluationRunStatus", + "EvaluationTargetType", "ExecutionExtraContent", "ExporleBanner", "ExternalKnowledgeApis", @@ -183,6 +197,7 @@ __all__ = [ "RecommendedApp", "SavedMessage", "Site", + "SnippetType", "Tag", "TagBinding", "Tenant", diff --git a/api/models/evaluation.py b/api/models/evaluation.py new file mode 100644 index 0000000000..680d6ab31c --- /dev/null +++ b/api/models/evaluation.py @@ -0,0 +1,205 @@ +from __future__ import annotations + +import json +from datetime import datetime +from enum import StrEnum +from typing import Any + +import sqlalchemy as sa +from sqlalchemy import DateTime, Float, Integer, String, Text, func +from sqlalchemy.orm import Mapped, mapped_column + +from libs.uuid_utils import uuidv7 + +from .base import Base +from .types import LongText, StringUUID + + +class EvaluationRunStatus(StrEnum): + PENDING = "pending" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + CANCELLED = "cancelled" + + +class EvaluationTargetType(StrEnum): + APP = "app" + SNIPPETS = "snippets" + KNOWLEDGE_BASE = "knowledge_base" + + +class EvaluationConfiguration(Base): + """Stores evaluation configuration for each target (App or Snippet).""" + + __tablename__ = "evaluation_configurations" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="evaluation_configuration_pkey"), + sa.Index("evaluation_configuration_target_idx", "tenant_id", "target_type", "target_id"), + sa.Index("evaluation_configuration_workflow_idx", "customized_workflow_id"), + sa.UniqueConstraint("tenant_id", "target_type", "target_id", name="evaluation_configuration_unique"), + ) + + id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7())) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + target_type: Mapped[str] = mapped_column(String(20), nullable=False) + target_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + + evaluation_model_provider: Mapped[str | None] = mapped_column(String(255), nullable=True) + evaluation_model: Mapped[str | None] = mapped_column(String(255), nullable=True) + metrics_config: Mapped[str | None] = mapped_column(LongText, nullable=True) + judgement_conditions: Mapped[str | None] = mapped_column(LongText, nullable=True) + 
customized_workflow_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) + + created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) + updated_by: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + ) + + @property + def metrics_config_dict(self) -> dict[str, Any]: + if self.metrics_config: + return json.loads(self.metrics_config) + return {} + + @metrics_config_dict.setter + def metrics_config_dict(self, value: dict[str, Any]) -> None: + self.metrics_config = json.dumps(value) + + @property + def default_metrics_list(self) -> list[dict[str, Any]]: + """Extract default_metrics from the stored metrics_config JSON.""" + config = self.metrics_config_dict + return config.get("default_metrics", []) + + @property + def customized_metrics_dict(self) -> dict[str, Any] | None: + """Extract customized_metrics from the stored metrics_config JSON.""" + config = self.metrics_config_dict + return config.get("customized_metrics") + + @property + def judgment_config_dict(self) -> dict[str, Any] | None: + """Return judgment config (stored in the judgement_conditions column).""" + if self.judgement_conditions: + parsed = json.loads(self.judgement_conditions) + return parsed if parsed else None + return None + + @property + def judgement_conditions_dict(self) -> dict[str, Any]: + if self.judgement_conditions: + return json.loads(self.judgement_conditions) + return {} + + @judgement_conditions_dict.setter + def judgement_conditions_dict(self, value: dict[str, Any]) -> None: + self.judgement_conditions = json.dumps(value) + + def __repr__(self) -> str: + return f"" + + +class EvaluationRun(Base): + """Stores each evaluation run record.""" + + __tablename__ = "evaluation_runs" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="evaluation_run_pkey"), + sa.Index("evaluation_run_target_idx", "tenant_id", "target_type", "target_id"), + sa.Index("evaluation_run_status_idx", "tenant_id", "status"), + ) + + id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7())) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + target_type: Mapped[str] = mapped_column(String(20), nullable=False) + target_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + evaluation_config_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + + status: Mapped[str] = mapped_column(String(20), nullable=False, default=EvaluationRunStatus.PENDING) + dataset_file_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) + result_file_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) + + total_items: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + completed_items: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + failed_items: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + error: Mapped[str | None] = mapped_column(Text, nullable=True) + + celery_task_id: Mapped[str | None] = mapped_column(String(255), nullable=True) + + created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) + started_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) + completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, 
server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + ) + + @property + def progress(self) -> float: + if self.total_items == 0: + return 0.0 + return (self.completed_items + self.failed_items) / self.total_items + + def __repr__(self) -> str: + return f"" + + +class EvaluationRunItem(Base): + """Stores per-row evaluation results.""" + + __tablename__ = "evaluation_run_items" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="evaluation_run_item_pkey"), + sa.Index("evaluation_run_item_run_idx", "evaluation_run_id"), + sa.Index("evaluation_run_item_index_idx", "evaluation_run_id", "item_index"), + sa.Index("evaluation_run_item_workflow_run_idx", "workflow_run_id"), + ) + + id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7())) + evaluation_run_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + workflow_run_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) + + item_index: Mapped[int] = mapped_column(Integer, nullable=False) + inputs: Mapped[str | None] = mapped_column(LongText, nullable=True) + expected_output: Mapped[str | None] = mapped_column(LongText, nullable=True) + context: Mapped[str | None] = mapped_column(LongText, nullable=True) + actual_output: Mapped[str | None] = mapped_column(LongText, nullable=True) + + metrics: Mapped[str | None] = mapped_column(LongText, nullable=True) + judgment: Mapped[str | None] = mapped_column(LongText, nullable=True) + metadata_json: Mapped[str | None] = mapped_column(LongText, nullable=True) + error: Mapped[str | None] = mapped_column(Text, nullable=True) + + overall_score: Mapped[float | None] = mapped_column(Float, nullable=True) + + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + + @property + def inputs_dict(self) -> dict[str, Any]: + if self.inputs: + return json.loads(self.inputs) + return {} + + @property + def metrics_list(self) -> list[dict[str, Any]]: + if self.metrics: + return json.loads(self.metrics) + return [] + + @property + def judgment_dict(self) -> dict[str, Any]: + if self.judgment: + return json.loads(self.judgment) + return {} + + @property + def metadata_dict(self) -> dict[str, Any]: + if self.metadata_json: + return json.loads(self.metadata_json) + return {} + + def __repr__(self) -> str: + return f"" diff --git a/api/models/snippet.py b/api/models/snippet.py new file mode 100644 index 0000000000..4a4645d744 --- /dev/null +++ b/api/models/snippet.py @@ -0,0 +1,101 @@ +import json +from datetime import datetime +from enum import StrEnum +from typing import Any + +import sqlalchemy as sa +from sqlalchemy import DateTime, String, func +from sqlalchemy.orm import Mapped, mapped_column + +from libs.uuid_utils import uuidv7 + +from .account import Account +from .base import Base +from .engine import db +from .types import AdjustedJSON, LongText, StringUUID + + +class SnippetType(StrEnum): + """Snippet Type Enum""" + + NODE = "node" + GROUP = "group" + + +class CustomizedSnippet(Base): + """ + Customized Snippet Model + + Stores reusable workflow components (nodes or node groups) that can be + shared across applications within a workspace. 
+ """ + + __tablename__ = "customized_snippets" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="customized_snippet_pkey"), + sa.Index("customized_snippet_tenant_idx", "tenant_id"), + sa.UniqueConstraint("tenant_id", "name", name="customized_snippet_tenant_name_key"), + ) + + id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7())) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + name: Mapped[str] = mapped_column(String(255), nullable=False) + description: Mapped[str | None] = mapped_column(LongText, nullable=True) + type: Mapped[str] = mapped_column(String(50), nullable=False, server_default=sa.text("'node'")) + + # Workflow reference for published version + workflow_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) + + # State flags + is_published: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) + version: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("1")) + use_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) + + # Visual customization + icon_info: Mapped[dict | None] = mapped_column(AdjustedJSON, nullable=True) + + # Snippet configuration (stored as JSON text) + input_fields: Mapped[str | None] = mapped_column(LongText, nullable=True) + + # Audit fields + created_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True) + updated_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + ) + + @property + def graph_dict(self) -> dict[str, Any]: + """Get graph from associated workflow.""" + if self.workflow_id: + from .workflow import Workflow + + workflow = db.session.get(Workflow, self.workflow_id) + if workflow: + return json.loads(workflow.graph) if workflow.graph else {} + return {} + + @property + def input_fields_list(self) -> list[dict[str, Any]]: + """Parse input_fields JSON to list.""" + return json.loads(self.input_fields) if self.input_fields else [] + + @property + def created_by_account(self) -> Account | None: + """Get the account that created this snippet.""" + if self.created_by: + return db.session.get(Account, self.created_by) + return None + + @property + def updated_by_account(self) -> Account | None: + """Get the account that last updated this snippet.""" + if self.updated_by: + return db.session.get(Account, self.updated_by) + return None + + @property + def version_str(self) -> str: + """Get version as string for API response.""" + return str(self.version) diff --git a/api/models/workflow.py b/api/models/workflow.py index 63abf8c3b6..aadb062426 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -106,6 +106,8 @@ class WorkflowType(StrEnum): WORKFLOW = "workflow" CHAT = "chat" RAG_PIPELINE = "rag-pipeline" + SNIPPET = "snippet" + EVALUATION = "evaluation" @classmethod def value_of(cls, value: str) -> "WorkflowType": diff --git a/api/pyproject.toml b/api/pyproject.toml index f22bafb03a..abef591783 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -234,6 +234,12 @@ storage = [ ############################################################ tools = ["cloudscraper~=1.2.71", "nltk~=3.9.1"] +############################################################ +# [ Evaluation ] dependency group +# 
Required for evaluation frameworks +############################################################ +evaluation = ["ragas>=0.2.0", "deepeval>=2.0.0"] + ############################################################ # [ VDB ] workspace plugins — hollow packages under providers/vdb/* # Each declares its own third-party deps and registers dify.vector_backends entry points. diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index 5e8c7aa337..2c9d815b64 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -18,12 +18,13 @@ from core.app.features.rate_limiting import RateLimit from core.app.features.rate_limiting.rate_limit import rate_limit_context from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig from core.db import session_factory -from enums.quota_type import QuotaType, unlimited +from enums.quota_type import QuotaType from extensions.otel import AppGenerateHandler, trace_span from models.model import Account, App, AppMode, EndUser from models.workflow import Workflow, WorkflowRun from services.errors.app import QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError from services.errors.llm import InvokeRateLimitError +from services.quota_service import QuotaService, unlimited from services.workflow_service import WorkflowService from tasks.app_generate.workflow_execute_task import AppExecutionParams, workflow_based_app_execution_task @@ -106,7 +107,7 @@ class AppGenerateService: quota_charge = unlimited() if dify_config.BILLING_ENABLED: try: - quota_charge = QuotaType.WORKFLOW.consume(app_model.tenant_id) + quota_charge = QuotaService.reserve(QuotaType.WORKFLOW, app_model.tenant_id) except QuotaExceededError: raise InvokeRateLimitError(f"Workflow execution quota limit reached for tenant {app_model.tenant_id}") @@ -116,6 +117,7 @@ class AppGenerateService: request_id = RateLimit.gen_request_key() try: request_id = rate_limit.enter(request_id) + quota_charge.commit() effective_mode = ( AppMode.AGENT_CHAT if app_model.is_agent and app_model.mode != AppMode.AGENT_CHAT else app_model.mode ) diff --git a/api/services/async_workflow_service.py b/api/services/async_workflow_service.py index a731d5c048..8b39d63385 100644 --- a/api/services/async_workflow_service.py +++ b/api/services/async_workflow_service.py @@ -22,6 +22,7 @@ from models.trigger import WorkflowTriggerLog, WorkflowTriggerLogDict from models.workflow import Workflow from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository from services.errors.app import QuotaExceededError, WorkflowNotFoundError, WorkflowQuotaLimitError +from services.quota_service import QuotaService, unlimited from services.workflow.entities import AsyncTriggerResponse, TriggerData, WorkflowTaskData from services.workflow.queue_dispatcher import QueueDispatcherManager, QueuePriority from services.workflow_service import WorkflowService @@ -131,9 +132,10 @@ class AsyncWorkflowService: trigger_log = trigger_log_repo.create(trigger_log) session.commit() - # 7. Check and consume quota + # 7. Reserve quota (commit after successful dispatch) + quota_charge = unlimited() try: - QuotaType.WORKFLOW.consume(trigger_data.tenant_id) + quota_charge = QuotaService.reserve(QuotaType.WORKFLOW, trigger_data.tenant_id) except QuotaExceededError as e: # Update trigger log status trigger_log.status = WorkflowTriggerStatus.RATE_LIMITED @@ -153,13 +155,18 @@ class AsyncWorkflowService: # 9. 
Dispatch to appropriate queue task_data_dict = task_data.model_dump(mode="json") - task: AsyncResult[Any] | None = None - if queue_name == QueuePriority.PROFESSIONAL: - task = execute_workflow_professional.delay(task_data_dict) - elif queue_name == QueuePriority.TEAM: - task = execute_workflow_team.delay(task_data_dict) - else: # SANDBOX - task = execute_workflow_sandbox.delay(task_data_dict) + try: + task: AsyncResult[Any] | None = None + if queue_name == QueuePriority.PROFESSIONAL: + task = execute_workflow_professional.delay(task_data_dict) + elif queue_name == QueuePriority.TEAM: + task = execute_workflow_team.delay(task_data_dict) + else: # SANDBOX + task = execute_workflow_sandbox.delay(task_data_dict) + quota_charge.commit() + except Exception: + quota_charge.refund() + raise # 10. Update trigger log with task info trigger_log.status = WorkflowTriggerStatus.QUEUED diff --git a/api/services/billing_service.py b/api/services/billing_service.py index a1362ccad6..eeaddfee2f 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -32,6 +32,102 @@ class SubscriptionPlan(TypedDict): expiration_date: int +class QuotaReserveResult(TypedDict): + reservation_id: str + available: int + reserved: int + + +class QuotaCommitResult(TypedDict): + available: int + reserved: int + refunded: int + + +class QuotaReleaseResult(TypedDict): + available: int + reserved: int + released: int + + +_quota_reserve_adapter = TypeAdapter(QuotaReserveResult) +_quota_commit_adapter = TypeAdapter(QuotaCommitResult) +_quota_release_adapter = TypeAdapter(QuotaReleaseResult) +class _BillingQuota(TypedDict): + size: int + limit: int + + +class _VectorSpaceQuota(TypedDict): + size: float + limit: int + + +class _KnowledgeRateLimit(TypedDict): + # NOTE (hj24): + # 1. Return for sandbox users but is null for other plans, it's defined but never used. + # 2. Keep it for compatibility for now, can be deprecated in future versions. + size: NotRequired[int] + # NOTE END + limit: int + + +class _BillingSubscription(TypedDict): + plan: str + interval: str + education: bool + + +class BillingInfo(TypedDict): + """Response of /subscription/info. + + NOTE (hj24): + - Fields not listed here (e.g. trigger_event, api_rate_limit) are stripped by TypeAdapter.validate_python() + - To ensure the precision, billing may convert fields like int as str, be careful when use TypeAdapter: + 1. validate_python in non-strict mode will coerce it to the expected type + 2. In strict mode, it will raise ValidationError + 3. To preserve compatibility, always keep non-strict mode here and avoid strict mode + """ + + enabled: bool + subscription: _BillingSubscription + members: _BillingQuota + apps: _BillingQuota + vector_space: _VectorSpaceQuota + knowledge_rate_limit: _KnowledgeRateLimit + documents_upload_quota: _BillingQuota + annotation_quota_limit: _BillingQuota + docs_processing: str + can_replace_logo: bool + model_load_balancing_enabled: bool + knowledge_pipeline_publish_enabled: bool + next_credit_reset_date: NotRequired[int] + + +_billing_info_adapter = TypeAdapter(BillingInfo) + + +class _TenantFeatureQuota(TypedDict): + usage: int + limit: int + reset_date: NotRequired[int] + + +class TenantFeatureQuotaInfo(TypedDict): + """Response of /quota/info. + + NOTE (hj24): + - Same convention as BillingInfo: billing may return int fields as str, + always keep non-strict mode to auto-coerce. 
+ """ + + trigger_event: _TenantFeatureQuota + api_rate_limit: _TenantFeatureQuota + + +_tenant_feature_quota_info_adapter = TypeAdapter(TenantFeatureQuotaInfo) + + class _BillingQuota(TypedDict): size: int limit: int @@ -149,11 +245,63 @@ class BillingService: @classmethod def get_tenant_feature_plan_usage_info(cls, tenant_id: str): + """Deprecated: Use get_quota_info instead.""" params = {"tenant_id": tenant_id} - usage_info = cls._send_request("GET", "/tenant-feature-usage/info", params=params) return usage_info + @classmethod + def get_quota_info(cls, tenant_id: str) -> TenantFeatureQuotaInfo: + params = {"tenant_id": tenant_id} + return _tenant_feature_quota_info_adapter.validate_python( + cls._send_request("GET", "/quota/info", params=params) + ) + + @classmethod + def quota_reserve( + cls, tenant_id: str, feature_key: str, request_id: str, amount: int = 1, meta: dict | None = None + ) -> QuotaReserveResult: + """Reserve quota before task execution.""" + payload: dict = { + "tenant_id": tenant_id, + "feature_key": feature_key, + "request_id": request_id, + "amount": amount, + } + if meta: + payload["meta"] = meta + return _quota_reserve_adapter.validate_python(cls._send_request("POST", "/quota/reserve", json=payload)) + + @classmethod + def quota_commit( + cls, tenant_id: str, feature_key: str, reservation_id: str, actual_amount: int, meta: dict | None = None + ) -> QuotaCommitResult: + """Commit a reservation with actual consumption.""" + payload: dict = { + "tenant_id": tenant_id, + "feature_key": feature_key, + "reservation_id": reservation_id, + "actual_amount": actual_amount, + } + if meta: + payload["meta"] = meta + return _quota_commit_adapter.validate_python(cls._send_request("POST", "/quota/commit", json=payload)) + + @classmethod + def quota_release(cls, tenant_id: str, feature_key: str, reservation_id: str) -> QuotaReleaseResult: + """Release a reservation (cancel, return frozen quota).""" + return _quota_release_adapter.validate_python( + cls._send_request( + "POST", + "/quota/release", + json={ + "tenant_id": tenant_id, + "feature_key": feature_key, + "reservation_id": reservation_id, + }, + ) + ) + @classmethod def get_knowledge_rate_limit(cls, tenant_id: str) -> KnowledgeRateLimitDict: params = {"tenant_id": tenant_id} diff --git a/api/services/errors/evaluation.py b/api/services/errors/evaluation.py new file mode 100644 index 0000000000..6affb68d21 --- /dev/null +++ b/api/services/errors/evaluation.py @@ -0,0 +1,21 @@ +from services.errors.base import BaseServiceError + + +class EvaluationFrameworkNotConfiguredError(BaseServiceError): + def __init__(self, description: str | None = None): + super().__init__(description or "Evaluation framework is not configured. 
Set EVALUATION_FRAMEWORK env var.") + + +class EvaluationNotFoundError(BaseServiceError): + def __init__(self, description: str | None = None): + super().__init__(description or "Evaluation not found.") + + +class EvaluationDatasetInvalidError(BaseServiceError): + def __init__(self, description: str | None = None): + super().__init__(description or "Evaluation dataset is invalid.") + + +class EvaluationMaxConcurrentRunsError(BaseServiceError): + def __init__(self, description: str | None = None): + super().__init__(description or "Maximum number of concurrent evaluation runs reached.") diff --git a/api/services/evaluation_service.py b/api/services/evaluation_service.py new file mode 100644 index 0000000000..196af2a617 --- /dev/null +++ b/api/services/evaluation_service.py @@ -0,0 +1,985 @@ +import io +import json +import logging +from collections.abc import Mapping +from typing import Any, Union + +from openpyxl import Workbook, load_workbook +from openpyxl.styles import Alignment, Border, Font, PatternFill, Side +from openpyxl.utils import get_column_letter +from sqlalchemy.orm import Session + +from configs import dify_config +from core.evaluation.entities.evaluation_entity import ( + METRIC_NODE_TYPE_MAPPING, + METRIC_VALUE_TYPE_MAPPING, + DefaultMetric, + EvaluationCategory, + EvaluationConfigData, + EvaluationDatasetInput, + EvaluationMetricName, + EvaluationRunData, + EvaluationRunRequest, + NodeInfo, +) +from core.evaluation.evaluation_manager import EvaluationManager +from graphon.enums import WorkflowNodeExecutionMetadataKey +from graphon.node_events.base import NodeRunResult +from models.evaluation import ( + EvaluationConfiguration, + EvaluationRun, + EvaluationRunItem, + EvaluationRunStatus, +) +from models.model import App, AppMode +from models.snippet import CustomizedSnippet +from models.workflow import Workflow +from services.errors.evaluation import ( + EvaluationDatasetInvalidError, + EvaluationFrameworkNotConfiguredError, + EvaluationMaxConcurrentRunsError, + EvaluationNotFoundError, +) +from services.snippet_service import SnippetService +from services.workflow_service import WorkflowService + +logger = logging.getLogger(__name__) + + +class EvaluationService: + """ + Service for evaluation-related operations. + + Provides functionality to generate evaluation dataset templates + based on App or Snippet input parameters. + """ + + # Excluded app modes that don't support evaluation templates + EXCLUDED_APP_MODES = {AppMode.RAG_PIPELINE} + + @classmethod + def generate_dataset_template( + cls, + target: Union[App, CustomizedSnippet], + target_type: str, + ) -> tuple[bytes, str]: + """ + Generate evaluation dataset template as XLSX bytes. + + Creates an XLSX file with headers based on the evaluation target's input parameters. + The first column is index, followed by input parameter columns. 
+ + :param target: App or CustomizedSnippet instance + :param target_type: Target type string ("app" or "snippet") + :return: Tuple of (xlsx_content_bytes, filename) + :raises ValueError: If target type is not supported or app mode is excluded + """ + # Validate target type + if target_type == "app": + if not isinstance(target, App): + raise ValueError("Invalid target: expected App instance") + if AppMode.value_of(target.mode) in cls.EXCLUDED_APP_MODES: + raise ValueError(f"App mode '{target.mode}' does not support evaluation templates") + input_fields = cls._get_app_input_fields(target) + elif target_type == "snippet": + if not isinstance(target, CustomizedSnippet): + raise ValueError("Invalid target: expected CustomizedSnippet instance") + input_fields = cls._get_snippet_input_fields(target) + else: + raise ValueError(f"Unsupported target type: {target_type}") + + # Generate XLSX template + xlsx_content = cls._generate_xlsx_template(input_fields, target.name) + + # Build filename + truncated_name = target.name[:10] + "..." if len(target.name) > 10 else target.name + filename = f"{truncated_name}-evaluation-dataset.xlsx" + + return xlsx_content, filename + + @classmethod + def _get_app_input_fields(cls, app: App) -> list[dict]: + """ + Get input fields from App's workflow. + + :param app: App instance + :return: List of input field definitions + """ + workflow_service = WorkflowService() + workflow = workflow_service.get_published_workflow(app_model=app) + if not workflow: + workflow = workflow_service.get_draft_workflow(app_model=app) + + if not workflow: + return [] + + # Get user input form from workflow + user_input_form = workflow.user_input_form() + return user_input_form + + @classmethod + def _get_snippet_input_fields(cls, snippet: CustomizedSnippet) -> list[dict]: + """ + Get input fields from Snippet. + + Tries to get from snippet's own input_fields first, + then falls back to workflow's user_input_form. + + :param snippet: CustomizedSnippet instance + :return: List of input field definitions + """ + # Try snippet's own input_fields first + input_fields = snippet.input_fields_list + if input_fields: + return input_fields + + # Fallback to workflow's user_input_form + snippet_service = SnippetService() + workflow = snippet_service.get_published_workflow(snippet=snippet) + if not workflow: + workflow = snippet_service.get_draft_workflow(snippet=snippet) + + if workflow: + return workflow.user_input_form() + + return [] + + @classmethod + def _generate_xlsx_template(cls, input_fields: list[dict], target_name: str) -> bytes: + """ + Generate XLSX template file content. 
+ + Creates a workbook with: + - First row as header row with "index" and input field names + - Styled header with background color and borders + - Empty data rows ready for user input + + :param input_fields: List of input field definitions + :param target_name: Name of the target (for sheet name) + :return: XLSX file content as bytes + """ + wb = Workbook() + ws = wb.active + if ws is None: + ws = wb.create_sheet("Evaluation Dataset") + + sheet_name = "Evaluation Dataset" + ws.title = sheet_name + + header_font = Font(bold=True, color="FFFFFF") + header_fill = PatternFill(start_color="4472C4", end_color="4472C4", fill_type="solid") + header_alignment = Alignment(horizontal="center", vertical="center") + thin_border = Border( + left=Side(style="thin"), + right=Side(style="thin"), + top=Side(style="thin"), + bottom=Side(style="thin"), + ) + + # Build header row + headers = ["index"] + + for field in input_fields: + field_label = str(field.get("label") or field.get("variable") or "") + headers.append(field_label) + + # Write header row + for col_idx, header in enumerate(headers, start=1): + cell = ws.cell(row=1, column=col_idx, value=header) + cell.font = header_font + cell.fill = header_fill + cell.alignment = header_alignment + cell.border = thin_border + + # Set column widths + ws.column_dimensions["A"].width = 10 # index column + for col_idx in range(2, len(headers) + 1): + ws.column_dimensions[get_column_letter(col_idx)].width = 20 + + # Add one empty row with row number for user reference + for col_idx in range(1, len(headers) + 1): + cell = ws.cell(row=2, column=col_idx, value="") + cell.border = thin_border + if col_idx == 1: + cell.value = 1 + cell.alignment = Alignment(horizontal="center") + + # Save to bytes + output = io.BytesIO() + wb.save(output) + output.seek(0) + + return output.getvalue() + + @classmethod + def generate_retrieval_dataset_template(cls) -> tuple[bytes, str]: + """Generate evaluation dataset XLSX template for knowledge base retrieval. + + The template contains three columns: ``index``, ``query``, and + ``expected_output``. Callers upload a filled copy and start an + evaluation run with ``target_type="dataset"``. 
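+
+        A minimal sketch of filling the downloaded template before uploading it
+        (illustrative only; the row values and variable names are placeholders):
+
+        .. code-block:: python
+
+            import io
+
+            from openpyxl import load_workbook
+
+            content, _filename = EvaluationService.generate_retrieval_dataset_template()
+            wb = load_workbook(io.BytesIO(content))
+            ws = wb.active
+            ws.append([2, "What does the refund policy cover?", "Refunds cover unused seats."])
+            buffer = io.BytesIO()
+            wb.save(buffer)
+            dataset_file_content = buffer.getvalue()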
+ + :returns: (xlsx_content_bytes, filename) + """ + wb = Workbook() + ws = wb.active + if ws is None: + ws = wb.create_sheet("Evaluation Dataset") + ws.title = "Evaluation Dataset" + + header_font = Font(bold=True, color="FFFFFF") + header_fill = PatternFill(start_color="4472C4", end_color="4472C4", fill_type="solid") + header_alignment = Alignment(horizontal="center", vertical="center") + thin_border = Border( + left=Side(style="thin"), + right=Side(style="thin"), + top=Side(style="thin"), + bottom=Side(style="thin"), + ) + + headers = ["index", "query", "expected_output"] + for col_idx, header in enumerate(headers, start=1): + cell = ws.cell(row=1, column=col_idx, value=header) + cell.font = header_font + cell.fill = header_fill + cell.alignment = header_alignment + cell.border = thin_border + + ws.column_dimensions["A"].width = 10 + ws.column_dimensions["B"].width = 30 + ws.column_dimensions["C"].width = 30 + + # Add one sample row + for col_idx in range(1, len(headers) + 1): + cell = ws.cell(row=2, column=col_idx, value="") + cell.border = thin_border + if col_idx == 1: + cell.value = 1 + cell.alignment = Alignment(horizontal="center") + + output = io.BytesIO() + wb.save(output) + output.seek(0) + return output.getvalue(), "retrieval-evaluation-dataset.xlsx" + + # ---- Evaluation Configuration CRUD ---- + + @classmethod + def get_evaluation_config( + cls, + session: Session, + tenant_id: str, + target_type: str, + target_id: str, + ) -> EvaluationConfiguration | None: + return ( + session.query(EvaluationConfiguration) + .filter_by(tenant_id=tenant_id, target_type=target_type, target_id=target_id) + .first() + ) + + @classmethod + def save_evaluation_config( + cls, + session: Session, + tenant_id: str, + target_type: str, + target_id: str, + account_id: str, + data: EvaluationConfigData, + ) -> EvaluationConfiguration: + config = cls.get_evaluation_config(session, tenant_id, target_type, target_id) + if config is None: + config = EvaluationConfiguration( + tenant_id=tenant_id, + target_type=target_type, + target_id=target_id, + created_by=account_id, + updated_by=account_id, + ) + session.add(config) + + config.evaluation_model_provider = data.evaluation_model_provider + config.evaluation_model = data.evaluation_model + config.metrics_config = json.dumps( + { + "default_metrics": [m.model_dump() for m in data.default_metrics], + "customized_metrics": data.customized_metrics.model_dump() if data.customized_metrics else None, + } + ) + config.judgement_conditions = json.dumps(data.judgment_config.model_dump() if data.judgment_config else {}) + config.customized_workflow_id = ( + data.customized_metrics.evaluation_workflow_id if data.customized_metrics else None + ) + config.updated_by = account_id + session.commit() + session.refresh(config) + return config + + @classmethod + def list_targets_by_customized_workflow( + cls, + session: Session, + tenant_id: str, + customized_workflow_id: str, + ) -> list[EvaluationConfiguration]: + """Return all evaluation configs that reference the given workflow as customized metrics.""" + from sqlalchemy import select + + return list( + session.scalars( + select(EvaluationConfiguration).where( + EvaluationConfiguration.tenant_id == tenant_id, + EvaluationConfiguration.customized_workflow_id == customized_workflow_id, + ) + ).all() + ) + + # ---- Evaluation Run Management ---- + + @classmethod + def start_evaluation_run( + cls, + session: Session, + tenant_id: str, + target_type: str, + target_id: str, + account_id: str, + dataset_file_content: bytes, + 
run_request: EvaluationRunRequest, + ) -> EvaluationRun: + """Validate dataset, create run record, dispatch Celery task. + + Saves the provided parameters as the latest EvaluationConfiguration + before creating the run. + """ + # Check framework is configured + evaluation_instance = EvaluationManager.get_evaluation_instance() + if evaluation_instance is None: + raise EvaluationFrameworkNotConfiguredError() + + # Save as latest EvaluationConfiguration + config = cls.save_evaluation_config( + session=session, + tenant_id=tenant_id, + target_type=target_type, + target_id=target_id, + account_id=account_id, + data=run_request, + ) + + # Check concurrent run limit + active_runs = ( + session.query(EvaluationRun) + .filter_by(tenant_id=tenant_id) + .filter(EvaluationRun.status.in_([EvaluationRunStatus.PENDING, EvaluationRunStatus.RUNNING])) + .count() + ) + max_concurrent = dify_config.EVALUATION_MAX_CONCURRENT_RUNS + if active_runs >= max_concurrent: + raise EvaluationMaxConcurrentRunsError(f"Maximum concurrent runs ({max_concurrent}) reached.") + + # Parse dataset + items = cls._parse_dataset(dataset_file_content) + max_rows = dify_config.EVALUATION_MAX_DATASET_ROWS + if len(items) > max_rows: + raise EvaluationDatasetInvalidError(f"Dataset has {len(items)} rows, max is {max_rows}.") + + # Create evaluation run + evaluation_run = EvaluationRun( + tenant_id=tenant_id, + target_type=target_type, + target_id=target_id, + evaluation_config_id=config.id, + status=EvaluationRunStatus.PENDING, + total_items=len(items), + created_by=account_id, + ) + session.add(evaluation_run) + session.commit() + session.refresh(evaluation_run) + + # Build Celery task data + run_data = EvaluationRunData( + evaluation_run_id=evaluation_run.id, + tenant_id=tenant_id, + target_type=target_type, + target_id=target_id, + evaluation_model_provider=run_request.evaluation_model_provider, + evaluation_model=run_request.evaluation_model, + default_metrics=run_request.default_metrics, + customized_metrics=run_request.customized_metrics, + judgment_config=run_request.judgment_config, + input_list=items, + ) + + # Dispatch Celery task + from tasks.evaluation_task import run_evaluation + + task = run_evaluation.delay(run_data.model_dump()) + evaluation_run.celery_task_id = task.id + session.commit() + + return evaluation_run + + @classmethod + def get_evaluation_runs( + cls, + session: Session, + tenant_id: str, + target_type: str, + target_id: str, + page: int = 1, + page_size: int = 20, + ) -> tuple[list[EvaluationRun], int]: + """Query evaluation run history with pagination.""" + query = ( + session.query(EvaluationRun) + .filter_by(tenant_id=tenant_id, target_type=target_type, target_id=target_id) + .order_by(EvaluationRun.created_at.desc()) + ) + total = query.count() + runs = query.offset((page - 1) * page_size).limit(page_size).all() + return runs, total + + @classmethod + def get_evaluation_run_detail( + cls, + session: Session, + tenant_id: str, + run_id: str, + ) -> EvaluationRun: + run = session.query(EvaluationRun).filter_by(id=run_id, tenant_id=tenant_id).first() + if not run: + raise EvaluationNotFoundError("Evaluation run not found.") + return run + + @classmethod + def get_evaluation_run_items( + cls, + session: Session, + run_id: str, + page: int = 1, + page_size: int = 50, + ) -> tuple[list[EvaluationRunItem], int]: + """Query evaluation run items with pagination.""" + query = ( + session.query(EvaluationRunItem) + .filter_by(evaluation_run_id=run_id) + .order_by(EvaluationRunItem.item_index.asc()) + ) + total = 
query.count() + items = query.offset((page - 1) * page_size).limit(page_size).all() + return items, total + + @classmethod + def cancel_evaluation_run( + cls, + session: Session, + tenant_id: str, + run_id: str, + ) -> EvaluationRun: + run = cls.get_evaluation_run_detail(session, tenant_id, run_id) + if run.status not in (EvaluationRunStatus.PENDING, EvaluationRunStatus.RUNNING): + raise ValueError(f"Cannot cancel evaluation run in status: {run.status}") + + run.status = EvaluationRunStatus.CANCELLED + + # Revoke Celery task if running + if run.celery_task_id: + try: + from celery import current_app as celery_app + + celery_app.control.revoke(run.celery_task_id, terminate=True) + except Exception: + logger.exception("Failed to revoke Celery task %s", run.celery_task_id) + + session.commit() + return run + + @classmethod + def get_supported_metrics(cls, category: EvaluationCategory) -> list[str]: + return EvaluationManager.get_supported_metrics(category) + + @staticmethod + def get_available_metrics() -> list[str]: + """Return the centrally-defined list of evaluation metrics.""" + return [m.value for m in EvaluationMetricName] + + @classmethod + def _nodes_for_metrics_from_workflow( + cls, + workflow: Workflow, + metrics: list[str], + ) -> dict[str, list[dict[str, str]]]: + node_type_to_nodes: dict[str, list[dict[str, str]]] = {} + for node_id, node_data in workflow.walk_nodes(): + ntype = node_data.get("type", "") + node_type_to_nodes.setdefault(ntype, []).append( + NodeInfo(node_id=node_id, type=ntype, title=node_data.get("title", "")).model_dump() + ) + + result: dict[str, list[dict[str, str]]] = {} + for metric in metrics: + required_node_type = METRIC_NODE_TYPE_MAPPING.get(metric) + if required_node_type is None: + result[metric] = [] + continue + result[metric] = node_type_to_nodes.get(required_node_type, []) + return result + + @classmethod + def _union_supported_metric_names(cls) -> list[str]: + """Metric names the current evaluation framework supports for any :class:`EvaluationCategory`.""" + ordered: list[str] = [] + seen: set[str] = set() + for category in EvaluationCategory: + for name in cls.get_supported_metrics(category): + if name not in seen: + seen.add(name) + ordered.append(name) + return ordered + + @classmethod + def get_default_metrics_with_nodes_for_published_target( + cls, + target: Union[App, CustomizedSnippet], + target_type: str, + ) -> list[DefaultMetric]: + """List default metrics and matching nodes using only the *published* workflow graph. + + Metrics are those supported by the configured evaluation framework and present in + :data:`METRIC_NODE_TYPE_MAPPING`. Node lists are derived from the published workflow only + (no draft fallback). + """ + workflow = cls._resolve_published_workflow(target, target_type) + if not workflow: + return [] + + supported = cls._union_supported_metric_names() + metric_names = sorted(m for m in supported if m in METRIC_NODE_TYPE_MAPPING) + if not metric_names: + return [] + + nodes_by_metric = cls._nodes_for_metrics_from_workflow(workflow, metric_names) + return [ + DefaultMetric( + metric=m, + value_type=METRIC_VALUE_TYPE_MAPPING.get(m, "number"), + node_info_list=[NodeInfo.model_validate(n) for n in nodes_by_metric.get(m, [])], + ) + for m in metric_names + ] + + @classmethod + def get_nodes_for_metrics( + cls, + target: Union[App, CustomizedSnippet], + target_type: str, + metrics: list[str] | None = None, + ) -> dict[str, list[dict[str, str]]]: + """Return node info grouped by metric (or all nodes when *metrics* is empty). 
+ + :param target: App or CustomizedSnippet instance. + :param target_type: ``"app"`` or ``"snippets"``. + :param metrics: Optional list of metric names to filter by. + When *None* or empty, returns ``{"all": []}``. + :returns: ``{metric_name: [NodeInfo dict, ...]}`` or + ``{"all": [NodeInfo dict, ...]}``. + """ + workflow = cls._resolve_workflow(target, target_type) + if not workflow: + return {"all": []} if not metrics else {m: [] for m in metrics} + + if not metrics: + all_nodes = [ + NodeInfo(node_id=node_id, type=node_data.get("type", ""), title=node_data.get("title", "")).model_dump() + for node_id, node_data in workflow.walk_nodes() + ] + return {"all": all_nodes} + + return cls._nodes_for_metrics_from_workflow(workflow, metrics) + + @classmethod + def _resolve_published_workflow( + cls, + target: Union[App, CustomizedSnippet], + target_type: str, + ) -> Workflow | None: + """Resolve only the published workflow for the target (no draft fallback).""" + if target_type == "snippets" and isinstance(target, CustomizedSnippet): + return SnippetService().get_published_workflow(snippet=target) + if target_type == "app" and isinstance(target, App): + return WorkflowService().get_published_workflow(app_model=target) + return None + + @classmethod + def _resolve_workflow( + cls, + target: Union[App, CustomizedSnippet], + target_type: str, + ) -> Workflow | None: + """Resolve the *published* (preferred) or *draft* workflow for the target.""" + if target_type == "snippets" and isinstance(target, CustomizedSnippet): + snippet_service = SnippetService() + workflow = snippet_service.get_published_workflow(snippet=target) + if not workflow: + workflow = snippet_service.get_draft_workflow(snippet=target) + return workflow + elif target_type == "app" and isinstance(target, App): + workflow_service = WorkflowService() + workflow = workflow_service.get_published_workflow(app_model=target) + if not workflow: + workflow = workflow_service.get_draft_workflow(app_model=target) + return workflow + return None + + # ---- Category Resolution ---- + + @classmethod + def _resolve_evaluation_category(cls, default_metrics: list[DefaultMetric]) -> EvaluationCategory: + """Derive evaluation category from default_metrics node_info types. + + Uses the type of the first node_info found in default_metrics. + Falls back to LLM if no metrics are provided. + """ + for metric in default_metrics: + for node_info in metric.node_info_list: + try: + return EvaluationCategory(node_info.type) + except ValueError: + continue + return EvaluationCategory.LLM + + @classmethod + def execute_targets( + cls, + tenant_id: str, + target_type: str, + target_id: str, + input_list: list[EvaluationDatasetInput], + max_workers: int = 5, + ) -> tuple[list[dict[str, NodeRunResult]], list[str | None]]: + """Execute the evaluation target for every test-data item in parallel. + + :param tenant_id: Workspace / tenant ID. + :param target_type: ``"app"`` or ``"snippet"``. + :param target_id: ID of the App or CustomizedSnippet. + :param input_list: All test-data items parsed from the dataset. + :param max_workers: Maximum number of parallel worker threads. + :return: Tuple of (node_results, workflow_run_ids). + node_results: ordered list of ``{node_id: NodeRunResult}`` mappings; + the *i*-th element corresponds to ``input_list[i]``. + workflow_run_ids: ordered list of workflow_run_id strings (or None) + for each input item. 
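+
+        A sketch of consuming the paired return values (variable names are
+        illustrative):
+
+        .. code-block:: python
+
+            node_results, run_ids = EvaluationService.execute_targets(
+                tenant_id=tenant_id,
+                target_type="app",
+                target_id=app_id,
+                input_list=items,
+            )
+            for item, results, run_id in zip(items, node_results, run_ids):
+                if run_id is None:
+                    continue  # execution failed for this row; ``results`` is empty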
+ """ + from concurrent.futures import ThreadPoolExecutor + + from flask import Flask, current_app + + flask_app: Flask = current_app._get_current_object() # type: ignore + + def _worker(item: EvaluationDatasetInput) -> tuple[dict[str, NodeRunResult], str | None]: + with flask_app.app_context(): + from models.engine import db + + with Session(db.engine, expire_on_commit=False) as thread_session: + try: + response = cls._run_single_target( + session=thread_session, + target_type=target_type, + target_id=target_id, + item=item, + ) + + workflow_run_id = cls._extract_workflow_run_id(response) + if not workflow_run_id: + logger.warning( + "No workflow_run_id for item %d (target=%s)", + item.index, + target_id, + ) + return {}, None + + node_results = cls._query_node_run_results( + session=thread_session, + tenant_id=tenant_id, + app_id=target_id, + workflow_run_id=workflow_run_id, + ) + return node_results, workflow_run_id + except Exception: + logger.exception( + "Target execution failed for item %d (target=%s)", + item.index, + target_id, + ) + return {}, None + + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [executor.submit(_worker, item) for item in input_list] + ordered_results: list[dict[str, NodeRunResult]] = [] + ordered_workflow_run_ids: list[str | None] = [] + for future in futures: + try: + node_result, wf_run_id = future.result() + ordered_results.append(node_result) + ordered_workflow_run_ids.append(wf_run_id) + except Exception: + logger.exception("Unexpected error collecting target execution result") + ordered_results.append({}) + ordered_workflow_run_ids.append(None) + + return ordered_results, ordered_workflow_run_ids + + @classmethod + def _run_single_target( + cls, + session: Session, + target_type: str, + target_id: str, + item: EvaluationDatasetInput, + ) -> Mapping[str, object]: + """Execute a single evaluation target with one test-data item. + + Dispatches to the appropriate execution service based on + ``target_type``: + + * ``"snippet"`` → :meth:`SnippetGenerateService.run_published` + * ``"app"`` → :meth:`WorkflowAppGenerator().generate` (blocking mode) + + :returns: The blocking response mapping from the workflow engine. + :raises ValueError: If the target is not found or not published. 
+ """ + from core.app.apps.workflow.app_generator import WorkflowAppGenerator + from core.app.entities.app_invoke_entities import InvokeFrom + from core.evaluation.runners import get_service_account_for_app, get_service_account_for_snippet + + if target_type == "snippet": + from services.snippet_generate_service import SnippetGenerateService + + snippet = session.query(CustomizedSnippet).filter_by(id=target_id).first() + if not snippet: + raise ValueError(f"Snippet {target_id} not found") + + service_account = get_service_account_for_snippet(session, target_id) + + return SnippetGenerateService.run_published( + snippet=snippet, + user=service_account, + args={"inputs": item.inputs}, + invoke_from=InvokeFrom.SERVICE_API, + ) + else: + # target_type == "app" + app = session.query(App).filter_by(id=target_id).first() + if not app: + raise ValueError(f"App {target_id} not found") + + service_account = get_service_account_for_app(session, target_id) + + workflow_service = WorkflowService() + workflow = workflow_service.get_published_workflow(app_model=app) + if not workflow: + raise ValueError(f"No published workflow for app {target_id}") + + response: Mapping[str, object] = WorkflowAppGenerator().generate( + app_model=app, + workflow=workflow, + user=service_account, + args={"inputs": item.inputs}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=False, + call_depth=0, + ) + return response + + @staticmethod + def _extract_workflow_run_id(response: Mapping[str, object]) -> str | None: + """Extract ``workflow_run_id`` from a blocking workflow response.""" + wf_run_id = response.get("workflow_run_id") + if wf_run_id: + return str(wf_run_id) + data = response.get("data") + if isinstance(data, Mapping) and data.get("id"): + return str(data["id"]) + return None + + @staticmethod + def _query_node_run_results( + session: Session, + tenant_id: str, + app_id: str, + workflow_run_id: str, + ) -> dict[str, NodeRunResult]: + """Query all node execution records for a workflow run.""" + from sqlalchemy import asc, select + + from graphon.enums import WorkflowNodeExecutionStatus + from models.workflow import WorkflowNodeExecutionModel + + stmt = ( + WorkflowNodeExecutionModel.preload_offload_data(select(WorkflowNodeExecutionModel)) + .where( + WorkflowNodeExecutionModel.tenant_id == tenant_id, + WorkflowNodeExecutionModel.app_id == app_id, + WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id, + ) + .order_by(asc(WorkflowNodeExecutionModel.created_at)) + ) + + node_models: list[WorkflowNodeExecutionModel] = list(session.execute(stmt).scalars().all()) + + result: dict[str, NodeRunResult] = {} + for node in node_models: + # Convert string-keyed metadata to WorkflowNodeExecutionMetadataKey-keyed + raw_metadata = node.execution_metadata_dict + typed_metadata: dict[WorkflowNodeExecutionMetadataKey, object] = {} + for key, val in raw_metadata.items(): + try: + typed_metadata[WorkflowNodeExecutionMetadataKey(key)] = val + except ValueError: + pass # skip unknown metadata keys + + result[node.node_id] = NodeRunResult( + status=WorkflowNodeExecutionStatus(node.status), + inputs=node.inputs_dict or {}, + process_data=node.process_data_dict or {}, + outputs=node.outputs_dict or {}, + metadata=typed_metadata, + error=node.error or "", + ) + return result + + # ---- Dataset Parsing ---- + + @classmethod + def _parse_dataset(cls, xlsx_content: bytes) -> list[EvaluationDatasetInput]: + """Parse evaluation dataset from XLSX bytes.""" + wb = load_workbook(io.BytesIO(xlsx_content), read_only=True) + ws = wb.active 
+ if ws is None: + raise EvaluationDatasetInvalidError("XLSX file has no active worksheet.") + + rows = list(ws.iter_rows(values_only=True)) + if len(rows) < 2: + raise EvaluationDatasetInvalidError("Dataset must have at least a header row and one data row.") + + headers = [str(h).strip() if h is not None else "" for h in rows[0]] + if not headers or headers[0].lower() != "index": + raise EvaluationDatasetInvalidError("First column header must be 'index'.") + + input_headers = headers[1:] # Skip 'index' + items = [] + for row_idx, row in enumerate(rows[1:], start=1): + values = list(row) + if all(v is None or str(v).strip() == "" for v in values): + continue # Skip empty rows + + index_val = values[0] if values else row_idx + try: + index = int(str(index_val)) + except (TypeError, ValueError): + index = row_idx + + inputs: dict[str, Any] = {} + for col_idx, header in enumerate(input_headers): + val = values[col_idx + 1] if col_idx + 1 < len(values) else None + inputs[header] = str(val) if val is not None else "" + + # Extract expected_output column into dedicated field + expected_output = inputs.pop("expected_output", None) + + items.append( + EvaluationDatasetInput( + index=index, + inputs=inputs, + expected_output=expected_output, + ) + ) + + wb.close() + return items + + @classmethod + def execute_retrieval_test_targets( + cls, + dataset_id: str, + account_id: str, + input_list: list[EvaluationDatasetInput], + max_workers: int = 5, + ) -> list[NodeRunResult]: + """Run hit testing against a knowledge base for every input item in parallel. + + Each item must supply a ``query`` key in its ``inputs`` dict. The + retrieved segments are normalised into the same ``NodeRunResult`` format + that :class:`RetrievalEvaluationRunner` expects: + + .. code-block:: python + + NodeRunResult( + inputs={"query": "..."}, + outputs={"result": [{"content": "...", "score": ...}, ...]}, + ) + + :returns: Ordered list of ``NodeRunResult`` — one per input item. + If retrieval fails for an item the result has an empty ``result`` + list so the runner can still persist a (metric-less) row. 
+ """ + from concurrent.futures import ThreadPoolExecutor + + from flask import current_app + + flask_app = current_app._get_current_object() # type: ignore + + def _worker(item: EvaluationDatasetInput) -> NodeRunResult: + with flask_app.app_context(): + from extensions.ext_database import db as flask_db + from models.account import Account + from models.dataset import Dataset + from services.hit_testing_service import HitTestingService + + dataset = flask_db.session.query(Dataset).filter_by(id=dataset_id).first() + if not dataset: + raise ValueError(f"Dataset {dataset_id} not found") + + account = flask_db.session.query(Account).filter_by(id=account_id).first() + if not account: + raise ValueError(f"Account {account_id} not found") + + query = str(item.inputs.get("query", "")) + response = HitTestingService.retrieve( + dataset=dataset, + query=query, + account=account, + retrieval_model=None, # Use dataset's configured retrieval model + external_retrieval_model={}, + limit=10, + ) + + records = response.get("records", []) + result_list = [ + { + "content": r.get("segment", {}).get("content", "") or r.get("content", ""), + "score": r.get("score"), + } + for r in records + if r.get("segment", {}).get("content") or r.get("content") + ] + + return NodeRunResult( + inputs={"query": query}, + outputs={"result": result_list}, + ) + + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [executor.submit(_worker, item) for item in input_list] + results: list[NodeRunResult] = [] + for item, future in zip(input_list, futures): + try: + results.append(future.result()) + except Exception: + logger.exception("Retrieval test failed for item %d (dataset=%s)", item.index, dataset_id) + results.append(NodeRunResult(inputs={}, outputs={"result": []})) + + return results diff --git a/api/services/feature_service.py b/api/services/feature_service.py index df653e0ba7..9216a7fb99 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -281,7 +281,7 @@ class FeatureService: def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str): billing_info = BillingService.get_info(tenant_id) - features_usage_info = BillingService.get_tenant_feature_plan_usage_info(tenant_id) + features_usage_info = BillingService.get_quota_info(tenant_id) features.billing.enabled = billing_info["enabled"] features.billing.subscription.plan = billing_info["subscription"]["plan"] diff --git a/api/services/quota_service.py b/api/services/quota_service.py new file mode 100644 index 0000000000..4c784315c7 --- /dev/null +++ b/api/services/quota_service.py @@ -0,0 +1,233 @@ +from __future__ import annotations + +import logging +import uuid +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +from configs import dify_config + +if TYPE_CHECKING: + from enums.quota_type import QuotaType + +logger = logging.getLogger(__name__) + + +@dataclass +class QuotaCharge: + """ + Result of a quota reservation (Reserve phase). + + Lifecycle: + charge = QuotaService.consume(QuotaType.TRIGGER, tenant_id) + try: + do_work() + charge.commit() # Confirm consumption + except: + charge.refund() # Release frozen quota + + If neither commit() nor refund() is called, the billing system's + cleanup CronJob will auto-release the reservation within ~75 seconds. 
+ """ + + success: bool + charge_id: str | None # reservation_id + _quota_type: QuotaType + _tenant_id: str | None = None + _feature_key: str | None = None + _amount: int = 0 + _committed: bool = field(default=False, repr=False) + + def commit(self, actual_amount: int | None = None) -> None: + """ + Confirm the consumption with actual amount. + + Args: + actual_amount: Actual amount consumed. Defaults to the reserved amount. + If less than reserved, the difference is refunded automatically. + """ + if self._committed or not self.charge_id or not self._tenant_id or not self._feature_key: + return + + try: + from services.billing_service import BillingService + + amount = actual_amount if actual_amount is not None else self._amount + BillingService.quota_commit( + tenant_id=self._tenant_id, + feature_key=self._feature_key, + reservation_id=self.charge_id, + actual_amount=amount, + ) + self._committed = True + logger.debug( + "Committed %s quota for tenant %s, reservation_id: %s, amount: %d", + self._quota_type, + self._tenant_id, + self.charge_id, + amount, + ) + except Exception: + logger.exception("Failed to commit quota, reservation_id: %s", self.charge_id) + + def refund(self) -> None: + """ + Release the reserved quota (cancel the charge). + + Safe to call even if: + - charge failed or was disabled (charge_id is None) + - already committed (Release after Commit is a no-op) + - already refunded (idempotent) + + This method guarantees no exceptions will be raised. + """ + if not self.charge_id or not self._tenant_id or not self._feature_key: + return + + QuotaService.release(self._quota_type, self.charge_id, self._tenant_id, self._feature_key) + + +def unlimited() -> QuotaCharge: + from enums.quota_type import QuotaType + + return QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.UNLIMITED) + + +class QuotaService: + """Orchestrates quota reserve / commit / release lifecycle via BillingService.""" + + @staticmethod + def consume(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> QuotaCharge: + """ + Reserve + immediate Commit (one-shot mode). + + The returned QuotaCharge supports .refund() which calls Release. + For two-phase usage (e.g. streaming), use reserve() directly. + """ + charge = QuotaService.reserve(quota_type, tenant_id, amount) + if charge.success and charge.charge_id: + charge.commit() + return charge + + @staticmethod + def reserve(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> QuotaCharge: + """ + Reserve quota before task execution (Reserve phase only). + + The caller MUST call charge.commit() after the task succeeds, + or charge.refund() if the task fails. 
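+
+        Example (two-phase; ``do_work`` is a placeholder for the caller's task):
+            charge = QuotaService.reserve(QuotaType.TRIGGER, tenant_id)
+            try:
+                do_work()
+                charge.commit()   # confirm the reserved amount was consumed
+            except Exception:
+                charge.refund()   # release the frozen quota; never raises
+                raise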
+ + Raises: + QuotaExceededError: When quota is insufficient + """ + from services.billing_service import BillingService + from services.errors.app import QuotaExceededError + + if not dify_config.BILLING_ENABLED: + logger.debug("Billing disabled, allowing request for %s", tenant_id) + return QuotaCharge(success=True, charge_id=None, _quota_type=quota_type) + + logger.info("Reserving %d %s quota for tenant %s", amount, quota_type.value, tenant_id) + + if amount <= 0: + raise ValueError("Amount to reserve must be greater than 0") + + request_id = str(uuid.uuid4()) + feature_key = quota_type.billing_key + + try: + reserve_resp = BillingService.quota_reserve( + tenant_id=tenant_id, + feature_key=feature_key, + request_id=request_id, + amount=amount, + ) + + reservation_id = reserve_resp.get("reservation_id") + if not reservation_id: + logger.warning( + "Reserve returned no reservation_id for %s, feature %s, response: %s", + tenant_id, + quota_type.value, + reserve_resp, + ) + raise QuotaExceededError(feature=quota_type.value, tenant_id=tenant_id, required=amount) + + logger.debug( + "Reserved %d %s quota for tenant %s, reservation_id: %s", + amount, + quota_type.value, + tenant_id, + reservation_id, + ) + return QuotaCharge( + success=True, + charge_id=reservation_id, + _quota_type=quota_type, + _tenant_id=tenant_id, + _feature_key=feature_key, + _amount=amount, + ) + + except QuotaExceededError: + raise + except ValueError: + raise + except Exception: + logger.exception("Failed to reserve quota for %s, feature %s", tenant_id, quota_type.value) + return unlimited() + + @staticmethod + def check(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> bool: + if not dify_config.BILLING_ENABLED: + return True + + if amount <= 0: + raise ValueError("Amount to check must be greater than 0") + + try: + remaining = QuotaService.get_remaining(quota_type, tenant_id) + return remaining >= amount if remaining != -1 else True + except Exception: + logger.exception("Failed to check quota for %s, feature %s", tenant_id, quota_type.value) + return True + + @staticmethod + def release(quota_type: QuotaType, reservation_id: str, tenant_id: str, feature_key: str) -> None: + """Release a reservation. 
Guarantees no exceptions.""" + try: + from services.billing_service import BillingService + + if not dify_config.BILLING_ENABLED: + return + + if not reservation_id: + return + + logger.info("Releasing %s quota, reservation_id: %s", quota_type.value, reservation_id) + BillingService.quota_release( + tenant_id=tenant_id, + feature_key=feature_key, + reservation_id=reservation_id, + ) + except Exception: + logger.exception("Failed to release quota, reservation_id: %s", reservation_id) + + @staticmethod + def get_remaining(quota_type: QuotaType, tenant_id: str) -> int: + from services.billing_service import BillingService + + try: + usage_info = BillingService.get_quota_info(tenant_id) + if isinstance(usage_info, dict): + feature_info = usage_info.get(quota_type.billing_key, {}) + if isinstance(feature_info, dict): + limit = feature_info.get("limit", 0) + usage = feature_info.get("usage", 0) + if limit == -1: + return -1 + return max(0, limit - usage) + return 0 + except Exception: + logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, quota_type.value) + return -1 diff --git a/api/services/snippet_dsl_service.py b/api/services/snippet_dsl_service.py new file mode 100644 index 0000000000..f074a40f09 --- /dev/null +++ b/api/services/snippet_dsl_service.py @@ -0,0 +1,555 @@ +import json +import logging +import uuid +from collections.abc import Mapping +from datetime import UTC, datetime +from enum import StrEnum +from urllib.parse import urlparse + +import yaml # type: ignore +from packaging import version +from pydantic import BaseModel, Field +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.helper import ssrf_proxy +from core.plugin.entities.plugin import PluginDependency +from extensions.ext_redis import redis_client +from graphon.enums import BuiltinNodeTypes +from graphon.model_runtime.utils.encoders import jsonable_encoder +from models import Account +from models.snippet import CustomizedSnippet, SnippetType +from models.workflow import Workflow +from services.plugin.dependencies_analysis import DependenciesAnalysisService +from services.snippet_service import SNIPPET_FORBIDDEN_NODE_TYPES, SnippetService + +logger = logging.getLogger(__name__) + +IMPORT_INFO_REDIS_KEY_PREFIX = "snippet_import_info:" +CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "snippet_check_dependencies:" +IMPORT_INFO_REDIS_EXPIRY = 10 * 60 # 10 minutes +DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB +CURRENT_DSL_VERSION = "0.1.0" + + +class ImportMode(StrEnum): + YAML_CONTENT = "yaml-content" + YAML_URL = "yaml-url" + + +class ImportStatus(StrEnum): + COMPLETED = "completed" + COMPLETED_WITH_WARNINGS = "completed-with-warnings" + PENDING = "pending" + FAILED = "failed" + + +class SnippetImportInfo(BaseModel): + id: str + status: ImportStatus + snippet_id: str | None = None + current_dsl_version: str = CURRENT_DSL_VERSION + imported_dsl_version: str = "" + error: str = "" + + +class CheckDependenciesResult(BaseModel): + leaked_dependencies: list[PluginDependency] = Field(default_factory=list) + + +def _check_version_compatibility(imported_version: str) -> ImportStatus: + """Determine import status based on version comparison""" + try: + current_ver = version.parse(CURRENT_DSL_VERSION) + imported_ver = version.parse(imported_version) + except version.InvalidVersion: + return ImportStatus.FAILED + + # If imported version is newer than current, always return PENDING + if imported_ver > current_ver: + return ImportStatus.PENDING + + # If imported version is older than current's 
major, return PENDING + if imported_ver.major < current_ver.major: + return ImportStatus.PENDING + + # If imported version is older than current's minor, return COMPLETED_WITH_WARNINGS + if imported_ver.minor < current_ver.minor: + return ImportStatus.COMPLETED_WITH_WARNINGS + + # If imported version equals or is older than current's micro, return COMPLETED + return ImportStatus.COMPLETED + + +class SnippetPendingData(BaseModel): + import_mode: str + yaml_content: str + snippet_id: str | None + + +class CheckDependenciesPendingData(BaseModel): + dependencies: list[PluginDependency] + snippet_id: str | None + + +class SnippetDslService: + def __init__(self, session: Session): + self._session = session + + def import_snippet( + self, + *, + account: Account, + import_mode: str, + yaml_content: str | None = None, + yaml_url: str | None = None, + snippet_id: str | None = None, + name: str | None = None, + description: str | None = None, + ) -> SnippetImportInfo: + """Import a snippet from YAML content or URL.""" + import_id = str(uuid.uuid4()) + + # Validate import mode + try: + mode = ImportMode(import_mode) + except ValueError: + raise ValueError(f"Invalid import_mode: {import_mode}") + + # Get YAML content + content: str = "" + if mode == ImportMode.YAML_URL: + if not yaml_url: + return SnippetImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error="yaml_url is required when import_mode is yaml-url", + ) + try: + parsed_url = urlparse(yaml_url) + if parsed_url.scheme not in ["http", "https"]: + return SnippetImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error="Invalid URL scheme, only http and https are allowed", + ) + response = ssrf_proxy.get(yaml_url, timeout=(10, 30)) + if response.status_code != 200: + return SnippetImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error=f"Failed to fetch YAML from URL: {response.status_code}", + ) + content = response.text + if len(content) > DSL_MAX_SIZE: + return SnippetImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error=f"YAML content size exceeds maximum limit of {DSL_MAX_SIZE} bytes", + ) + except Exception as e: + logger.exception("Failed to fetch YAML from URL") + return SnippetImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error=f"Failed to fetch YAML from URL: {str(e)}", + ) + elif mode == ImportMode.YAML_CONTENT: + if not yaml_content: + return SnippetImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error="yaml_content is required when import_mode is yaml-content", + ) + content = yaml_content + if len(content) > DSL_MAX_SIZE: + return SnippetImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error=f"YAML content size exceeds maximum limit of {DSL_MAX_SIZE} bytes", + ) + + try: + # Parse YAML + data = yaml.safe_load(content) + if not isinstance(data, dict): + return SnippetImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error="Invalid YAML format: expected a dictionary", + ) + + # Validate and fix DSL version + if not data.get("version"): + data["version"] = "0.1.0" + + # Strictly validate kind field + kind = data.get("kind") + if not kind: + return SnippetImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error="Missing 'kind' field in DSL. Expected 'kind: snippet'.", + ) + if kind != "snippet": + return SnippetImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error=f"Invalid DSL kind: expected 'snippet', got '{kind}'. 
This DSL is for {kind}, not snippet.", + ) + + imported_version = data.get("version", "0.1.0") + if not isinstance(imported_version, str): + raise ValueError(f"Invalid version type, expected str, got {type(imported_version)}") + status = _check_version_compatibility(imported_version) + + # Extract snippet data + snippet_data = data.get("snippet") + if not snippet_data: + return SnippetImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error="Missing snippet data in YAML content", + ) + + # Validate workflow nodes - check for forbidden node types + workflow_data = data.get("workflow", {}) + if workflow_data: + graph = workflow_data.get("graph", {}) + nodes = graph.get("nodes", []) + forbidden_nodes_found = [] + for node in nodes: + node_data = node.get("data", {}) + if not node_data: + continue + node_type = node_data.get("type", "") + if node_type in SNIPPET_FORBIDDEN_NODE_TYPES: + forbidden_nodes_found.append(node_type) + + if forbidden_nodes_found: + forbidden_types_str = ", ".join(set(forbidden_nodes_found)) + return SnippetImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error=f"Snippet cannot contain the following node types: {forbidden_types_str}", + ) + + # If snippet_id is provided, check if it exists + snippet = None + if snippet_id: + stmt = select(CustomizedSnippet).where( + CustomizedSnippet.id == snippet_id, + CustomizedSnippet.tenant_id == account.current_tenant_id, + ) + snippet = self._session.scalar(stmt) + + if not snippet: + return SnippetImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error="Snippet not found", + ) + + # If major version mismatch, store import info in Redis + if status == ImportStatus.PENDING: + pending_data = SnippetPendingData( + import_mode=import_mode, + yaml_content=content, + snippet_id=snippet_id, + ) + redis_client.setex( + f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}", + IMPORT_INFO_REDIS_EXPIRY, + pending_data.model_dump_json(), + ) + + return SnippetImportInfo( + id=import_id, + status=status, + snippet_id=snippet_id, + imported_dsl_version=imported_version, + ) + + # Extract dependencies + dependencies = data.get("dependencies", []) + check_dependencies_pending_data = None + if dependencies: + check_dependencies_pending_data = [PluginDependency.model_validate(d) for d in dependencies] + + # Create or update snippet + snippet = self._create_or_update_snippet( + snippet=snippet, + data=data, + account=account, + name=name, + description=description, + dependencies=check_dependencies_pending_data, + ) + + return SnippetImportInfo( + id=import_id, + status=status, + snippet_id=snippet.id, + imported_dsl_version=imported_version, + ) + + except yaml.YAMLError as e: + return SnippetImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error=f"Invalid YAML format: {str(e)}", + ) + + except Exception as e: + logger.exception("Failed to import snippet") + return SnippetImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error=str(e), + ) + + def confirm_import(self, *, import_id: str, account: Account) -> SnippetImportInfo: + """ + Confirm an import that requires confirmation + """ + redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}" + pending_data = redis_client.get(redis_key) + + if not pending_data: + return SnippetImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error="Import information expired or does not exist", + ) + + try: + if not isinstance(pending_data, str | bytes): + return SnippetImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error="Invalid import information", 
+ ) + + pending_data_str = pending_data.decode("utf-8") if isinstance(pending_data, bytes) else pending_data + pending = SnippetPendingData.model_validate_json(pending_data_str) + + # Re-import with the pending data + return self.import_snippet( + account=account, + import_mode=pending.import_mode, + yaml_content=pending.yaml_content, + snippet_id=pending.snippet_id, + ) + + except Exception as e: + logger.exception("Failed to confirm import") + return SnippetImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error=str(e), + ) + + def check_dependencies(self, snippet: CustomizedSnippet) -> CheckDependenciesResult: + """ + Check dependencies for a snippet + """ + snippet_service = SnippetService() + workflow = snippet_service.get_draft_workflow(snippet=snippet) + if not workflow: + return CheckDependenciesResult(leaked_dependencies=[]) + + dependencies = self._extract_dependencies_from_workflow(workflow) + leaked_dependencies = DependenciesAnalysisService.generate_dependencies( + tenant_id=snippet.tenant_id, dependencies=dependencies + ) + + return CheckDependenciesResult(leaked_dependencies=leaked_dependencies) + + def _create_or_update_snippet( + self, + *, + snippet: CustomizedSnippet | None, + data: dict, + account: Account, + name: str | None = None, + description: str | None = None, + dependencies: list[PluginDependency] | None = None, + ) -> CustomizedSnippet: + """ + Create or update snippet from DSL data + """ + snippet_data = data.get("snippet", {}) + workflow_data = data.get("workflow", {}) + + # Extract snippet info + snippet_name = name or snippet_data.get("name") or "Untitled Snippet" + snippet_description = description or snippet_data.get("description") or "" + snippet_type_str = snippet_data.get("type", "node") + try: + snippet_type = SnippetType(snippet_type_str) + except ValueError: + snippet_type = SnippetType.NODE + + icon_info = snippet_data.get("icon_info", {}) + input_fields = snippet_data.get("input_fields", []) + + # Create or update snippet + if snippet: + # Update existing snippet + snippet.name = snippet_name + snippet.description = snippet_description + snippet.type = snippet_type.value + snippet.icon_info = icon_info or None + snippet.input_fields = json.dumps(input_fields) if input_fields else None + snippet.updated_by = account.id + snippet.updated_at = datetime.now(UTC).replace(tzinfo=None) + else: + # Create new snippet + snippet = CustomizedSnippet( + tenant_id=account.current_tenant_id, + name=snippet_name, + description=snippet_description, + type=snippet_type.value, + icon_info=icon_info or None, + input_fields=json.dumps(input_fields) if input_fields else None, + created_by=account.id, + ) + self._session.add(snippet) + self._session.flush() + + # Create or update draft workflow + if workflow_data: + graph = workflow_data.get("graph", {}) + + snippet_service = SnippetService() + # Get existing workflow hash if exists + existing_workflow = snippet_service.get_draft_workflow(snippet=snippet) + unique_hash = existing_workflow.unique_hash if existing_workflow else None + + snippet_service.sync_draft_workflow( + snippet=snippet, + graph=graph, + unique_hash=unique_hash, + account=account, + input_fields=input_fields, + ) + + self._session.commit() + return snippet + + def export_snippet_dsl(self, snippet: CustomizedSnippet, include_secret: bool = False) -> str: + """ + Export snippet as DSL + :param snippet: CustomizedSnippet instance + :param include_secret: Whether include secret variable + :return: YAML string + """ + snippet_service = 
SnippetService() + workflow = snippet_service.get_draft_workflow(snippet=snippet) + if not workflow: + raise ValueError("Missing draft workflow configuration, please check.") + + icon_info = snippet.icon_info or {} + export_data = { + "version": CURRENT_DSL_VERSION, + "kind": "snippet", + "snippet": { + "name": snippet.name, + "description": snippet.description or "", + "type": snippet.type, + "icon_info": icon_info, + "input_fields": snippet.input_fields_list, + }, + } + + self._append_workflow_export_data( + export_data=export_data, snippet=snippet, workflow=workflow, include_secret=include_secret + ) + + return yaml.dump(export_data, allow_unicode=True) # type: ignore + + def _append_workflow_export_data( + self, *, export_data: dict, snippet: CustomizedSnippet, workflow: Workflow, include_secret: bool + ) -> None: + """ + Append workflow export data + """ + workflow_dict = workflow.to_dict(include_secret=include_secret) + # Filter workspace related data from nodes + workflow_dict["environment_variables"] = [] + workflow_dict["conversation_variables"] = [] + + for node in workflow_dict.get("graph", {}).get("nodes", []): + node_data = node.get("data", {}) + if not node_data: + continue + data_type = node_data.get("type", "") + if data_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: + dataset_ids = node_data.get("dataset_ids", []) + node["data"]["dataset_ids"] = [ + self._encrypt_dataset_id(dataset_id=dataset_id, tenant_id=snippet.tenant_id) + for dataset_id in dataset_ids + ] + # filter credential id from tool node + if not include_secret and data_type == BuiltinNodeTypes.TOOL: + node_data.pop("credential_id", None) + # filter credential id from agent node + if not include_secret and data_type == BuiltinNodeTypes.AGENT: + for tool in node_data.get("agent_parameters", {}).get("tools", {}).get("value", []): + tool.pop("credential_id", None) + + export_data["workflow"] = workflow_dict + dependencies = self._extract_dependencies_from_workflow(workflow) + export_data["dependencies"] = [ + jsonable_encoder(d.model_dump()) + for d in DependenciesAnalysisService.generate_dependencies( + tenant_id=snippet.tenant_id, dependencies=dependencies + ) + ] + + def _encrypt_dataset_id(self, *, dataset_id: str, tenant_id: str) -> str: + """ + Encrypt dataset ID for export + """ + # For now, just return the dataset_id as-is + # In the future, we might want to encrypt it + return dataset_id + + def _extract_dependencies_from_workflow(self, workflow: Workflow) -> list[str]: + """ + Extract dependencies from workflow + :param workflow: Workflow instance + :return: dependencies list format like ["langgenius/google"] + """ + graph = workflow.graph_dict + dependencies = self._extract_dependencies_from_workflow_graph(graph) + return dependencies + + def _extract_dependencies_from_workflow_graph(self, graph: Mapping) -> list[str]: + """ + Extract dependencies from workflow graph + :param graph: Workflow graph + :return: dependencies list format like ["langgenius/google"] + """ + dependencies = [] + for node in graph.get("nodes", []): + node_data = node.get("data", {}) + if not node_data: + continue + data_type = node_data.get("type", "") + if data_type == BuiltinNodeTypes.TOOL: + tool_config = node_data.get("tool_configurations", {}) + provider_type = tool_config.get("provider_type") + provider_name = tool_config.get("provider") + if provider_type and provider_name: + dependencies.append(f"{provider_name}/{provider_name}") + elif data_type == BuiltinNodeTypes.AGENT: + agent_parameters = 
node_data.get("agent_parameters", {}) + tools = agent_parameters.get("tools", {}).get("value", []) + for tool in tools: + provider_type = tool.get("provider_type") + provider_name = tool.get("provider") + if provider_type and provider_name: + dependencies.append(f"{provider_name}/{provider_name}") + + return dependencies diff --git a/api/services/snippet_generate_service.py b/api/services/snippet_generate_service.py new file mode 100644 index 0000000000..5e0d25c8f7 --- /dev/null +++ b/api/services/snippet_generate_service.py @@ -0,0 +1,421 @@ +""" +Service for generating snippet workflow executions. + +Uses an adapter pattern to bridge CustomizedSnippet with the App-based +WorkflowAppGenerator. The adapter (_SnippetAsApp) provides the minimal App-like +interface needed by the generator, avoiding modifications to core workflow +infrastructure. + +Key invariants: +- Snippets always run as WORKFLOW mode (not CHAT or ADVANCED_CHAT). +- The adapter maps snippet.id to app_id in workflow execution records. +- Snippet debugging has no rate limiting (max_active_requests = 0). + +Supported execution modes: +- Full workflow run (generate): Runs the entire draft workflow as SSE stream. +- Single node run (run_draft_node): Synchronous single-step debugging for regular nodes. +- Single iteration run (generate_single_iteration): SSE stream for iteration container nodes. +- Single loop run (generate_single_loop): SSE stream for loop container nodes. +""" + +import json +import logging +from collections.abc import Generator, Mapping, Sequence +from typing import Any, Union + +from sqlalchemy.orm import make_transient + +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.apps.workflow.app_generator import WorkflowAppGenerator +from core.app.entities.app_invoke_entities import InvokeFrom +from factories import file_factory +from graphon.file.models import File +from models import Account +from models.model import AppMode, EndUser +from models.snippet import CustomizedSnippet +from models.workflow import Workflow, WorkflowNodeExecutionModel +from services.snippet_service import SnippetService +from services.workflow_service import WorkflowService + +logger = logging.getLogger(__name__) + + +class _SnippetAsApp: + """ + Minimal adapter that wraps a CustomizedSnippet to satisfy the App-like + interface required by WorkflowAppGenerator, WorkflowAppConfigManager, + and WorkflowService.run_draft_workflow_node. + + Used properties: + - id: maps to snippet.id (stored as app_id in workflows table) + - tenant_id: maps to snippet.tenant_id + - mode: hardcoded to AppMode.WORKFLOW since snippets always run as workflows + - max_active_requests: defaults to 0 (no limit) for snippet debugging + - app_model_config_id: None (snippets don't have app model configs) + """ + + id: str + tenant_id: str + mode: str + max_active_requests: int + app_model_config_id: str | None + + def __init__(self, snippet: CustomizedSnippet) -> None: + self.id = snippet.id + self.tenant_id = snippet.tenant_id + self.mode = AppMode.WORKFLOW.value + self.max_active_requests = 0 + self.app_model_config_id = None + + +class SnippetGenerateService: + """ + Service for running snippet workflow executions. + + Adapts CustomizedSnippet to work with the existing App-based + WorkflowAppGenerator infrastructure, avoiding duplication of the + complex workflow execution pipeline. 
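+
+    A minimal blocking-run sketch (``snippet`` and ``account`` are assumed to
+    already exist and are not constructed here):
+
+    .. code-block:: python
+
+        from core.app.entities.app_invoke_entities import InvokeFrom
+
+        response = SnippetGenerateService.run_published(
+            snippet=snippet,                      # CustomizedSnippet with a published workflow
+            user=account,                         # Account or EndUser
+            args={"inputs": {"query": "hello"}},  # keys must match the snippet's input fields
+            invoke_from=InvokeFrom.SERVICE_API,
+        )
+        workflow_run_id = response.get("workflow_run_id")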
+ """ + + # Specific ID for the injected virtual Start node so it can be recognised + _VIRTUAL_START_NODE_ID = "__snippet_virtual_start__" + + @classmethod + def generate( + cls, + snippet: CustomizedSnippet, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: bool = True, + ) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: + """ + Run a snippet's draft workflow. + + Retrieves the draft workflow, adapts the snippet to an App-like proxy, + then delegates execution to WorkflowAppGenerator. + + If the workflow graph has no Start node, a virtual Start node is injected + in-memory so that: + 1. Graph validation passes (root node must have execution_type=ROOT). + 2. User inputs are processed into the variable pool by the StartNode logic. + + :param snippet: CustomizedSnippet instance + :param user: Account or EndUser initiating the run + :param args: Workflow inputs (must include "inputs" key) + :param invoke_from: Source of invocation (typically DEBUGGER) + :param streaming: Whether to stream the response + :return: Blocking response mapping or SSE streaming generator + :raises ValueError: If the snippet has no draft workflow + """ + snippet_service = SnippetService() + workflow = snippet_service.get_draft_workflow(snippet=snippet) + if not workflow: + raise ValueError("Workflow not initialized") + + # Inject a virtual Start node when the graph doesn't have one. + workflow = cls._ensure_start_node(workflow, snippet) + + # Adapt snippet to App-like interface for WorkflowAppGenerator + app_proxy = _SnippetAsApp(snippet) + + return WorkflowAppGenerator.convert_to_event_stream( + WorkflowAppGenerator().generate( + app_model=app_proxy, # type: ignore[arg-type] + workflow=workflow, + user=user, + args=args, + invoke_from=invoke_from, + streaming=streaming, + call_depth=0, + ) + ) + + @classmethod + def run_published( + cls, + snippet: CustomizedSnippet, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + ) -> Mapping[str, Any]: + """ + Run a snippet's published workflow in non-streaming (blocking) mode. + + Similar to :meth:`generate` but targets the published workflow instead + of the draft, and returns the raw blocking response without SSE + wrapping. Designed for programmatic callers such as evaluation runners. + + :param snippet: CustomizedSnippet instance (must be published) + :param user: Account or EndUser initiating the run + :param args: Workflow inputs (must include "inputs" key) + :param invoke_from: Source of invocation + :return: Blocking response mapping with workflow outputs + :raises ValueError: If the snippet has no published workflow + """ + snippet_service = SnippetService() + workflow = snippet_service.get_published_workflow(snippet) + if not workflow: + raise ValueError("No published workflow found for snippet") + + # Inject a virtual Start node when the graph doesn't have one. 
+        workflow = cls._ensure_start_node(workflow, snippet)
+
+        app_proxy = _SnippetAsApp(snippet)
+
+        response: Mapping[str, Any] = WorkflowAppGenerator().generate(
+            app_model=app_proxy,  # type: ignore[arg-type]
+            workflow=workflow,
+            user=user,
+            args=args,
+            invoke_from=invoke_from,
+            streaming=False,
+        )
+        return response
+
+    @classmethod
+    def ensure_start_node_for_worker(cls, workflow: Workflow, snippet: CustomizedSnippet) -> Workflow:
+        """Public wrapper for worker-thread start-node injection."""
+        return cls._ensure_start_node(workflow, snippet)
+
+    @classmethod
+    def _ensure_start_node(cls, workflow: Workflow, snippet: CustomizedSnippet) -> Workflow:
+        """
+        Return *workflow* with a Start node.
+
+        If the graph already contains a Start node, the original workflow is
+        returned unchanged. Otherwise a virtual Start node is injected and the
+        workflow object is detached from the SQLAlchemy session so the in-memory
+        change is never flushed to the database.
+        """
+        graph_dict = workflow.graph_dict
+        nodes: list[dict[str, Any]] = graph_dict.get("nodes", [])
+
+        has_start = any(node.get("data", {}).get("type") == "start" for node in nodes)
+        if has_start:
+            return workflow
+
+        modified_graph = cls._inject_virtual_start_node(
+            graph_dict=graph_dict,
+            input_fields=snippet.input_fields_list,
+        )
+
+        # Detach from session to prevent accidental DB persistence of the
+        # modified graph. All attributes remain accessible for read.
+        make_transient(workflow)
+        workflow.graph = json.dumps(modified_graph)
+        return workflow
+
+    @classmethod
+    def _inject_virtual_start_node(
+        cls,
+        graph_dict: Mapping[str, Any],
+        input_fields: list[dict[str, Any]],
+    ) -> dict[str, Any]:
+        """
+        Build a new graph dict with a virtual Start node prepended.
+
+        The virtual Start node is wired to every existing node that has no
+        incoming edges (i.e. the current root candidates). This guarantees that
+        graph validation passes (the injected Start node becomes the graph root)
+        and that user inputs are processed into the variable pool by the
+        StartNode logic.
+
+        :param graph_dict: Original graph configuration.
+        :param input_fields: Snippet input field definitions from
+            ``CustomizedSnippet.input_fields_list``.
+        :return: New graph dict containing the virtual Start node and edges.
+        """
+        nodes: list[dict[str, Any]] = list(graph_dict.get("nodes", []))
+        edges: list[dict[str, Any]] = list(graph_dict.get("edges", []))
+
+        # Identify nodes with no incoming edges.
+        nodes_with_incoming: set[str] = set()
+        for edge in edges:
+            target = edge.get("target")
+            if isinstance(target, str):
+                nodes_with_incoming.add(target)
+        root_candidate_ids = [n["id"] for n in nodes if n["id"] not in nodes_with_incoming]
+
+        # Build Start node ``variables`` from snippet input fields.
+        start_variables: list[dict[str, Any]] = []
+        for field in input_fields:
+            var: dict[str, Any] = {
+                "variable": field.get("variable", ""),
+                "label": field.get("label", field.get("variable", "")),
+                "type": field.get("type", "text-input"),
+                "required": field.get("required", False),
+                "options": field.get("options", []),
+            }
+            if field.get("max_length") is not None:
+                var["max_length"] = field["max_length"]
+            start_variables.append(var)
+
+        virtual_start_node: dict[str, Any] = {
+            "id": cls._VIRTUAL_START_NODE_ID,
+            "data": {
+                "type": "start",
+                "title": "Start",
+                "variables": start_variables,
+            },
+        }
+
+        # Create edges from virtual Start to each root candidate.
+ new_edges: list[dict[str, Any]] = [ + { + "source": cls._VIRTUAL_START_NODE_ID, + "sourceHandle": "source", + "target": root_id, + "targetHandle": "target", + } + for root_id in root_candidate_ids + ] + + return { + **graph_dict, + "nodes": [virtual_start_node, *nodes], + "edges": [*edges, *new_edges], + } + + @classmethod + def run_draft_node( + cls, + snippet: CustomizedSnippet, + node_id: str, + user_inputs: Mapping[str, Any], + account: Account, + query: str = "", + files: Sequence[File] | None = None, + ) -> WorkflowNodeExecutionModel: + """ + Run a single node in a snippet's draft workflow (single-step debugging). + + Retrieves the draft workflow, adapts the snippet to an App-like proxy, + parses file inputs, then delegates to WorkflowService.run_draft_workflow_node. + + :param snippet: CustomizedSnippet instance + :param node_id: ID of the node to run + :param user_inputs: User input values for the node + :param account: Account initiating the run + :param query: Optional query string + :param files: Optional parsed file objects + :return: WorkflowNodeExecutionModel with execution results + :raises ValueError: If the snippet has no draft workflow + """ + snippet_service = SnippetService() + draft_workflow = snippet_service.get_draft_workflow(snippet=snippet) + if not draft_workflow: + raise ValueError("Workflow not initialized") + + app_proxy = _SnippetAsApp(snippet) + + workflow_service = WorkflowService() + return workflow_service.run_draft_workflow_node( + app_model=app_proxy, # type: ignore[arg-type] + draft_workflow=draft_workflow, + node_id=node_id, + user_inputs=user_inputs, + account=account, + query=query, + files=files, + ) + + @classmethod + def generate_single_iteration( + cls, + snippet: CustomizedSnippet, + user: Union[Account, EndUser], + node_id: str, + args: Mapping[str, Any], + streaming: bool = True, + ) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: + """ + Run a single iteration node in a snippet's draft workflow. + + Iteration nodes are container nodes that execute their sub-graph multiple + times, producing many events. Therefore, this uses the full WorkflowAppGenerator + pipeline with SSE streaming (unlike regular single-step node run). + + :param snippet: CustomizedSnippet instance + :param user: Account or EndUser initiating the run + :param node_id: ID of the iteration node to run + :param args: Dict containing 'inputs' key with iteration input data + :param streaming: Whether to stream the response (should be True) + :return: SSE streaming generator + :raises ValueError: If the snippet has no draft workflow + """ + snippet_service = SnippetService() + workflow = snippet_service.get_draft_workflow(snippet=snippet) + if not workflow: + raise ValueError("Workflow not initialized") + + app_proxy = _SnippetAsApp(snippet) + + return WorkflowAppGenerator.convert_to_event_stream( + WorkflowAppGenerator().single_iteration_generate( + app_model=app_proxy, # type: ignore[arg-type] + workflow=workflow, + node_id=node_id, + user=user, + args=args, + streaming=streaming, + ) + ) + + @classmethod + def generate_single_loop( + cls, + snippet: CustomizedSnippet, + user: Union[Account, EndUser], + node_id: str, + args: Any, + streaming: bool = True, + ) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: + """ + Run a single loop node in a snippet's draft workflow. + + Loop nodes are container nodes that execute their sub-graph repeatedly, + producing many events. 
Therefore, this uses the full WorkflowAppGenerator + pipeline with SSE streaming (unlike regular single-step node run). + + :param snippet: CustomizedSnippet instance + :param user: Account or EndUser initiating the run + :param node_id: ID of the loop node to run + :param args: Pydantic model with 'inputs' attribute containing loop input data + :param streaming: Whether to stream the response (should be True) + :return: SSE streaming generator + :raises ValueError: If the snippet has no draft workflow + """ + snippet_service = SnippetService() + workflow = snippet_service.get_draft_workflow(snippet=snippet) + if not workflow: + raise ValueError("Workflow not initialized") + + app_proxy = _SnippetAsApp(snippet) + + return WorkflowAppGenerator.convert_to_event_stream( + WorkflowAppGenerator().single_loop_generate( + app_model=app_proxy, # type: ignore[arg-type] + workflow=workflow, + node_id=node_id, + user=user, + args=args, # type: ignore[arg-type] + streaming=streaming, + ) + ) + + @staticmethod + def parse_files(workflow: Workflow, files: list[dict] | None = None) -> Sequence[File]: + """ + Parse file mappings into File objects based on workflow configuration. + + :param workflow: Workflow instance for file upload config + :param files: Raw file mapping dicts + :return: Parsed File objects + """ + files = files or [] + file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) + if file_extra_config is None: + return [] + return file_factory.build_from_mappings( + mappings=files, + tenant_id=workflow.tenant_id, + config=file_extra_config, + ) diff --git a/api/services/snippet_service.py b/api/services/snippet_service.py new file mode 100644 index 0000000000..a2cdc23f3d --- /dev/null +++ b/api/services/snippet_service.py @@ -0,0 +1,608 @@ +import json +import logging +from collections.abc import Mapping, Sequence +from datetime import UTC, datetime +from typing import Any + +from sqlalchemy import func, select +from sqlalchemy.orm import Session, sessionmaker + +from core.workflow.node_factory import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING +from extensions.ext_database import db +from graphon.enums import BuiltinNodeTypes, NodeType +from libs.infinite_scroll_pagination import InfiniteScrollPagination +from models import Account +from models.enums import WorkflowRunTriggeredFrom +from models.snippet import CustomizedSnippet, SnippetType +from models.workflow import ( + Workflow, + WorkflowNodeExecutionModel, + WorkflowRun, + WorkflowType, +) +from repositories.factory import DifyAPIRepositoryFactory +from services.errors.app import WorkflowHashNotEqualError + +logger = logging.getLogger(__name__) + +# Node types not allowed in snippet workflows (sync, publish, DSL import). 
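+# These are rejected by ``SnippetService.validate_snippet_graph_forbidden_nodes``;
+# a Start node is instead injected virtually at run time (see ``_ensure_start_node``
+# in the snippet generate service above).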
+SNIPPET_FORBIDDEN_NODE_TYPES: frozenset[str] = frozenset( + { + BuiltinNodeTypes.START, + BuiltinNodeTypes.HUMAN_INPUT, + BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL, + } +) + + +class SnippetService: + """Service for managing customized snippets.""" + + def __init__(self, session_maker: sessionmaker | None = None): + """Initialize SnippetService with repository dependencies.""" + if session_maker is None: + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository( + session_maker + ) + self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) + + @staticmethod + def validate_snippet_graph_forbidden_nodes(graph: Mapping[str, Any]) -> None: + """Reject graphs that contain node types not allowed in snippets.""" + nodes = graph.get("nodes") or [] + disallowed: list[tuple[str, str]] = [] + for node in nodes: + if not isinstance(node, dict): + continue + node_data = node.get("data") or {} + node_type = node_data.get("type") + if not isinstance(node_type, str): + continue + if node_type in SNIPPET_FORBIDDEN_NODE_TYPES: + node_id = node.get("id") + disallowed.append((str(node_id) if node_id is not None else "?", node_type)) + if not disallowed: + return + detail = ", ".join(f"{nid}:{t}" for nid, t in disallowed) + raise ValueError( + "Snippet workflow cannot contain start, human-input, or knowledge-retrieval nodes. " + f"Found: {detail}" + ) + + # --- CRUD Operations --- + + @staticmethod + def get_snippets( + *, + tenant_id: str, + page: int = 1, + limit: int = 20, + keyword: str | None = None, + is_published: bool | None = None, + creators: list[str] | None = None, + ) -> tuple[Sequence[CustomizedSnippet], int, bool]: + """ + Get paginated list of snippets with optional search. + + :param tenant_id: Tenant ID + :param page: Page number (1-indexed) + :param limit: Number of items per page + :param keyword: Optional search keyword for name/description + :param is_published: Optional filter by published status (True/False/None for all) + :param creators: Optional filter by creator account IDs + :return: Tuple of (snippets list, total count, has_more flag) + """ + stmt = ( + select(CustomizedSnippet) + .where(CustomizedSnippet.tenant_id == tenant_id) + .order_by(CustomizedSnippet.created_at.desc()) + ) + + if keyword: + stmt = stmt.where( + CustomizedSnippet.name.ilike(f"%{keyword}%") | CustomizedSnippet.description.ilike(f"%{keyword}%") + ) + + if is_published is not None: + stmt = stmt.where(CustomizedSnippet.is_published == is_published) + + if creators: + stmt = stmt.where(CustomizedSnippet.created_by.in_(creators)) + + # Get total count + count_stmt = select(func.count()).select_from(stmt.subquery()) + total = db.session.scalar(count_stmt) or 0 + + # Apply pagination + stmt = stmt.limit(limit + 1).offset((page - 1) * limit) + snippets = list(db.session.scalars(stmt).all()) + + has_more = len(snippets) > limit + if has_more: + snippets = snippets[:-1] + + return snippets, total, has_more + + @staticmethod + def get_snippet_by_id( + *, + snippet_id: str, + tenant_id: str, + ) -> CustomizedSnippet | None: + """ + Get snippet by ID with tenant isolation. 
+ + :param snippet_id: Snippet ID + :param tenant_id: Tenant ID + :return: CustomizedSnippet or None + """ + return ( + db.session.query(CustomizedSnippet) + .where( + CustomizedSnippet.id == snippet_id, + CustomizedSnippet.tenant_id == tenant_id, + ) + .first() + ) + + @staticmethod + def create_snippet( + *, + tenant_id: str, + name: str, + description: str | None, + snippet_type: SnippetType, + icon_info: dict | None, + input_fields: list[dict] | None, + account: Account, + ) -> CustomizedSnippet: + """ + Create a new snippet. + + :param tenant_id: Tenant ID + :param name: Snippet name (must be unique per tenant) + :param description: Snippet description + :param snippet_type: Type of snippet (node or group) + :param icon_info: Icon information + :param input_fields: Input field definitions + :param account: Creator account + :return: Created CustomizedSnippet + :raises ValueError: If name already exists + """ + # Check if name already exists for this tenant + existing = ( + db.session.query(CustomizedSnippet) + .where( + CustomizedSnippet.tenant_id == tenant_id, + CustomizedSnippet.name == name, + ) + .first() + ) + if existing: + raise ValueError(f"Snippet with name '{name}' already exists") + + snippet = CustomizedSnippet( + tenant_id=tenant_id, + name=name, + description=description or "", + type=snippet_type.value, + icon_info=icon_info, + input_fields=json.dumps(input_fields) if input_fields else None, + created_by=account.id, + ) + + db.session.add(snippet) + db.session.commit() + + return snippet + + @staticmethod + def update_snippet( + *, + session: Session, + snippet: CustomizedSnippet, + account_id: str, + data: dict, + ) -> CustomizedSnippet: + """ + Update snippet attributes. + + :param session: Database session + :param snippet: Snippet to update + :param account_id: ID of account making the update + :param data: Dictionary of fields to update + :return: Updated CustomizedSnippet + """ + if "name" in data: + # Check if new name already exists for this tenant + existing = ( + session.query(CustomizedSnippet) + .where( + CustomizedSnippet.tenant_id == snippet.tenant_id, + CustomizedSnippet.name == data["name"], + CustomizedSnippet.id != snippet.id, + ) + .first() + ) + if existing: + raise ValueError(f"Snippet with name '{data['name']}' already exists") + snippet.name = data["name"] + + if "description" in data: + snippet.description = data["description"] + + if "icon_info" in data: + snippet.icon_info = data["icon_info"] + + snippet.updated_by = account_id + snippet.updated_at = datetime.now(UTC).replace(tzinfo=None) + + session.add(snippet) + return snippet + + @staticmethod + def delete_snippet( + *, + session: Session, + snippet: CustomizedSnippet, + ) -> bool: + """ + Delete a snippet. + + :param session: Database session + :param snippet: Snippet to delete + :return: True if deleted successfully + """ + session.delete(snippet) + return True + + # --- Workflow Operations --- + + def get_draft_workflow(self, snippet: CustomizedSnippet) -> Workflow | None: + """ + Get draft workflow for snippet. + + :param snippet: CustomizedSnippet instance + :return: Draft Workflow or None + """ + workflow = ( + db.session.query(Workflow) + .where( + Workflow.tenant_id == snippet.tenant_id, + Workflow.app_id == snippet.id, + Workflow.type == WorkflowType.SNIPPET.value, + Workflow.version == "draft", + ) + .first() + ) + return workflow + + def get_published_workflow(self, snippet: CustomizedSnippet) -> Workflow | None: + """ + Get published workflow for snippet. 
+ + :param snippet: CustomizedSnippet instance + :return: Published Workflow or None + """ + if not snippet.workflow_id: + return None + + workflow = ( + db.session.query(Workflow) + .where( + Workflow.tenant_id == snippet.tenant_id, + Workflow.app_id == snippet.id, + Workflow.type == WorkflowType.SNIPPET.value, + Workflow.id == snippet.workflow_id, + ) + .first() + ) + return workflow + + def sync_draft_workflow( + self, + *, + snippet: CustomizedSnippet, + graph: dict, + unique_hash: str | None, + account: Account, + input_fields: list[dict] | None = None, + ) -> Workflow: + """ + Sync draft workflow for snippet. + + Snippet workflows do not persist environment variables (always empty) or + conversation variables (always empty). + + :param snippet: CustomizedSnippet instance + :param graph: Workflow graph configuration + :param unique_hash: Hash for conflict detection + :param account: Account making the change + :param input_fields: Input fields for snippet + :return: Synced Workflow + :raises WorkflowHashNotEqualError: If hash mismatch + """ + SnippetService.validate_snippet_graph_forbidden_nodes(graph) + + workflow = self.get_draft_workflow(snippet=snippet) + + if workflow and workflow.unique_hash != unique_hash: + raise WorkflowHashNotEqualError() + + # Create draft workflow if not found + if not workflow: + workflow = Workflow( + tenant_id=snippet.tenant_id, + app_id=snippet.id, + features="{}", + type=WorkflowType.SNIPPET.value, + version="draft", + graph=json.dumps(graph), + created_by=account.id, + environment_variables=[], + conversation_variables=[], + ) + db.session.add(workflow) + db.session.flush() + else: + # Update existing draft workflow + workflow.graph = json.dumps(graph) + workflow.updated_by = account.id + workflow.updated_at = datetime.now(UTC).replace(tzinfo=None) + workflow.environment_variables = [] + workflow.conversation_variables = [] + + # Update snippet's input_fields if provided + if input_fields is not None: + snippet.input_fields = json.dumps(input_fields) + snippet.updated_by = account.id + snippet.updated_at = datetime.now(UTC).replace(tzinfo=None) + + db.session.commit() + return workflow + + def publish_workflow( + self, + *, + session: Session, + snippet: CustomizedSnippet, + account: Account, + ) -> Workflow: + """ + Publish the draft workflow as a new version. 
+ + :param session: Database session + :param snippet: CustomizedSnippet instance + :param account: Account making the change + :return: Published Workflow + :raises ValueError: If no draft workflow exists + """ + draft_workflow_stmt = select(Workflow).where( + Workflow.tenant_id == snippet.tenant_id, + Workflow.app_id == snippet.id, + Workflow.type == WorkflowType.SNIPPET.value, + Workflow.version == "draft", + ) + draft_workflow = session.scalar(draft_workflow_stmt) + if not draft_workflow: + raise ValueError("No valid workflow found.") + + SnippetService.validate_snippet_graph_forbidden_nodes(draft_workflow.graph_dict) + + # Create new published workflow + workflow = Workflow.new( + tenant_id=snippet.tenant_id, + app_id=snippet.id, + type=draft_workflow.type, + version=str(datetime.now(UTC).replace(tzinfo=None)), + graph=draft_workflow.graph, + features=draft_workflow.features, + created_by=account.id, + environment_variables=[], + conversation_variables=[], + rag_pipeline_variables=draft_workflow.rag_pipeline_variables, + marked_name="", + marked_comment="", + ) + session.add(workflow) + + # Update snippet version + snippet.version += 1 + snippet.is_published = True + snippet.workflow_id = workflow.id + snippet.updated_by = account.id + session.add(snippet) + + return workflow + + def get_all_published_workflows( + self, + *, + session: Session, + snippet: CustomizedSnippet, + page: int, + limit: int, + ) -> tuple[Sequence[Workflow], bool]: + """ + Get all published workflow versions for snippet. + + :param session: Database session + :param snippet: CustomizedSnippet instance + :param page: Page number + :param limit: Items per page + :return: Tuple of (workflows list, has_more flag) + """ + if not snippet.workflow_id: + return [], False + + stmt = ( + select(Workflow) + .where( + Workflow.app_id == snippet.id, + Workflow.type == WorkflowType.SNIPPET.value, + Workflow.version != "draft", + ) + .order_by(Workflow.version.desc()) + .limit(limit + 1) + .offset((page - 1) * limit) + ) + + workflows = list(session.scalars(stmt).all()) + has_more = len(workflows) > limit + if has_more: + workflows = workflows[:-1] + + return workflows, has_more + + # --- Default Block Configs --- + + def get_default_block_configs(self) -> list[dict]: + """ + Get default block configurations for all node types. + + :return: List of default configurations + """ + default_block_configs: list[dict[str, Any]] = [] + for node_class_mapping in NODE_TYPE_CLASSES_MAPPING.values(): + node_class = node_class_mapping[LATEST_VERSION] + default_config = node_class.get_default_config() + if default_config: + default_block_configs.append(dict(default_config)) + + return default_block_configs + + def get_default_block_config(self, node_type: str, filters: dict | None = None) -> Mapping[str, object] | None: + """ + Get default config for specific node type. + + :param node_type: Node type string + :param filters: Optional filters + :return: Default configuration or None + """ + node_type_enum = NodeType(node_type) + + if node_type_enum not in NODE_TYPE_CLASSES_MAPPING: + return None + + node_class = NODE_TYPE_CLASSES_MAPPING[node_type_enum][LATEST_VERSION] + default_config = node_class.get_default_config(filters=filters) + if not default_config: + return None + + return default_config + + # --- Workflow Run Operations --- + + def get_snippet_workflow_runs( + self, + *, + snippet: CustomizedSnippet, + args: dict, + ) -> InfiniteScrollPagination: + """ + Get paginated workflow runs for snippet. 
+ + :param snippet: CustomizedSnippet instance + :param args: Request arguments (last_id, limit) + :return: InfiniteScrollPagination result + """ + limit = int(args.get("limit", 20)) + last_id = args.get("last_id") + + triggered_from_values = [ + WorkflowRunTriggeredFrom.DEBUGGING, + ] + + return self._workflow_run_repo.get_paginated_workflow_runs( + tenant_id=snippet.tenant_id, + app_id=snippet.id, + triggered_from=triggered_from_values, + limit=limit, + last_id=last_id, + ) + + def get_snippet_workflow_run( + self, + *, + snippet: CustomizedSnippet, + run_id: str, + ) -> WorkflowRun | None: + """ + Get workflow run details. + + :param snippet: CustomizedSnippet instance + :param run_id: Workflow run ID + :return: WorkflowRun or None + """ + return self._workflow_run_repo.get_workflow_run_by_id( + tenant_id=snippet.tenant_id, + app_id=snippet.id, + run_id=run_id, + ) + + def get_snippet_workflow_run_node_executions( + self, + *, + snippet: CustomizedSnippet, + run_id: str, + ) -> Sequence[WorkflowNodeExecutionModel]: + """ + Get workflow run node execution list. + + :param snippet: CustomizedSnippet instance + :param run_id: Workflow run ID + :return: List of WorkflowNodeExecutionModel + """ + workflow_run = self.get_snippet_workflow_run(snippet=snippet, run_id=run_id) + if not workflow_run: + return [] + + node_executions = self._node_execution_service_repo.get_executions_by_workflow_run( + tenant_id=snippet.tenant_id, + app_id=snippet.id, + workflow_run_id=workflow_run.id, + ) + + return node_executions + + # --- Node Execution Operations --- + + def get_snippet_node_last_run( + self, + *, + snippet: CustomizedSnippet, + workflow: Workflow, + node_id: str, + ) -> WorkflowNodeExecutionModel | None: + """ + Get the most recent execution for a specific node in a snippet workflow. + + :param snippet: CustomizedSnippet instance + :param workflow: Workflow instance + :param node_id: Node identifier + :return: WorkflowNodeExecutionModel or None + """ + return self._node_execution_service_repo.get_node_last_execution( + tenant_id=snippet.tenant_id, + app_id=snippet.id, + workflow_id=workflow.id, + node_id=node_id, + ) + + # --- Use Count --- + + @staticmethod + def increment_use_count( + *, + session: Session, + snippet: CustomizedSnippet, + ) -> None: + """ + Increment the use_count when snippet is used. 
+ + :param session: Database session + :param snippet: CustomizedSnippet instance + """ + snippet.use_count += 1 + session.add(snippet) diff --git a/api/services/trigger/webhook_service.py b/api/services/trigger/webhook_service.py index bb767a6759..c782bffad4 100644 --- a/api/services/trigger/webhook_service.py +++ b/api/services/trigger/webhook_service.py @@ -38,6 +38,7 @@ from models.workflow import Workflow from services.async_workflow_service import AsyncWorkflowService from services.end_user_service import EndUserService from services.errors.app import QuotaExceededError +from services.quota_service import QuotaService from services.trigger.app_trigger_service import AppTriggerService from services.workflow.entities import WebhookTriggerData @@ -819,9 +820,9 @@ class WebhookService: user_id=None, ) - # consume quota before triggering workflow execution + # reserve quota before triggering workflow execution try: - QuotaType.TRIGGER.consume(webhook_trigger.tenant_id) + quota_charge = QuotaService.reserve(QuotaType.TRIGGER, webhook_trigger.tenant_id) except QuotaExceededError: AppTriggerService.mark_tenant_triggers_rate_limited(webhook_trigger.tenant_id) logger.info( @@ -832,11 +833,16 @@ class WebhookService: raise # Trigger workflow execution asynchronously - AsyncWorkflowService.trigger_workflow_async( - session, - end_user, - trigger_data, - ) + try: + AsyncWorkflowService.trigger_workflow_async( + session, + end_user, + trigger_data, + ) + quota_charge.commit() + except Exception: + quota_charge.refund() + raise except Exception: logger.exception("Failed to trigger workflow for webhook %s", webhook_trigger.webhook_id) diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py index b5ab176ad2..5ab3430883 100644 --- a/api/services/workflow_app_service.py +++ b/api/services/workflow_app_service.py @@ -23,17 +23,28 @@ class LogView: """Lightweight wrapper for WorkflowAppLog with computed details. 
- Exposes `details_` for marshalling to `details` in API response + - Exposes `evaluation_` for marshalling evaluation metrics in API response - Proxies all other attributes to the underlying `WorkflowAppLog` """ - def __init__(self, log: WorkflowAppLog, details: LogViewDetails | None): + def __init__( + self, + log: WorkflowAppLog, + details: LogViewDetails | None, + evaluation: list[dict] | None = None, + ): self.log = log self.details_ = details + self.evaluation_ = evaluation @property def details(self) -> LogViewDetails | None: return self.details_ + @property + def evaluation(self) -> list[dict] | None: + return self.evaluation_ + def __getattr__(self, name): return getattr(self.log, name) @@ -170,12 +181,20 @@ class WorkflowAppService: # Execute query and get items if detail: rows = session.execute(offset_stmt).all() - items = [ - LogView(log, {"trigger_metadata": self.handle_trigger_metadata(app_model.tenant_id, meta_val)}) + logs_with_details = [ + (log, {"trigger_metadata": self.handle_trigger_metadata(app_model.tenant_id, meta_val)}) for log, meta_val in rows ] else: - items = [LogView(log, None) for log in session.scalars(offset_stmt).all()] + logs_with_details = [(log, None) for log in session.scalars(offset_stmt).all()] + + workflow_run_ids = [log.workflow_run_id for log, _ in logs_with_details] + eval_map = self._batch_query_evaluation_metrics(session, workflow_run_ids) + + items = [ + LogView(log, details, evaluation=eval_map.get(log.workflow_run_id)) + for log, details in logs_with_details + ] return { "page": page, "limit": limit, @@ -257,6 +276,45 @@ class WorkflowAppService: "data": items, } + @staticmethod + def _batch_query_evaluation_metrics( + session: Session, + workflow_run_ids: list[str], + ) -> dict[str, list[dict[str, Any]]]: + """Return evaluation metrics keyed by workflow_run_id. + + Only returns metrics from completed evaluation runs. If a workflow + run was not part of any evaluation (or the evaluation has not + completed), it will be absent from the result dict. 
+ """ + from models.evaluation import EvaluationRun, EvaluationRunItem, EvaluationRunStatus + + if not workflow_run_ids: + return {} + + non_null_ids = [wid for wid in workflow_run_ids if wid] + if not non_null_ids: + return {} + + stmt = ( + select(EvaluationRunItem.workflow_run_id, EvaluationRunItem.metrics) + .join(EvaluationRun, EvaluationRun.id == EvaluationRunItem.evaluation_run_id) + .where( + EvaluationRunItem.workflow_run_id.in_(non_null_ids), + EvaluationRun.status == EvaluationRunStatus.COMPLETED, + ) + ) + rows = session.execute(stmt).all() + + result: dict[str, list[dict[str, Any]]] = {} + for wf_run_id, metrics_json in rows: + if wf_run_id and metrics_json: + parsed: list[dict[str, Any]] = json.loads(metrics_json) + existing = result.get(wf_run_id, []) + existing.extend(parsed) + result[wf_run_id] = existing + return result + def handle_trigger_metadata(self, tenant_id: str, meta_val: str | None) -> dict[str, Any]: metadata: dict[str, Any] | None = self._safe_json_loads(meta_val) if not metadata: diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index 2cc6e21574..fae5dea3cb 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -1,7 +1,7 @@ import dataclasses import json import logging -from collections.abc import Mapping, Sequence +from collections.abc import Mapping, Sequence, Set from concurrent.futures import ThreadPoolExecutor from datetime import datetime from enum import StrEnum @@ -271,12 +271,20 @@ class WorkflowDraftVariableService: ) def list_variables_without_values( - self, app_id: str, page: int, limit: int, user_id: str + self, + app_id: str, + page: int, + limit: int, + user_id: str, + *, + exclude_node_ids: Set[str] | None = None, ) -> WorkflowDraftVariableList: criteria = [ WorkflowDraftVariable.app_id == app_id, WorkflowDraftVariable.user_id == user_id, ] + if exclude_node_ids: + criteria.append(WorkflowDraftVariable.node_id.notin_(list(exclude_node_ids))) total = None base_stmt = select(WorkflowDraftVariable).where(*criteria) if page == 1: diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 0e1864ce9a..76cf4dc687 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -5,7 +5,7 @@ import uuid from collections.abc import Callable, Generator, Mapping, Sequence from typing import Any, cast -from graphon.entities import WorkflowNodeExecution +from graphon.entities import GraphInitParams, WorkflowNodeExecution from graphon.entities.graph_config import NodeConfigDict from graphon.entities.pause_reason import HumanInputRequired from graphon.enums import ( @@ -30,7 +30,7 @@ from graphon.variable_loader import load_into_variable_pool from graphon.variables import VariableBase from graphon.variables.input_entities import VariableEntityType from graphon.variables.variables import Variable -from sqlalchemy import exists, select +from sqlalchemy import and_, exists, or_, select from sqlalchemy.orm import Session, sessionmaker from configs import dify_config @@ -42,7 +42,7 @@ from core.entities import PluginCredentialType from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly, create_plugin_provider_manager from core.repositories import DifyCoreRepositoryFactory from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl -from core.trigger.constants import is_trigger_node_type +from core.trigger.constants import 
TRIGGER_NODE_TYPES, is_trigger_node_type from core.workflow.human_input_compat import ( DeliveryChannelConfig, normalize_human_input_node_data_for_graph, @@ -65,6 +65,7 @@ from extensions.ext_database import db from extensions.ext_storage import storage from factories.file_factory import build_from_mapping, build_from_mappings from libs.datetime_utils import naive_utc_now +from libs.helper import escape_like_pattern from models import Account from models.human_input import HumanInputFormRecipient, RecipientType from models.model import App, AppMode @@ -99,6 +100,15 @@ class WorkflowService: Workflow Service """ + # Centralized unsupported node types for evaluation workflow publishing. + # Keep this set updated when evaluation workflow constraints change. + EVALUATION_UNSUPPORTED_NODE_TYPES: frozenset[str] = frozenset( + { + BuiltinNodeTypes.HUMAN_INPUT, + *TRIGGER_NODE_TYPES, + } + ) + def __init__(self, session_maker: sessionmaker | None = None): """Initialize WorkflowService with repository dependencies.""" if session_maker is None: @@ -237,6 +247,59 @@ class WorkflowService: return workflows, has_more + def list_published_evaluation_workflows( + self, + *, + session: Session, + tenant_id: str, + page: int, + limit: int, + user_id: str | None, + named_only: bool = False, + keyword: str | None = None, + ) -> tuple[Sequence[Workflow], bool]: + """ + List published evaluation-type workflows for a tenant (cross-app), excluding draft rows. + + When ``keyword`` is non-empty, match workflows whose marked name or parent app name contains + the substring (case-insensitive, LIKE wildcards escaped). + """ + stmt = select(Workflow).where( + Workflow.tenant_id == tenant_id, + Workflow.type == WorkflowType.EVALUATION, + Workflow.version != Workflow.VERSION_DRAFT, + ) + + if user_id: + stmt = stmt.where(Workflow.created_by == user_id) + + if named_only: + stmt = stmt.where(Workflow.marked_name != "") + + keyword_stripped = keyword.strip() if keyword else "" + if keyword_stripped: + escaped = escape_like_pattern(keyword_stripped) + pattern = f"%{escaped}%" + stmt = stmt.join( + App, + and_(Workflow.app_id == App.id, App.tenant_id == tenant_id), + ).where( + or_( + Workflow.marked_name.ilike(pattern, escape="\\"), + App.name.ilike(pattern, escape="\\"), + ) + ) + + stmt = stmt.order_by(Workflow.created_at.desc()).limit(limit + 1).offset((page - 1) * limit) + + workflows = session.scalars(stmt).all() + + has_more = len(workflows) > limit + if has_more: + workflows = workflows[:-1] + + return workflows, has_more + def sync_draft_workflow( self, *, @@ -400,6 +463,127 @@ class WorkflowService: # return new workflow return workflow + def publish_evaluation_workflow( + self, + *, + session: Session, + app_model: App, + account: Account, + marked_name: str = "", + marked_comment: str = "", + ) -> Workflow: + """Publish draft workflow as an evaluation workflow version. + + Compared to standard publish: + - force published workflow type to ``evaluation``; + - reject graphs containing trigger or human-input nodes. 
+ """ + draft_workflow_stmt = select(Workflow).where( + Workflow.tenant_id == app_model.tenant_id, + Workflow.app_id == app_model.id, + Workflow.version == Workflow.VERSION_DRAFT, + ) + draft_workflow = session.scalar(draft_workflow_stmt) + if not draft_workflow: + raise ValueError("No valid workflow found.") + + # Validate credentials before publishing, for credential policy check + from services.feature_service import FeatureService + + if FeatureService.get_system_features().plugin_manager.enabled: + self._validate_workflow_credentials(draft_workflow) + + # validate graph structure + self.validate_graph_structure(graph=draft_workflow.graph_dict) + self._validate_evaluation_workflow_nodes(draft_workflow) + + workflow = Workflow.new( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + type=WorkflowType.EVALUATION.value, + version=Workflow.version_from_datetime(naive_utc_now()), + graph=draft_workflow.graph, + created_by=account.id, + environment_variables=draft_workflow.environment_variables, + conversation_variables=draft_workflow.conversation_variables, + marked_name=marked_name, + marked_comment=marked_comment, + rag_pipeline_variables=draft_workflow.rag_pipeline_variables, + features=draft_workflow.features, + ) + + session.add(workflow) + + # trigger app workflow events + app_published_workflow_was_updated.send(app_model, published_workflow=workflow) + + return workflow + + def convert_published_workflow_type( + self, + *, + session: Session, + app_model: App, + target_type: WorkflowType, + account: Account, + ) -> Workflow: + """ + Convert a published workflow type in-place. + + This endpoint only supports conversion between standard workflow and evaluation workflow. + """ + if target_type not in {WorkflowType.WORKFLOW, WorkflowType.EVALUATION}: + raise ValueError("target_type must be either 'workflow' or 'evaluation'") + + if not app_model.workflow_id: + raise WorkflowNotFoundError("Published workflow not found") + + stmt = select(Workflow).where( + Workflow.tenant_id == app_model.tenant_id, + Workflow.app_id == app_model.id, + Workflow.id == app_model.workflow_id, + ) + workflow = session.scalar(stmt) + if not workflow: + raise WorkflowNotFoundError("Published workflow not found") + + if workflow.version == Workflow.VERSION_DRAFT: + raise IsDraftWorkflowError("Current effective workflow cannot be a draft version.") + + if workflow.type == target_type: + return workflow + + if target_type == WorkflowType.EVALUATION: + self._validate_evaluation_workflow_nodes(workflow) + + workflow.type = target_type + workflow.updated_by = account.id + workflow.updated_at = naive_utc_now() + + app_published_workflow_was_updated.send(app_model, published_workflow=workflow) + + return workflow + + @staticmethod + def _validate_evaluation_workflow_nodes(workflow: Workflow) -> None: + """Ensure evaluation workflows do not contain unsupported node types.""" + disallowed_nodes: list[tuple[str, str]] = [] + for node_id, node_data in workflow.walk_nodes(): + node_type = node_data.get("type") + if not isinstance(node_type, str): + continue + if node_type in WorkflowService.EVALUATION_UNSUPPORTED_NODE_TYPES: + disallowed_nodes.append((node_id, node_type)) + + if not disallowed_nodes: + return + + formatted_nodes = ", ".join(f"{node_id}:{node_type}" for node_id, node_type in disallowed_nodes) + raise ValueError( + "Evaluation workflow cannot contain trigger or human-input nodes. 
" + f"Found disallowed nodes: {formatted_nodes}" + ) + def _validate_workflow_credentials(self, workflow: Workflow) -> None: """ Validate all credentials in workflow nodes before publishing. @@ -1550,8 +1734,8 @@ def _setup_variable_pool( "workflow_execution_id": str(uuid.uuid4()), } - # Only add chatflow-specific variables for non-workflow types. - if workflow.type != WorkflowType.WORKFLOW: + # Only add chatflow-specific variables for chat-like workflow types. + if workflow.type not in {WorkflowType.WORKFLOW, WorkflowType.EVALUATION}: system_variable_values.update( { "query": query, diff --git a/api/tasks/evaluation_task.py b/api/tasks/evaluation_task.py new file mode 100644 index 0000000000..4e3f7acb2e --- /dev/null +++ b/api/tasks/evaluation_task.py @@ -0,0 +1,541 @@ +import io +import json +import logging +from typing import Any + +from celery import shared_task +from openpyxl import Workbook +from openpyxl.styles import Alignment, Border, Font, PatternFill, Side +from openpyxl.utils import get_column_letter + +from configs import dify_config +from core.evaluation.base_evaluation_instance import BaseEvaluationInstance +from core.evaluation.entities.evaluation_entity import ( + EvaluationCategory, + EvaluationDatasetInput, + EvaluationItemResult, + EvaluationRunData, + NodeInfo, +) +from core.evaluation.entities.judgment_entity import JudgmentConfig +from core.evaluation.evaluation_manager import EvaluationManager +from core.evaluation.judgment.processor import JudgmentProcessor +from core.evaluation.runners.agent_evaluation_runner import AgentEvaluationRunner +from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner +from core.evaluation.runners.llm_evaluation_runner import LLMEvaluationRunner +from core.evaluation.runners.retrieval_evaluation_runner import RetrievalEvaluationRunner +from core.evaluation.runners.snippet_evaluation_runner import SnippetEvaluationRunner +from core.evaluation.runners.workflow_evaluation_runner import WorkflowEvaluationRunner +from extensions.ext_database import db +from graphon.node_events import NodeRunResult +from libs.datetime_utils import naive_utc_now +from models.enums import CreatorUserRole +from models.evaluation import EvaluationRun, EvaluationRunItem, EvaluationRunStatus +from models.model import UploadFile +from services.evaluation_service import EvaluationService + +logger = logging.getLogger(__name__) + + +@shared_task(queue="evaluation") +def run_evaluation(run_data_dict: dict[str, Any]) -> None: + """Celery task for running evaluations asynchronously. + + Workflow: + 1. Deserialize EvaluationRunData + 2. Execute target and collect node results + 3. Evaluate metrics via runners (one per metric-node pair) + 4. Merge results per test-data row (1 item = 1 EvaluationRunItem) + 5. Apply judgment conditions + 6. Persist results + generate result XLSX + 7. 
Update EvaluationRun status to COMPLETED + """ + run_data = EvaluationRunData.model_validate(run_data_dict) + + with db.engine.connect() as connection: + from sqlalchemy.orm import Session + + session = Session(bind=connection) + + try: + _execute_evaluation(session, run_data) + except Exception as e: + logger.exception("Evaluation run %s failed", run_data.evaluation_run_id) + _mark_run_failed(session, run_data.evaluation_run_id, str(e)) + finally: + session.close() + + +def _execute_evaluation(session: Any, run_data: EvaluationRunData) -> None: + """Core evaluation execution logic.""" + evaluation_run = session.query(EvaluationRun).filter_by(id=run_data.evaluation_run_id).first() + if not evaluation_run: + logger.error("EvaluationRun %s not found", run_data.evaluation_run_id) + return + + if evaluation_run.status == EvaluationRunStatus.CANCELLED: + logger.info("EvaluationRun %s was cancelled", run_data.evaluation_run_id) + return + + evaluation_instance = EvaluationManager.get_evaluation_instance() + if evaluation_instance is None: + raise ValueError("Evaluation framework not configured") + + # Mark as running + evaluation_run.status = EvaluationRunStatus.RUNNING + evaluation_run.started_at = naive_utc_now() + session.commit() + + if run_data.target_type == "dataset": + results: list[EvaluationItemResult] = _execute_retrieval_test( + session=session, + evaluation_run=evaluation_run, + run_data=run_data, + evaluation_instance=evaluation_instance, + ) + else: + evaluation_service = EvaluationService() + node_run_result_mapping_list, workflow_run_ids = evaluation_service.execute_targets( + tenant_id=run_data.tenant_id, + target_type=run_data.target_type, + target_id=run_data.target_id, + input_list=run_data.input_list, + ) + + workflow_run_id_map = { + item.index: wf_run_id + for item, wf_run_id in zip(run_data.input_list, workflow_run_ids) + if wf_run_id + } + + results = _execute_evaluation_runner( + session=session, + run_data=run_data, + evaluation_instance=evaluation_instance, + node_run_result_mapping_list=node_run_result_mapping_list, + workflow_run_id_map=workflow_run_id_map, + ) + + # Compute summary metrics + metrics_summary = _compute_metrics_summary(results, run_data.judgment_config) + + # Generate result XLSX + result_xlsx = _generate_result_xlsx(run_data.input_list, results) + + # Store result file + result_file_id = _store_result_file(run_data.tenant_id, run_data.evaluation_run_id, result_xlsx, session) + + # Update run to completed + evaluation_run = session.query(EvaluationRun).filter_by(id=run_data.evaluation_run_id).first() + if evaluation_run: + evaluation_run.status = EvaluationRunStatus.COMPLETED + evaluation_run.completed_at = naive_utc_now() + evaluation_run.metrics_summary = json.dumps(metrics_summary) + if result_file_id: + evaluation_run.result_file_id = result_file_id + session.commit() + + logger.info("Evaluation run %s completed successfully", run_data.evaluation_run_id) + + +# --------------------------------------------------------------------------- +# Evaluation orchestration — merge + judgment + persist +# --------------------------------------------------------------------------- + + +def _execute_evaluation_runner( + session: Any, + run_data: EvaluationRunData, + evaluation_instance: BaseEvaluationInstance, + node_run_result_mapping_list: list[dict[str, NodeRunResult]], + workflow_run_id_map: dict[int, str] | None = None, +) -> list[EvaluationItemResult]: + """Evaluate all metrics, merge per-item, apply judgment, persist once. 
+ + Ensures 1 test-data row = 1 EvaluationRunItem with all metrics combined. + """ + results_by_index: dict[int, EvaluationItemResult] = {} + + # Phase 1: Default metrics — one batch per (metric, node) pair + for default_metric in run_data.default_metrics: + for node_info in default_metric.node_info_list: + node_run_result_list: list[NodeRunResult] = [] + item_indices: list[int] = [] + for i, mapping in enumerate(node_run_result_mapping_list): + node_result = mapping.get(node_info.node_id) + if node_result is not None: + node_run_result_list.append(node_result) + item_indices.append(i) + + if not node_run_result_list: + continue + + runner = _create_runner(EvaluationCategory(node_info.type), evaluation_instance) + try: + evaluated = runner.evaluate_metrics( + node_run_result_list=node_run_result_list, + default_metric=default_metric, + model_provider=run_data.evaluation_model_provider, + model_name=run_data.evaluation_model, + tenant_id=run_data.tenant_id, + ) + except Exception: + logger.exception( + "Failed metrics for %s on node %s", default_metric.metric, node_info.node_id + ) + continue + + _stamp_and_merge(evaluated, item_indices, node_info, results_by_index) + + # Phase 2: Customized metrics + if run_data.customized_metrics: + try: + customized_results = evaluation_instance.evaluate_with_customized_workflow( + node_run_result_mapping_list=node_run_result_mapping_list, + customized_metrics=run_data.customized_metrics, + tenant_id=run_data.tenant_id, + ) + for result in customized_results: + _merge_result(results_by_index, result.index, result) + except Exception: + logger.exception("Failed customized metrics for run %s", run_data.evaluation_run_id) + + results = list(results_by_index.values()) + + # Phase 3: Judgment + if run_data.judgment_config: + results = _apply_judgment(results, run_data.judgment_config) + + # Phase 4: Persist — one EvaluationRunItem per test-data row + _persist_results( + session, run_data.evaluation_run_id, results, run_data.input_list, workflow_run_id_map + ) + + return results + + +def _execute_retrieval_test( + session: Any, + evaluation_run: EvaluationRun, + run_data: EvaluationRunData, + evaluation_instance: BaseEvaluationInstance, +) -> list[EvaluationItemResult]: + """Execute knowledge base retrieval for all items, then evaluate metrics. + + Unlike the workflow-based path, there are no workflow nodes to traverse. + Hit testing is run directly for each dataset item and the results are fed + straight into :class:`RetrievalEvaluationRunner`. 
+ """ + node_run_result_list = EvaluationService.execute_retrieval_test_targets( + dataset_id=run_data.target_id, + account_id=evaluation_run.created_by, + input_list=run_data.input_list, + ) + + results_by_index: dict[int, EvaluationItemResult] = {} + runner = RetrievalEvaluationRunner(evaluation_instance) + + for default_metric in run_data.default_metrics: + try: + evaluated = runner.evaluate_metrics( + node_run_result_list=node_run_result_list, + default_metric=default_metric, + model_provider=run_data.evaluation_model_provider, + model_name=run_data.evaluation_model, + tenant_id=run_data.tenant_id, + ) + item_indices = list(range(len(node_run_result_list))) + _stamp_and_merge(evaluated, item_indices, None, results_by_index) + except Exception: + logger.exception("Failed retrieval metrics for run %s", run_data.evaluation_run_id) + + results = list(results_by_index.values()) + + if run_data.judgment_config: + results = _apply_judgment(results, run_data.judgment_config) + + _persist_results(session, run_data.evaluation_run_id, results, run_data.input_list) + + return results + + +# --------------------------------------------------------------------------- +# Helpers — merge, judgment, persist +# --------------------------------------------------------------------------- + + +def _stamp_and_merge( + evaluated: list[EvaluationItemResult], + item_indices: list[int], + node_info: NodeInfo | None, + results_by_index: dict[int, EvaluationItemResult], +) -> None: + """Attach node_info to each metric and merge into results_by_index.""" + for result in evaluated: + original_index = item_indices[result.index] + if node_info is not None: + for metric in result.metrics: + metric.node_info = node_info + _merge_result(results_by_index, original_index, result) + + +def _merge_result( + results_by_index: dict[int, EvaluationItemResult], + index: int, + new_result: EvaluationItemResult, +) -> None: + """Merge new metrics into an existing result for the same index.""" + existing = results_by_index.get(index) + if existing: + merged_metrics = existing.metrics + new_result.metrics + actual = existing.actual_output or new_result.actual_output + results_by_index[index] = existing.model_copy( + update={"metrics": merged_metrics, "actual_output": actual} + ) + else: + results_by_index[index] = new_result.model_copy(update={"index": index}) + + +def _apply_judgment( + results: list[EvaluationItemResult], + judgment_config: JudgmentConfig, +) -> list[EvaluationItemResult]: + """Evaluate pass/fail judgment conditions on each result's metrics.""" + judged: list[EvaluationItemResult] = [] + for result in results: + if result.error is not None or not result.metrics: + judged.append(result) + continue + metric_values: dict[tuple[str, str], object] = { + (m.node_info.node_id, m.name): m.value for m in result.metrics if m.node_info + } + judgment_result = JudgmentProcessor.evaluate(metric_values, judgment_config) + judged.append(result.model_copy(update={"judgment": judgment_result})) + return judged + + +def _persist_results( + session: Any, + evaluation_run_id: str, + results: list[EvaluationItemResult], + input_list: list[EvaluationDatasetInput], + workflow_run_id_map: dict[int, str] | None = None, +) -> None: + """Persist evaluation results — one EvaluationRunItem per test-data row.""" + dataset_map = {item.index: item for item in input_list} + wf_map = workflow_run_id_map or {} + + for result in results: + item_input = dataset_map.get(result.index) + run_item = EvaluationRunItem( + 
evaluation_run_id=evaluation_run_id, + workflow_run_id=wf_map.get(result.index), + item_index=result.index, + inputs=json.dumps(item_input.inputs) if item_input else None, + expected_output=item_input.expected_output if item_input else None, + actual_output=result.actual_output, + metrics=json.dumps([m.model_dump() for m in result.metrics]) if result.metrics else None, + judgment=json.dumps(result.judgment.model_dump()) if result.judgment else None, + metadata_json=json.dumps(result.metadata) if result.metadata else None, + error=result.error, + overall_score=getattr(result, "overall_score", None), + ) + session.add(run_item) + + session.commit() + + +def _create_runner( + category: EvaluationCategory, + evaluation_instance: BaseEvaluationInstance, +) -> BaseEvaluationRunner: + """Create the appropriate runner for the evaluation category.""" + match category: + case EvaluationCategory.LLM: + return LLMEvaluationRunner(evaluation_instance) + case EvaluationCategory.RETRIEVAL | EvaluationCategory.KNOWLEDGE_BASE: + return RetrievalEvaluationRunner(evaluation_instance) + case EvaluationCategory.AGENT: + return AgentEvaluationRunner(evaluation_instance) + case EvaluationCategory.WORKFLOW: + return WorkflowEvaluationRunner(evaluation_instance) + case EvaluationCategory.SNIPPET: + return SnippetEvaluationRunner(evaluation_instance) + case _: + raise ValueError(f"Unknown evaluation category: {category}") + + +# --------------------------------------------------------------------------- +# Status / summary / XLSX / storage helpers (unchanged logic) +# --------------------------------------------------------------------------- + + +def _mark_run_failed(session: Any, run_id: str, error: str) -> None: + """Mark an evaluation run as failed.""" + try: + evaluation_run = session.query(EvaluationRun).filter_by(id=run_id).first() + if evaluation_run: + evaluation_run.status = EvaluationRunStatus.FAILED + evaluation_run.error = error[:2000] + evaluation_run.completed_at = naive_utc_now() + session.commit() + except Exception: + logger.exception("Failed to mark run %s as failed", run_id) + + +def _compute_metrics_summary( + results: list[EvaluationItemResult], + judgment_config: JudgmentConfig | None, +) -> dict[str, Any]: + """Compute aggregate metric and judgment summaries for an evaluation run.""" + summary: dict[str, Any] = {} + + if judgment_config is not None and judgment_config.conditions: + evaluated_results: list[EvaluationItemResult] = [ + result for result in results if result.error is None and result.metrics + ] + passed_items = sum(1 for result in evaluated_results if result.judgment.passed) + evaluated_items = len(evaluated_results) + summary["_judgment"] = { + "enabled": True, + "logical_operator": judgment_config.logical_operator, + "configured_conditions": len(judgment_config.conditions), + "evaluated_items": evaluated_items, + "passed_items": passed_items, + "failed_items": evaluated_items - passed_items, + "pass_rate": passed_items / evaluated_items if evaluated_items else 0.0, + } + + return summary + + +def _generate_result_xlsx( + input_list: list[EvaluationDatasetInput], + results: list[EvaluationItemResult], +) -> bytes: + """Generate result XLSX with input data, actual output, metric scores, and judgment.""" + wb = Workbook() + ws = wb.active + if ws is None: + ws = wb.create_sheet("Evaluation Results") + ws.title = "Evaluation Results" + + header_font = Font(bold=True, color="FFFFFF") + header_fill = PatternFill(start_color="4472C4", end_color="4472C4", fill_type="solid") + 
header_alignment = Alignment(horizontal="center", vertical="center") + thin_border = Border( + left=Side(style="thin"), + right=Side(style="thin"), + top=Side(style="thin"), + bottom=Side(style="thin"), + ) + + # Collect all metric names + all_metric_names: list[str] = [] + for result in results: + for metric in result.metrics: + if metric.name not in all_metric_names: + all_metric_names.append(metric.name) + + # Collect all input keys + input_keys: list[str] = [] + for item in input_list: + for key in item.inputs: + if key not in input_keys: + input_keys.append(key) + + has_judgment = any(bool(r.judgment.condition_results) for r in results) + + judgment_headers = ["judgment"] if has_judgment else [] + headers = ( + ["index"] + input_keys + ["expected_output", "actual_output"] + all_metric_names + judgment_headers + ["error"] + ) + + for col_idx, header in enumerate(headers, start=1): + cell = ws.cell(row=1, column=col_idx, value=header) + cell.font = header_font + cell.fill = header_fill + cell.alignment = header_alignment + cell.border = thin_border + + ws.column_dimensions["A"].width = 10 + for col_idx in range(2, len(headers) + 1): + ws.column_dimensions[get_column_letter(col_idx)].width = 25 + + result_by_index = {r.index: r for r in results} + + for row_idx, item in enumerate(input_list, start=2): + result = result_by_index.get(item.index) + + col = 1 + ws.cell(row=row_idx, column=col, value=item.index).border = thin_border + col += 1 + + for key in input_keys: + val = item.inputs.get(key, "") + ws.cell(row=row_idx, column=col, value=str(val)).border = thin_border + col += 1 + + ws.cell(row=row_idx, column=col, value=item.expected_output or "").border = thin_border + col += 1 + + ws.cell(row=row_idx, column=col, value=result.actual_output if result else "").border = thin_border + col += 1 + + metric_scores = {m.name: m.value for m in result.metrics} if result else {} + for metric_name in all_metric_names: + score = metric_scores.get(metric_name) + ws.cell(row=row_idx, column=col, value=score if score is not None else "").border = thin_border + col += 1 + + if has_judgment: + if result and result.judgment.condition_results: + judgment_value = "Pass" if result.judgment.passed else "Fail" + else: + judgment_value = "" + ws.cell(row=row_idx, column=col, value=judgment_value).border = thin_border + col += 1 + + ws.cell(row=row_idx, column=col, value=result.error if result else "").border = thin_border + + output = io.BytesIO() + wb.save(output) + output.seek(0) + return output.getvalue() + + +def _store_result_file( + tenant_id: str, + run_id: str, + xlsx_content: bytes, + session: Any, +) -> str | None: + """Store result XLSX file and return the UploadFile ID.""" + try: + from extensions.ext_storage import storage + from libs.uuid_utils import uuidv7 + + filename = f"evaluation-result-{run_id[:8]}.xlsx" + storage_key = f"evaluation_results/{tenant_id}/{str(uuidv7())}.xlsx" + + storage.save(storage_key, xlsx_content) + + upload_file: UploadFile = UploadFile( + tenant_id=tenant_id, + storage_type=dify_config.STORAGE_TYPE, + key=storage_key, + name=filename, + size=len(xlsx_content), + extension="xlsx", + mime_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + created_by_role=CreatorUserRole.ACCOUNT, + created_by="system", + created_at=naive_utc_now(), + used=False, + ) + session.add(upload_file) + session.commit() + return upload_file.id + except Exception: + logger.exception("Failed to store result file for run %s", run_id) + return None diff --git 
a/api/tasks/trigger_processing_tasks.py b/api/tasks/trigger_processing_tasks.py index 56626e372e..b9f382eccf 100644 --- a/api/tasks/trigger_processing_tasks.py +++ b/api/tasks/trigger_processing_tasks.py @@ -28,7 +28,7 @@ from core.trigger.entities.entities import TriggerProviderEntity from core.trigger.provider import PluginTriggerProviderController from core.trigger.trigger_manager import TriggerManager from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData -from enums.quota_type import QuotaType, unlimited +from enums.quota_type import QuotaType from models.enums import ( AppTriggerType, CreatorUserRole, @@ -42,6 +42,7 @@ from models.workflow import Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom, from services.async_workflow_service import AsyncWorkflowService from services.end_user_service import EndUserService from services.errors.app import QuotaExceededError +from services.quota_service import QuotaService, unlimited from services.trigger.app_trigger_service import AppTriggerService from services.trigger.trigger_provider_service import TriggerProviderService from services.trigger.trigger_request_service import TriggerHttpRequestCachingService @@ -298,10 +299,10 @@ def dispatch_triggered_workflow( icon_dark_filename=trigger_entity.identity.icon_dark or "", ) - # consume quota before invoking trigger + # reserve quota before invoking trigger quota_charge = unlimited() try: - quota_charge = QuotaType.TRIGGER.consume(subscription.tenant_id) + quota_charge = QuotaService.reserve(QuotaType.TRIGGER, subscription.tenant_id) except QuotaExceededError: AppTriggerService.mark_tenant_triggers_rate_limited(subscription.tenant_id) logger.info( @@ -387,6 +388,7 @@ def dispatch_triggered_workflow( raise ValueError(f"End user not found for app {plugin_trigger.app_id}") AsyncWorkflowService.trigger_workflow_async(session=session, user=end_user, trigger_data=trigger_data) + quota_charge.commit() dispatched_count += 1 logger.info( "Triggered workflow for app %s with trigger event %s", diff --git a/api/tasks/workflow_schedule_tasks.py b/api/tasks/workflow_schedule_tasks.py index 8c64d3ab27..dfb2fb3391 100644 --- a/api/tasks/workflow_schedule_tasks.py +++ b/api/tasks/workflow_schedule_tasks.py @@ -8,10 +8,11 @@ from core.workflow.nodes.trigger_schedule.exc import ( ScheduleNotFoundError, TenantOwnerNotFoundError, ) -from enums.quota_type import QuotaType, unlimited +from enums.quota_type import QuotaType from models.trigger import WorkflowSchedulePlan from services.async_workflow_service import AsyncWorkflowService from services.errors.app import QuotaExceededError +from services.quota_service import QuotaService, unlimited from services.trigger.app_trigger_service import AppTriggerService from services.trigger.schedule_service import ScheduleService from services.workflow.entities import ScheduleTriggerData @@ -43,7 +44,7 @@ def run_schedule_trigger(schedule_id: str) -> None: quota_charge = unlimited() try: - quota_charge = QuotaType.TRIGGER.consume(schedule.tenant_id) + quota_charge = QuotaService.reserve(QuotaType.TRIGGER, schedule.tenant_id) except QuotaExceededError: AppTriggerService.mark_tenant_triggers_rate_limited(schedule.tenant_id) logger.info("Tenant %s rate limited, skipping schedule trigger %s", schedule.tenant_id, schedule_id) @@ -61,6 +62,7 @@ def run_schedule_trigger(schedule_id: str) -> None: tenant_id=schedule.tenant_id, ), ) + quota_charge.commit() logger.info("Schedule %s triggered workflow: %s", schedule_id, response.workflow_trigger_log_id) except 
Exception as e: quota_charge.refund() diff --git a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py index 5b1a4790f5..3229693fd4 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py @@ -36,12 +36,19 @@ class TestAppGenerateService: ) as mock_message_based_generator, patch("services.account_service.FeatureService", autospec=True) as mock_account_feature_service, patch("services.app_generate_service.dify_config", autospec=True) as mock_dify_config, + patch("services.quota_service.dify_config", autospec=True) as mock_quota_dify_config, patch("configs.dify_config", autospec=True) as mock_global_dify_config, ): # Setup default mock returns for billing service - mock_billing_service.update_tenant_feature_plan_usage.return_value = { - "result": "success", - "history_id": "test_history_id", + mock_billing_service.quota_reserve.return_value = { + "reservation_id": "test-reservation-id", + "available": 100, + "reserved": 1, + } + mock_billing_service.quota_commit.return_value = { + "available": 99, + "reserved": 0, + "refunded": 0, } # Setup default mock returns for workflow service @@ -101,6 +108,8 @@ class TestAppGenerateService: mock_dify_config.APP_DEFAULT_ACTIVE_REQUESTS = 100 mock_dify_config.APP_DAILY_RATE_LIMIT = 1000 + mock_quota_dify_config.BILLING_ENABLED = False + mock_global_dify_config.BILLING_ENABLED = False mock_global_dify_config.APP_MAX_ACTIVE_REQUESTS = 100 mock_global_dify_config.APP_DAILY_RATE_LIMIT = 1000 @@ -118,6 +127,7 @@ class TestAppGenerateService: "message_based_generator": mock_message_based_generator, "account_feature_service": mock_account_feature_service, "dify_config": mock_dify_config, + "quota_dify_config": mock_quota_dify_config, "global_dify_config": mock_global_dify_config, } @@ -465,6 +475,7 @@ class TestAppGenerateService: # Set BILLING_ENABLED to True for this test mock_external_service_dependencies["dify_config"].BILLING_ENABLED = True + mock_external_service_dependencies["quota_dify_config"].BILLING_ENABLED = True mock_external_service_dependencies["global_dify_config"].BILLING_ENABLED = True # Setup test arguments @@ -478,8 +489,10 @@ class TestAppGenerateService: # Verify the result assert result == ["test_response"] - # Verify billing service was called to consume quota - mock_external_service_dependencies["billing_service"].update_tenant_feature_plan_usage.assert_called_once() + # Verify billing two-phase quota (reserve + commit) + billing = mock_external_service_dependencies["billing_service"] + billing.quota_reserve.assert_called_once() + billing.quota_commit.assert_called_once() def test_generate_with_invalid_app_mode( self, db_session_with_containers: Session, mock_external_service_dependencies diff --git a/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py b/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py index d725fb990a..7c4553d4a0 100644 --- a/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py +++ b/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py @@ -605,9 +605,9 @@ def test_schedule_trigger_creates_trigger_log( ) # Mock quota to avoid rate limiting - from enums import quota_type + from services import quota_service - monkeypatch.setattr(quota_type.QuotaType.TRIGGER, "consume", lambda _tenant_id: 
quota_type.unlimited()) + monkeypatch.setattr(quota_service.QuotaService, "reserve", lambda *_args, **_kwargs: quota_service.unlimited()) # Execute schedule trigger workflow_schedule_tasks.run_schedule_trigger(plan.id) diff --git a/api/tests/unit_tests/core/evaluation/judgment/test_processor.py b/api/tests/unit_tests/core/evaluation/judgment/test_processor.py new file mode 100644 index 0000000000..6f4cdc6708 --- /dev/null +++ b/api/tests/unit_tests/core/evaluation/judgment/test_processor.py @@ -0,0 +1,145 @@ +"""Unit tests for metric-based judgment evaluation.""" + +from core.evaluation.entities.judgment_entity import JudgmentCondition, JudgmentConfig +from core.evaluation.judgment.processor import JudgmentProcessor + + +def test_evaluate_uses_and_conditions_against_metric_values() -> None: + """All conditions must pass when the logical operator is ``and``.""" + config = JudgmentConfig( + logical_operator="and", + conditions=[ + JudgmentCondition( + variable_selector=["llm_node_1", "faithfulness"], + comparison_operator=">", + value="0.8", + ), + JudgmentCondition( + variable_selector=["llm_node_1", "answer_relevancy"], + comparison_operator="≥", + value="0.7", + ), + ], + ) + + result = JudgmentProcessor.evaluate( + { + ("llm_node_1", "faithfulness"): 0.9, + ("llm_node_1", "answer_relevancy"): 0.75, + }, + config, + ) + + assert result.passed is True + assert len(result.condition_results) == 2 + assert all(condition_result.passed for condition_result in result.condition_results) + + +def test_evaluate_sets_passed_false_when_any_and_condition_fails() -> None: + """A failed metric comparison should make the overall judgment fail.""" + config = JudgmentConfig( + logical_operator="and", + conditions=[ + JudgmentCondition( + variable_selector=["llm_node_1", "faithfulness"], + comparison_operator=">", + value="0.8", + ), + JudgmentCondition( + variable_selector=["llm_node_1", "answer_relevancy"], + comparison_operator="≥", + value="0.7", + ), + ], + ) + + result = JudgmentProcessor.evaluate( + { + ("llm_node_1", "faithfulness"): 0.9, + ("llm_node_1", "answer_relevancy"): 0.6, + }, + config, + ) + + assert result.passed is False + assert result.condition_results[-1].passed is False + + +def test_evaluate_with_different_nodes_same_metric() -> None: + """Conditions can target different nodes even with the same metric name.""" + config = JudgmentConfig( + logical_operator="and", + conditions=[ + JudgmentCondition( + variable_selector=["llm_node_1", "faithfulness"], + comparison_operator=">", + value="0.8", + ), + JudgmentCondition( + variable_selector=["llm_node_2", "faithfulness"], + comparison_operator=">", + value="0.5", + ), + ], + ) + + result = JudgmentProcessor.evaluate( + { + ("llm_node_1", "faithfulness"): 0.9, + ("llm_node_2", "faithfulness"): 0.6, + }, + config, + ) + + assert result.passed is True + assert len(result.condition_results) == 2 + + +def test_evaluate_or_operator_passes_when_one_condition_met() -> None: + """With ``or`` logical operator, one passing condition should suffice.""" + config = JudgmentConfig( + logical_operator="or", + conditions=[ + JudgmentCondition( + variable_selector=["node_a", "score"], + comparison_operator=">", + value="0.9", + ), + JudgmentCondition( + variable_selector=["node_b", "score"], + comparison_operator=">", + value="0.5", + ), + ], + ) + + result = JudgmentProcessor.evaluate( + { + ("node_a", "score"): 0.3, + ("node_b", "score"): 0.8, + }, + config, + ) + + assert result.passed is True + + +def test_evaluate_string_contains_operator() -> None: 
+ """String operators should work correctly via workflow engine delegation.""" + config = JudgmentConfig( + logical_operator="and", + conditions=[ + JudgmentCondition( + variable_selector=["node_a", "status"], + comparison_operator="contains", + value="success", + ), + ], + ) + + result = JudgmentProcessor.evaluate( + {("node_a", "status"): "evaluation_success_done"}, + config, + ) + + assert result.passed is True diff --git a/api/tests/unit_tests/core/evaluation/runners/test_base_evaluation_runner.py b/api/tests/unit_tests/core/evaluation/runners/test_base_evaluation_runner.py new file mode 100644 index 0000000000..e833331e82 --- /dev/null +++ b/api/tests/unit_tests/core/evaluation/runners/test_base_evaluation_runner.py @@ -0,0 +1,78 @@ +"""Tests for judgment application logic (moved from BaseEvaluationRunner to evaluation_task).""" + +from core.evaluation.entities.evaluation_entity import EvaluationItemResult, EvaluationMetric, NodeInfo +from core.evaluation.entities.judgment_entity import JudgmentCondition, JudgmentConfig +from tasks.evaluation_task import _apply_judgment + +_NODE_INFO = NodeInfo(node_id="llm_1", type="llm", title="LLM Node") + + +def test_apply_judgment_marks_passing_result() -> None: + """Items whose metrics satisfy the judgment conditions should be marked as passed.""" + results = [ + EvaluationItemResult( + index=0, + actual_output="result", + metrics=[EvaluationMetric(name="faithfulness", value=0.91, node_info=_NODE_INFO)], + ) + ] + judgment_config = JudgmentConfig( + logical_operator="and", + conditions=[ + JudgmentCondition( + variable_selector=["llm_1", "faithfulness"], + comparison_operator=">", + value="0.8", + ) + ], + ) + + judged = _apply_judgment(results, judgment_config) + + assert judged[0].judgment.passed is True + + +def test_apply_judgment_marks_failing_result() -> None: + """Items whose metrics do NOT satisfy the conditions should be marked as failed.""" + results = [ + EvaluationItemResult( + index=0, + metrics=[EvaluationMetric(name="faithfulness", value=0.5, node_info=_NODE_INFO)], + ) + ] + judgment_config = JudgmentConfig( + logical_operator="and", + conditions=[ + JudgmentCondition( + variable_selector=["llm_1", "faithfulness"], + comparison_operator=">", + value="0.8", + ) + ], + ) + + judged = _apply_judgment(results, judgment_config) + + assert judged[0].judgment.passed is False + + +def test_apply_judgment_skips_errored_items() -> None: + """Items with errors should be passed through without judgment evaluation.""" + results = [ + EvaluationItemResult(index=0, error="timeout"), + ] + judgment_config = JudgmentConfig( + logical_operator="and", + conditions=[ + JudgmentCondition( + variable_selector=["llm_1", "faithfulness"], + comparison_operator=">", + value="0.8", + ) + ], + ) + + judged = _apply_judgment(results, judgment_config) + + assert judged[0].error == "timeout" + assert judged[0].judgment.passed is False diff --git a/api/tests/unit_tests/enums/__init__.py b/api/tests/unit_tests/enums/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/enums/test_quota_type.py b/api/tests/unit_tests/enums/test_quota_type.py new file mode 100644 index 0000000000..f256ff3b4e --- /dev/null +++ b/api/tests/unit_tests/enums/test_quota_type.py @@ -0,0 +1,349 @@ +"""Unit tests for QuotaType, QuotaService, and QuotaCharge.""" + +from unittest.mock import patch + +import pytest + +from enums.quota_type import QuotaType +from services.quota_service import QuotaCharge, QuotaService, unlimited + + +class 
TestQuotaType: + def test_billing_key_trigger(self): + assert QuotaType.TRIGGER.billing_key == "trigger_event" + + def test_billing_key_workflow(self): + assert QuotaType.WORKFLOW.billing_key == "api_rate_limit" + + def test_billing_key_unlimited_raises(self): + with pytest.raises(ValueError, match="Invalid quota type"): + _ = QuotaType.UNLIMITED.billing_key + + +class TestQuotaService: + def test_reserve_billing_disabled(self): + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch("services.billing_service.BillingService"), + ): + mock_cfg.BILLING_ENABLED = False + charge = QuotaService.reserve(QuotaType.TRIGGER, "t1") + assert charge.success is True + assert charge.charge_id is None + + def test_reserve_zero_amount_raises(self): + with patch("services.quota_service.dify_config") as mock_cfg: + mock_cfg.BILLING_ENABLED = True + with pytest.raises(ValueError, match="greater than 0"): + QuotaService.reserve(QuotaType.TRIGGER, "t1", amount=0) + + def test_reserve_success(self): + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch("services.billing_service.BillingService") as mock_bs, + ): + mock_cfg.BILLING_ENABLED = True + mock_bs.quota_reserve.return_value = {"reservation_id": "rid-1", "available": 99} + + charge = QuotaService.reserve(QuotaType.TRIGGER, "t1", amount=1) + + assert charge.success is True + assert charge.charge_id == "rid-1" + assert charge._tenant_id == "t1" + assert charge._feature_key == "trigger_event" + assert charge._amount == 1 + mock_bs.quota_reserve.assert_called_once() + + def test_reserve_no_reservation_id_raises(self): + from services.errors.app import QuotaExceededError + + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch("services.billing_service.BillingService") as mock_bs, + ): + mock_cfg.BILLING_ENABLED = True + mock_bs.quota_reserve.return_value = {} + + with pytest.raises(QuotaExceededError): + QuotaService.reserve(QuotaType.TRIGGER, "t1") + + def test_reserve_quota_exceeded_propagates(self): + from services.errors.app import QuotaExceededError + + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch("services.billing_service.BillingService") as mock_bs, + ): + mock_cfg.BILLING_ENABLED = True + mock_bs.quota_reserve.side_effect = QuotaExceededError(feature="trigger", tenant_id="t1", required=1) + + with pytest.raises(QuotaExceededError): + QuotaService.reserve(QuotaType.TRIGGER, "t1") + + def test_reserve_api_exception_returns_unlimited(self): + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch("services.billing_service.BillingService") as mock_bs, + ): + mock_cfg.BILLING_ENABLED = True + mock_bs.quota_reserve.side_effect = RuntimeError("network") + + charge = QuotaService.reserve(QuotaType.TRIGGER, "t1") + assert charge.success is True + assert charge.charge_id is None + + def test_consume_calls_reserve_and_commit(self): + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch("services.billing_service.BillingService") as mock_bs, + ): + mock_cfg.BILLING_ENABLED = True + mock_bs.quota_reserve.return_value = {"reservation_id": "rid-c"} + mock_bs.quota_commit.return_value = {} + + charge = QuotaService.consume(QuotaType.TRIGGER, "t1") + assert charge.success is True + mock_bs.quota_commit.assert_called_once() + + def test_check_billing_disabled(self): + with patch("services.quota_service.dify_config") as mock_cfg: + mock_cfg.BILLING_ENABLED = False + assert QuotaService.check(QuotaType.TRIGGER, "t1") is True + + def 
test_check_zero_amount_raises(self): + with patch("services.quota_service.dify_config") as mock_cfg: + mock_cfg.BILLING_ENABLED = True + with pytest.raises(ValueError, match="greater than 0"): + QuotaService.check(QuotaType.TRIGGER, "t1", amount=0) + + def test_check_sufficient_quota(self): + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch.object(QuotaService, "get_remaining", return_value=100), + ): + mock_cfg.BILLING_ENABLED = True + assert QuotaService.check(QuotaType.TRIGGER, "t1", amount=50) is True + + def test_check_insufficient_quota(self): + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch.object(QuotaService, "get_remaining", return_value=5), + ): + mock_cfg.BILLING_ENABLED = True + assert QuotaService.check(QuotaType.TRIGGER, "t1", amount=10) is False + + def test_check_unlimited_quota(self): + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch.object(QuotaService, "get_remaining", return_value=-1), + ): + mock_cfg.BILLING_ENABLED = True + assert QuotaService.check(QuotaType.TRIGGER, "t1", amount=999) is True + + def test_check_exception_returns_true(self): + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch.object(QuotaService, "get_remaining", side_effect=RuntimeError), + ): + mock_cfg.BILLING_ENABLED = True + assert QuotaService.check(QuotaType.TRIGGER, "t1") is True + + def test_release_billing_disabled(self): + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch("services.billing_service.BillingService") as mock_bs, + ): + mock_cfg.BILLING_ENABLED = False + QuotaService.release(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event") + mock_bs.quota_release.assert_not_called() + + def test_release_empty_reservation(self): + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch("services.billing_service.BillingService") as mock_bs, + ): + mock_cfg.BILLING_ENABLED = True + QuotaService.release(QuotaType.TRIGGER, "", "t1", "trigger_event") + mock_bs.quota_release.assert_not_called() + + def test_release_success(self): + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch("services.billing_service.BillingService") as mock_bs, + ): + mock_cfg.BILLING_ENABLED = True + mock_bs.quota_release.return_value = {} + QuotaService.release(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event") + mock_bs.quota_release.assert_called_once_with( + tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1" + ) + + def test_release_exception_swallowed(self): + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch("services.billing_service.BillingService") as mock_bs, + ): + mock_cfg.BILLING_ENABLED = True + mock_bs.quota_release.side_effect = RuntimeError("fail") + QuotaService.release(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event") + + def test_get_remaining_normal(self): + with patch("services.billing_service.BillingService") as mock_bs: + mock_bs.get_quota_info.return_value = {"trigger_event": {"limit": 100, "usage": 30}} + assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 70 + + def test_get_remaining_unlimited(self): + with patch("services.billing_service.BillingService") as mock_bs: + mock_bs.get_quota_info.return_value = {"trigger_event": {"limit": -1, "usage": 0}} + assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == -1 + + def test_get_remaining_over_limit_returns_zero(self): + with patch("services.billing_service.BillingService") as mock_bs: + mock_bs.get_quota_info.return_value = 
{"trigger_event": {"limit": 10, "usage": 15}} + assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0 + + def test_get_remaining_exception_returns_neg1(self): + with patch("services.billing_service.BillingService") as mock_bs: + mock_bs.get_quota_info.side_effect = RuntimeError + assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == -1 + + def test_get_remaining_empty_response(self): + with patch("services.billing_service.BillingService") as mock_bs: + mock_bs.get_quota_info.return_value = {} + assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0 + + def test_get_remaining_non_dict_response(self): + with patch("services.billing_service.BillingService") as mock_bs: + mock_bs.get_quota_info.return_value = "invalid" + assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0 + + def test_get_remaining_feature_not_in_response(self): + with patch("services.billing_service.BillingService") as mock_bs: + mock_bs.get_quota_info.return_value = {"other_feature": {"limit": 100, "usage": 0}} + remaining = QuotaService.get_remaining(QuotaType.TRIGGER, "t1") + assert remaining == 0 + + def test_get_remaining_non_dict_feature_info(self): + with patch("services.billing_service.BillingService") as mock_bs: + mock_bs.get_quota_info.return_value = {"trigger_event": "not_a_dict"} + assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0 + + +class TestQuotaCharge: + def test_commit_success(self): + with patch("services.billing_service.BillingService") as mock_bs: + mock_bs.quota_commit.return_value = {} + charge = QuotaCharge( + success=True, + charge_id="rid-1", + _quota_type=QuotaType.TRIGGER, + _tenant_id="t1", + _feature_key="trigger_event", + _amount=1, + ) + charge.commit() + mock_bs.quota_commit.assert_called_once_with( + tenant_id="t1", + feature_key="trigger_event", + reservation_id="rid-1", + actual_amount=1, + ) + assert charge._committed is True + + def test_commit_with_actual_amount(self): + with patch("services.billing_service.BillingService") as mock_bs: + mock_bs.quota_commit.return_value = {} + charge = QuotaCharge( + success=True, + charge_id="rid-1", + _quota_type=QuotaType.TRIGGER, + _tenant_id="t1", + _feature_key="trigger_event", + _amount=10, + ) + charge.commit(actual_amount=5) + call_kwargs = mock_bs.quota_commit.call_args[1] + assert call_kwargs["actual_amount"] == 5 + + def test_commit_idempotent(self): + with patch("services.billing_service.BillingService") as mock_bs: + mock_bs.quota_commit.return_value = {} + charge = QuotaCharge( + success=True, + charge_id="rid-1", + _quota_type=QuotaType.TRIGGER, + _tenant_id="t1", + _feature_key="trigger_event", + _amount=1, + ) + charge.commit() + charge.commit() + assert mock_bs.quota_commit.call_count == 1 + + def test_commit_no_charge_id_noop(self): + with patch("services.billing_service.BillingService") as mock_bs: + charge = QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.TRIGGER) + charge.commit() + mock_bs.quota_commit.assert_not_called() + + def test_commit_no_tenant_id_noop(self): + with patch("services.billing_service.BillingService") as mock_bs: + charge = QuotaCharge( + success=True, + charge_id="rid-1", + _quota_type=QuotaType.TRIGGER, + _tenant_id=None, + _feature_key="trigger_event", + ) + charge.commit() + mock_bs.quota_commit.assert_not_called() + + def test_commit_exception_swallowed(self): + with patch("services.billing_service.BillingService") as mock_bs: + mock_bs.quota_commit.side_effect = RuntimeError("fail") + charge = QuotaCharge( + success=True, + 
charge_id="rid-1", + _quota_type=QuotaType.TRIGGER, + _tenant_id="t1", + _feature_key="trigger_event", + _amount=1, + ) + charge.commit() + + def test_refund_success(self): + with patch.object(QuotaService, "release") as mock_rel: + charge = QuotaCharge( + success=True, + charge_id="rid-1", + _quota_type=QuotaType.TRIGGER, + _tenant_id="t1", + _feature_key="trigger_event", + ) + charge.refund() + mock_rel.assert_called_once_with(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event") + + def test_refund_no_charge_id_noop(self): + with patch.object(QuotaService, "release") as mock_rel: + charge = QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.TRIGGER) + charge.refund() + mock_rel.assert_not_called() + + def test_refund_no_tenant_id_noop(self): + with patch.object(QuotaService, "release") as mock_rel: + charge = QuotaCharge( + success=True, + charge_id="rid-1", + _quota_type=QuotaType.TRIGGER, + _tenant_id=None, + ) + charge.refund() + mock_rel.assert_not_called() + + +class TestUnlimited: + def test_unlimited_returns_success_with_no_charge_id(self): + charge = unlimited() + assert charge.success is True + assert charge.charge_id is None + assert charge._quota_type == QuotaType.UNLIMITED diff --git a/api/tests/unit_tests/services/test_app_generate_service.py b/api/tests/unit_tests/services/test_app_generate_service.py index c2b430c551..c88daf6b1e 100644 --- a/api/tests/unit_tests/services/test_app_generate_service.py +++ b/api/tests/unit_tests/services/test_app_generate_service.py @@ -23,6 +23,7 @@ import pytest import services.app_generate_service as ags_module from core.app.entities.app_invoke_entities import InvokeFrom +from enums.quota_type import QuotaType from models.model import AppMode from services.app_generate_service import AppGenerateService from services.errors.app import WorkflowIdFormatError, WorkflowNotFoundError @@ -447,8 +448,8 @@ class TestGenerateBilling: def test_billing_enabled_consumes_quota(self, mocker, monkeypatch): monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True) quota_charge = MagicMock() - consume_mock = mocker.patch( - "services.app_generate_service.QuotaType.WORKFLOW.consume", + reserve_mock = mocker.patch( + "services.app_generate_service.QuotaService.reserve", return_value=quota_charge, ) mocker.patch( @@ -467,7 +468,8 @@ class TestGenerateBilling: invoke_from=InvokeFrom.SERVICE_API, streaming=False, ) - consume_mock.assert_called_once_with("tenant-id") + reserve_mock.assert_called_once_with(QuotaType.WORKFLOW, "tenant-id") + quota_charge.commit.assert_called_once() def test_billing_quota_exceeded_raises_rate_limit_error(self, mocker, monkeypatch): from services.errors.app import QuotaExceededError @@ -475,7 +477,7 @@ class TestGenerateBilling: monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True) mocker.patch( - "services.app_generate_service.QuotaType.WORKFLOW.consume", + "services.app_generate_service.QuotaService.reserve", side_effect=QuotaExceededError(feature="workflow", tenant_id="t", required=1), ) @@ -492,7 +494,7 @@ class TestGenerateBilling: monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True) quota_charge = MagicMock() mocker.patch( - "services.app_generate_service.QuotaType.WORKFLOW.consume", + "services.app_generate_service.QuotaService.reserve", return_value=quota_charge, ) mocker.patch( diff --git a/api/tests/unit_tests/services/test_async_workflow_service.py b/api/tests/unit_tests/services/test_async_workflow_service.py index ca6ff9dc63..361e95a557 100644 --- 
a/api/tests/unit_tests/services/test_async_workflow_service.py +++ b/api/tests/unit_tests/services/test_async_workflow_service.py @@ -57,7 +57,7 @@ class TestAsyncWorkflowService: - repo: SQLAlchemyWorkflowTriggerLogRepository - dispatcher_manager_class: QueueDispatcherManager class - dispatcher: dispatcher instance - - quota_workflow: QuotaType.WORKFLOW + - quota_service: QuotaService mock - get_workflow: AsyncWorkflowService._get_workflow method - professional_task: execute_workflow_professional - team_task: execute_workflow_team @@ -72,7 +72,12 @@ class TestAsyncWorkflowService: mock_repo.create.side_effect = _create_side_effect mock_dispatcher = MagicMock() - quota_workflow = MagicMock() + mock_quota_service = MagicMock() + mock_get_workflow = MagicMock() + + mock_professional_task = MagicMock() + mock_team_task = MagicMock() + mock_sandbox_task = MagicMock() with ( patch.object( @@ -88,8 +93,8 @@ class TestAsyncWorkflowService: ) as mock_get_workflow, patch.object( async_workflow_service_module, - "QuotaType", - new=SimpleNamespace(WORKFLOW=quota_workflow), + "QuotaService", + new=mock_quota_service, ), patch.object(async_workflow_service_module, "execute_workflow_professional") as mock_professional_task, patch.object(async_workflow_service_module, "execute_workflow_team") as mock_team_task, @@ -102,7 +107,7 @@ class TestAsyncWorkflowService: "repo": mock_repo, "dispatcher_manager_class": mock_dispatcher_manager_class, "dispatcher": mock_dispatcher, - "quota_workflow": quota_workflow, + "quota_service": mock_quota_service, "get_workflow": mock_get_workflow, "professional_task": mock_professional_task, "team_task": mock_team_task, @@ -141,6 +146,9 @@ class TestAsyncWorkflowService: mocks["team_task"].delay.return_value = task_result mocks["sandbox_task"].delay.return_value = task_result + quota_charge_mock = MagicMock() + mocks["quota_service"].reserve.return_value = quota_charge_mock + class DummyAccount: def __init__(self, user_id: str): self.id = user_id @@ -158,7 +166,8 @@ class TestAsyncWorkflowService: assert result.status == "queued" assert result.queue == queue_name - mocks["quota_workflow"].consume.assert_called_once_with("tenant-123") + mocks["quota_service"].reserve.assert_called_once() + quota_charge_mock.commit.assert_called_once() assert session.commit.call_count == 2 created_log = mocks["repo"].create.call_args[0][0] @@ -245,7 +254,7 @@ class TestAsyncWorkflowService: mocks = async_workflow_trigger_mocks mocks["dispatcher"].get_queue_name.return_value = QueuePriority.TEAM mocks["get_workflow"].return_value = workflow - mocks["quota_workflow"].consume.side_effect = QuotaExceededError( + mocks["quota_service"].reserve.side_effect = QuotaExceededError( feature="workflow", tenant_id="tenant-123", required=1, diff --git a/api/tests/unit_tests/services/test_billing_service.py b/api/tests/unit_tests/services/test_billing_service.py index 9ab0171eac..36592196c6 100644 --- a/api/tests/unit_tests/services/test_billing_service.py +++ b/api/tests/unit_tests/services/test_billing_service.py @@ -425,7 +425,7 @@ class TestBillingServiceUsageCalculation: yield mock def test_get_tenant_feature_plan_usage_info(self, mock_send_request): - """Test retrieval of tenant feature plan usage information.""" + """Test retrieval of tenant feature plan usage information (legacy endpoint).""" # Arrange tenant_id = "tenant-123" expected_response = {"features": {"trigger": {"used": 50, "limit": 100}, "workflow": {"used": 20, "limit": 50}}} @@ -438,6 +438,20 @@ class TestBillingServiceUsageCalculation: 
assert result == expected_response mock_send_request.assert_called_once_with("GET", "/tenant-feature-usage/info", params={"tenant_id": tenant_id}) + def test_get_quota_info(self, mock_send_request): + """Test retrieval of quota info from new endpoint.""" + # Arrange + tenant_id = "tenant-123" + expected_response = {"trigger_event": {"limit": 100, "usage": 30}, "api_rate_limit": {"limit": -1, "usage": 0}} + mock_send_request.return_value = expected_response + + # Act + result = BillingService.get_quota_info(tenant_id) + + # Assert + assert result == expected_response + mock_send_request.assert_called_once_with("GET", "/quota/info", params={"tenant_id": tenant_id}) + def test_update_tenant_feature_plan_usage_positive_delta(self, mock_send_request): """Test updating tenant feature usage with positive delta (adding credits).""" # Arrange @@ -515,6 +529,150 @@ class TestBillingServiceUsageCalculation: ) +class TestBillingServiceQuotaOperations: + """Unit tests for quota reserve/commit/release operations.""" + + @pytest.fixture + def mock_send_request(self): + with patch.object(BillingService, "_send_request") as mock: + yield mock + + def test_quota_reserve_success(self, mock_send_request): + expected = {"reservation_id": "rid-1", "available": 99, "reserved": 1} + mock_send_request.return_value = expected + + result = BillingService.quota_reserve(tenant_id="t1", feature_key="trigger_event", request_id="req-1", amount=1) + + assert result == expected + mock_send_request.assert_called_once_with( + "POST", + "/quota/reserve", + json={"tenant_id": "t1", "feature_key": "trigger_event", "request_id": "req-1", "amount": 1}, + ) + + def test_quota_reserve_coerces_string_to_int(self, mock_send_request): + """Test that TypeAdapter coerces string values to int.""" + mock_send_request.return_value = {"reservation_id": "rid-str", "available": "99", "reserved": "1"} + + result = BillingService.quota_reserve(tenant_id="t1", feature_key="trigger_event", request_id="req-s", amount=1) + + assert result["available"] == 99 + assert isinstance(result["available"], int) + assert result["reserved"] == 1 + assert isinstance(result["reserved"], int) + + def test_quota_reserve_with_meta(self, mock_send_request): + mock_send_request.return_value = {"reservation_id": "rid-2", "available": 98, "reserved": 1} + meta = {"source": "webhook"} + + BillingService.quota_reserve( + tenant_id="t1", feature_key="trigger_event", request_id="req-2", amount=1, meta=meta + ) + + call_json = mock_send_request.call_args[1]["json"] + assert call_json["meta"] == {"source": "webhook"} + + def test_quota_commit_success(self, mock_send_request): + expected = {"available": 98, "reserved": 0, "refunded": 0} + mock_send_request.return_value = expected + + result = BillingService.quota_commit( + tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1", actual_amount=1 + ) + + assert result == expected + mock_send_request.assert_called_once_with( + "POST", + "/quota/commit", + json={ + "tenant_id": "t1", + "feature_key": "trigger_event", + "reservation_id": "rid-1", + "actual_amount": 1, + }, + ) + + def test_quota_commit_coerces_string_to_int(self, mock_send_request): + """Test that TypeAdapter coerces string values to int.""" + mock_send_request.return_value = {"available": "97", "reserved": "0", "refunded": "1"} + + result = BillingService.quota_commit( + tenant_id="t1", feature_key="trigger_event", reservation_id="rid-s", actual_amount=1 + ) + + assert result["available"] == 97 + assert isinstance(result["available"], int) + assert 
result["refunded"] == 1 + assert isinstance(result["refunded"], int) + + def test_quota_commit_with_meta(self, mock_send_request): + mock_send_request.return_value = {"available": 97, "reserved": 0, "refunded": 0} + meta = {"reason": "partial"} + + BillingService.quota_commit( + tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1", actual_amount=1, meta=meta + ) + + call_json = mock_send_request.call_args[1]["json"] + assert call_json["meta"] == {"reason": "partial"} + + def test_quota_release_success(self, mock_send_request): + expected = {"available": 100, "reserved": 0, "released": 1} + mock_send_request.return_value = expected + + result = BillingService.quota_release(tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1") + + assert result == expected + mock_send_request.assert_called_once_with( + "POST", + "/quota/release", + json={"tenant_id": "t1", "feature_key": "trigger_event", "reservation_id": "rid-1"}, + ) + + def test_quota_release_coerces_string_to_int(self, mock_send_request): + """Test that TypeAdapter coerces string values to int.""" + mock_send_request.return_value = {"available": "100", "reserved": "0", "released": "1"} + + result = BillingService.quota_release(tenant_id="t1", feature_key="trigger_event", reservation_id="rid-s") + + assert result["available"] == 100 + assert isinstance(result["available"], int) + assert result["released"] == 1 + assert isinstance(result["released"], int) + + def test_get_quota_info_coerces_string_to_int(self, mock_send_request): + """Test that TypeAdapter coerces string values to int for get_quota_info.""" + mock_send_request.return_value = { + "trigger_event": {"usage": "42", "limit": "3000", "reset_date": "1700000000"}, + "api_rate_limit": {"usage": "10", "limit": "-1", "reset_date": "-1"}, + } + + result = BillingService.get_quota_info("t1") + + assert result["trigger_event"]["usage"] == 42 + assert isinstance(result["trigger_event"]["usage"], int) + assert result["trigger_event"]["limit"] == 3000 + assert isinstance(result["trigger_event"]["limit"], int) + assert result["trigger_event"]["reset_date"] == 1700000000 + assert isinstance(result["trigger_event"]["reset_date"], int) + assert result["api_rate_limit"]["limit"] == -1 + assert isinstance(result["api_rate_limit"]["limit"], int) + + def test_get_quota_info_accepts_int_values(self, mock_send_request): + """Test that get_quota_info works with native int values.""" + expected = { + "trigger_event": {"usage": 42, "limit": 3000, "reset_date": 1700000000}, + "api_rate_limit": {"usage": 0, "limit": -1}, + } + mock_send_request.return_value = expected + + result = BillingService.get_quota_info("t1") + + assert result["trigger_event"]["usage"] == 42 + assert result["trigger_event"]["limit"] == 3000 + assert result["api_rate_limit"]["limit"] == -1 + + class TestBillingServiceRateLimitEnforcement: """Unit tests for rate limit enforcement mechanisms. 
diff --git a/api/tests/unit_tests/services/test_webhook_service.py b/api/tests/unit_tests/services/test_webhook_service.py index ffdcc046f9..02fbe473df 100644 --- a/api/tests/unit_tests/services/test_webhook_service.py +++ b/api/tests/unit_tests/services/test_webhook_service.py @@ -559,3 +559,772 @@ class TestWebhookServiceUnit: result = _prepare_webhook_execution("test_webhook", is_debug=True) assert result == (mock_trigger, mock_workflow, mock_config, mock_data, None) + + + +# === Merged from test_webhook_service_additional.py === + + +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import MagicMock + +import pytest +from flask import Flask +from graphon.variables.types import SegmentType +from werkzeug.datastructures import FileStorage +from werkzeug.exceptions import RequestEntityTooLarge + +from core.workflow.nodes.trigger_webhook.entities import ( + ContentType, + WebhookBodyParameter, + WebhookData, + WebhookParameter, +) +from models.enums import AppTriggerStatus +from models.model import App +from models.trigger import WorkflowWebhookTrigger +from models.workflow import Workflow +from services.errors.app import QuotaExceededError +from services.trigger import webhook_service as service_module +from services.trigger.webhook_service import WebhookService + + +class _FakeQuery: + def __init__(self, result: Any) -> None: + self._result = result + + def where(self, *args: Any, **kwargs: Any) -> "_FakeQuery": + return self + + def filter(self, *args: Any, **kwargs: Any) -> "_FakeQuery": + return self + + def order_by(self, *args: Any, **kwargs: Any) -> "_FakeQuery": + return self + + def first(self) -> Any: + return self._result + + +class _SessionContext: + def __init__(self, session: Any) -> None: + self._session = session + + def __enter__(self) -> Any: + return self._session + + def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool: + return False + + +class _SessionmakerContext: + def __init__(self, session: Any) -> None: + self._session = session + + def begin(self) -> "_SessionmakerContext": + return self + + def __enter__(self) -> Any: + return self._session + + def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool: + return False + + +@pytest.fixture +def flask_app() -> Flask: + return Flask(__name__) + + +def _patch_session(monkeypatch: pytest.MonkeyPatch, session: Any) -> None: + monkeypatch.setattr(service_module, "db", SimpleNamespace(engine=MagicMock(), session=MagicMock())) + monkeypatch.setattr(service_module, "Session", lambda *args, **kwargs: _SessionContext(session)) + monkeypatch.setattr(service_module, "sessionmaker", lambda *args, **kwargs: _SessionmakerContext(session)) + + +def _workflow_trigger(**kwargs: Any) -> WorkflowWebhookTrigger: + return cast(WorkflowWebhookTrigger, SimpleNamespace(**kwargs)) + + +def _workflow(**kwargs: Any) -> Workflow: + return cast(Workflow, SimpleNamespace(**kwargs)) + + +def _app(**kwargs: Any) -> App: + return cast(App, SimpleNamespace(**kwargs)) + + +def test_get_webhook_trigger_and_workflow_should_raise_when_webhook_not_found(monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + fake_session = MagicMock() + fake_session.scalar.return_value = None + _patch_session(monkeypatch, fake_session) + + # Act / Assert + with pytest.raises(ValueError, match="Webhook not found"): + WebhookService.get_webhook_trigger_and_workflow("webhook-1") + + +def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_not_found( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + 
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1") + fake_session = MagicMock() + fake_session.scalar.side_effect = [webhook_trigger, None] + _patch_session(monkeypatch, fake_session) + + # Act / Assert + with pytest.raises(ValueError, match="App trigger not found"): + WebhookService.get_webhook_trigger_and_workflow("webhook-1") + + +def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_rate_limited( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1") + app_trigger = SimpleNamespace(status=AppTriggerStatus.RATE_LIMITED) + fake_session = MagicMock() + fake_session.scalar.side_effect = [webhook_trigger, app_trigger] + _patch_session(monkeypatch, fake_session) + + # Act / Assert + with pytest.raises(ValueError, match="rate limited"): + WebhookService.get_webhook_trigger_and_workflow("webhook-1") + + +def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_disabled( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1") + app_trigger = SimpleNamespace(status=AppTriggerStatus.DISABLED) + fake_session = MagicMock() + fake_session.scalar.side_effect = [webhook_trigger, app_trigger] + _patch_session(monkeypatch, fake_session) + + # Act / Assert + with pytest.raises(ValueError, match="disabled"): + WebhookService.get_webhook_trigger_and_workflow("webhook-1") + + +def test_get_webhook_trigger_and_workflow_should_raise_when_workflow_not_found(monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1") + app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED) + fake_session = MagicMock() + fake_session.scalar.side_effect = [webhook_trigger, app_trigger, None] + _patch_session(monkeypatch, fake_session) + + # Act / Assert + with pytest.raises(ValueError, match="Workflow not found"): + WebhookService.get_webhook_trigger_and_workflow("webhook-1") + + +def test_get_webhook_trigger_and_workflow_should_return_values_for_non_debug_mode( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1") + app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED) + workflow = MagicMock() + workflow.get_node_config_by_id.return_value = {"data": {"key": "value"}} + + fake_session = MagicMock() + fake_session.scalar.side_effect = [webhook_trigger, app_trigger, workflow] + _patch_session(monkeypatch, fake_session) + + # Act + got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow("webhook-1") + + # Assert + assert got_trigger is webhook_trigger + assert got_workflow is workflow + assert got_node_config == {"data": {"key": "value"}} + + +def test_get_webhook_trigger_and_workflow_should_return_values_for_debug_mode(monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1") + workflow = MagicMock() + workflow.get_node_config_by_id.return_value = {"data": {"mode": "debug"}} + + fake_session = MagicMock() + fake_session.scalar.side_effect = [webhook_trigger, workflow] + _patch_session(monkeypatch, fake_session) + + # Act + got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow( + "webhook-1", is_debug=True + ) + + # Assert + assert got_trigger is webhook_trigger + assert got_workflow is workflow + assert got_node_config == {"data": {"mode": 
"debug"}} + + +def test_extract_webhook_data_should_use_text_fallback_for_unknown_content_type( + flask_app: Flask, + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + warning_mock = MagicMock() + monkeypatch.setattr(service_module.logger, "warning", warning_mock) + webhook_trigger = MagicMock() + + # Act + with flask_app.test_request_context( + "/webhook", + method="POST", + headers={"Content-Type": "application/vnd.custom"}, + data="plain content", + ): + result = WebhookService.extract_webhook_data(webhook_trigger) + + # Assert + assert result["body"] == {"raw": "plain content"} + warning_mock.assert_called_once() + + +def test_extract_webhook_data_should_raise_for_request_too_large( + flask_app: Flask, + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + monkeypatch.setattr(service_module.dify_config, "WEBHOOK_REQUEST_BODY_MAX_SIZE", 1) + + # Act / Assert + with flask_app.test_request_context("/webhook", method="POST", data="ab"): + with pytest.raises(RequestEntityTooLarge): + WebhookService.extract_webhook_data(MagicMock()) + + +def test_extract_octet_stream_body_should_return_none_when_empty_payload(flask_app: Flask) -> None: + # Arrange + webhook_trigger = MagicMock() + + # Act + with flask_app.test_request_context("/webhook", method="POST", data=b""): + body, files = WebhookService._extract_octet_stream_body(webhook_trigger) + + # Assert + assert body == {"raw": None} + assert files == {} + + +def test_extract_octet_stream_body_should_return_none_when_processing_raises( + flask_app: Flask, + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + webhook_trigger = MagicMock() + monkeypatch.setattr(WebhookService, "_detect_binary_mimetype", MagicMock(return_value="application/octet-stream")) + monkeypatch.setattr(WebhookService, "_create_file_from_binary", MagicMock(side_effect=RuntimeError("boom"))) + + # Act + with flask_app.test_request_context("/webhook", method="POST", data=b"abc"): + body, files = WebhookService._extract_octet_stream_body(webhook_trigger) + + # Assert + assert body == {"raw": None} + assert files == {} + + +def test_extract_text_body_should_return_empty_string_when_request_read_fails( + flask_app: Flask, + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + monkeypatch.setattr("flask.wrappers.Request.get_data", MagicMock(side_effect=RuntimeError("read error"))) + + # Act + with flask_app.test_request_context("/webhook", method="POST", data="abc"): + body, files = WebhookService._extract_text_body() + + # Assert + assert body == {"raw": ""} + assert files == {} + + +def test_detect_binary_mimetype_should_fallback_when_magic_raises(monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + fake_magic = MagicMock() + fake_magic.from_buffer.side_effect = RuntimeError("magic failed") + monkeypatch.setattr(service_module, "magic", fake_magic) + + # Act + result = WebhookService._detect_binary_mimetype(b"binary") + + # Assert + assert result == "application/octet-stream" + + +def test_process_file_uploads_should_use_octet_stream_fallback_when_mimetype_unknown( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + webhook_trigger = _workflow_trigger(created_by="user-1", tenant_id="tenant-1") + file_obj = MagicMock() + file_obj.to_dict.return_value = {"id": "f-1"} + monkeypatch.setattr(WebhookService, "_create_file_from_binary", MagicMock(return_value=file_obj)) + monkeypatch.setattr(service_module.mimetypes, "guess_type", MagicMock(return_value=(None, None))) + + uploaded = MagicMock() + uploaded.filename = "file.unknown" + 
uploaded.content_type = None + uploaded.read.return_value = b"content" + + # Act + result = WebhookService._process_file_uploads({"f": uploaded}, webhook_trigger) + + # Assert + assert result == {"f": {"id": "f-1"}} + + +def test_create_file_from_binary_should_call_tool_file_manager_and_file_factory( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + webhook_trigger = _workflow_trigger(created_by="user-1", tenant_id="tenant-1") + manager = MagicMock() + manager.create_file_by_raw.return_value = SimpleNamespace(id="tool-file-1") + monkeypatch.setattr(service_module, "ToolFileManager", MagicMock(return_value=manager)) + expected_file = MagicMock() + monkeypatch.setattr(service_module.file_factory, "build_from_mapping", MagicMock(return_value=expected_file)) + + # Act + result = WebhookService._create_file_from_binary(b"abc", "text/plain", webhook_trigger) + + # Assert + assert result is expected_file + manager.create_file_by_raw.assert_called_once() + + +@pytest.mark.parametrize( + ("raw_value", "param_type", "expected"), + [ + ("42", SegmentType.NUMBER, 42), + ("3.14", SegmentType.NUMBER, 3.14), + ("yes", SegmentType.BOOLEAN, True), + ("no", SegmentType.BOOLEAN, False), + ], +) +def test_convert_form_value_should_convert_supported_types( + raw_value: str, + param_type: str, + expected: Any, +) -> None: + # Arrange + + # Act + result = WebhookService._convert_form_value("param", raw_value, param_type) + + # Assert + assert result == expected + + +def test_convert_form_value_should_raise_for_unsupported_type() -> None: + # Arrange + + # Act / Assert + with pytest.raises(ValueError, match="Unsupported type"): + WebhookService._convert_form_value("p", "x", SegmentType.FILE) + + +def test_validate_json_value_should_return_original_for_unmapped_supported_segment_type( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + warning_mock = MagicMock() + monkeypatch.setattr(service_module.logger, "warning", warning_mock) + + # Act + result = WebhookService._validate_json_value("param", {"x": 1}, "unsupported-type") + + # Assert + assert result == {"x": 1} + warning_mock.assert_called_once() + + +def test_validate_and_convert_value_should_wrap_conversion_errors() -> None: + # Arrange + + # Act / Assert + with pytest.raises(ValueError, match="validation failed"): + WebhookService._validate_and_convert_value("param", "bad", SegmentType.NUMBER, is_form_data=True) + + +def test_process_parameters_should_raise_when_required_parameter_missing() -> None: + # Arrange + raw_params = {"optional": "x"} + config = [WebhookParameter(name="required_param", type=SegmentType.STRING, required=True)] + + # Act / Assert + with pytest.raises(ValueError, match="Required parameter missing"): + WebhookService._process_parameters(raw_params, config, is_form_data=True) + + +def test_process_parameters_should_include_unconfigured_parameters() -> None: + # Arrange + raw_params = {"known": "1", "unknown": "x"} + config = [WebhookParameter(name="known", type=SegmentType.NUMBER, required=False)] + + # Act + result = WebhookService._process_parameters(raw_params, config, is_form_data=True) + + # Assert + assert result == {"known": 1, "unknown": "x"} + + +def test_process_body_parameters_should_raise_when_required_text_raw_is_missing() -> None: + # Arrange + + # Act / Assert + with pytest.raises(ValueError, match="Required body content missing"): + WebhookService._process_body_parameters( + raw_body={"raw": ""}, + body_configs=[WebhookBodyParameter(name="raw", required=True)], + content_type=ContentType.TEXT, + ) + + 
+def test_process_body_parameters_should_skip_file_config_for_multipart_form_data() -> None: + # Arrange + raw_body = {"message": "hello", "extra": "x"} + body_configs = [ + WebhookBodyParameter(name="upload", type=SegmentType.FILE, required=True), + WebhookBodyParameter(name="message", type=SegmentType.STRING, required=True), + ] + + # Act + result = WebhookService._process_body_parameters(raw_body, body_configs, ContentType.FORM_DATA) + + # Assert + assert result == {"message": "hello", "extra": "x"} + + +def test_validate_required_headers_should_accept_sanitized_header_names() -> None: + # Arrange + headers = {"x_api_key": "123"} + configs = [WebhookParameter(name="x-api-key", required=True)] + + # Act + WebhookService._validate_required_headers(headers, configs) + + # Assert + assert True + + +def test_validate_required_headers_should_raise_when_required_header_missing() -> None: + # Arrange + headers = {"x-other": "123"} + configs = [WebhookParameter(name="x-api-key", required=True)] + + # Act / Assert + with pytest.raises(ValueError, match="Required header missing"): + WebhookService._validate_required_headers(headers, configs) + + +def test_validate_http_metadata_should_return_content_type_mismatch_error() -> None: + # Arrange + webhook_data = {"method": "POST", "headers": {"Content-Type": "application/json"}} + node_data = WebhookData(method="post", content_type=ContentType.TEXT) + + # Act + result = WebhookService._validate_http_metadata(webhook_data, node_data) + + # Assert + assert result["valid"] is False + assert "Content-type mismatch" in result["error"] + + +def test_extract_content_type_should_fallback_to_lowercase_header_key() -> None: + # Arrange + headers = {"content-type": "application/json; charset=utf-8"} + + # Act + result = WebhookService._extract_content_type(headers) + + # Assert + assert result == "application/json" + + +def test_build_workflow_inputs_should_include_expected_keys() -> None: + # Arrange + webhook_data = {"headers": {"h": "v"}, "query_params": {"q": 1}, "body": {"b": 2}} + + # Act + result = WebhookService.build_workflow_inputs(webhook_data) + + # Assert + assert result["webhook_data"] == webhook_data + assert result["webhook_headers"] == {"h": "v"} + assert result["webhook_query_params"] == {"q": 1} + assert result["webhook_body"] == {"b": 2} + + +def test_trigger_workflow_execution_should_trigger_async_workflow_successfully(monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + webhook_trigger = _workflow_trigger( + app_id="app-1", + node_id="node-1", + tenant_id="tenant-1", + webhook_id="webhook-1", + ) + workflow = _workflow(id="wf-1") + webhook_data = {"body": {"x": 1}} + + session = MagicMock() + _patch_session(monkeypatch, session) + + end_user = SimpleNamespace(id="end-user-1") + monkeypatch.setattr( + service_module.EndUserService, "get_or_create_end_user_by_type", MagicMock(return_value=end_user) + ) + quota_charge = MagicMock() + monkeypatch.setattr(service_module.QuotaService, "reserve", MagicMock(return_value=quota_charge)) + trigger_async_mock = MagicMock() + monkeypatch.setattr(service_module.AsyncWorkflowService, "trigger_workflow_async", trigger_async_mock) + + # Act + WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow) + + # Assert + trigger_async_mock.assert_called_once() + + +def test_trigger_workflow_execution_should_mark_tenant_rate_limited_when_quota_exceeded( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + webhook_trigger = _workflow_trigger( + app_id="app-1", +
node_id="node-1", + tenant_id="tenant-1", + webhook_id="webhook-1", + ) + workflow = _workflow(id="wf-1") + + session = MagicMock() + _patch_session(monkeypatch, session) + + monkeypatch.setattr( + service_module.EndUserService, + "get_or_create_end_user_by_type", + MagicMock(return_value=SimpleNamespace(id="end-user-1")), + ) + monkeypatch.setattr( + service_module.QuotaService, + "reserve", + MagicMock(side_effect=QuotaExceededError(feature="trigger", tenant_id="tenant-1", required=1)), + ) + mark_rate_limited_mock = MagicMock() + monkeypatch.setattr(service_module.AppTriggerService, "mark_tenant_triggers_rate_limited", mark_rate_limited_mock) + + # Act / Assert + with pytest.raises(QuotaExceededError): + WebhookService.trigger_workflow_execution(webhook_trigger, {"body": {}}, workflow) + mark_rate_limited_mock.assert_called_once_with("tenant-1") + + +def test_trigger_workflow_execution_should_log_and_reraise_unexpected_errors(monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + webhook_trigger = _workflow_trigger( + app_id="app-1", + node_id="node-1", + tenant_id="tenant-1", + webhook_id="webhook-1", + ) + workflow = _workflow(id="wf-1") + + session = MagicMock() + _patch_session(monkeypatch, session) + + monkeypatch.setattr( + service_module.EndUserService, "get_or_create_end_user_by_type", MagicMock(side_effect=RuntimeError("boom")) + ) + logger_exception_mock = MagicMock() + monkeypatch.setattr(service_module.logger, "exception", logger_exception_mock) + + # Act / Assert + with pytest.raises(RuntimeError, match="boom"): + WebhookService.trigger_workflow_execution(webhook_trigger, {"body": {}}, workflow) + logger_exception_mock.assert_called_once() + + +def test_sync_webhook_relationships_should_raise_when_workflow_exceeds_node_limit() -> None: + # Arrange + app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1") + workflow = _workflow( + walk_nodes=lambda _node_type: [ + (f"node-{i}", {}) for i in range(WebhookService.MAX_WEBHOOK_NODES_PER_WORKFLOW + 1) + ] + ) + + # Act / Assert + with pytest.raises(ValueError, match="maximum webhook node limit"): + WebhookService.sync_webhook_relationships(app, workflow) + + +def test_sync_webhook_relationships_should_raise_when_lock_not_acquired(monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1") + workflow = _workflow(walk_nodes=lambda _node_type: [("node-1", {})]) + + lock = MagicMock() + lock.acquire.return_value = False + monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock)) + + # Act / Assert + with pytest.raises(RuntimeError, match="Failed to acquire lock"): + WebhookService.sync_webhook_relationships(app, workflow) + + +def test_sync_webhook_relationships_should_create_missing_records_and_delete_stale_records( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1") + workflow = _workflow(walk_nodes=lambda _node_type: [("node-new", {})]) + + class _WorkflowWebhookTrigger: + app_id = "app_id" + tenant_id = "tenant_id" + webhook_id = "webhook_id" + node_id = "node_id" + + def __init__(self, app_id: str, tenant_id: str, node_id: str, webhook_id: str, created_by: str) -> None: + self.id = None + self.app_id = app_id + self.tenant_id = tenant_id + self.node_id = node_id + self.webhook_id = webhook_id + self.created_by = created_by + + class _Select: + def where(self, *args: 
Any, **kwargs: Any) -> "_Select": + return self + + class _Session: + def __init__(self) -> None: + self.added: list[Any] = [] + self.deleted: list[Any] = [] + self.commit_count = 0 + self.existing_records = [SimpleNamespace(node_id="node-stale")] + + def scalars(self, _stmt: Any) -> Any: + return SimpleNamespace(all=lambda: self.existing_records) + + def add(self, obj: Any) -> None: + self.added.append(obj) + + def flush(self) -> None: + for idx, obj in enumerate(self.added, start=1): + if obj.id is None: + obj.id = f"rec-{idx}" + + def commit(self) -> None: + self.commit_count += 1 + + def delete(self, obj: Any) -> None: + self.deleted.append(obj) + + lock = MagicMock() + lock.acquire.return_value = True + lock.release.return_value = None + + fake_session = _Session() + + monkeypatch.setattr(service_module, "WorkflowWebhookTrigger", _WorkflowWebhookTrigger) + monkeypatch.setattr(service_module, "select", MagicMock(return_value=_Select())) + monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock)) + redis_set_mock = MagicMock() + redis_delete_mock = MagicMock() + monkeypatch.setattr(service_module.redis_client, "set", redis_set_mock) + monkeypatch.setattr(service_module.redis_client, "delete", redis_delete_mock) + monkeypatch.setattr(WebhookService, "generate_webhook_id", MagicMock(return_value="generated-webhook-id")) + _patch_session(monkeypatch, fake_session) + + # Act + WebhookService.sync_webhook_relationships(app, workflow) + + # Assert + assert len(fake_session.added) == 1 + assert len(fake_session.deleted) == 1 + redis_set_mock.assert_called_once() + redis_delete_mock.assert_called_once() + lock.release.assert_called_once() + + +def test_sync_webhook_relationships_should_log_when_lock_release_fails(monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1") + workflow = _workflow(walk_nodes=lambda _node_type: []) + + class _Select: + def where(self, *args: Any, **kwargs: Any) -> "_Select": + return self + + class _Session: + def scalars(self, _stmt: Any) -> Any: + return SimpleNamespace(all=lambda: []) + + def commit(self) -> None: + return None + + lock = MagicMock() + lock.acquire.return_value = True + lock.release.side_effect = RuntimeError("release failed") + + logger_exception_mock = MagicMock() + + monkeypatch.setattr(service_module, "select", MagicMock(return_value=_Select())) + monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(service_module.logger, "exception", logger_exception_mock) + _patch_session(monkeypatch, _Session()) + + # Act + WebhookService.sync_webhook_relationships(app, workflow) + + # Assert + assert logger_exception_mock.call_count == 1 + + +def test_generate_webhook_response_should_fallback_when_response_body_is_not_json() -> None: + # Arrange + node_config = {"data": {"status_code": 200, "response_body": "{bad-json"}} + + # Act + body, status = WebhookService.generate_webhook_response(node_config) + + # Assert + assert status == 200 + assert "message" in body + + +def test_generate_webhook_id_should_return_24_character_identifier() -> None: + # Arrange + + # Act + webhook_id = WebhookService.generate_webhook_id() + + # Assert + assert isinstance(webhook_id, str) + assert len(webhook_id) == 24 + + +def 
test_sanitize_key_should_return_original_value_for_non_string_input() -> None: + # Arrange + + # Act + result = WebhookService._sanitize_key(123) # type: ignore[arg-type] + + # Assert + assert result == 123 + diff --git a/api/tests/unit_tests/tasks/test_evaluation_task.py b/api/tests/unit_tests/tasks/test_evaluation_task.py new file mode 100644 index 0000000000..6922ec522e --- /dev/null +++ b/api/tests/unit_tests/tasks/test_evaluation_task.py @@ -0,0 +1,97 @@ +"""Unit tests for evaluation task helpers.""" + +from core.evaluation.entities.evaluation_entity import EvaluationItemResult, EvaluationMetric, NodeInfo +from core.evaluation.entities.judgment_entity import ( + JudgmentCondition, + JudgmentConfig, + JudgmentResult, +) +from tasks.evaluation_task import _compute_metrics_summary, _merge_result, _stamp_and_merge + +_NODE_INFO = NodeInfo(node_id="llm_1", type="llm", title="LLM Node") + + +def test_compute_metrics_summary_includes_judgment_counts() -> None: + """Summary should expose pass/fail counts when judgment rules are configured.""" + judgment_config = JudgmentConfig( + logical_operator="and", + conditions=[ + JudgmentCondition( + variable_selector=["llm_1", "faithfulness"], + comparison_operator=">", + value="0.8", + ) + ], + ) + results = [ + EvaluationItemResult( + index=0, + metrics=[EvaluationMetric(name="faithfulness", value=0.9, node_info=_NODE_INFO)], + judgment=JudgmentResult(passed=True, logical_operator="and", condition_results=[]), + ), + EvaluationItemResult( + index=1, + metrics=[EvaluationMetric(name="faithfulness", value=0.4, node_info=_NODE_INFO)], + judgment=JudgmentResult(passed=False, logical_operator="and", condition_results=[]), + ), + EvaluationItemResult(index=2, error="timeout"), + ] + + summary = _compute_metrics_summary(results, judgment_config) + + assert summary["_judgment"] == { + "enabled": True, + "logical_operator": "and", + "configured_conditions": 1, + "evaluated_items": 2, + "passed_items": 1, + "failed_items": 1, + "pass_rate": 0.5, + } + + +def test_merge_result_combines_metrics_for_same_index() -> None: + """Merging two results with the same index should concatenate their metrics.""" + results_by_index: dict[int, EvaluationItemResult] = {} + + first = EvaluationItemResult( + index=0, + actual_output="output_1", + metrics=[EvaluationMetric(name="faithfulness", value=0.9)], + ) + _merge_result(results_by_index, 0, first) + + second = EvaluationItemResult( + index=0, + actual_output="output_2", + metrics=[EvaluationMetric(name="context_precision", value=0.7)], + ) + _merge_result(results_by_index, 0, second) + + merged = results_by_index[0] + assert len(merged.metrics) == 2 + assert merged.metrics[0].name == "faithfulness" + assert merged.metrics[1].name == "context_precision" + assert merged.actual_output == "output_1" + + +def test_stamp_and_merge_attaches_node_info() -> None: + """_stamp_and_merge should set node_info on every metric and remap indices.""" + results_by_index: dict[int, EvaluationItemResult] = {} + node_info = NodeInfo(node_id="llm_1", type="llm", title="GPT-4") + + evaluated = [ + EvaluationItemResult( + index=0, + metrics=[EvaluationMetric(name="faithfulness", value=0.85)], + ) + ] + item_indices = [3] + + _stamp_and_merge(evaluated, item_indices, node_info, results_by_index) + + assert 3 in results_by_index + metric = results_by_index[3].metrics[0] + assert metric.node_info is not None + assert metric.node_info.node_id == "llm_1" + assert metric.node_info.type == "llm" diff --git a/api/uv.lock b/api/uv.lock index 
9ed8d16107..d0f1ed826e 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -320,6 +320,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/15/b3/9b1a8074496371342ec1e796a96f99c82c945a339cd81a8e73de28b4cf9e/anyio-4.11.0-py3-none-any.whl", hash = "sha256:0287e96f4d26d4149305414d4e3bc32f0dcd0862365a4bddea19d7a1ec38c4fc", size = 109097, upload-time = "2025-09-23T09:19:10.601Z" }, ] +[[package]] +name = "appdirs" +version = "1.4.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d7/d8/05696357e0311f5b5c316d7b95f46c669dd9c15aaeecbb48c7d0aeb88c40/appdirs-1.4.4.tar.gz", hash = "sha256:7d5d0167b2b1ba821647616af46a749d1c653740dd0d2415100fe26e27afdf41", size = 13470, upload-time = "2020-05-11T07:59:51.037Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3b/00/2344469e2084fb287c2e0b57b72910309874c3245463acd6cf5e3db69324/appdirs-1.4.4-py2.py3-none-any.whl", hash = "sha256:a841dacd6b99318a741b166adb07e19ee71a274450e68237b4650ca1055ab128", size = 9566, upload-time = "2020-05-11T07:59:49.499Z" }, +] + [[package]] name = "apscheduler" version = "3.11.2" @@ -1243,6 +1252,45 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a7/27/b822b474aaefb684d11df358d52e012699a2a8af231f9b47c54b73f280cb/databricks_sdk-0.73.0-py3-none-any.whl", hash = "sha256:a4d3cfd19357a2b459d2dc3101454d7f0d1b62865ce099c35d0c342b66ac64ff", size = 753896, upload-time = "2025-11-05T06:52:56.451Z" }, ] +[[package]] +name = "dataclasses-json" +version = "0.6.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "marshmallow" }, + { name = "typing-inspect" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/64/a4/f71d9cf3a5ac257c993b5ca3f93df5f7fb395c725e7f1e6479d2514173c3/dataclasses_json-0.6.7.tar.gz", hash = "sha256:b6b3e528266ea45b9535223bc53ca645f5208833c29229e847b3f26a1cc55fc0", size = 32227, upload-time = "2024-06-09T16:20:19.103Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c3/be/d0d44e092656fe7a06b55e6103cbce807cdbdee17884a5367c68c9860853/dataclasses_json-0.6.7-py3-none-any.whl", hash = "sha256:0dbf33f26c8d5305befd61b39d2b3414e8a407bedc2834dea9b8d642666fb40a", size = 28686, upload-time = "2024-06-09T16:20:16.715Z" }, +] + +[[package]] +name = "datasets" +version = "2.19.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "aiohttp" }, + { name = "dill" }, + { name = "filelock" }, + { name = "fsspec", extra = ["http"] }, + { name = "huggingface-hub" }, + { name = "multiprocess" }, + { name = "numpy" }, + { name = "packaging" }, + { name = "pandas" }, + { name = "pyarrow" }, + { name = "pyarrow-hotfix" }, + { name = "pyyaml" }, + { name = "requests" }, + { name = "tqdm" }, + { name = "xxhash" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/52/e7/6ee66732f74e4fb1c8915e58b3c253aded777ad0fa457f3f831dd0cd09b4/datasets-2.19.2.tar.gz", hash = "sha256:eccb82fb3bb5ee26ccc6d7a15b7f1f834e2cc4e59b7cff7733a003552bad51ef", size = 2215337, upload-time = "2024-06-03T05:11:44.756Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3f/59/46818ebeb708234a60e42ccf409d20709e482519d2aa450b501ddbba4594/datasets-2.19.2-py3-none-any.whl", hash = "sha256:e07ff15d75b1af75c87dd96323ba2a361128d495136652f37fd62f918d17bb4e", size = 542113, upload-time = "2024-06-03T05:11:41.151Z" }, +] + [[package]] name = "dateparser" version = "1.2.2" @@ -1267,6 +1315,45 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/4e/8c/f3147f5c4b73e7550fe5f9352eaa956ae838d5c51eb58e7a25b9f3e2643b/decorator-5.2.1-py3-none-any.whl", hash = "sha256:d316bb415a2d9e2d2b3abcc4084c6502fc09240e292cd76a76afc106a1c8e04a", size = 9190, upload-time = "2025-02-24T04:41:32.565Z" }, ] +[[package]] +name = "deepeval" +version = "3.9.6" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "aiohttp" }, + { name = "click" }, + { name = "grpcio" }, + { name = "jinja2" }, + { name = "nest-asyncio" }, + { name = "openai" }, + { name = "opentelemetry-api" }, + { name = "opentelemetry-sdk" }, + { name = "portalocker" }, + { name = "posthog" }, + { name = "pydantic" }, + { name = "pydantic-settings" }, + { name = "pyfiglet" }, + { name = "pytest" }, + { name = "pytest-asyncio" }, + { name = "pytest-repeat" }, + { name = "pytest-rerunfailures" }, + { name = "pytest-xdist" }, + { name = "python-dotenv" }, + { name = "requests" }, + { name = "rich" }, + { name = "sentry-sdk" }, + { name = "setuptools" }, + { name = "tabulate" }, + { name = "tenacity" }, + { name = "tqdm" }, + { name = "typer" }, + { name = "wheel" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c0/64/dd6c1bce14324c211d72facd6d4deb1453321d7e32a82b7d5f1ead7ba817/deepeval-3.9.6.tar.gz", hash = "sha256:07eaad40f17a21b809bd6719f9f4ae61d8880753edddaee2abbcd4a27cfb225d", size = 614512, upload-time = "2026-04-07T15:27:41.533Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0e/3e/cac5dad55d9a9540b5288dbe853554638302df43c459e0f64663d5ad129a/deepeval-3.9.6-py3-none-any.whl", hash = "sha256:6a368cdd7fda7bf94ef1bcb65b40914271096dbdc54aa8988aa62bda41c68f94", size = 843386, upload-time = "2026-04-07T15:27:39.134Z" }, +] + [[package]] name = "defusedxml" version = "0.7.1" @@ -1459,6 +1546,10 @@ dev = [ { name = "types-ujson" }, { name = "xinference-client" }, ] +evaluation = [ + { name = "deepeval" }, + { name = "ragas" }, +] storage = [ { name = "azure-storage-blob" }, { name = "bce-python-sdk" }, @@ -1756,6 +1847,10 @@ dev = [ { name = "types-ujson", specifier = ">=5.10.0" }, { name = "xinference-client", specifier = "~=2.4.0" }, ] +evaluation = [ + { name = "deepeval", specifier = ">=2.0.0" }, + { name = "ragas", specifier = ">=0.2.0" }, +] storage = [ { name = "azure-storage-blob", specifier = "==12.28.0" }, { name = "bce-python-sdk", specifier = "~=0.9.69" }, @@ -2167,6 +2262,24 @@ dependencies = [ [package.metadata] requires-dist = [{ name = "weaviate-client", specifier = "==4.20.5" }] +[[package]] +name = "dill" +version = "0.3.8" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/17/4d/ac7ffa80c69ea1df30a8aa11b3578692a5118e7cd1aa157e3ef73b092d15/dill-0.3.8.tar.gz", hash = "sha256:3ebe3c479ad625c4553aca177444d89b486b1d84982eeacded644afc0cf797ca", size = 184847, upload-time = "2024-01-27T23:42:16.145Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c9/7a/cef76fd8438a42f96db64ddaa85280485a9c395e7df3db8158cfec1eee34/dill-0.3.8-py3-none-any.whl", hash = "sha256:c36ca9ffb54365bdd2f8eb3eff7d2a21237f8452b57ace88b1ac615b7e815bd7", size = 116252, upload-time = "2024-01-27T23:42:14.239Z" }, +] + +[[package]] +name = "diskcache" +version = "5.6.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/3f/21/1c1ffc1a039ddcc459db43cc108658f32c57d271d7289a2794e401d0fdb6/diskcache-5.6.3.tar.gz", hash = 
"sha256:2c3a3fa2743d8535d832ec61c2054a1641f41775aa7c556758a109941e33e4fc", size = 67916, upload-time = "2023-08-31T06:12:00.316Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3f/27/4570e78fc0bf5ea0ca45eb1de3818a23787af9b390c0b0a0033a1b8236f9/diskcache-5.6.3-py3-none-any.whl", hash = "sha256:5e31b2d5fbad117cc363ebaf6b689474db18a1f6438bc82358b024abd4c2ca19", size = 45550, upload-time = "2023-08-31T06:11:58.822Z" }, +] + [[package]] name = "diskcache-weave" version = "5.6.3.post1" @@ -2546,11 +2659,16 @@ wheels = [ [[package]] name = "fsspec" -version = "2025.10.0" +version = "2024.3.1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/24/7f/2747c0d332b9acfa75dc84447a066fdf812b5a6b8d30472b74d309bfe8cb/fsspec-2025.10.0.tar.gz", hash = "sha256:b6789427626f068f9a83ca4e8a3cc050850b6c0f71f99ddb4f542b8266a26a59", size = 309285, upload-time = "2025-10-30T14:58:44.036Z" } +sdist = { url = "https://files.pythonhosted.org/packages/8b/b8/e3ba21f03c00c27adc9a8cd1cab8adfb37b6024757133924a9a4eab63a83/fsspec-2024.3.1.tar.gz", hash = "sha256:f39780e282d7d117ffb42bb96992f8a90795e4d0fb0f661a70ca39fe9c43ded9", size = 170742, upload-time = "2024-03-18T19:35:13.995Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/eb/02/a6b21098b1d5d6249b7c5ab69dde30108a71e4e819d4a9778f1de1d5b70d/fsspec-2025.10.0-py3-none-any.whl", hash = "sha256:7c7712353ae7d875407f97715f0e1ffcc21e33d5b24556cb1e090ae9409ec61d", size = 200966, upload-time = "2025-10-30T14:58:42.53Z" }, + { url = "https://files.pythonhosted.org/packages/93/6d/66d48b03460768f523da62a57a7e14e5e95fdf339d79e996ce3cecda2cdb/fsspec-2024.3.1-py3-none-any.whl", hash = "sha256:918d18d41bf73f0e2b261824baeb1b124bcf771767e3a26425cd7dec3332f512", size = 171991, upload-time = "2024-03-18T19:35:11.259Z" }, +] + +[package.optional-dependencies] +http = [ + { name = "aiohttp" }, ] [[package]] @@ -3333,6 +3451,28 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e5/ca/1172b6638d52f2d6caa2dd262ec4c811ba59eee96d54a7701930726bce18/installer-0.7.0-py3-none-any.whl", hash = "sha256:05d1933f0a5ba7d8d6296bb6d5018e7c94fa473ceb10cf198a92ccea19c27b53", size = 453838, upload-time = "2023-03-17T20:39:36.219Z" }, ] +[[package]] +name = "instructor" +version = "1.15.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "aiohttp" }, + { name = "docstring-parser" }, + { name = "jinja2" }, + { name = "jiter" }, + { name = "openai" }, + { name = "pydantic" }, + { name = "pydantic-core" }, + { name = "requests" }, + { name = "rich" }, + { name = "tenacity" }, + { name = "typer" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/dc/a4/832cfb15420360e26d2d85bd9d5fe1e4b839d52587574d389bc31284bf6f/instructor-1.15.1.tar.gz", hash = "sha256:c72406469d9025b742e83cf0c13e914b317db2089d08d889944e74fcd659ef94", size = 69948370, upload-time = "2026-04-03T01:51:30.107Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d8/c8/36c5d9b80aaf40ba9a7084a8fc18c967db6bf248a4cc8d0f0816b14284be/instructor-1.15.1-py3-none-any.whl", hash = "sha256:be81d17ba2b154a04ab4720808f24f9d6b598f80992f82eaf9cc79006099cf6c", size = 178156, upload-time = "2026-04-03T01:51:23.098Z" }, +] + [[package]] name = "intersystems-irispython" version = "5.3.2" @@ -3442,6 +3582,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e1/03/7afcecb4242d93b684708b47fb014abdc1922a01b38c0e30f1117ae74a83/json_repair-0.59.2-py3-none-any.whl", hash = 
"sha256:6ca6238519c24f671bcb05d1f38a0d6a452bb4ca5af82137595c5c2f1a0fb785", size = 46918, upload-time = "2026-04-11T15:55:39.817Z" }, ] +[[package]] +name = "jsonpatch" +version = "1.33" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jsonpointer" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/42/78/18813351fe5d63acad16aec57f94ec2b70a09e53ca98145589e185423873/jsonpatch-1.33.tar.gz", hash = "sha256:9fcd4009c41e6d12348b4a0ff2563ba56a2923a7dfee731d004e212e1ee5030c", size = 21699, upload-time = "2023-06-26T12:07:29.144Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/73/07/02e16ed01e04a374e644b575638ec7987ae846d25ad97bcc9945a3ee4b0e/jsonpatch-1.33-py2.py3-none-any.whl", hash = "sha256:0ae28c0cd062bbd8b8ecc26d7d164fbbea9652a1a3693f3b956c1eae5145dade", size = 12898, upload-time = "2023-06-16T21:01:28.466Z" }, +] + +[[package]] +name = "jsonpointer" +version = "3.1.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/18/c7/af399a2e7a67fd18d63c40c5e62d3af4e67b836a2107468b6a5ea24c4304/jsonpointer-3.1.1.tar.gz", hash = "sha256:0b801c7db33a904024f6004d526dcc53bbb8a4a0f4e32bfd10beadf60adf1900", size = 9068, upload-time = "2026-03-23T22:32:32.458Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9e/6a/a83720e953b1682d2d109d3c2dbb0bc9bf28cc1cbc205be4ef4be5da709d/jsonpointer-3.1.1-py3-none-any.whl", hash = "sha256:8ff8b95779d071ba472cf5bc913028df06031797532f08a7d5b602d8b2a488ca", size = 7659, upload-time = "2026-03-23T22:32:31.568Z" }, +] + [[package]] name = "jsonschema" version = "4.25.1" @@ -3515,6 +3676,106 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/89/43/d9bebfc3db7dea6ec80df5cb2aad8d274dd18ec2edd6c4f21f32c237cbbb/kubernetes-33.1.0-py2.py3-none-any.whl", hash = "sha256:544de42b24b64287f7e0aa9513c93cb503f7f40eea39b20f66810011a86eabc5", size = 1941335, upload-time = "2025-06-09T21:57:56.327Z" }, ] +[[package]] +name = "langchain" +version = "1.2.15" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "langchain-core" }, + { name = "langgraph" }, + { name = "pydantic" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/98/3f/888a7099d2bd2917f8b0c3ffc7e347f1e664cf64267820b0b923c4f339fc/langchain-1.2.15.tar.gz", hash = "sha256:1717b6719daefae90b2728314a5e2a117ff916291e2862595b6c3d6fba33d652", size = 574732, upload-time = "2026-04-03T14:26:03.994Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3f/e8/a3b8cb0005553f6a876865073c81ef93bd7c5b18381bcb9ba4013af96ebc/langchain-1.2.15-py3-none-any.whl", hash = "sha256:e349db349cb3e9550c4044077cf90a1717691756cc236438404b23500e615874", size = 112714, upload-time = "2026-04-03T14:26:02.557Z" }, +] + +[[package]] +name = "langchain-classic" +version = "1.0.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "langchain-core" }, + { name = "langchain-text-splitters" }, + { name = "langsmith" }, + { name = "pydantic" }, + { name = "pyyaml" }, + { name = "requests" }, + { name = "sqlalchemy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/32/04/b01c09e37414bab9f209efa311502841a3c0de5bc6c35e729c8d8a9893c9/langchain_classic-1.0.3.tar.gz", hash = "sha256:168ef1dfbfb18cae5a9ff0accecc9413a5b5aa3464b53fa841561a3384b6324a", size = 10534933, upload-time = "2026-03-13T13:56:11.96Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/ab/e6/cfdeedec0537ffbf5041773590d25beb7f2aa467cc6630e788c9c7c72c3e/langchain_classic-1.0.3-py3-none-any.whl", hash = "sha256:26df1ec9806b1fbff19d9085a747ea7d8d82d7e3fb1d25132859979de627ef79", size = 1041335, upload-time = "2026-03-13T13:56:09.677Z" }, +] + +[[package]] +name = "langchain-community" +version = "0.4.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "aiohttp" }, + { name = "dataclasses-json" }, + { name = "httpx-sse" }, + { name = "langchain-classic" }, + { name = "langchain-core" }, + { name = "langsmith" }, + { name = "numpy" }, + { name = "pydantic-settings" }, + { name = "pyyaml" }, + { name = "requests" }, + { name = "sqlalchemy" }, + { name = "tenacity" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/53/97/a03585d42b9bdb6fbd935282d6e3348b10322a24e6ce12d0c99eb461d9af/langchain_community-0.4.1.tar.gz", hash = "sha256:f3b211832728ee89f169ddce8579b80a085222ddb4f4ed445a46e977d17b1e85", size = 33241144, upload-time = "2025-10-27T15:20:32.504Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f0/a4/c4fde67f193401512337456cabc2148f2c43316e445f5decd9f8806e2992/langchain_community-0.4.1-py3-none-any.whl", hash = "sha256:2135abb2c7748a35c84613108f7ebf30f8505b18c3c18305ffaecfc7651f6c6a", size = 2533285, upload-time = "2025-10-27T15:20:30.767Z" }, +] + +[[package]] +name = "langchain-core" +version = "1.2.28" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jsonpatch" }, + { name = "langsmith" }, + { name = "packaging" }, + { name = "pydantic" }, + { name = "pyyaml" }, + { name = "tenacity" }, + { name = "typing-extensions" }, + { name = "uuid-utils" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f8/a4/317a1a3ac1df33a64adb3670bf88bbe3b3d5baa274db6863a979db472897/langchain_core-1.2.28.tar.gz", hash = "sha256:271a3d8bd618f795fdeba112b0753980457fc90537c46a0c11998516a74dc2cb", size = 846119, upload-time = "2026-04-08T18:19:34.867Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a8/92/32f785f077c7e898da97064f113c73fbd9ad55d1e2169cf3a391b183dedb/langchain_core-1.2.28-py3-none-any.whl", hash = "sha256:80764232581eaf8057bcefa71dbf8adc1f6a28d257ebd8b95ba9b8b452e8c6ac", size = 508727, upload-time = "2026-04-08T18:19:32.823Z" }, +] + +[[package]] +name = "langchain-openai" +version = "1.1.9" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "langchain-core" }, + { name = "openai" }, + { name = "tiktoken" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/49/ae/1dbeb49ab8f098f78ec52e21627e705e5d7c684dc8826c2c34cc2746233a/langchain_openai-1.1.9.tar.gz", hash = "sha256:fdee25dcf4b0685d8e2f59856f4d5405431ef9e04ab53afe19e2e8360fed8234", size = 1004828, upload-time = "2026-02-10T21:03:21.615Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/52/a1/8a20d19f69d022c10d34afa42d972cc50f971b880d0eb4a828cf3dd824a8/langchain_openai-1.1.9-py3-none-any.whl", hash = "sha256:ca2482b136c45fb67c0db84a9817de675e0eb8fb2203a33914c1b7a96f273940", size = 85769, upload-time = "2026-02-10T21:03:20.333Z" }, +] + +[[package]] +name = "langchain-text-splitters" +version = "1.1.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "langchain-core" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/85/38/14121ead61e0e75f79c3a35e5148ac7c2fe754a55f76eab3eed573269524/langchain_text_splitters-1.1.1.tar.gz", hash = 
"sha256:34861abe7c07d9e49d4dc852d0129e26b32738b60a74486853ec9b6d6a8e01d2", size = 279352, upload-time = "2026-02-18T23:02:42.798Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/84/66/d9e0c3b83b0ad75ee746c51ba347cacecb8d656b96e1d513f3e334d1ccab/langchain_text_splitters-1.1.1-py3-none-any.whl", hash = "sha256:5ed0d7bf314ba925041e7d7d17cd8b10f688300d5415fb26c29442f061e329dc", size = 35734, upload-time = "2026-02-18T23:02:41.913Z" }, +] + [[package]] name = "langdetect" version = "1.0.9" @@ -3543,6 +3804,62 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/be/0a/b84e3e68a690ccfe6d64953c572772c685fcb0915b7f2ee3a87c22e388ab/langfuse-4.2.0-py3-none-any.whl", hash = "sha256:bfd760bf10fd0228f297f6369436620f76d16b589de46393d65706b27e4e4082", size = 475449, upload-time = "2026-04-10T11:55:23.624Z" }, ] +[[package]] +name = "langgraph" +version = "1.1.6" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "langchain-core" }, + { name = "langgraph-checkpoint" }, + { name = "langgraph-prebuilt" }, + { name = "langgraph-sdk" }, + { name = "pydantic" }, + { name = "xxhash" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5c/e5/d3f72ead3c7f15769d5a9c07e373628f1fbaf6cbe7735694d7085859acf6/langgraph-1.1.6.tar.gz", hash = "sha256:1783f764b08a607e9f288dbcf6da61caeb0dd40b337e5c9fb8b412341fbc0b60", size = 549634, upload-time = "2026-04-03T19:01:32.561Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/71/e6/b36ecdb3ff4ba9a290708d514bae89ebbe2f554b6abbe4642acf3fddbe51/langgraph-1.1.6-py3-none-any.whl", hash = "sha256:fdbf5f54fa5a5a4c4b09b7b5e537f1b2fa283d2f0f610d3457ddeecb479458b9", size = 169755, upload-time = "2026-04-03T19:01:30.686Z" }, +] + +[[package]] +name = "langgraph-checkpoint" +version = "4.0.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "langchain-core" }, + { name = "ormsgpack" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b1/44/a8df45d1e8b4637e29789fa8bae1db022c953cc7ac80093cfc52e923547e/langgraph_checkpoint-4.0.1.tar.gz", hash = "sha256:b433123735df11ade28829e40ce25b9be614930cd50245ff2af60629234befd9", size = 158135, upload-time = "2026-02-27T21:06:16.092Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/65/4c/09a4a0c42f5d2fc38d6c4d67884788eff7fd2cfdf367fdf7033de908b4c0/langgraph_checkpoint-4.0.1-py3-none-any.whl", hash = "sha256:e3adcd7a0e0166f3b48b8cf508ce0ea366e7420b5a73aa81289888727769b034", size = 50453, upload-time = "2026-02-27T21:06:14.293Z" }, +] + +[[package]] +name = "langgraph-prebuilt" +version = "1.0.9" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "langchain-core" }, + { name = "langgraph-checkpoint" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/99/4c/06dac899f4945bedb0c3a1583c19484c2cc894114ea30d9a538dd270086e/langgraph_prebuilt-1.0.9.tar.gz", hash = "sha256:93de7512e9caade4b77ead92428f6215c521fdb71b8ffda8cd55f0ad814e64de", size = 165850, upload-time = "2026-04-03T14:06:37.721Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1d/a2/8368ac187b75e7f9d938ca075d34f116683f5cfc48d924029ee79aea147b/langgraph_prebuilt-1.0.9-py3-none-any.whl", hash = "sha256:776c8e3154a5aef5ad0e5bf3f263f2dcaab3983786cc20014b7f955d99d2d1b2", size = 35958, upload-time = "2026-04-03T14:06:36.58Z" }, +] + +[[package]] +name = "langgraph-sdk" +version = "0.3.13" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "httpx" }, + { name = "orjson" }, +] +sdist 
= { url = "https://files.pythonhosted.org/packages/0e/db/77a45127dddcfea5e4256ba916182903e4c31dc4cfca305b8c386f0a9e53/langgraph_sdk-0.3.13.tar.gz", hash = "sha256:419ca5663eec3cec192ad194ac0647c0c826866b446073eb40f384f950986cd5", size = 196360, upload-time = "2026-04-07T20:34:18.766Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fe/ef/64d64e9f8eea47ce7b939aa6da6863b674c8d418647813c20111645fcc62/langgraph_sdk-0.3.13-py3-none-any.whl", hash = "sha256:aee09e345c90775f6de9d6f4c7b847cfc652e49055c27a2aed0d981af2af3bd0", size = 96668, upload-time = "2026-04-07T20:34:17.866Z" }, +] + [[package]] name = "langsmith" version = "0.7.30" @@ -3731,6 +4048,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e5/f1/216fc1bbfd74011693a4fd837e7026152e89c4bcf3e77b6692fba9923123/markupsafe-3.0.3-cp312-cp312-win_arm64.whl", hash = "sha256:35add3b638a5d900e807944a078b51922212fb3dedb01633a8defc4b01a3c85f", size = 13906, upload-time = "2025-09-27T18:36:40.689Z" }, ] +[[package]] +name = "marshmallow" +version = "3.26.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "packaging" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/55/79/de6c16cc902f4fc372236926b0ce2ab7845268dcc30fb2fbb7f71b418631/marshmallow-3.26.2.tar.gz", hash = "sha256:bbe2adb5a03e6e3571b573f42527c6fe926e17467833660bebd11593ab8dfd57", size = 222095, upload-time = "2025-12-22T06:53:53.309Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/be/2f/5108cb3ee4ba6501748c4908b908e55f42a5b66245b4cfe0c99326e1ef6e/marshmallow-3.26.2-py3-none-any.whl", hash = "sha256:013fa8a3c4c276c24d26d84ce934dc964e2aa794345a0f8c7e5a7191482c8a73", size = 50964, upload-time = "2025-12-22T06:53:51.801Z" }, +] + [[package]] name = "mdurl" version = "0.1.2" @@ -3870,6 +4199,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b7/da/7d22601b625e241d4f23ef1ebff8acfc60da633c9e7e7922e24d10f592b3/multidict-6.7.0-py3-none-any.whl", hash = "sha256:394fc5c42a333c9ffc3e421a4c85e08580d990e08b99f6bf35b4132114c5dcb3", size = 12317, upload-time = "2025-10-06T14:52:29.272Z" }, ] +[[package]] +name = "multiprocess" +version = "0.70.16" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "dill" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b5/ae/04f39c5d0d0def03247c2893d6f2b83c136bf3320a2154d7b8858f2ba72d/multiprocess-0.70.16.tar.gz", hash = "sha256:161af703d4652a0e1410be6abccecde4a7ddffd19341be0a7011b94aeb171ac1", size = 1772603, upload-time = "2024-01-28T18:52:34.85Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bc/f7/7ec7fddc92e50714ea3745631f79bd9c96424cb2702632521028e57d3a36/multiprocess-0.70.16-py310-none-any.whl", hash = "sha256:c4a9944c67bd49f823687463660a2d6daae94c289adff97e0f9d696ba6371d02", size = 134824, upload-time = "2024-01-28T18:52:26.062Z" }, + { url = "https://files.pythonhosted.org/packages/50/15/b56e50e8debaf439f44befec5b2af11db85f6e0f344c3113ae0be0593a91/multiprocess-0.70.16-py311-none-any.whl", hash = "sha256:af4cabb0dac72abfb1e794fa7855c325fd2b55a10a44628a3c1ad3311c04127a", size = 143519, upload-time = "2024-01-28T18:52:28.115Z" }, + { url = "https://files.pythonhosted.org/packages/0a/7d/a988f258104dcd2ccf1ed40fdc97e26c4ac351eeaf81d76e266c52d84e2f/multiprocess-0.70.16-py312-none-any.whl", hash = "sha256:fc0544c531920dde3b00c29863377f87e1632601092ea2daca74e4beb40faa2e", size = 146741, upload-time = "2024-01-28T18:52:29.395Z" }, + { url = 
"https://files.pythonhosted.org/packages/ea/89/38df130f2c799090c978b366cfdf5b96d08de5b29a4a293df7f7429fa50b/multiprocess-0.70.16-py38-none-any.whl", hash = "sha256:a71d82033454891091a226dfc319d0cfa8019a4e888ef9ca910372a446de4435", size = 132628, upload-time = "2024-01-28T18:52:30.853Z" }, + { url = "https://files.pythonhosted.org/packages/da/d9/f7f9379981e39b8c2511c9e0326d212accacb82f12fbfdc1aa2ce2a7b2b6/multiprocess-0.70.16-py39-none-any.whl", hash = "sha256:a0bafd3ae1b732eac64be2e72038231c1ba97724b60b09400d68f229fcc2fbf3", size = 133351, upload-time = "2024-01-28T18:52:31.981Z" }, +] + [[package]] name = "murmurhash" version = "1.0.15" @@ -3940,6 +4285,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/15/dd/b3250826c29cee7816de4409a2fe5e469a68b9a89f6bfaa5eed74f05532c/mysql_connector_python-9.6.0-py2.py3-none-any.whl", hash = "sha256:44b0fb57207ebc6ae05b5b21b7968a9ed33b29187fe87b38951bad2a334d75d5", size = 480527, upload-time = "2026-02-10T12:04:36.176Z" }, ] +[[package]] +name = "nest-asyncio" +version = "1.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/83/f8/51569ac65d696c8ecbee95938f89d4abf00f47d58d48f6fbabfe8f0baefe/nest_asyncio-1.6.0.tar.gz", hash = "sha256:6f172d5449aca15afd6c646851f4e31e02c598d553a667e38cafa997cfec55fe", size = 7418, upload-time = "2024-01-21T14:25:19.227Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/c4/c2971a3ba4c6103a3d10c4b0f24f461ddc027f0f09763220cf35ca1401b3/nest_asyncio-1.6.0-py3-none-any.whl", hash = "sha256:87af6efd6b5e897c81050477ef65c62e2b2f35d51703cae01aff2905b1852e1c", size = 5195, upload-time = "2024-01-21T14:25:17.223Z" }, +] + [[package]] name = "networkx" version = "3.6" @@ -4567,6 +4921,23 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/04/3d/b4fefad8bdf91e0fe212eb04975aeb36ea92997269d68857efcc7eb1dda3/orjson-3.11.6-cp312-cp312-win_arm64.whl", hash = "sha256:65dfa096f4e3a5e02834b681f539a87fbe85adc82001383c0db907557f666bfc", size = 135212, upload-time = "2026-01-29T15:12:18.3Z" }, ] +[[package]] +name = "ormsgpack" +version = "1.12.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/12/0c/f1761e21486942ab9bb6feaebc610fa074f7c5e496e6962dea5873348077/ormsgpack-1.12.2.tar.gz", hash = "sha256:944a2233640273bee67521795a73cf1e959538e0dfb7ac635505010455e53b33", size = 39031, upload-time = "2026-01-18T20:55:28.023Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4c/36/16c4b1921c308a92cef3bf6663226ae283395aa0ff6e154f925c32e91ff5/ormsgpack-1.12.2-cp312-cp312-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:7a29d09b64b9694b588ff2f80e9826bdceb3a2b91523c5beae1fab27d5c940e7", size = 378618, upload-time = "2026-01-18T20:55:50.835Z" }, + { url = "https://files.pythonhosted.org/packages/c0/68/468de634079615abf66ed13bb5c34ff71da237213f29294363beeeca5306/ormsgpack-1.12.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0b39e629fd2e1c5b2f46f99778450b59454d1f901bc507963168985e79f09c5d", size = 203186, upload-time = "2026-01-18T20:56:11.163Z" }, + { url = "https://files.pythonhosted.org/packages/73/a9/d756e01961442688b7939bacd87ce13bfad7d26ce24f910f6028178b2cc8/ormsgpack-1.12.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:958dcb270d30a7cb633a45ee62b9444433fa571a752d2ca484efdac07480876e", size = 210738, upload-time = "2026-01-18T20:56:09.181Z" }, + { url = 
"https://files.pythonhosted.org/packages/7b/ba/795b1036888542c9113269a3f5690ab53dd2258c6fb17676ac4bd44fcf94/ormsgpack-1.12.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58d379d72b6c5e964851c77cfedfb386e474adee4fd39791c2c5d9efb53505cc", size = 212569, upload-time = "2026-01-18T20:56:06.135Z" }, + { url = "https://files.pythonhosted.org/packages/6c/aa/bff73c57497b9e0cba8837c7e4bcab584b1a6dbc91a5dd5526784a5030c8/ormsgpack-1.12.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8463a3fc5f09832e67bdb0e2fda6d518dc4281b133166146a67f54c08496442e", size = 387166, upload-time = "2026-01-18T20:55:36.738Z" }, + { url = "https://files.pythonhosted.org/packages/d3/cf/f8283cba44bcb7b14f97b6274d449db276b3a86589bdb363169b51bc12de/ormsgpack-1.12.2-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:eddffb77eff0bad4e67547d67a130604e7e2dfbb7b0cde0796045be4090f35c6", size = 482498, upload-time = "2026-01-18T20:55:29.626Z" }, + { url = "https://files.pythonhosted.org/packages/05/be/71e37b852d723dfcbe952ad04178c030df60d6b78eba26bfd14c9a40575e/ormsgpack-1.12.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:fcd55e5f6ba0dbce624942adf9f152062135f991a0126064889f68eb850de0dd", size = 425518, upload-time = "2026-01-18T20:55:49.556Z" }, + { url = "https://files.pythonhosted.org/packages/7a/0c/9803aa883d18c7ef197213cd2cbf73ba76472a11fe100fb7dab2884edf48/ormsgpack-1.12.2-cp312-cp312-win_amd64.whl", hash = "sha256:d024b40828f1dde5654faebd0d824f9cc29ad46891f626272dd5bfd7af2333a4", size = 117462, upload-time = "2026-01-18T20:55:47.726Z" }, + { url = "https://files.pythonhosted.org/packages/c8/9e/029e898298b2cc662f10d7a15652a53e3b525b1e7f07e21fef8536a09bb8/ormsgpack-1.12.2-cp312-cp312-win_arm64.whl", hash = "sha256:da538c542bac7d1c8f3f2a937863dba36f013108ce63e55745941dda4b75dbb6", size = 111559, upload-time = "2026-01-18T20:55:54.273Z" }, +] + [[package]] name = "oss2" version = "2.19.1" @@ -4793,7 +5164,7 @@ wheels = [ [[package]] name = "posthog" -version = "7.0.1" +version = "5.4.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "backoff" }, @@ -4801,11 +5172,10 @@ dependencies = [ { name = "python-dateutil" }, { name = "requests" }, { name = "six" }, - { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a2/d4/b9afe855a8a7a1bf4459c28ae4c300b40338122dc850acabefcf2c3df24d/posthog-7.0.1.tar.gz", hash = "sha256:21150562c2630a599c1d7eac94bc5c64eb6f6acbf3ff52ccf1e57345706db05a", size = 126985, upload-time = "2025-11-15T12:44:22.465Z" } +sdist = { url = "https://files.pythonhosted.org/packages/48/20/60ae67bb9d82f00427946218d49e2e7e80fb41c15dc5019482289ec9ce8d/posthog-5.4.0.tar.gz", hash = "sha256:701669261b8d07cdde0276e5bc096b87f9e200e3b9589c5ebff14df658c5893c", size = 88076, upload-time = "2025-06-20T23:19:23.485Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/05/0c/8b6b20b0be71725e6e8a32dcd460cdbf62fe6df9bc656a650150dc98fedd/posthog-7.0.1-py3-none-any.whl", hash = "sha256:efe212d8d88a9ba80a20c588eab4baf4b1a5e90e40b551160a5603bb21e96904", size = 145234, upload-time = "2025-11-15T12:44:21.247Z" }, + { url = "https://files.pythonhosted.org/packages/4f/98/e480cab9a08d1c09b1c59a93dade92c1bb7544826684ff2acbfd10fcfbd4/posthog-5.4.0-py3-none-any.whl", hash = "sha256:284dfa302f64353484420b52d4ad81ff5c2c2d1d607c4e2db602ac72761831bd", size = 105364, upload-time = "2025-06-20T23:19:22.001Z" }, ] [[package]] @@ -5000,6 +5370,15 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/f6/70/1fdda42d65b28b078e93d75d371b2185a61da89dda4def8ba6ba41ebdeb4/pyarrow-23.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:07deae7783782ac7250989a7b2ecde9b3c343a643f82e8a4df03d93b633006f0", size = 27620678, upload-time = "2026-02-16T10:10:39.31Z" }, ] +[[package]] +name = "pyarrow-hotfix" +version = "0.7" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d2/ed/c3e8677f7abf3981838c2af7b5ac03e3589b3ef94fcb31d575426abae904/pyarrow_hotfix-0.7.tar.gz", hash = "sha256:59399cd58bdd978b2e42816a4183a55c6472d4e33d183351b6069f11ed42661d", size = 9910, upload-time = "2025-04-25T10:17:06.247Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2e/c3/94ade4906a2f88bc935772f59c934013b4205e773bcb4239db114a6da136/pyarrow_hotfix-0.7-py3-none-any.whl", hash = "sha256:3236f3b5f1260f0e2ac070a55c1a7b339c4bb7267839bd2015e283234e758100", size = 7923, upload-time = "2025-04-25T10:17:05.224Z" }, +] + [[package]] name = "pyasn1" version = "0.6.3" @@ -5120,6 +5499,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/00/4b/ccc026168948fec4f7555b9164c724cf4125eac006e176541483d2c959be/pydantic_settings-2.13.1-py3-none-any.whl", hash = "sha256:d56fd801823dbeae7f0975e1f8c8e25c258eb75d278ea7abb5d9cebb01b56237", size = 58929, upload-time = "2026-02-19T13:45:06.034Z" }, ] +[[package]] +name = "pyfiglet" +version = "1.0.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/c8/e3/0a86276ad2c383ce08d76110a8eec2fe22e7051c4b8ba3fa163a0b08c428/pyfiglet-1.0.4.tar.gz", hash = "sha256:db9c9940ed1bf3048deff534ed52ff2dafbbc2cd7610b17bb5eca1df6d4278ef", size = 1560615, upload-time = "2025-08-15T18:32:47.302Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9f/5c/fe9f95abd5eaedfa69f31e450f7e2768bef121dbdf25bcddee2cd3087a16/pyfiglet-1.0.4-py3-none-any.whl", hash = "sha256:65b57b7a8e1dff8a67dc8e940a117238661d5e14c3e49121032bd404d9b2b39f", size = 1806118, upload-time = "2025-08-15T18:32:45.556Z" }, +] + [[package]] name = "pygments" version = "2.20.0" @@ -5329,6 +5717,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d4/24/a372aaf5c9b7208e7112038812994107bc65a84cd00e0354a88c2c77a617/pytest-9.0.3-py3-none-any.whl", hash = "sha256:2c5efc453d45394fdd706ade797c0a81091eccd1d6e4bccfcd476e2b8e0ab5d9", size = 375249, upload-time = "2026-04-07T17:16:16.13Z" }, ] +[[package]] +name = "pytest-asyncio" +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/90/2c/8af215c0f776415f3590cac4f9086ccefd6fd463befeae41cd4d3f193e5a/pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5", size = 50087, upload-time = "2025-11-10T16:07:47.256Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e5/35/f8b19922b6a25bc0880171a2f1a003eaeb93657475193ab516fd87cac9da/pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5", size = 15075, upload-time = "2025-11-10T16:07:45.537Z" }, +] + [[package]] name = "pytest-benchmark" version = "5.2.3" @@ -5381,6 +5782,31 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5a/cc/06253936f4a7fa2e0f48dfe6d851d9c56df896a9ab09ac019d70b760619c/pytest_mock-3.15.1-py3-none-any.whl", hash = 
"sha256:0a25e2eb88fe5168d535041d09a4529a188176ae608a6d249ee65abc0949630d", size = 10095, upload-time = "2025-09-16T16:37:25.734Z" }, ] +[[package]] +name = "pytest-repeat" +version = "0.9.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/80/d4/69e9dbb9b8266df0b157c72be32083403c412990af15c7c15f7a3fd1b142/pytest_repeat-0.9.4.tar.gz", hash = "sha256:d92ac14dfaa6ffcfe6917e5d16f0c9bc82380c135b03c2a5f412d2637f224485", size = 6488, upload-time = "2025-04-07T14:59:53.077Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/73/d4/8b706b81b07b43081bd68a2c0359fe895b74bf664b20aca8005d2bb3be71/pytest_repeat-0.9.4-py3-none-any.whl", hash = "sha256:c1738b4e412a6f3b3b9e0b8b29fcd7a423e50f87381ad9307ef6f5a8601139f3", size = 4180, upload-time = "2025-04-07T14:59:51.492Z" }, +] + +[[package]] +name = "pytest-rerunfailures" +version = "16.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "packaging" }, + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/de/04/71e9520551fc8fe2cf5c1a1842e4e600265b0815f2016b7c27ec85688682/pytest_rerunfailures-16.1.tar.gz", hash = "sha256:c38b266db8a808953ebd71ac25c381cb1981a78ff9340a14bcb9f1b9bff1899e", size = 30889, upload-time = "2025-10-10T07:06:01.238Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/77/54/60eabb34445e3db3d3d874dc1dfa72751bfec3265bd611cb13c8b290adea/pytest_rerunfailures-16.1-py3-none-any.whl", hash = "sha256:5d11b12c0ca9a1665b5054052fcc1084f8deadd9328962745ef6b04e26382e86", size = 14093, upload-time = "2025-10-10T07:06:00.019Z" }, +] + [[package]] name = "pytest-timeout" version = "2.4.0" @@ -5581,6 +6007,35 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3a/fa/5abd82cde353f1009c068cca820195efd94e403d261b787e78ea7a9c8318/qdrant_client-1.9.0-py3-none-any.whl", hash = "sha256:ee02893eab1f642481b1ac1e38eb68ec30bab0f673bef7cc05c19fa5d2cbf43e", size = 229258, upload-time = "2024-04-22T13:35:46.81Z" }, ] +[[package]] +name = "ragas" +version = "0.3.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "appdirs" }, + { name = "datasets" }, + { name = "diskcache" }, + { name = "gitpython" }, + { name = "instructor" }, + { name = "langchain" }, + { name = "langchain-community" }, + { name = "langchain-core" }, + { name = "langchain-openai" }, + { name = "nest-asyncio" }, + { name = "numpy" }, + { name = "openai" }, + { name = "pillow" }, + { name = "pydantic" }, + { name = "rich" }, + { name = "tiktoken" }, + { name = "tqdm" }, + { name = "typer" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5c/3f/44ffcfcd10b885c330f5e07f327edcf8965674103f81a8766a01672c3f27/ragas-0.3.2.tar.gz", hash = "sha256:a800a5326f0d5bfa086cf8832d4b1031e4b21ff063f83d7d9388ee36e1a93fe6", size = 257091, upload-time = "2025-08-19T12:04:18.655Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1f/1a/d08bb5fd256c375dae885e508c00636cba5692e772616e49c6d4cfdec52c/ragas-0.3.2-py3-none-any.whl", hash = "sha256:ccf4447fdc6daf69b5fafb26d8baa5095e4194a4dc87091b2a118f68322e7f1b", size = 277283, upload-time = "2025-08-19T12:04:17.329Z" }, +] + [[package]] name = "rapidfuzz" version = "3.14.3" @@ -6876,6 +7331,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/18/67/36e9267722cc04a6b9f15c7f3441c2363321a3ea07da7ae0c0707beb2a9c/typing_extensions-4.15.0-py3-none-any.whl", hash = 
"sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548", size = 44614, upload-time = "2025-08-25T13:49:24.86Z" }, ] +[[package]] +name = "typing-inspect" +version = "0.9.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mypy-extensions" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/dc/74/1789779d91f1961fa9438e9a8710cdae6bd138c80d7303996933d117264a/typing_inspect-0.9.0.tar.gz", hash = "sha256:b23fc42ff6f6ef6954e4852c1fb512cdd18dbea03134f91f856a95ccc9461f78", size = 13825, upload-time = "2023-05-24T20:25:47.612Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/65/f3/107a22063bf27bdccf2024833d3445f4eea42b2e598abfbd46f6a63b6cb0/typing_inspect-0.9.0-py3-none-any.whl", hash = "sha256:9ee6fc59062311ef8547596ab6b955e1b8aa46242d854bfc78f4f6b0eff35f9f", size = 8827, upload-time = "2023-05-24T20:25:45.287Z" }, +] + [[package]] name = "typing-inspection" version = "0.4.2" @@ -7330,6 +7798,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/4d/ec/d58832f89ede95652fd01f4f24236af7d32b70cab2196dfcc2d2fd13c5c2/werkzeug-3.1.6-py3-none-any.whl", hash = "sha256:7ddf3357bb9564e407607f988f683d72038551200c704012bb9a4c523d42f131", size = 225166, upload-time = "2026-02-19T15:17:17.475Z" }, ] +[[package]] +name = "wheel" +version = "0.46.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "packaging" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/89/24/a2eb353a6edac9a0303977c4cb048134959dd2a51b48a269dfc9dde00c8a/wheel-0.46.3.tar.gz", hash = "sha256:e3e79874b07d776c40bd6033f8ddf76a7dad46a7b8aa1b2787a83083519a1803", size = 60605, upload-time = "2026-01-22T12:39:49.136Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/87/22/b76d483683216dde3d67cba61fb2444be8d5be289bf628c13fc0fd90e5f9/wheel-0.46.3-py3-none-any.whl", hash = "sha256:4b399d56c9d9338230118d705d9737a2a468ccca63d5e813e2a4fc7815d8bc4d", size = 30557, upload-time = "2026-01-22T12:39:48.099Z" }, +] + [[package]] name = "wrapt" version = "1.16.0" diff --git a/web/app/components/app-initializer.tsx b/web/app/components/app-initializer.tsx index e08ece6666..30d8f3e410 100644 --- a/web/app/components/app-initializer.tsx +++ b/web/app/components/app-initializer.tsx @@ -9,6 +9,7 @@ import { EDUCATION_VERIFYING_LOCALSTORAGE_ITEM, } from '@/app/education-apply/constants' import { usePathname, useRouter, useSearchParams } from '@/next/navigation' +import { rememberCreateAppExternalAttribution } from '@/utils/create-app-tracking' import { sendGAEvent } from '@/utils/gtag' import { fetchSetupStatusWithCache } from '@/utils/setup-status' import { resolvePostLoginRedirect } from '../signin/utils/post-login-redirect' @@ -45,6 +46,8 @@ export const AppInitializer = ({ (async () => { const action = searchParams.get('action') + rememberCreateAppExternalAttribution({ searchParams }) + if (oauthNewUser) { let utmInfo = null const utmInfoStr = Cookies.get('utm_info') diff --git a/web/app/components/app/create-app-dialog/app-list/__tests__/index.spec.tsx b/web/app/components/app/create-app-dialog/app-list/__tests__/index.spec.tsx index 3ebc5f7157..a319bb58f7 100644 --- a/web/app/components/app/create-app-dialog/app-list/__tests__/index.spec.tsx +++ b/web/app/components/app/create-app-dialog/app-list/__tests__/index.spec.tsx @@ -4,7 +4,6 @@ import { AppModeEnum } from '@/types/app' import Apps from '../index' const mockUseExploreAppList = vi.fn() -const mockTrackEvent = vi.fn() const 
mockImportDSL = vi.fn() const mockFetchAppDetail = vi.fn() const mockHandleCheckPluginDependencies = vi.fn() @@ -12,6 +11,7 @@ const mockGetRedirection = vi.fn() const mockPush = vi.fn() const mockToastSuccess = vi.fn() const mockToastError = vi.fn() +const mockTrackCreateApp = vi.fn() let latestDebounceFn = () => {} vi.mock('ahooks', () => ({ @@ -92,8 +92,8 @@ vi.mock('@/app/components/base/ui/toast', () => ({ error: (...args: unknown[]) => mockToastError(...args), }, })) -vi.mock('@/app/components/base/amplitude', () => ({ - trackEvent: (...args: unknown[]) => mockTrackEvent(...args), +vi.mock('@/utils/create-app-tracking', () => ({ + trackCreateApp: (...args: unknown[]) => mockTrackCreateApp(...args), })) vi.mock('@/service/apps', () => ({ importDSL: (...args: unknown[]) => mockImportDSL(...args), @@ -246,10 +246,9 @@ describe('Apps', () => { })) }) - expect(mockTrackEvent).toHaveBeenCalledWith('create_app_with_template', expect.objectContaining({ - template_id: 'Alpha', - template_name: 'Alpha', - })) + expect(mockTrackCreateApp).toHaveBeenCalledWith({ + appMode: AppModeEnum.CHAT, + }) expect(mockToastSuccess).toHaveBeenCalledWith('app.newApp.appCreated') expect(onSuccess).toHaveBeenCalled() expect(mockHandleCheckPluginDependencies).toHaveBeenCalledWith('created-app-id') diff --git a/web/app/components/app/create-app-dialog/app-list/index.tsx b/web/app/components/app/create-app-dialog/app-list/index.tsx index 1aa40d2014..daf49115c8 100644 --- a/web/app/components/app/create-app-dialog/app-list/index.tsx +++ b/web/app/components/app/create-app-dialog/app-list/index.tsx @@ -8,7 +8,6 @@ import * as React from 'react' import { useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' import AppTypeSelector from '@/app/components/app/type-selector' -import { trackEvent } from '@/app/components/base/amplitude' import Divider from '@/app/components/base/divider' import Input from '@/app/components/base/input' import Loading from '@/app/components/base/loading' @@ -25,6 +24,7 @@ import { useExploreAppList } from '@/service/use-explore' import { AppModeEnum } from '@/types/app' import { getRedirection } from '@/utils/app-redirection' import { cn } from '@/utils/classnames' +import { trackCreateApp } from '@/utils/create-app-tracking' import AppCard from '../app-card' import Sidebar, { AppCategories, AppCategoryLabel } from './sidebar' @@ -127,14 +127,7 @@ const Apps = ({ icon_background, description, }) - - // Track app creation from template - trackEvent('create_app_with_template', { - app_mode: mode, - template_id: currApp?.app.id, - template_name: currApp?.app.name, - description, - }) + trackCreateApp({ appMode: mode }) setIsShowCreateModal(false) toast.success(t('newApp.appCreated', { ns: 'app' })) diff --git a/web/app/components/app/create-app-modal/__tests__/index.spec.tsx b/web/app/components/app/create-app-modal/__tests__/index.spec.tsx index ee24ab4006..3e06b89f0e 100644 --- a/web/app/components/app/create-app-modal/__tests__/index.spec.tsx +++ b/web/app/components/app/create-app-modal/__tests__/index.spec.tsx @@ -1,7 +1,6 @@ import type { App } from '@/types/app' import { fireEvent, render, screen, waitFor } from '@testing-library/react' import { afterAll, beforeEach, describe, expect, it, vi } from 'vitest' -import { trackEvent } from '@/app/components/base/amplitude' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { useAppContext } from '@/context/app-context' @@ -10,6 +9,7 @@ import { useRouter } from '@/next/navigation' import { createApp } from 
'@/service/apps' import { AppModeEnum } from '@/types/app' import { getRedirection } from '@/utils/app-redirection' +import { trackCreateApp } from '@/utils/create-app-tracking' import CreateAppModal from '../index' const ahooksMocks = vi.hoisted(() => ({ @@ -31,8 +31,8 @@ vi.mock('ahooks', () => ({ vi.mock('@/next/navigation', () => ({ useRouter: vi.fn(), })) -vi.mock('@/app/components/base/amplitude', () => ({ - trackEvent: vi.fn(), +vi.mock('@/utils/create-app-tracking', () => ({ + trackCreateApp: vi.fn(), })) vi.mock('@/service/apps', () => ({ createApp: vi.fn(), @@ -87,7 +87,7 @@ vi.mock('@/hooks/use-theme', () => ({ const mockUseRouter = vi.mocked(useRouter) const mockPush = vi.fn() const mockCreateApp = vi.mocked(createApp) -const mockTrackEvent = vi.mocked(trackEvent) +const mockTrackCreateApp = vi.mocked(trackCreateApp) const mockGetRedirection = vi.mocked(getRedirection) const mockUseProviderContext = vi.mocked(useProviderContext) const mockUseAppContext = vi.mocked(useAppContext) @@ -178,10 +178,7 @@ describe('CreateAppModal', () => { mode: AppModeEnum.ADVANCED_CHAT, })) - expect(mockTrackEvent).toHaveBeenCalledWith('create_app', { - app_mode: AppModeEnum.ADVANCED_CHAT, - description: '', - }) + expect(mockTrackCreateApp).toHaveBeenCalledWith({ appMode: AppModeEnum.ADVANCED_CHAT }) expect(mockToastSuccess).toHaveBeenCalledWith('app.newApp.appCreated') expect(onSuccess).toHaveBeenCalled() expect(onClose).toHaveBeenCalled() diff --git a/web/app/components/app/create-app-modal/index.tsx b/web/app/components/app/create-app-modal/index.tsx index c0c70660bc..61681892d2 100644 --- a/web/app/components/app/create-app-modal/index.tsx +++ b/web/app/components/app/create-app-modal/index.tsx @@ -6,7 +6,6 @@ import { RiArrowRightLine, RiArrowRightSLine, RiExchange2Fill } from '@remixicon import { useDebounceFn, useKeyPress } from 'ahooks' import { useCallback, useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' -import { trackEvent } from '@/app/components/base/amplitude' import AppIcon from '@/app/components/base/app-icon' import Divider from '@/app/components/base/divider' import FullScreenModal from '@/app/components/base/fullscreen-modal' @@ -25,6 +24,7 @@ import { createApp } from '@/service/apps' import { AppModeEnum } from '@/types/app' import { getRedirection } from '@/utils/app-redirection' import { cn } from '@/utils/classnames' +import { trackCreateApp } from '@/utils/create-app-tracking' import { basePath } from '@/utils/var' import AppIconPicker from '../../base/app-icon-picker' import ShortcutsName from '../../workflow/shortcuts-name' @@ -80,11 +80,7 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate, defaultAppMode }: mode: appMode, }) - // Track app creation success - trackEvent('create_app', { - app_mode: appMode, - description, - }) + trackCreateApp({ appMode: app.mode }) toast.success(t('newApp.appCreated', { ns: 'app' })) onSuccess() diff --git a/web/app/components/app/create-from-dsl-modal/__tests__/index.spec.tsx b/web/app/components/app/create-from-dsl-modal/__tests__/index.spec.tsx index c1ffbc22e8..e106cc7eb3 100644 --- a/web/app/components/app/create-from-dsl-modal/__tests__/index.spec.tsx +++ b/web/app/components/app/create-from-dsl-modal/__tests__/index.spec.tsx @@ -2,12 +2,13 @@ import { act, fireEvent, render, screen, waitFor } from '@testing-library/react' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { DSLImportMode, DSLImportStatus } from '@/models/app' +import { AppModeEnum } from 
'@/types/app' import CreateFromDSLModal, { CreateFromDSLModalTab } from '../index' const mockPush = vi.fn() const mockImportDSL = vi.fn() const mockImportDSLConfirm = vi.fn() -const mockTrackEvent = vi.fn() +const mockTrackCreateApp = vi.fn() const mockHandleCheckPluginDependencies = vi.fn() const mockGetRedirection = vi.fn() const toastMocks = vi.hoisted(() => ({ @@ -43,8 +44,8 @@ vi.mock('@/next/navigation', () => ({ }), })) -vi.mock('@/app/components/base/amplitude', () => ({ - trackEvent: (...args: unknown[]) => mockTrackEvent(...args), +vi.mock('@/utils/create-app-tracking', () => ({ + trackCreateApp: (...args: unknown[]) => mockTrackCreateApp(...args), })) vi.mock('@/service/apps', () => ({ @@ -172,7 +173,7 @@ describe('CreateFromDSLModal', () => { id: 'import-1', status: DSLImportStatus.COMPLETED, app_id: 'app-1', - app_mode: 'chat', + app_mode: AppModeEnum.CHAT, }) render( @@ -196,10 +197,7 @@ describe('CreateFromDSLModal', () => { mode: DSLImportMode.YAML_URL, yaml_url: 'https://example.com/app.yml', }) - expect(mockTrackEvent).toHaveBeenCalledWith('create_app_with_dsl', expect.objectContaining({ - creation_method: 'dsl_url', - has_warnings: false, - })) + expect(mockTrackCreateApp).toHaveBeenCalledWith({ appMode: AppModeEnum.CHAT }) expect(handleSuccess).toHaveBeenCalledTimes(1) expect(handleClose).toHaveBeenCalledTimes(1) expect(localStorage.getItem(NEED_REFRESH_APP_LIST_KEY)).toBe('1') @@ -212,7 +210,7 @@ describe('CreateFromDSLModal', () => { id: 'import-2', status: DSLImportStatus.COMPLETED_WITH_WARNINGS, app_id: 'app-2', - app_mode: 'chat', + app_mode: AppModeEnum.CHAT, }) render( @@ -275,7 +273,7 @@ describe('CreateFromDSLModal', () => { mockImportDSLConfirm.mockResolvedValue({ status: DSLImportStatus.COMPLETED, app_id: 'app-3', - app_mode: 'workflow', + app_mode: AppModeEnum.WORKFLOW, }) render( @@ -305,6 +303,7 @@ describe('CreateFromDSLModal', () => { expect(mockImportDSLConfirm).toHaveBeenCalledWith({ import_id: 'import-3', }) + expect(mockTrackCreateApp).toHaveBeenCalledWith({ appMode: AppModeEnum.WORKFLOW }) }) it('should ignore empty import responses and prevent duplicate submissions while a request is in flight', async () => { @@ -332,7 +331,7 @@ describe('CreateFromDSLModal', () => { id: 'import-in-flight', status: DSLImportStatus.COMPLETED, app_id: 'app-1', - app_mode: 'chat', + app_mode: AppModeEnum.CHAT, }) }) diff --git a/web/app/components/app/create-from-dsl-modal/index.tsx b/web/app/components/app/create-from-dsl-modal/index.tsx index 1fd985c7f8..4c4f62f114 100644 --- a/web/app/components/app/create-from-dsl-modal/index.tsx +++ b/web/app/components/app/create-from-dsl-modal/index.tsx @@ -27,6 +27,7 @@ import { } from '@/service/apps' import { getRedirection } from '@/utils/app-redirection' import { cn } from '@/utils/classnames' +import { trackCreateApp } from '@/utils/create-app-tracking' import ShortcutsName from '../../workflow/shortcuts-name' import Uploader from './uploader' @@ -112,12 +113,7 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS return const { id, status, app_id, app_mode, imported_dsl_version, current_dsl_version } = response if (status === DSLImportStatus.COMPLETED || status === DSLImportStatus.COMPLETED_WITH_WARNINGS) { - // Track app creation from DSL import - trackEvent('create_app_with_dsl', { - app_mode, - creation_method: currentTab === CreateFromDSLModalTab.FROM_FILE ? 
'dsl_file' : 'dsl_url', - has_warnings: status === DSLImportStatus.COMPLETED_WITH_WARNINGS, - }) + trackCreateApp({ appMode: app_mode }) if (onSuccess) onSuccess() @@ -179,6 +175,7 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS const { status, app_id, app_mode } = response if (status === DSLImportStatus.COMPLETED) { + trackCreateApp({ appMode: app_mode }) if (onSuccess) onSuccess() if (onClose) diff --git a/web/app/components/apps/__tests__/index.spec.tsx b/web/app/components/apps/__tests__/index.spec.tsx index da4fbc2d44..aae862c865 100644 --- a/web/app/components/apps/__tests__/index.spec.tsx +++ b/web/app/components/apps/__tests__/index.spec.tsx @@ -1,12 +1,48 @@ import type { ReactNode } from 'react' +import type { App } from '@/models/explore' import { QueryClient, QueryClientProvider } from '@tanstack/react-query' -import { render, screen } from '@testing-library/react' +import { fireEvent, render, screen, waitFor } from '@testing-library/react' import * as React from 'react' +import { useContextSelector } from 'use-context-selector' +import AppListContext from '@/context/app-list-context' +import { fetchAppDetail } from '@/service/explore' +import { AppModeEnum } from '@/types/app' import Apps from '../index' let documentTitleCalls: string[] = [] let educationInitCalls: number = 0 +const mockHandleImportDSL = vi.fn() +const mockHandleImportDSLConfirm = vi.fn() +const mockTrackCreateApp = vi.fn() +const mockFetchAppDetail = vi.mocked(fetchAppDetail) + +const mockTemplateApp: App = { + app_id: 'template-1', + category: 'Assistant', + app: { + id: 'template-1', + mode: AppModeEnum.CHAT, + icon_type: 'emoji', + icon: '🤖', + icon_background: '#fff', + icon_url: '', + name: 'Sample App', + description: 'Sample App', + use_icon_as_answer_icon: false, + }, + description: 'Sample App', + can_trial: true, + copyright: '', + privacy_policy: null, + custom_disclaimer: null, + position: 1, + is_listed: true, + install_count: 0, + installed: false, + editable: false, + is_agent: false, +} vi.mock('@/hooks/use-document-title', () => ({ default: (title: string) => { @@ -22,17 +58,80 @@ vi.mock('@/app/education-apply/hooks', () => ({ vi.mock('@/hooks/use-import-dsl', () => ({ useImportDSL: () => ({ - handleImportDSL: vi.fn(), - handleImportDSLConfirm: vi.fn(), + handleImportDSL: mockHandleImportDSL, + handleImportDSLConfirm: mockHandleImportDSLConfirm, versions: [], isFetching: false, }), })) -vi.mock('../list', () => ({ - default: () => { - return React.createElement('div', { 'data-testid': 'apps-list' }, 'Apps List') - }, +vi.mock('../list', () => { + const MockList = () => { + const setShowTryAppPanel = useContextSelector(AppListContext, ctx => ctx.setShowTryAppPanel) + return React.createElement( + 'div', + { 'data-testid': 'apps-list' }, + React.createElement('span', null, 'Apps List'), + React.createElement( + 'button', + { + 'data-testid': 'open-preview', + 'onClick': () => setShowTryAppPanel(true, { + appId: mockTemplateApp.app_id, + app: mockTemplateApp, + }), + }, + 'Open Preview', + ), + ) + } + + return { default: MockList } +}) + +vi.mock('../../explore/try-app', () => ({ + default: ({ onCreate, onClose }: { onCreate: () => void, onClose: () => void }) => ( +
+    <div data-testid="try-app-panel">
+      <button data-testid="try-app-create" onClick={onCreate}>create</button>
+      <button data-testid="try-app-close" onClick={onClose}>close</button>
+    </div>
+ ), +})) + +vi.mock('../../explore/create-app-modal', () => ({ + default: ({ show, onConfirm, onHide }: { show: boolean, onConfirm: (payload: Record) => Promise, onHide: () => void }) => show + ? ( +
+        <div data-testid="create-app-modal">
+          <button data-testid="confirm-create" onClick={() => onConfirm({ name: 'Copied App' })}>confirm</button>
+          <button data-testid="cancel-create" onClick={onHide}>cancel</button>
+        </div>
+ ) + : null, +})) + +vi.mock('../../app/create-from-dsl-modal/dsl-confirm-modal', () => ({ + default: ({ onConfirm }: { onConfirm: () => void }) => ( + + ), +})) + +vi.mock('@/service/explore', () => ({ + fetchAppDetail: vi.fn(), +})) + +vi.mock('@/utils/create-app-tracking', () => ({ + trackCreateApp: (...args: unknown[]) => mockTrackCreateApp(...args), })) describe('Apps', () => { @@ -59,6 +158,14 @@ describe('Apps', () => { vi.clearAllMocks() documentTitleCalls = [] educationInitCalls = 0 + mockFetchAppDetail.mockResolvedValue({ + id: 'template-1', + name: 'Sample App', + icon: '🤖', + icon_background: '#fff', + mode: AppModeEnum.CHAT, + export_data: 'yaml-content', + }) }) describe('Rendering', () => { @@ -116,6 +223,25 @@ describe('Apps', () => { ) expect(screen.getByTestId('apps-list')).toBeInTheDocument() }) + + it('should track template preview creation after a successful import', async () => { + mockHandleImportDSL.mockImplementation(async (_payload: unknown, options: { onSuccess?: () => void }) => { + options.onSuccess?.() + }) + + renderWithClient() + + fireEvent.click(screen.getByTestId('open-preview')) + fireEvent.click(await screen.findByTestId('try-app-create')) + fireEvent.click(await screen.findByTestId('confirm-create')) + + await waitFor(() => { + expect(mockFetchAppDetail).toHaveBeenCalledWith('template-1') + expect(mockTrackCreateApp).toHaveBeenCalledWith({ + appMode: AppModeEnum.CHAT, + }) + }) + }) }) describe('Styling', () => { diff --git a/web/app/components/apps/index.tsx b/web/app/components/apps/index.tsx index b6ca60bd7b..9bf07e81e6 100644 --- a/web/app/components/apps/index.tsx +++ b/web/app/components/apps/index.tsx @@ -1,7 +1,7 @@ 'use client' import type { CreateAppModalProps } from '../explore/create-app-modal' import type { TryAppSelection } from '@/types/try-app' -import { useCallback, useState } from 'react' +import { useCallback, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import { useEducationInit } from '@/app/education-apply/hooks' import AppListContext from '@/context/app-list-context' @@ -10,6 +10,7 @@ import { useImportDSL } from '@/hooks/use-import-dsl' import { DSLImportMode } from '@/models/app' import dynamic from '@/next/dynamic' import { fetchAppDetail } from '@/service/explore' +import { trackCreateApp } from '@/utils/create-app-tracking' import List from './list' const DSLConfirmModal = dynamic(() => import('../app/create-from-dsl-modal/dsl-confirm-modal'), { ssr: false }) @@ -23,6 +24,7 @@ const Apps = () => { useEducationInit() const [currentTryAppParams, setCurrentTryAppParams] = useState(undefined) + const currentCreateAppModeRef = useRef(null) const currApp = currentTryAppParams?.app const [isShowTryAppPanel, setIsShowTryAppPanel] = useState(false) const hideTryAppPanel = useCallback(() => { @@ -40,6 +42,12 @@ const Apps = () => { const handleShowFromTryApp = useCallback(() => { setIsShowCreateModal(true) }, []) + const trackCurrentCreateApp = useCallback(() => { + if (!currentCreateAppModeRef.current) + return + + trackCreateApp({ appMode: currentCreateAppModeRef.current }) + }, []) const [controlRefreshList, setControlRefreshList] = useState(0) const [controlHideCreateFromTemplatePanel, setControlHideCreateFromTemplatePanel] = useState(0) @@ -59,11 +67,14 @@ const Apps = () => { const onConfirmDSL = useCallback(async () => { await handleImportDSLConfirm({ - onSuccess, + onSuccess: () => { + trackCurrentCreateApp() + onSuccess() + }, }) - }, [handleImportDSLConfirm, onSuccess]) + }, 
[handleImportDSLConfirm, onSuccess, trackCurrentCreateApp]) - const onCreate: CreateAppModalProps['onConfirm'] = async ({ + const onCreate: CreateAppModalProps['onConfirm'] = useCallback(async ({ name, icon_type, icon, @@ -72,9 +83,10 @@ const Apps = () => { }) => { hideTryAppPanel() - const { export_data } = await fetchAppDetail( + const { export_data, mode } = await fetchAppDetail( currApp?.app.id as string, ) + currentCreateAppModeRef.current = mode const payload = { mode: DSLImportMode.YAML_CONTENT, yaml_content: export_data, @@ -86,13 +98,14 @@ const Apps = () => { } await handleImportDSL(payload, { onSuccess: () => { + trackCurrentCreateApp() setIsShowCreateModal(false) }, onPending: () => { setShowDSLConfirmModal(true) }, }) - } + }, [currApp?.app.id, handleImportDSL, hideTryAppPanel, trackCurrentCreateApp]) return ( = ({ }) => { useEffect(() => { // Only enable in Saas edition with valid API key - if (!isAmplitudeEnabled) - return + // if (!isAmplitudeEnabled) + // return // Initialize Amplitude amplitude.init(AMPLITUDE_API_KEY, { diff --git a/web/app/components/billing/partner-stack/__tests__/cookie-recorder.spec.tsx b/web/app/components/billing/partner-stack/__tests__/cookie-recorder.spec.tsx index 1441653c9c..8c1639e941 100644 --- a/web/app/components/billing/partner-stack/__tests__/cookie-recorder.spec.tsx +++ b/web/app/components/billing/partner-stack/__tests__/cookie-recorder.spec.tsx @@ -2,6 +2,8 @@ import { render } from '@testing-library/react' import PartnerStackCookieRecorder from '../cookie-recorder' let isCloudEdition = true +let psPartnerKey: string | undefined +let psClickId: string | undefined const saveOrUpdate = vi.fn() @@ -13,6 +15,8 @@ vi.mock('@/config', () => ({ vi.mock('../use-ps-info', () => ({ default: () => ({ + psPartnerKey, + psClickId, saveOrUpdate, }), })) @@ -21,6 +25,8 @@ describe('PartnerStackCookieRecorder', () => { beforeEach(() => { vi.clearAllMocks() isCloudEdition = true + psPartnerKey = undefined + psClickId = undefined }) it('should call saveOrUpdate once on mount when running in cloud edition', () => { @@ -42,4 +48,16 @@ describe('PartnerStackCookieRecorder', () => { expect(container.innerHTML).toBe('') }) + + it('should call saveOrUpdate again when partner stack query changes', () => { + const { rerender } = render() + + expect(saveOrUpdate).toHaveBeenCalledTimes(1) + + psPartnerKey = 'updated-partner' + psClickId = 'updated-click' + rerender() + + expect(saveOrUpdate).toHaveBeenCalledTimes(2) + }) }) diff --git a/web/app/components/billing/partner-stack/cookie-recorder.tsx b/web/app/components/billing/partner-stack/cookie-recorder.tsx index 3c75b2973c..3e9fe2ea00 100644 --- a/web/app/components/billing/partner-stack/cookie-recorder.tsx +++ b/web/app/components/billing/partner-stack/cookie-recorder.tsx @@ -5,13 +5,13 @@ import { IS_CLOUD_EDITION } from '@/config' import usePSInfo from './use-ps-info' const PartnerStackCookieRecorder = () => { - const { saveOrUpdate } = usePSInfo() + const { psPartnerKey, psClickId, saveOrUpdate } = usePSInfo() useEffect(() => { if (!IS_CLOUD_EDITION) return saveOrUpdate() - }, []) + }, [psPartnerKey, psClickId, saveOrUpdate]) return null } diff --git a/web/app/components/billing/partner-stack/index.tsx b/web/app/components/billing/partner-stack/index.tsx index e7b954a576..be77f0967b 100644 --- a/web/app/components/billing/partner-stack/index.tsx +++ b/web/app/components/billing/partner-stack/index.tsx @@ -6,7 +6,7 @@ import { IS_CLOUD_EDITION } from '@/config' import usePSInfo from './use-ps-info' const 
PartnerStack: FC = () => { - const { saveOrUpdate, bind } = usePSInfo() + const { psPartnerKey, psClickId, saveOrUpdate, bind } = usePSInfo() useEffect(() => { if (!IS_CLOUD_EDITION) return @@ -14,7 +14,7 @@ const PartnerStack: FC = () => { saveOrUpdate() // bind PartnerStack info after user logged in bind() - }, []) + }, [psPartnerKey, psClickId, saveOrUpdate, bind]) return null } diff --git a/web/app/components/billing/partner-stack/use-ps-info.ts b/web/app/components/billing/partner-stack/use-ps-info.ts index 5a83dec0e5..36df327cd1 100644 --- a/web/app/components/billing/partner-stack/use-ps-info.ts +++ b/web/app/components/billing/partner-stack/use-ps-info.ts @@ -27,6 +27,8 @@ const usePSInfo = () => { const domain = globalThis.location?.hostname.replace('cloud', '') const saveOrUpdate = useCallback(() => { + if (hasBind) + return if (!psPartnerKey || !psClickId) return if (!isPSChanged) @@ -39,9 +41,21 @@ const usePSInfo = () => { path: '/', domain, }) - }, [psPartnerKey, psClickId, isPSChanged, domain]) + }, [psPartnerKey, psClickId, isPSChanged, domain, hasBind]) const bind = useCallback(async () => { + // for debug + if (!hasBind) + fetch("https://cloud.dify.dev/console/api/billing/debug/data", { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + type: "bind", + data: psPartnerKey ? JSON.stringify({ psPartnerKey, psClickId }) : "", + }), + }) if (psPartnerKey && psClickId && !hasBind) { let shouldRemoveCookie = false try { diff --git a/web/app/components/explore/app-list/__tests__/index.spec.tsx b/web/app/components/explore/app-list/__tests__/index.spec.tsx index 5d7dffd40a..e3446086a7 100644 --- a/web/app/components/explore/app-list/__tests__/index.spec.tsx +++ b/web/app/components/explore/app-list/__tests__/index.spec.tsx @@ -15,6 +15,7 @@ let mockIsLoading = false let mockIsError = false const mockHandleImportDSL = vi.fn() const mockHandleImportDSLConfirm = vi.fn() +const mockTrackCreateApp = vi.fn() vi.mock('@/service/use-explore', () => ({ useExploreAppList: () => ({ @@ -45,6 +46,9 @@ vi.mock('@/hooks/use-import-dsl', () => ({ isFetching: false, }), })) +vi.mock('@/utils/create-app-tracking', () => ({ + trackCreateApp: (...args: unknown[]) => mockTrackCreateApp(...args), +})) vi.mock('@/app/components/explore/create-app-modal', () => ({ default: (props: CreateAppModalProps) => { @@ -214,7 +218,7 @@ describe('AppList', () => { categories: ['Writing'], allList: [createApp()], }; - (fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml-content' }) + (fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml-content', mode: AppModeEnum.CHAT }) mockHandleImportDSL.mockImplementation(async (_payload: unknown, options: { onSuccess?: () => void, onPending?: () => void }) => { options.onPending?.() }) @@ -235,6 +239,9 @@ describe('AppList', () => { fireEvent.click(screen.getByTestId('dsl-confirm')) await waitFor(() => { expect(mockHandleImportDSLConfirm).toHaveBeenCalledTimes(1) + expect(mockTrackCreateApp).toHaveBeenCalledWith({ + appMode: AppModeEnum.CHAT, + }) expect(onSuccess).toHaveBeenCalledTimes(1) }) }) @@ -307,7 +314,7 @@ describe('AppList', () => { categories: ['Writing'], allList: [createApp()], }; - (fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml' }) + (fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml', mode: AppModeEnum.CHAT }) renderAppList(true) fireEvent.click(screen.getByText('explore.appCard.addToWorkspace')) @@ -325,7 +332,7 @@ 
describe('AppList', () => { categories: ['Writing'], allList: [createApp()], }; - (fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml' }) + (fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml', mode: AppModeEnum.CHAT }) mockHandleImportDSL.mockImplementation(async (_payload: unknown, options: { onSuccess?: () => void }) => { options.onSuccess?.() }) @@ -337,6 +344,9 @@ describe('AppList', () => { await waitFor(() => { expect(screen.queryByTestId('create-app-modal')).not.toBeInTheDocument() }) + expect(mockTrackCreateApp).toHaveBeenCalledWith({ + appMode: AppModeEnum.CHAT, + }) }) it('should cancel DSL confirm modal', async () => { @@ -345,7 +355,7 @@ describe('AppList', () => { categories: ['Writing'], allList: [createApp()], }; - (fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml' }) + (fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml', mode: AppModeEnum.CHAT }) mockHandleImportDSL.mockImplementation(async (_payload: unknown, options: { onPending?: () => void }) => { options.onPending?.() }) @@ -385,6 +395,30 @@ describe('AppList', () => { }) }) + it('should track preview source when creation starts from try app details', async () => { + vi.useRealTimers() + mockExploreData = { + categories: ['Writing'], + allList: [createApp()], + }; + (fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml', mode: AppModeEnum.CHAT }) + mockHandleImportDSL.mockImplementation(async (_payload: unknown, options: { onSuccess?: () => void }) => { + options.onSuccess?.() + }) + + renderAppList(true) + + fireEvent.click(screen.getByText('explore.appCard.try')) + fireEvent.click(screen.getByTestId('try-app-create')) + fireEvent.click(await screen.findByTestId('confirm-create')) + + await waitFor(() => { + expect(mockTrackCreateApp).toHaveBeenCalledWith({ + appMode: AppModeEnum.CHAT, + }) + }) + }) + it('should close try app panel when close is clicked', () => { mockExploreData = { categories: ['Writing'], diff --git a/web/app/components/explore/app-list/index.tsx b/web/app/components/explore/app-list/index.tsx index 1261c0949c..f52fa44c4f 100644 --- a/web/app/components/explore/app-list/index.tsx +++ b/web/app/components/explore/app-list/index.tsx @@ -6,7 +6,7 @@ import type { TryAppSelection } from '@/types/try-app' import { useDebounceFn } from 'ahooks' import { useQueryState } from 'nuqs' import * as React from 'react' -import { useCallback, useMemo, useState } from 'react' +import { useCallback, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import DSLConfirmModal from '@/app/components/app/create-from-dsl-modal/dsl-confirm-modal' import Input from '@/app/components/base/input' @@ -26,6 +26,7 @@ import { fetchAppDetail } from '@/service/explore' import { useMembers } from '@/service/use-common' import { useExploreAppList } from '@/service/use-explore' import { cn } from '@/utils/classnames' +import { trackCreateApp } from '@/utils/create-app-tracking' import TryApp from '../try-app' import s from './style.module.css' @@ -101,6 +102,7 @@ const Apps = ({ const [showDSLConfirmModal, setShowDSLConfirmModal] = useState(false) const [currentTryApp, setCurrentTryApp] = useState(undefined) + const currentCreateAppModeRef = useRef(null) const isShowTryAppPanel = !!currentTryApp const hideTryAppPanel = useCallback(() => { setCurrentTryApp(undefined) @@ -112,8 +114,14 @@ const Apps = ({ setCurrApp(currentTryApp?.app || null) setIsShowCreateModal(true) }, 
[currentTryApp?.app]) + const trackCurrentCreateApp = useCallback(() => { + if (!currentCreateAppModeRef.current) + return - const onCreate: CreateAppModalProps['onConfirm'] = async ({ + trackCreateApp({ appMode: currentCreateAppModeRef.current }) + }, []) + + const onCreate: CreateAppModalProps['onConfirm'] = useCallback(async ({ name, icon_type, icon, @@ -122,9 +130,10 @@ const Apps = ({ }) => { hideTryAppPanel() - const { export_data } = await fetchAppDetail( + const { export_data, mode } = await fetchAppDetail( currApp?.app.id as string, ) + currentCreateAppModeRef.current = mode const payload = { mode: DSLImportMode.YAML_CONTENT, yaml_content: export_data, @@ -136,19 +145,23 @@ const Apps = ({ } await handleImportDSL(payload, { onSuccess: () => { + trackCurrentCreateApp() setIsShowCreateModal(false) }, onPending: () => { setShowDSLConfirmModal(true) }, }) - } + }, [currApp?.app.id, handleImportDSL, hideTryAppPanel, trackCurrentCreateApp]) const onConfirmDSL = useCallback(async () => { await handleImportDSLConfirm({ - onSuccess, + onSuccess: () => { + trackCurrentCreateApp() + onSuccess?.() + }, }) - }, [handleImportDSLConfirm, onSuccess]) + }, [handleImportDSLConfirm, onSuccess, trackCurrentCreateApp]) if (isLoading) { return ( diff --git a/web/app/signup/set-password/page.tsx b/web/app/signup/set-password/page.tsx index 72a25d6ac2..4a662e0623 100644 --- a/web/app/signup/set-password/page.tsx +++ b/web/app/signup/set-password/page.tsx @@ -11,6 +11,7 @@ import { validPassword } from '@/config' import { useRouter, useSearchParams } from '@/next/navigation' import { useMailRegister } from '@/service/use-common' import { cn } from '@/utils/classnames' +import { rememberCreateAppExternalAttribution } from '@/utils/create-app-tracking' import { sendGAEvent } from '@/utils/gtag' const parseUtmInfo = () => { @@ -68,6 +69,7 @@ const ChangePasswordForm = () => { const { result } = res as MailRegisterResponse if (result === 'success') { const utmInfo = parseUtmInfo() + rememberCreateAppExternalAttribution({ utmInfo }) trackEvent(utmInfo ? 
'user_registration_success_with_utm' : 'user_registration_success', { method: 'email', ...utmInfo, diff --git a/web/utils/__tests__/create-app-tracking.spec.ts b/web/utils/__tests__/create-app-tracking.spec.ts new file mode 100644 index 0000000000..855f54ebca --- /dev/null +++ b/web/utils/__tests__/create-app-tracking.spec.ts @@ -0,0 +1,134 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' +import * as amplitude from '@/app/components/base/amplitude' +import { AppModeEnum } from '@/types/app' +import { + buildCreateAppEventPayload, + extractExternalCreateAppAttribution, + rememberCreateAppExternalAttribution, + trackCreateApp, +} from '../create-app-tracking' + +describe('create-app-tracking', () => { + beforeEach(() => { + vi.restoreAllMocks() + vi.spyOn(amplitude, 'trackEvent').mockImplementation(() => {}) + window.sessionStorage.clear() + window.history.replaceState({}, '', '/apps') + }) + + describe('extractExternalCreateAppAttribution', () => { + it('should map campaign links to external attribution', () => { + const attribution = extractExternalCreateAppAttribution({ + searchParams: new URLSearchParams('utm_source=x&slug=how-to-build-rag-agent'), + }) + + expect(attribution).toEqual({ + utmSource: 'twitter/x', + utmCampaign: 'how-to-build-rag-agent', + }) + }) + + it('should map newsletter and blog sources to blog', () => { + expect(extractExternalCreateAppAttribution({ + searchParams: new URLSearchParams('utm_source=newsletter'), + })).toEqual({ utmSource: 'blog' }) + + expect(extractExternalCreateAppAttribution({ + utmInfo: { utm_source: 'dify_blog', slug: 'launch-week' }, + })).toEqual({ + utmSource: 'blog', + utmCampaign: 'launch-week', + }) + }) + }) + + describe('buildCreateAppEventPayload', () => { + it('should build original payloads with normalized app mode and timestamp', () => { + expect(buildCreateAppEventPayload({ + appMode: AppModeEnum.ADVANCED_CHAT, + }, null, new Date(2026, 3, 13, 14, 5, 9))).toEqual({ + source: 'original', + app_mode: 'chatflow', + time: '04-13-14:05:09', + }) + }) + + it('should map agent mode into the canonical app mode bucket', () => { + expect(buildCreateAppEventPayload({ + appMode: AppModeEnum.AGENT_CHAT, + }, null, new Date(2026, 3, 13, 9, 8, 7))).toEqual({ + source: 'original', + app_mode: 'agent', + time: '04-13-09:08:07', + }) + }) + + it('should fold legacy non-agent modes into chatflow', () => { + expect(buildCreateAppEventPayload({ + appMode: AppModeEnum.CHAT, + }, null, new Date(2026, 3, 13, 8, 0, 1))).toEqual({ + source: 'original', + app_mode: 'chatflow', + time: '04-13-08:00:01', + }) + + expect(buildCreateAppEventPayload({ + appMode: AppModeEnum.COMPLETION, + }, null, new Date(2026, 3, 13, 8, 0, 2))).toEqual({ + source: 'original', + app_mode: 'chatflow', + time: '04-13-08:00:02', + }) + }) + + it('should map workflow mode into the workflow bucket', () => { + expect(buildCreateAppEventPayload({ + appMode: AppModeEnum.WORKFLOW, + }, null, new Date(2026, 3, 13, 7, 6, 5))).toEqual({ + source: 'original', + app_mode: 'workflow', + time: '04-13-07:06:05', + }) + }) + + it('should prefer external attribution when present', () => { + expect(buildCreateAppEventPayload( + { + appMode: AppModeEnum.WORKFLOW, + }, + { + utmSource: 'linkedin', + utmCampaign: 'agent-launch', + }, + )).toEqual({ + source: 'external', + utm_source: 'linkedin', + utm_campaign: 'agent-launch', + }) + }) + }) + + describe('trackCreateApp', () => { + it('should track remembered external attribution once before falling back to internal source', () => { + 
rememberCreateAppExternalAttribution({ + searchParams: new URLSearchParams('utm_source=newsletter&slug=how-to-build-rag-agent'), + }) + + trackCreateApp({ appMode: AppModeEnum.WORKFLOW }) + + expect(amplitude.trackEvent).toHaveBeenNthCalledWith(1, 'create_app', { + source: 'external', + utm_source: 'blog', + utm_campaign: 'how-to-build-rag-agent', + }) + + trackCreateApp({ appMode: AppModeEnum.WORKFLOW }) + + expect(amplitude.trackEvent).toHaveBeenNthCalledWith(2, 'create_app', { + source: 'original', + app_mode: 'workflow', + time: expect.stringMatching(/^\d{2}-\d{2}-\d{2}:\d{2}:\d{2}$/), + }) + }) + }) +}) diff --git a/web/utils/create-app-tracking.ts b/web/utils/create-app-tracking.ts new file mode 100644 index 0000000000..8be63511bb --- /dev/null +++ b/web/utils/create-app-tracking.ts @@ -0,0 +1,187 @@ +import Cookies from 'js-cookie' +import { trackEvent } from '@/app/components/base/amplitude' +import { AppModeEnum } from '@/types/app' + +const CREATE_APP_EXTERNAL_ATTRIBUTION_STORAGE_KEY = 'create_app_external_attribution' + +const EXTERNAL_UTM_SOURCE_MAP = { + blog: 'blog', + dify_blog: 'blog', + linkedin: 'linkedin', + newsletter: 'blog', + twitter: 'twitter/x', + x: 'twitter/x', +} as const + +type SearchParamReader = { + get: (name: string) => string | null +} + +type OriginalCreateAppMode = 'workflow' | 'chatflow' | 'agent' + +type TrackCreateAppParams = { + appMode: AppModeEnum +} + +type ExternalCreateAppAttribution = { + utmSource: typeof EXTERNAL_UTM_SOURCE_MAP[keyof typeof EXTERNAL_UTM_SOURCE_MAP] + utmCampaign?: string +} + +const normalizeString = (value?: string | null) => { + const trimmed = value?.trim() + return trimmed || undefined +} + +const getObjectStringValue = (value: unknown) => { + return typeof value === 'string' ? normalizeString(value) : undefined +} + +const getSearchParamValue = (searchParams?: SearchParamReader | null, key?: string) => { + if (!searchParams || !key) + return undefined + return normalizeString(searchParams.get(key)) +} + +const parseJSONRecord = (value?: string | null): Record | null => { + if (!value) + return null + + try { + const parsed = JSON.parse(value) + return parsed && typeof parsed === 'object' ? parsed as Record : null + } + catch { + return null + } +} + +const getCookieUtmInfo = () => { + return parseJSONRecord(Cookies.get('utm_info')) +} + +const mapExternalUtmSource = (value?: string) => { + if (!value) + return undefined + + const normalized = value.toLowerCase() + return EXTERNAL_UTM_SOURCE_MAP[normalized as keyof typeof EXTERNAL_UTM_SOURCE_MAP] +} + +const padTimeValue = (value: number) => String(value).padStart(2, '0') + +const formatCreateAppTime = (date: Date) => { + return `${padTimeValue(date.getMonth() + 1)}-${padTimeValue(date.getDate())}-${padTimeValue(date.getHours())}:${padTimeValue(date.getMinutes())}:${padTimeValue(date.getSeconds())}` +} + +const mapOriginalCreateAppMode = (appMode: AppModeEnum): OriginalCreateAppMode => { + if (appMode === AppModeEnum.WORKFLOW) + return 'workflow' + + if (appMode === AppModeEnum.AGENT_CHAT) + return 'agent' + + return 'chatflow' +} + +export const extractExternalCreateAppAttribution = ({ + searchParams, + utmInfo, +}: { + searchParams?: SearchParamReader | null + utmInfo?: Record | null +}) => { + const rawSource = getSearchParamValue(searchParams, 'utm_source') ?? getObjectStringValue(utmInfo?.utm_source) + const mappedSource = mapExternalUtmSource(rawSource) + + if (!mappedSource) + return null + + const utmCampaign = getSearchParamValue(searchParams, 'slug') + ?? 
getSearchParamValue(searchParams, 'utm_campaign') + ?? getObjectStringValue(utmInfo?.slug) + ?? getObjectStringValue(utmInfo?.utm_campaign) + + return { + utmSource: mappedSource, + ...(utmCampaign ? { utmCampaign } : {}), + } satisfies ExternalCreateAppAttribution +} + +const readRememberedExternalCreateAppAttribution = (): ExternalCreateAppAttribution | null => { + if (typeof window === 'undefined') + return null + + return parseJSONRecord(window.sessionStorage.getItem(CREATE_APP_EXTERNAL_ATTRIBUTION_STORAGE_KEY)) as ExternalCreateAppAttribution | null +} + +const writeRememberedExternalCreateAppAttribution = (attribution: ExternalCreateAppAttribution) => { + if (typeof window === 'undefined') + return + + window.sessionStorage.setItem(CREATE_APP_EXTERNAL_ATTRIBUTION_STORAGE_KEY, JSON.stringify(attribution)) +} + +const clearRememberedExternalCreateAppAttribution = () => { + if (typeof window === 'undefined') + return + + window.sessionStorage.removeItem(CREATE_APP_EXTERNAL_ATTRIBUTION_STORAGE_KEY) +} + +export const rememberCreateAppExternalAttribution = ({ + searchParams, + utmInfo, +}: { + searchParams?: SearchParamReader | null + utmInfo?: Record | null +} = {}) => { + const attribution = extractExternalCreateAppAttribution({ + searchParams, + utmInfo: utmInfo ?? getCookieUtmInfo(), + }) + + if (attribution) + writeRememberedExternalCreateAppAttribution(attribution) + + return attribution +} + +const resolveCurrentExternalCreateAppAttribution = () => { + if (typeof window === 'undefined') + return null + + return rememberCreateAppExternalAttribution({ + searchParams: new URLSearchParams(window.location.search), + }) ?? readRememberedExternalCreateAppAttribution() +} + +export const buildCreateAppEventPayload = ( + params: TrackCreateAppParams, + externalAttribution?: ExternalCreateAppAttribution | null, + currentTime = new Date(), +) => { + if (externalAttribution) { + return { + source: 'external', + utm_source: externalAttribution.utmSource, + ...(externalAttribution.utmCampaign ? { utm_campaign: externalAttribution.utmCampaign } : {}), + } satisfies Record + } + + return { + source: 'original', + app_mode: mapOriginalCreateAppMode(params.appMode), + time: formatCreateAppTime(currentTime), + } satisfies Record +} + +export const trackCreateApp = (params: TrackCreateAppParams) => { + const externalAttribution = resolveCurrentExternalCreateAppAttribution() + const payload = buildCreateAppEventPayload(params, externalAttribution) + + if (externalAttribution) + clearRememberedExternalCreateAppAttribution() + + trackEvent('create_app', payload) +}
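
Usage note (illustrative, not part of the diff): the sketch below shows how the new helpers in web/utils/create-app-tracking.ts are expected to be combined — rememberCreateAppExternalAttribution persists campaign attribution at signup, and trackCreateApp consumes it exactly once before falling back to the 'original' payload. Only the imported helpers and AppModeEnum come from this patch; the surrounding script is an assumption for demonstration and presumes the current URL carries no mapped utm_source of its own.

import { AppModeEnum } from '@/types/app'
import {
  rememberCreateAppExternalAttribution,
  trackCreateApp,
} from '@/utils/create-app-tracking'

// At signup (see web/app/signup/set-password/page.tsx): stash campaign
// attribution in sessionStorage for the first create_app event.
rememberCreateAppExternalAttribution({
  searchParams: new URLSearchParams('utm_source=newsletter&slug=launch-week'),
})

// The first creation reports the remembered external attribution:
// { source: 'external', utm_source: 'blog', utm_campaign: 'launch-week' }
trackCreateApp({ appMode: AppModeEnum.WORKFLOW })

// The remembered attribution is cleared after use, so later creations fall
// back to the original payload:
// { source: 'original', app_mode: 'workflow', time: 'MM-DD-HH:mm:ss' }
trackCreateApp({ appMode: AppModeEnum.WORKFLOW })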
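
A second sketch (also illustrative; the hook name is an assumption, not code from this patch) captures the deferred-tracking pattern applied in web/app/components/apps/index.tsx and web/app/components/explore/app-list/index.tsx: the mode returned by fetchAppDetail is remembered in a ref when the DSL import starts, and trackCreateApp only fires once the import (or its confirm step) reports success, so no event is emitted for failed or abandoned imports.

import { useCallback, useRef } from 'react'
import type { AppModeEnum } from '@/types/app'
import { trackCreateApp } from '@/utils/create-app-tracking'

// Hypothetical helper hook mirroring the ref-based pattern used in the patch.
const useDeferredCreateAppTracking = () => {
  const modeRef = useRef<AppModeEnum | null>(null)

  // Call when the template's mode becomes known (e.g. after fetchAppDetail).
  const rememberMode = useCallback((mode: AppModeEnum) => {
    modeRef.current = mode
  }, [])

  // Call from the import onSuccess / DSL-confirm onSuccess callbacks.
  const trackIfKnown = useCallback(() => {
    if (!modeRef.current)
      return
    trackCreateApp({ appMode: modeRef.current })
  }, [])

  return { rememberMode, trackIfKnown }
}

export default useDeferredCreateAppTracking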