mirror of
https://github.com/langgenius/dify.git
synced 2026-05-06 18:27:19 +08:00
feat: evaluation (#35688)
Co-authored-by: jyong <718720800@qq.com> Co-authored-by: Yansong Zhang <916125788@qq.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: hj24 <mambahj24@gmail.com> Co-authored-by: hj24 <huangjian@dify.ai> Co-authored-by: Joel <iamjoel007@gmail.com> Co-authored-by: Stephen Zhou <38493346+hyoban@users.noreply.github.com> Co-authored-by: CodingOnStar <hanxujiang@dify.com> Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
This commit is contained in:
parent
a0d8e84667
commit
5402132525
@ -29,6 +29,7 @@ from graphon.file import helpers as file_helpers
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App, Dataset
|
||||
from models.evaluation import EvaluationTargetType
|
||||
from models.model import UploadFile
|
||||
from models.snippet import CustomizedSnippet
|
||||
from services.errors.evaluation import (
|
||||
@ -48,8 +49,10 @@ logger = logging.getLogger(__name__)
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
# Valid evaluation target types
|
||||
EVALUATE_TARGET_TYPES = {"app", "snippets"}
|
||||
EVALUATE_TARGET_TYPES = {
|
||||
EvaluationTargetType.APPS.value,
|
||||
EvaluationTargetType.SNIPPETS.value,
|
||||
}
|
||||
|
||||
|
||||
class VersionQuery(BaseModel):
|
||||
@ -187,7 +190,7 @@ evaluation_default_metrics_response_model = console_ns.model(
|
||||
|
||||
def get_evaluation_target(view_func: Callable[P, R]):
|
||||
"""
|
||||
Decorator to resolve polymorphic evaluation target (app or snippet).
|
||||
Decorator to resolve polymorphic evaluation target (apps or snippets).
|
||||
|
||||
Validates the target_type parameter and fetches the corresponding
|
||||
model (App or CustomizedSnippet) with tenant isolation.
|
||||
@ -209,20 +212,16 @@ def get_evaluation_target(view_func: Callable[P, R]):
|
||||
del kwargs["evaluate_target_type"]
|
||||
del kwargs["evaluate_target_id"]
|
||||
|
||||
target: Union[App, CustomizedSnippet, Dataset] | None = None
|
||||
target: Union[App, CustomizedSnippet] | None = None
|
||||
|
||||
if target_type == "app":
|
||||
if target_type == EvaluationTargetType.APPS.value:
|
||||
target = db.session.query(App).where(App.id == target_id, App.tenant_id == current_tenant_id).first()
|
||||
elif target_type == "snippets":
|
||||
elif target_type == EvaluationTargetType.SNIPPETS.value:
|
||||
target = (
|
||||
db.session.query(CustomizedSnippet)
|
||||
.where(CustomizedSnippet.id == target_id, CustomizedSnippet.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
elif target_type == "knowledge":
|
||||
target = (db.session.query(Dataset)
|
||||
.where(Dataset.id == target_id, Dataset.tenant_id == current_tenant_id)
|
||||
.first())
|
||||
|
||||
if not target:
|
||||
raise NotFound(f"{str(target_type)} not found")
|
||||
@ -681,7 +680,7 @@ class EvaluationVersionApi(Resource):
|
||||
return {"message": "version parameter is required"}, 400
|
||||
|
||||
graph = {}
|
||||
if target_type == "snippets" and isinstance(target, CustomizedSnippet):
|
||||
if target_type == EvaluationTargetType.SNIPPETS.value and isinstance(target, CustomizedSnippet):
|
||||
graph = target.graph_dict
|
||||
|
||||
return {
|
||||
@ -791,8 +790,10 @@ class EvaluationWorkflowAssociatedTargetsApi(Resource):
|
||||
target_ids_by_type.setdefault(cfg.target_type, []).append(cfg.target_id)
|
||||
|
||||
app_names: dict[str, str] = {}
|
||||
if "app" in target_ids_by_type:
|
||||
apps = session.scalars(select(App).where(App.id.in_(target_ids_by_type["app"]))).all()
|
||||
if EvaluationTargetType.APPS.value in target_ids_by_type:
|
||||
apps = session.scalars(
|
||||
select(App).where(App.id.in_(target_ids_by_type[EvaluationTargetType.APPS.value]))
|
||||
).all()
|
||||
app_names = {a.id: a.name for a in apps}
|
||||
|
||||
snippet_names: dict[str, str] = {}
|
||||
@ -812,9 +813,9 @@ class EvaluationWorkflowAssociatedTargetsApi(Resource):
|
||||
items = []
|
||||
for cfg in configs:
|
||||
name = ""
|
||||
if cfg.target_type == "app":
|
||||
if cfg.target_type == EvaluationTargetType.APPS.value:
|
||||
name = app_names.get(cfg.target_id, "")
|
||||
elif cfg.target_type == "snippets":
|
||||
elif cfg.target_type == EvaluationTargetType.SNIPPETS.value:
|
||||
name = snippet_names.get(cfg.target_id, "")
|
||||
elif cfg.target_type == "knowledge_base":
|
||||
name = dataset_names.get(cfg.target_id, "")
|
||||
|
||||
@ -16,6 +16,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerat
|
||||
from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
||||
from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
|
||||
from core.workflow.node_factory import get_default_root_node_id
|
||||
from core.workflow.snippet_start import get_compatible_start_aliases
|
||||
from core.workflow.system_variables import build_bootstrap_variables, build_system_variables
|
||||
from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
@ -116,7 +117,15 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
),
|
||||
)
|
||||
root_node_id = self._root_node_id or get_default_root_node_id(self._workflow.graph_dict)
|
||||
add_node_inputs_to_pool(variable_pool, node_id=root_node_id, inputs=inputs)
|
||||
add_node_inputs_to_pool(
|
||||
variable_pool,
|
||||
node_id=root_node_id,
|
||||
inputs=inputs,
|
||||
aliases=get_compatible_start_aliases(
|
||||
workflow_kind=getattr(self._workflow, "kind_or_standard", None),
|
||||
root_node_id=root_node_id,
|
||||
),
|
||||
)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
graph = self._init_graph(
|
||||
|
||||
21
api/core/workflow/snippet_start.py
Normal file
21
api/core/workflow/snippet_start.py
Normal file
@ -0,0 +1,21 @@
|
||||
"""Shared snippet virtual Start-node identifiers and compatibility helpers.
|
||||
|
||||
Snippet workflows do not persist a real canvas Start node, so the backend
|
||||
injects one at runtime. Existing workflow references commonly use the public
|
||||
selector shape ``#start.<var>#``; keep that contract stable by treating the
|
||||
runtime-only snippet Start node as compatible with the legacy ``start`` id.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
LEGACY_START_NODE_ID = "start"
|
||||
SNIPPET_VIRTUAL_START_NODE_ID = "__snippet_virtual_start__"
|
||||
|
||||
|
||||
def get_compatible_start_aliases(*, workflow_kind: str | None, root_node_id: str | None) -> tuple[str, ...]:
|
||||
"""Return additional selector ids that should mirror snippet Start inputs."""
|
||||
if workflow_kind == "snippet" and root_node_id == SNIPPET_VIRTUAL_START_NODE_ID:
|
||||
return (LEGACY_START_NODE_ID,)
|
||||
|
||||
return ()
|
||||
@ -10,6 +10,19 @@ def add_variables_to_pool(variable_pool: VariablePool, variables: Sequence[Varia
|
||||
variable_pool.add(variable.selector, variable)
|
||||
|
||||
|
||||
def add_node_inputs_to_pool(variable_pool: VariablePool, *, node_id: str, inputs: Mapping[str, Any]) -> None:
|
||||
for key, value in inputs.items():
|
||||
variable_pool.add((node_id, key), value)
|
||||
def add_node_inputs_to_pool(
|
||||
variable_pool: VariablePool,
|
||||
*,
|
||||
node_id: str,
|
||||
inputs: Mapping[str, Any],
|
||||
aliases: Sequence[str] = (),
|
||||
) -> None:
|
||||
"""Store node inputs under the primary node id and any compatible aliases."""
|
||||
node_ids: list[str] = [node_id]
|
||||
for alias in aliases:
|
||||
if alias not in node_ids:
|
||||
node_ids.append(alias)
|
||||
|
||||
for current_node_id in node_ids:
|
||||
for key, value in inputs.items():
|
||||
variable_pool.add((current_node_id, key), value)
|
||||
|
||||
@ -24,7 +24,7 @@ class EvaluationRunStatus(StrEnum):
|
||||
|
||||
|
||||
class EvaluationTargetType(StrEnum):
|
||||
APP = "app"
|
||||
APPS = "apps"
|
||||
SNIPPETS = "snippets"
|
||||
KNOWLEDGE_BASE = "knowledge_base"
|
||||
|
||||
|
||||
@ -30,6 +30,7 @@ from models.evaluation import (
|
||||
EvaluationRun,
|
||||
EvaluationRunItem,
|
||||
EvaluationRunStatus,
|
||||
EvaluationTargetType,
|
||||
)
|
||||
from models.model import App, AppMode
|
||||
from models.snippet import CustomizedSnippet
|
||||
@ -70,18 +71,18 @@ class EvaluationService:
|
||||
The first column is index, followed by input parameter columns.
|
||||
|
||||
:param target: App or CustomizedSnippet instance
|
||||
:param target_type: Target type string ("app" or "snippet")
|
||||
:param target_type: Target type string ("apps" or "snippets")
|
||||
:return: Tuple of (xlsx_content_bytes, filename)
|
||||
:raises ValueError: If target type is not supported or app mode is excluded
|
||||
"""
|
||||
# Validate target type
|
||||
if target_type == "app":
|
||||
if target_type == EvaluationTargetType.APPS.value:
|
||||
if not isinstance(target, App):
|
||||
raise ValueError("Invalid target: expected App instance")
|
||||
if AppMode.value_of(target.mode) in cls.EXCLUDED_APP_MODES:
|
||||
raise ValueError(f"App mode '{target.mode}' does not support evaluation templates")
|
||||
input_fields = cls._get_app_input_fields(target)
|
||||
elif target_type == "snippet":
|
||||
elif target_type == EvaluationTargetType.SNIPPETS.value:
|
||||
if not isinstance(target, CustomizedSnippet):
|
||||
raise ValueError("Invalid target: expected CustomizedSnippet instance")
|
||||
input_fields = cls._get_snippet_input_fields(target)
|
||||
@ -581,7 +582,7 @@ class EvaluationService:
|
||||
"""Return node info grouped by metric (or all nodes when *metrics* is empty).
|
||||
|
||||
:param target: App or CustomizedSnippet instance.
|
||||
:param target_type: ``"app"`` or ``"snippets"``.
|
||||
:param target_type: ``"apps"`` or ``"snippets"``.
|
||||
:param metrics: Optional list of metric names to filter by.
|
||||
When *None* or empty, returns ``{"all": [<every node>]}``.
|
||||
:returns: ``{metric_name: [NodeInfo dict, ...]}`` or
|
||||
@ -607,9 +608,9 @@ class EvaluationService:
|
||||
target_type: str,
|
||||
) -> Workflow | None:
|
||||
"""Resolve only the published workflow for the target (no draft fallback)."""
|
||||
if target_type == "snippets" and isinstance(target, CustomizedSnippet):
|
||||
if target_type == EvaluationTargetType.SNIPPETS.value and isinstance(target, CustomizedSnippet):
|
||||
return SnippetService().get_published_workflow(snippet=target)
|
||||
if target_type == "app" and isinstance(target, App):
|
||||
if target_type == EvaluationTargetType.APPS.value and isinstance(target, App):
|
||||
return WorkflowService().get_published_workflow(app_model=target)
|
||||
return None
|
||||
|
||||
@ -620,13 +621,13 @@ class EvaluationService:
|
||||
target_type: str,
|
||||
) -> Workflow | None:
|
||||
"""Resolve the *published* (preferred) or *draft* workflow for the target."""
|
||||
if target_type == "snippets" and isinstance(target, CustomizedSnippet):
|
||||
if target_type == EvaluationTargetType.SNIPPETS.value and isinstance(target, CustomizedSnippet):
|
||||
snippet_service = SnippetService()
|
||||
workflow = snippet_service.get_published_workflow(snippet=target)
|
||||
if not workflow:
|
||||
workflow = snippet_service.get_draft_workflow(snippet=target)
|
||||
return workflow
|
||||
elif target_type == "app" and isinstance(target, App):
|
||||
elif target_type == EvaluationTargetType.APPS.value and isinstance(target, App):
|
||||
workflow_service = WorkflowService()
|
||||
workflow = workflow_service.get_published_workflow(app_model=target)
|
||||
if not workflow:
|
||||
@ -663,7 +664,7 @@ class EvaluationService:
|
||||
"""Execute the evaluation target for every test-data item in parallel.
|
||||
|
||||
:param tenant_id: Workspace / tenant ID.
|
||||
:param target_type: ``"app"`` or ``"snippet"``.
|
||||
:param target_type: ``"apps"`` or ``"snippets"``.
|
||||
:param target_id: ID of the App or CustomizedSnippet.
|
||||
:param input_list: All test-data items parsed from the dataset.
|
||||
:param max_workers: Maximum number of parallel worker threads.
|
||||
@ -745,8 +746,8 @@ class EvaluationService:
|
||||
Dispatches to the appropriate execution service based on
|
||||
``target_type``:
|
||||
|
||||
* ``"snippet"`` → :meth:`SnippetGenerateService.run_published`
|
||||
* ``"app"`` → :meth:`WorkflowAppGenerator().generate` (blocking mode)
|
||||
* ``"snippets"`` → :meth:`SnippetGenerateService.run_published`
|
||||
* ``"apps"`` → :meth:`WorkflowAppGenerator().generate` (blocking mode)
|
||||
|
||||
:returns: The blocking response mapping from the workflow engine.
|
||||
:raises ValueError: If the target is not found or not published.
|
||||
@ -755,7 +756,7 @@ class EvaluationService:
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.evaluation.runners import get_service_account_for_app, get_service_account_for_snippet
|
||||
|
||||
if target_type == "snippet":
|
||||
if target_type == EvaluationTargetType.SNIPPETS.value:
|
||||
from services.snippet_generate_service import SnippetGenerateService
|
||||
|
||||
snippet = session.query(CustomizedSnippet).filter_by(id=target_id).first()
|
||||
@ -771,7 +772,7 @@ class EvaluationService:
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
)
|
||||
else:
|
||||
# target_type == "app"
|
||||
# target_type == "apps"
|
||||
app = session.query(App).filter_by(id=target_id).first()
|
||||
if not app:
|
||||
raise ValueError(f"App {target_id} not found")
|
||||
|
||||
@ -28,6 +28,7 @@ from sqlalchemy.orm import make_transient
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.snippet_start import SNIPPET_VIRTUAL_START_NODE_ID
|
||||
from factories import file_factory
|
||||
from graphon.file.models import File
|
||||
from models import Account
|
||||
@ -78,7 +79,51 @@ class SnippetGenerateService:
|
||||
"""
|
||||
|
||||
# Specific ID for the injected virtual Start node so it can be recognised
|
||||
_VIRTUAL_START_NODE_ID = "__snippet_virtual_start__"
|
||||
_VIRTUAL_START_NODE_ID = SNIPPET_VIRTUAL_START_NODE_ID
|
||||
|
||||
@classmethod
|
||||
def _is_virtual_start_event(cls, message: Mapping[str, Any] | str) -> bool:
|
||||
"""
|
||||
Return True when *message* is a snippet-only virtual Start node event.
|
||||
|
||||
The virtual Start node is injected purely for snippet execution and is
|
||||
not part of the persisted draft graph. Filter its node lifecycle events
|
||||
out of the SSE stream so the frontend only receives nodes that exist on
|
||||
the canvas.
|
||||
"""
|
||||
if not isinstance(message, Mapping):
|
||||
return False
|
||||
|
||||
if message.get("event") not in {"node_started", "node_finished"}:
|
||||
return False
|
||||
|
||||
data = message.get("data")
|
||||
if not isinstance(data, Mapping):
|
||||
return False
|
||||
|
||||
return data.get("node_id") == cls._VIRTUAL_START_NODE_ID
|
||||
|
||||
@classmethod
|
||||
def _filter_virtual_start_events(
|
||||
cls,
|
||||
response: Mapping[str, Any] | Generator[Mapping[str, Any] | str, None, None],
|
||||
) -> Mapping[str, Any] | Generator[Mapping[str, Any] | str, None, None]:
|
||||
"""
|
||||
Drop snippet virtual Start node lifecycle events from stream responses.
|
||||
|
||||
Blocking responses are returned unchanged because they never expose the
|
||||
injected node as a standalone event payload.
|
||||
"""
|
||||
if isinstance(response, Mapping):
|
||||
return response
|
||||
|
||||
def _stream() -> Generator[Mapping[str, Any] | str, None, None]:
|
||||
for message in response:
|
||||
if cls._is_virtual_start_event(message):
|
||||
continue
|
||||
yield message
|
||||
|
||||
return _stream()
|
||||
|
||||
@classmethod
|
||||
def _is_virtual_start_event(cls, message: Mapping[str, Any] | str) -> bool:
|
||||
|
||||
@ -11,6 +11,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.workflow.app_runner import WorkflowAppRunner
|
||||
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.workflow.snippet_start import SNIPPET_VIRTUAL_START_NODE_ID
|
||||
from core.workflow.system_variables import default_system_variables
|
||||
from models.workflow import Workflow
|
||||
|
||||
@ -163,3 +164,68 @@ def test_single_node_run_validates_target_node_config(monkeypatch) -> None:
|
||||
)
|
||||
|
||||
assert seen_configs == [workflow.graph_dict["nodes"][0]]
|
||||
|
||||
|
||||
def test_run_adds_legacy_start_alias_for_snippet_virtual_start() -> None:
|
||||
app_config = MagicMock()
|
||||
app_config.app_id = "app"
|
||||
app_config.tenant_id = "tenant"
|
||||
app_config.workflow_id = "workflow"
|
||||
|
||||
app_generate_entity = MagicMock(spec=WorkflowAppGenerateEntity)
|
||||
app_generate_entity.app_config = app_config
|
||||
app_generate_entity.inputs = {"query": "123"}
|
||||
app_generate_entity.files = []
|
||||
app_generate_entity.user_id = "user"
|
||||
app_generate_entity.invoke_from = InvokeFrom.DEBUGGER
|
||||
app_generate_entity.workflow_execution_id = "execution-id"
|
||||
app_generate_entity.task_id = "task-id"
|
||||
app_generate_entity.call_depth = 0
|
||||
app_generate_entity.trace_manager = None
|
||||
app_generate_entity.single_iteration_run = None
|
||||
app_generate_entity.single_loop_run = None
|
||||
|
||||
workflow = MagicMock(spec=Workflow)
|
||||
workflow.tenant_id = "tenant"
|
||||
workflow.app_id = "app"
|
||||
workflow.id = "workflow"
|
||||
workflow.type = "workflow"
|
||||
workflow.version = "draft"
|
||||
workflow.graph_dict = {"nodes": [], "edges": []}
|
||||
workflow.environment_variables = []
|
||||
workflow.kind_or_standard = "snippet"
|
||||
|
||||
runner = WorkflowAppRunner(
|
||||
application_generate_entity=app_generate_entity,
|
||||
queue_manager=MagicMock(spec=AppQueueManager),
|
||||
variable_loader=MagicMock(),
|
||||
workflow=workflow,
|
||||
system_user_id="system-user",
|
||||
workflow_execution_repository=MagicMock(),
|
||||
workflow_node_execution_repository=MagicMock(),
|
||||
)
|
||||
|
||||
mock_workflow_entry = MagicMock()
|
||||
mock_workflow_entry.graph_engine = MagicMock()
|
||||
mock_workflow_entry.graph_engine.layer = MagicMock()
|
||||
mock_workflow_entry.run.return_value = iter([])
|
||||
|
||||
def _init_graph(**kwargs):
|
||||
variable_pool = kwargs["graph_runtime_state"].variable_pool
|
||||
virtual_start_query = variable_pool.get((SNIPPET_VIRTUAL_START_NODE_ID, "query"))
|
||||
legacy_start_query = variable_pool.get(("start", "query"))
|
||||
|
||||
assert virtual_start_query is not None
|
||||
assert virtual_start_query.value == "123"
|
||||
assert legacy_start_query is not None
|
||||
assert legacy_start_query.value == "123"
|
||||
return MagicMock()
|
||||
|
||||
with (
|
||||
patch("core.app.apps.workflow.app_runner.RedisChannel"),
|
||||
patch("core.app.apps.workflow.app_runner.redis_client"),
|
||||
patch("core.app.apps.workflow.app_runner.WorkflowEntry", return_value=mock_workflow_entry),
|
||||
patch("core.app.apps.workflow.app_runner.get_default_root_node_id", return_value=SNIPPET_VIRTUAL_START_NODE_ID),
|
||||
patch.object(runner, "_init_graph", side_effect=_init_graph),
|
||||
):
|
||||
runner.run()
|
||||
|
||||
@ -0,0 +1,38 @@
|
||||
from graphon.runtime import VariablePool
|
||||
|
||||
from core.workflow.variable_pool_initializer import add_node_inputs_to_pool
|
||||
|
||||
|
||||
def test_add_node_inputs_to_pool_writes_primary_and_alias_selectors() -> None:
|
||||
variable_pool = VariablePool()
|
||||
|
||||
add_node_inputs_to_pool(
|
||||
variable_pool,
|
||||
node_id="__snippet_virtual_start__",
|
||||
inputs={"query": "123"},
|
||||
aliases=("start",),
|
||||
)
|
||||
|
||||
virtual_start_query = variable_pool.get(("__snippet_virtual_start__", "query"))
|
||||
legacy_start_query = variable_pool.get(("start", "query"))
|
||||
|
||||
assert virtual_start_query is not None
|
||||
assert virtual_start_query.value == "123"
|
||||
assert legacy_start_query is not None
|
||||
assert legacy_start_query.value == "123"
|
||||
|
||||
|
||||
def test_add_node_inputs_to_pool_deduplicates_aliases() -> None:
|
||||
variable_pool = VariablePool()
|
||||
|
||||
add_node_inputs_to_pool(
|
||||
variable_pool,
|
||||
node_id="start",
|
||||
inputs={"query": "123"},
|
||||
aliases=("start",),
|
||||
)
|
||||
|
||||
start_query = variable_pool.get(("start", "query"))
|
||||
|
||||
assert start_query is not None
|
||||
assert start_query.value == "123"
|
||||
Loading…
Reference in New Issue
Block a user