From 7fc40e6c9ef7ec1caf03cd21aa40f52596b3954e Mon Sep 17 00:00:00 2001 From: Blackoutta <37723456+Blackoutta@users.noreply.github.com> Date: Mon, 11 May 2026 16:37:17 +0800 Subject: [PATCH 1/4] feat: improve phoenix workflow tracing (#35605) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: -LAN- --- api/.env.example | 4 + api/configs/feature/__init__.py | 13 + api/core/app/apps/workflow/app_generator.py | 3 +- api/core/app/workflow/layers/persistence.py | 9 +- api/core/helper/trace_id_helper.py | 35 + api/core/ops/entities/trace_entity.py | 21 +- api/core/ops/exceptions.py | 22 + api/core/ops/ops_trace_manager.py | 24 +- api/core/tools/workflow_as_tool/tool.py | 32 +- api/core/workflow/node_runtime.py | 33 +- .../arize_phoenix_trace.py | 493 ++++++- .../test_arize_phoenix_trace.py | 1194 ++++++++++++++++- .../unit_tests/test_arize_phoenix_trace.py | 36 - api/tasks/ops_trace_task.py | 91 +- .../app/apps/test_workflow_app_generator.py | 71 + .../app/workflow/test_persistence_layer.py | 54 + .../core/helper/test_trace_id_helper.py | 97 +- .../core/tools/workflow_as_tool/test_tool.py | 136 ++ .../workflow/nodes/tool/test_tool_node.py | 25 +- .../nodes/tool/test_tool_node_runtime.py | 63 + .../core/workflow/test_node_runtime.py | 75 ++ .../unit_tests/tasks/test_ops_trace_task.py | 301 +++++ docker/.env.example | 1 + docker/envs/core-services/shared.env.example | 2 + 24 files changed, 2727 insertions(+), 108 deletions(-) create mode 100644 api/core/ops/exceptions.py delete mode 100644 api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_arize_phoenix_trace.py create mode 100644 api/tests/unit_tests/tasks/test_ops_trace_task.py diff --git a/api/.env.example b/api/.env.example index ba153e4c9c..40fed7403c 100644 --- a/api/.env.example +++ b/api/.env.example @@ -88,6 +88,10 @@ REDIS_HEALTH_CHECK_INTERVAL=30 CELERY_BROKER_URL=redis://:difyai123456@localhost:${REDIS_PORT}/1 CELERY_BACKEND=redis +# Ops trace retry 
configuration +OPS_TRACE_RETRYABLE_DISPATCH_MAX_RETRIES=60 +OPS_TRACE_RETRYABLE_DISPATCH_DELAY_SECONDS=5 + # Database configuration DB_TYPE=postgresql DB_USERNAME=postgres diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index e9bb34fa75..26b8ea670b 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -1137,6 +1137,18 @@ class MultiModalTransferConfig(BaseSettings): ) +class OpsTraceConfig(BaseSettings): + OPS_TRACE_RETRYABLE_DISPATCH_MAX_RETRIES: PositiveInt = Field( + description="Maximum retry attempts for transient ops trace provider dispatch failures.", + default=60, + ) + + OPS_TRACE_RETRYABLE_DISPATCH_DELAY_SECONDS: PositiveInt = Field( + description="Delay in seconds between transient ops trace provider dispatch retry attempts.", + default=5, + ) + + class CeleryBeatConfig(BaseSettings): CELERY_BEAT_SCHEDULER_TIME: int = Field( description="Interval in days for Celery Beat scheduler execution, default to 1 day", @@ -1417,6 +1429,7 @@ class FeatureConfig( ModelLoadBalanceConfig, ModerationConfig, MultiModalTransferConfig, + OpsTraceConfig, PositionConfig, RagEtlConfig, RepositoryConfig, diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index e811c2b2e0..43546d57f5 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -32,7 +32,7 @@ from core.app.entities.task_entities import ( ) from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer from core.db.session_factory import session_factory -from core.helper.trace_id_helper import extract_external_trace_id_from_args +from core.helper.trace_id_helper import extract_external_trace_id_from_args, extract_parent_trace_context_from_args from core.ops.ops_trace_manager import TraceQueueManager from core.repositories import DifyCoreRepositoryFactory from core.repositories.factory import 
WorkflowExecutionRepository, WorkflowNodeExecutionRepository @@ -166,6 +166,7 @@ class WorkflowAppGenerator(BaseAppGenerator): extras = { **extract_external_trace_id_from_args(args), + **extract_parent_trace_context_from_args(args), } workflow_run_id = str(workflow_run_id or uuid.uuid4()) # FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args diff --git a/api/core/app/workflow/layers/persistence.py b/api/core/app/workflow/layers/persistence.py index d521304615..19152cebae 100644 --- a/api/core/app/workflow/layers/persistence.py +++ b/api/core/app/workflow/layers/persistence.py @@ -15,6 +15,7 @@ from datetime import datetime from typing import Any, Union from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity +from core.helper.trace_id_helper import ParentTraceContext from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository @@ -403,8 +404,13 @@ class WorkflowPersistenceLayer(GraphEngineLayer): conversation_id = self._system_variables().get(SystemVariableKey.CONVERSATION_ID.value) external_trace_id = None + parent_trace_context = None if isinstance(self._application_generate_entity, (WorkflowAppGenerateEntity, AdvancedChatAppGenerateEntity)): - external_trace_id = self._application_generate_entity.extras.get("external_trace_id") + extras = self._application_generate_entity.extras + external_trace_id = extras.get("external_trace_id") + parent_trace_context = extras.get("parent_trace_context") + if isinstance(parent_trace_context, ParentTraceContext): + parent_trace_context = parent_trace_context.model_dump(exclude_none=True) trace_task = TraceTask( TraceTaskName.WORKFLOW_TRACE, @@ -412,6 +418,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer): conversation_id=conversation_id, user_id=self._trace_manager.user_id, 
external_trace_id=external_trace_id, + parent_trace_context=parent_trace_context, ) self._trace_manager.add_trace_task(trace_task) diff --git a/api/core/helper/trace_id_helper.py b/api/core/helper/trace_id_helper.py index e827859109..e4890c8d4d 100644 --- a/api/core/helper/trace_id_helper.py +++ b/api/core/helper/trace_id_helper.py @@ -3,6 +3,17 @@ import re from collections.abc import Mapping from typing import Any +from pydantic import BaseModel, ConfigDict, StrictStr, ValidationError + + +class ParentTraceContext(BaseModel): + """Typed parent trace context propagated from an outer workflow tool node.""" + + parent_workflow_run_id: StrictStr + parent_node_execution_id: StrictStr | None = None + + model_config = ConfigDict(extra="forbid") + def is_valid_trace_id(trace_id: str) -> bool: """ @@ -61,6 +72,30 @@ def extract_external_trace_id_from_args(args: Mapping[str, Any]): return {} +def extract_parent_trace_context_from_args(args: Mapping[str, Any]) -> dict[str, ParentTraceContext]: + """ + Extract 'parent_trace_context' from args. + + Returns a dict suitable for use in extras when both parent identifiers exist. + Returns an empty dict if the context is missing or incomplete. + """ + parent_trace_context = args.get("parent_trace_context") + if isinstance(parent_trace_context, ParentTraceContext): + context = parent_trace_context + elif isinstance(parent_trace_context, Mapping): + try: + context = ParentTraceContext.model_validate(parent_trace_context) + except ValidationError: + return {} + else: + return {} + + if context.parent_node_execution_id is None: + return {} + + return {"parent_trace_context": context} + + def get_trace_id_from_otel_context() -> str | None: """ Retrieve the current trace ID from the active OpenTelemetry trace context. 
diff --git a/api/core/ops/entities/trace_entity.py b/api/core/ops/entities/trace_entity.py index 45b2f635ba..98e87a0ceb 100644 --- a/api/core/ops/entities/trace_entity.py +++ b/api/core/ops/entities/trace_entity.py @@ -5,6 +5,8 @@ from typing import Any, Union from pydantic import BaseModel, ConfigDict, field_serializer, field_validator +from core.helper.trace_id_helper import ParentTraceContext + class BaseTraceInfo(BaseModel): message_id: str | None = None @@ -51,8 +53,8 @@ class BaseTraceInfo(BaseModel): def resolved_parent_context(self) -> tuple[str | None, str | None]: """Resolve cross-workflow parent linking from metadata. - Extracts typed parent IDs from the untyped ``parent_trace_context`` - metadata dict (set by tool_node when invoking nested workflows). + Extracts typed parent IDs from the ``parent_trace_context`` metadata + payload (set by tool_node when invoking nested workflows). Returns: (trace_correlation_override, parent_span_id_source) where @@ -60,13 +62,18 @@ class BaseTraceInfo(BaseModel): parent_span_id_source is the outer node_execution_id. 
""" parent_ctx = self.metadata.get("parent_trace_context") - if not isinstance(parent_ctx, dict): + if isinstance(parent_ctx, ParentTraceContext): + context = parent_ctx + elif isinstance(parent_ctx, Mapping): + try: + context = ParentTraceContext.model_validate(parent_ctx) + except ValueError: + return None, None + else: return None, None - trace_override = parent_ctx.get("parent_workflow_run_id") - parent_span = parent_ctx.get("parent_node_execution_id") return ( - trace_override if isinstance(trace_override, str) else None, - parent_span if isinstance(parent_span, str) else None, + context.parent_workflow_run_id, + context.parent_node_execution_id, ) @field_serializer("start_time", "end_time") diff --git a/api/core/ops/exceptions.py b/api/core/ops/exceptions.py new file mode 100644 index 0000000000..4551704687 --- /dev/null +++ b/api/core/ops/exceptions.py @@ -0,0 +1,22 @@ +"""Core exceptions shared by ops trace dispatchers and trace providers. + +Provider packages may raise these types to request generic task behavior, but +generic Celery tasks should not import provider-specific exception classes. +""" + + +class RetryableTraceDispatchError(RuntimeError): + """Base class for transient trace dispatch failures that Celery may retry.""" + + +class PendingTraceParentContextError(RetryableTraceDispatchError): + """Raised when a nested trace arrives before its parent span context is available.""" + + parent_node_execution_id: str + + def __init__(self, parent_node_execution_id: str) -> None: + self.parent_node_execution_id = parent_node_execution_id + super().__init__( + "Pending trace parent context for parent_node_execution_id=" + f"{parent_node_execution_id}. Retry after the parent span context is published." 
+ ) diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index bae0016744..61fd0e5c1f 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -16,6 +16,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token +from core.helper.trace_id_helper import ParentTraceContext from core.ops.entities.config_entity import ( OPS_FILE_PATH, BaseTracingConfig, @@ -52,6 +53,17 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +def _dump_parent_trace_context(parent_trace_context: Any) -> dict[str, str] | None: + if isinstance(parent_trace_context, ParentTraceContext): + return parent_trace_context.model_dump(exclude_none=True) + if isinstance(parent_trace_context, dict): + try: + return ParentTraceContext.model_validate(parent_trace_context).model_dump(exclude_none=True) + except ValueError: + return None + return None + + class _AppTracingConfig(TypedDict, total=False): enabled: bool tracing_provider: str | None @@ -857,8 +869,9 @@ class TraceTask: } parent_trace_context = self.kwargs.get("parent_trace_context") - if parent_trace_context: - metadata["parent_trace_context"] = parent_trace_context + dumped_parent_trace_context = _dump_parent_trace_context(parent_trace_context) + if dumped_parent_trace_context: + metadata["parent_trace_context"] = dumped_parent_trace_context workflow_trace_info = WorkflowTraceInfo( trace_id=self.trace_id, @@ -1371,13 +1384,14 @@ class TraceTask: } parent_trace_context = node_data.get("parent_trace_context") - if parent_trace_context: - metadata["parent_trace_context"] = parent_trace_context + dumped_parent_trace_context = _dump_parent_trace_context(parent_trace_context) + if dumped_parent_trace_context: + metadata["parent_trace_context"] = dumped_parent_trace_context message_id: str | None = None conversation_id = node_data.get("conversation_id") 
workflow_execution_id = node_data.get("workflow_execution_id") - if conversation_id and workflow_execution_id and not parent_trace_context: + if conversation_id and workflow_execution_id and not dumped_parent_trace_context: with Session(db.engine) as session: msg_id = session.scalar( select(Message.id).where( diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index cd8c6352b5..3fbd456fe5 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -9,6 +9,7 @@ from sqlalchemy import select from core.app.file_access import DatabaseFileAccessController from core.db.session_factory import session_factory +from core.helper.trace_id_helper import ParentTraceContext, extract_parent_trace_context_from_args from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ( @@ -36,6 +37,8 @@ class WorkflowTool(Tool): Workflow tool. """ + _parent_trace_context: ParentTraceContext | None + def __init__( self, workflow_app_id: str, @@ -54,6 +57,7 @@ class WorkflowTool(Tool): self.workflow_call_depth = workflow_call_depth self.label = label self._latest_usage = LLMUsage.empty_usage() + self._parent_trace_context = None super().__init__(entity=entity, runtime=runtime) @@ -94,11 +98,17 @@ class WorkflowTool(Tool): self._latest_usage = LLMUsage.empty_usage() + generator_args: dict[str, Any] = {"inputs": tool_parameters, "files": files} + if self._parent_trace_context: + generator_args.update( + extract_parent_trace_context_from_args({"parent_trace_context": self._parent_trace_context}) + ) + result = generator.generate( app_model=app, workflow=workflow, user=user, - args={"inputs": tool_parameters, "files": files}, + args=generator_args, invoke_from=self.runtime.invoke_from, streaming=False, call_depth=self.workflow_call_depth + 1, @@ -194,7 +204,7 @@ class WorkflowTool(Tool): :return: the new tool """ - return 
self.__class__( + forked = self.__class__( entity=self.entity.model_copy(), runtime=runtime, workflow_app_id=self.workflow_app_id, @@ -204,6 +214,24 @@ class WorkflowTool(Tool): version=self.version, label=self.label, ) + forked._parent_trace_context = self._parent_trace_context.model_copy() if self._parent_trace_context else None + return forked + + def set_parent_trace_context( + self, + *, + parent_workflow_run_id: str, + parent_node_execution_id: str, + ) -> None: + """Attach outer workflow trace context without exposing it as tool input.""" + self._parent_trace_context = ParentTraceContext( + parent_workflow_run_id=parent_workflow_run_id, + parent_node_execution_id=parent_node_execution_id, + ) + + def clear_parent_trace_context(self) -> None: + """Remove parent trace context before invoking this tool outside a nested workflow.""" + self._parent_trace_context = None def _resolve_user(self, user_id: str) -> Account | EndUser | None: """ diff --git a/api/core/workflow/node_runtime.py b/api/core/workflow/node_runtime.py index d687d9a6e0..db7d78bf45 100644 --- a/api/core/workflow/node_runtime.py +++ b/api/core/workflow/node_runtime.py @@ -10,6 +10,7 @@ from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext from core.app.file_access import DatabaseFileAccessController from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler +from core.helper.trace_id_helper import ParentTraceContext from core.llm_generator.output_parser.errors import OutputParserError from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output from core.model_manager import ModelInstance @@ -358,6 +359,7 @@ class _WorkflowToolRuntimeBinding: tool: Tool conversation_id: str | None = None + parent_trace_context: ParentTraceContext | None = None class DifyToolNodeRuntime(ToolNodeRuntimeProtocol): @@ -398,7 +400,25 @@ class 
DifyToolNodeRuntime(ToolNodeRuntimeProtocol): conversation_id = ( None if variable_pool is None else get_system_text(variable_pool, SystemVariableKey.CONVERSATION_ID) ) - return ToolRuntimeHandle(raw=_WorkflowToolRuntimeBinding(tool=tool_runtime, conversation_id=conversation_id)) + parent_trace_context: ParentTraceContext | None = None + if self._is_workflow_tool_provider(node_data): + outer_workflow_run_id = ( + None + if variable_pool is None + else get_system_text(variable_pool, SystemVariableKey.WORKFLOW_EXECUTION_ID) + ) + if isinstance(outer_workflow_run_id, str) and isinstance(node_execution_id, str): + parent_trace_context = ParentTraceContext( + parent_workflow_run_id=outer_workflow_run_id, + parent_node_execution_id=node_execution_id, + ) + return ToolRuntimeHandle( + raw=_WorkflowToolRuntimeBinding( + tool=tool_runtime, + conversation_id=conversation_id, + parent_trace_context=parent_trace_context, + ) + ) def get_runtime_parameters( self, @@ -422,6 +442,13 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol): runtime_binding = self._binding_from_handle(tool_runtime) tool = runtime_binding.tool callback = DifyWorkflowCallbackHandler() + if runtime_binding.parent_trace_context and hasattr(tool, "set_parent_trace_context"): + tool.set_parent_trace_context( + parent_workflow_run_id=runtime_binding.parent_trace_context.parent_workflow_run_id, + parent_node_execution_id=runtime_binding.parent_trace_context.parent_node_execution_id, + ) + elif hasattr(tool, "clear_parent_trace_context"): + tool.clear_parent_trace_context() try: messages = ToolEngine.generic_invoke( @@ -514,6 +541,10 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol): credential_id=node_data.credential_id, ) + @staticmethod + def _is_workflow_tool_provider(node_data: ToolNodeData) -> bool: + return node_data.provider_type.value == CoreToolProviderType.WORKFLOW.value + def _adapt_messages( self, messages: Generator[CoreToolInvokeMessage, None, None], diff --git 
a/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py index 96df49ed0e..a0d150e1b6 100644 --- a/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py +++ b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py @@ -1,9 +1,11 @@ import json import logging import os +import re import traceback +from collections.abc import Mapping, Sequence from datetime import datetime, timedelta -from typing import Any, Union, cast +from typing import Any, Protocol, Union, cast from urllib.parse import urlparse from openinference.semconv.trace import ( @@ -19,7 +21,7 @@ from opentelemetry.sdk import trace as trace_sdk from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace.export import SimpleSpanProcessor from opentelemetry.semconv.attributes import exception_attributes -from opentelemetry.trace import Span, Status, StatusCode, set_span_in_context, use_span +from opentelemetry.trace import Span, Status, StatusCode, get_current_span, set_span_in_context, use_span from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from opentelemetry.util.types import AttributeValue from sqlalchemy.orm import sessionmaker @@ -36,16 +38,106 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) +from core.ops.exceptions import PendingTraceParentContextError from core.ops.utils import JSON_DICT_ADAPTER from core.repositories import DifyCoreRepositoryFactory from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig from extensions.ext_database import db +from extensions.ext_redis import redis_client from graphon.enums import WorkflowNodeExecutionStatus from models.model import EndUser, MessageFile from models.workflow import WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) +# 
This parent-span carrier store is intentionally Phoenix-local for the current +# nested workflow tracing feature. If other trace providers need the same +# cross-task parent restoration behavior, move the storage and retry signaling +# behind a core trace coordination interface instead of duplicating it here. +_PHOENIX_PARENT_SPAN_CONTEXT_TTL_SECONDS = 300 +_TRACEPARENT_PATTERN = re.compile( + r"^(?P<version>[0-9a-f]{2})-(?P<trace_id>[0-9a-f]{32})-(?P<span_id>[0-9a-f]{16})-(?P<trace_flags>[0-9a-f]{2})$" +) + + +def _phoenix_parent_span_redis_key(parent_node_execution_id: str) -> str: + """Build the Redis key that stores a restorable Phoenix parent span carrier.""" + return f"trace:phoenix:parent_span:{parent_node_execution_id}" + + +def _publish_parent_span_context(parent_node_execution_id: str, carrier: Mapping[str, str]) -> None: + """Persist a tracecontext carrier so nested workflow spans can restore the tool span parent.""" + redis_client.setex( + _phoenix_parent_span_redis_key(parent_node_execution_id), + _PHOENIX_PARENT_SPAN_CONTEXT_TTL_SECONDS, + safe_json_dumps(dict(carrier)), + ) + + +def _resolve_published_parent_span_context(parent_node_execution_id: str) -> dict[str, str]: + """Load a previously published tool-span carrier for nested workflow parenting.""" + raw_carrier = redis_client.get(_phoenix_parent_span_redis_key(parent_node_execution_id)) + if raw_carrier is None: + raise PendingTraceParentContextError(parent_node_execution_id) + + if isinstance(raw_carrier, bytes): + raw_carrier = raw_carrier.decode("utf-8") + + carrier = json.loads(raw_carrier) + if not isinstance(carrier, dict): + raise ValueError( + "Phoenix parent span context must be stored as a JSON object: " + f"parent_node_execution_id={parent_node_execution_id}" + ) + + normalized_carrier = {str(key): str(value) for key, value in carrier.items()} + if not normalized_carrier: + raise ValueError( + f"Phoenix parent span context payload is empty: parent_node_execution_id={parent_node_execution_id}" + ) + + traceparent =
normalized_carrier.get("traceparent") + if not isinstance(traceparent, str): + raise ValueError( + "Phoenix parent span context payload is missing traceparent: " + f"parent_node_execution_id={parent_node_execution_id}" + ) + + traceparent_match = _TRACEPARENT_PATTERN.fullmatch(traceparent) + if traceparent_match is None: + raise ValueError( + "Phoenix parent span context payload has invalid traceparent format: " + f"parent_node_execution_id={parent_node_execution_id}" + ) + + if traceparent_match.group("version") == "ff": + raise ValueError( + "Phoenix parent span context payload has unsupported traceparent version: " + f"parent_node_execution_id={parent_node_execution_id}" + ) + + if traceparent_match.group("trace_id") == "0" * 32: + raise ValueError( + "Phoenix parent span context payload has zero trace_id in traceparent: " + f"parent_node_execution_id={parent_node_execution_id}" + ) + + if traceparent_match.group("span_id") == "0" * 16: + raise ValueError( + "Phoenix parent span context payload has zero span_id in traceparent: " + f"parent_node_execution_id={parent_node_execution_id}" + ) + + extracted_context = TraceContextTextMapPropagator().extract(carrier=normalized_carrier) + extracted_span_context = get_current_span(extracted_context).get_span_context() + if not extracted_span_context.is_valid or not extracted_span_context.is_remote: + raise ValueError( + "Phoenix parent span context payload could not be restored into a valid parent span: " + f"parent_node_execution_id={parent_node_execution_id}" + ) + + return normalized_carrier + def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[trace_sdk.Tracer, SimpleSpanProcessor]: """Configure OpenTelemetry tracer with OTLP exporter for Arize/Phoenix.""" @@ -177,6 +269,246 @@ def _get_node_span_kind(node_type: str) -> OpenInferenceSpanKindValues: return _NODE_TYPE_TO_SPAN_KIND.get(node_type, OpenInferenceSpanKindValues.CHAIN) +def _resolve_workflow_session_id(trace_info: WorkflowTraceInfo) 
-> str: + """Resolve the workflow session ID for Phoenix workflow spans.""" + if trace_info.conversation_id: + return trace_info.conversation_id + + parent_workflow_run_id, _ = _resolve_workflow_parent_context(trace_info) + if parent_workflow_run_id: + return parent_workflow_run_id + + return trace_info.workflow_run_id + + +def _resolve_workflow_parent_context(trace_info: BaseTraceInfo) -> tuple[str | None, str | None]: + """Expose the typed parent context already resolved on the trace info.""" + return trace_info.resolved_parent_context + + +def _resolve_workflow_root_trace_id(trace_info: WorkflowTraceInfo) -> str: + """Resolve the canonical root trace ID for Phoenix workflow spans.""" + trace_correlation_override, _ = _resolve_workflow_parent_context(trace_info) + return trace_correlation_override or trace_info.resolved_trace_id or trace_info.workflow_run_id + + +class _NodeExecutionIdentityLike(Protocol): + @property + def node_execution_id(self) -> str | None: ... + + @property + def node_id(self) -> str: ... + + @property + def predecessor_node_id(self) -> str | None: ... + + +class _NodeExecutionLike(_NodeExecutionIdentityLike, Protocol): + @property + def id(self) -> str: ... + + @property + def node_type(self) -> str: ... + + @property + def title(self) -> str | None: ... + + @property + def inputs(self) -> Mapping[str, Any] | None: ... + + @property + def process_data(self) -> Mapping[str, Any] | None: ... + + @property + def outputs(self) -> Mapping[str, Any] | None: ... + + @property + def status(self) -> WorkflowNodeExecutionStatus: ... + + @property + def error(self) -> str | None: ... + + @property + def elapsed_time(self) -> float | None: ... + + @property + def metadata(self) -> Mapping[Any, Any] | None: ... + + @property + def created_at(self) -> datetime | None: ... 
+ + +_PHOENIX_STRUCTURED_NODE_TYPES = frozenset({"start", "end", "loop", "iteration"}) + + +def _resolve_workflow_span_name(trace_info: WorkflowTraceInfo) -> str: + """Resolve the Phoenix workflow span display name.""" + workflow_run_id = trace_info.workflow_run_id.strip() if trace_info.workflow_run_id else "" + if workflow_run_id: + return f"{TraceTaskName.WORKFLOW_TRACE.value}_{workflow_run_id}" + return TraceTaskName.WORKFLOW_TRACE.value + + +def _build_node_title_by_id(trace_info: WorkflowTraceInfo) -> dict[str, str]: + """Build an authoritative node-title index from the persisted workflow graph.""" + workflow_data = trace_info.workflow_data + workflow_graph = getattr(workflow_data, "graph_dict", None) + if not isinstance(workflow_graph, Mapping): + workflow_graph = workflow_data.get("graph") if isinstance(workflow_data, Mapping) else None + if not isinstance(workflow_graph, Mapping): + return {} + + graph_nodes = workflow_graph.get("nodes") + if not isinstance(graph_nodes, Sequence): + return {} + + node_title_by_id: dict[str, str] = {} + for graph_node in graph_nodes: + if not isinstance(graph_node, Mapping): + continue + node_id = graph_node.get("id") + node_data = graph_node.get("data") + if not isinstance(node_id, str) or not isinstance(node_data, Mapping): + continue + node_title = node_data.get("title") + if isinstance(node_title, str) and node_title.strip(): + node_title_by_id[node_id] = node_title.strip() + + return node_title_by_id + + +def _resolve_workflow_node_span_name( + node_execution: _NodeExecutionLike, + node_title_by_id: Mapping[str, str] | None = None, +) -> str: + """Resolve the Phoenix workflow node span display name.""" + node_type = str(node_execution.node_type or "") + graph_node_title = None + if node_title_by_id is not None and isinstance(node_execution.node_id, str): + graph_node_title = node_title_by_id.get(node_execution.node_id) + + node_title = graph_node_title or (node_execution.title.strip() if isinstance(node_execution.title, 
str) else "") + if node_title: + return f"{node_type}_{node_title}" + return node_type + + +def _get_node_execution_id(node_execution: _NodeExecutionIdentityLike) -> str: + """Return the stable execution identifier for a workflow node execution.""" + return str(getattr(node_execution, "id", None) or node_execution.node_execution_id) + + +def _build_execution_id_by_node_id(node_executions: Sequence[_NodeExecutionIdentityLike]) -> dict[str, str]: + """Index unique workflow graph node ids by execution id. + + This Phoenix-local hierarchy reconstruction intentionally drops ambiguous + node ids instead of guessing based on repository order. That keeps parent + selection deterministic until upstream tracing exposes explicit parent span + data for repeated executions. + """ + execution_id_by_node_id: dict[str, str] = {} + ambiguous_node_ids: set[str] = set() + + for node_execution in node_executions: + node_id = node_execution.node_id + if not isinstance(node_id, str): + continue + execution_id = _get_node_execution_id(node_execution) + + if node_id in ambiguous_node_ids: + continue + + existing_execution_id = execution_id_by_node_id.get(node_id) + if existing_execution_id is None: + execution_id_by_node_id[node_id] = execution_id + continue + + if existing_execution_id != execution_id: + ambiguous_node_ids.add(node_id) + execution_id_by_node_id.pop(node_id, None) + + return execution_id_by_node_id + + +def _build_graph_parent_index(node_executions: Sequence[_NodeExecutionIdentityLike]) -> dict[str, str]: + """Build an execution-id parent index from predecessor node ids.""" + execution_id_by_node_id = _build_execution_id_by_node_id(node_executions) + graph_parent_index: dict[str, str] = {} + + for node_execution in node_executions: + predecessor_node_id = node_execution.predecessor_node_id + if not isinstance(predecessor_node_id, str): + continue + + predecessor_execution_id = execution_id_by_node_id.get(predecessor_node_id) + if predecessor_execution_id is not None: + 
execution_id = _get_node_execution_id(node_execution) + graph_parent_index[execution_id] = predecessor_execution_id + + return graph_parent_index + + +def _resolve_structured_parent_execution_id( + node_execution: object, execution_id_by_node_id: Mapping[str, str] +) -> str | None: + """Resolve Phoenix-local structured parents from loop/iteration node ids. + + Any execution carrying ``iteration_id`` or ``loop_id`` belongs to an + enclosing structured node. When predecessor node ids are ambiguous because + the graph node repeats inside that structure, Phoenix can still keep the + child span under the enclosing loop/iteration span without relying on + execution-order heuristics. + """ + execution_metadata = getattr(node_execution, "execution_metadata_dict", None) + if not isinstance(execution_metadata, Mapping): + execution_metadata = getattr(node_execution, "metadata", None) + if not isinstance(execution_metadata, Mapping): + execution_metadata = {} + + for enclosing_node_id in ( + getattr(node_execution, "iteration_id", None), + getattr(node_execution, "loop_id", None), + execution_metadata.get("iteration_id"), + execution_metadata.get("loop_id"), + ): + if not isinstance(enclosing_node_id, str): + continue + + enclosing_execution_id = execution_id_by_node_id.get(enclosing_node_id) + if enclosing_execution_id is not None: + return enclosing_execution_id + + return None + + +def _resolve_node_parent( + execution_id: str, + predecessor_execution_id: str | None, + structured_parent_execution_id: str | None, + span_by_execution_id: Mapping[str, Span], + graph_parent_index: Mapping[str, str], + workflow_span: Span, +) -> Span: + """Resolve the parent span for a workflow node execution.""" + if predecessor_execution_id is not None: + predecessor_span = span_by_execution_id.get(predecessor_execution_id) + if predecessor_span is not None: + return predecessor_span + + graph_parent_execution_id = graph_parent_index.get(execution_id) + if graph_parent_execution_id is not 
None: + graph_parent_span = span_by_execution_id.get(graph_parent_execution_id) + if graph_parent_span is not None: + return graph_parent_span + + if structured_parent_execution_id is not None: + structured_parent_span = span_by_execution_id.get(structured_parent_execution_id) + if structured_parent_span is not None: + return structured_parent_span + + return workflow_span + + class ArizePhoenixDataTrace(BaseTraceInstance): def __init__( self, @@ -189,6 +521,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance): self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001") self.propagator = TraceContextTextMapPropagator() self.dify_trace_ids: set[str] = set() + self.root_span_carriers: dict[str, dict[str, str]] = {} + self.carrier: dict[str, str] = {} def trace(self, trace_info: BaseTraceInfo): logger.info("[Arize/Phoenix] Trace Entity Info: %s", trace_info) @@ -235,13 +569,41 @@ class ArizePhoenixDataTrace(BaseTraceInstance): file_list=safe_json_dumps(file_list), query=trace_info.query or "", ) + workflow_session_id = _resolve_workflow_session_id(trace_info) + parent_workflow_run_id, parent_node_execution_id = _resolve_workflow_parent_context(trace_info) + logger.info( + "[Arize/Phoenix] Workflow session resolution: workflow_run_id=%s conversation_id=%s " + "parent_workflow_run_id=%s parent_node_execution_id=%s resolved_session_id=%s", + trace_info.workflow_run_id, + trace_info.conversation_id, + parent_workflow_run_id, + parent_node_execution_id, + workflow_session_id, + ) - dify_trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id - self.ensure_root_span(dify_trace_id) - root_span_context = self.propagator.extract(carrier=self.carrier) + if parent_node_execution_id: + workflow_parent_carrier = _resolve_published_parent_span_context(parent_node_execution_id) + else: + root_trace_id = _resolve_workflow_root_trace_id(trace_info) + workflow_root_span_name: str | None = trace_info.workflow_run_id + if not 
isinstance(workflow_root_span_name, str) or not workflow_root_span_name.strip(): + workflow_root_span_name = None + + workflow_parent_carrier = self.ensure_root_span( + root_trace_id, + root_span_name=workflow_root_span_name, + root_span_attributes={ + SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.workflow_run_inputs), + SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + SpanAttributes.OUTPUT_VALUE: safe_json_dumps(trace_info.workflow_run_outputs), + SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + }, + ) + + workflow_span_context = self.propagator.extract(carrier=workflow_parent_carrier) workflow_span = self.tracer.start_span( - name=TraceTaskName.WORKFLOW_TRACE.value, + name=_resolve_workflow_span_name(trace_info), attributes={ SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value, SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.workflow_run_inputs), @@ -249,10 +611,10 @@ class ArizePhoenixDataTrace(BaseTraceInstance): SpanAttributes.OUTPUT_VALUE: safe_json_dumps(trace_info.workflow_run_outputs), SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, SpanAttributes.METADATA: safe_json_dumps(metadata), - SpanAttributes.SESSION_ID: trace_info.conversation_id or "", + SpanAttributes.SESSION_ID: workflow_session_id or "", }, start_time=datetime_to_nanos(trace_info.start_time), - context=root_span_context, + context=workflow_span_context, ) # Through workflow_run_id, get all_nodes_execution using repository @@ -276,16 +638,50 @@ class ArizePhoenixDataTrace(BaseTraceInstance): workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution( workflow_execution_id=trace_info.workflow_run_id ) + node_title_by_id = _build_node_title_by_id(trace_info) + execution_id_by_node_id = _build_execution_id_by_node_id(workflow_node_executions) + graph_parent_index = _build_graph_parent_index(workflow_node_executions) + node_execution_by_execution_id = 
{ + _get_node_execution_id(node_execution): node_execution for node_execution in workflow_node_executions + } + span_by_execution_id: dict[str, Span] = {} + emitting_execution_ids: set[str] = set() + workflow_span_error: Exception | str | None = trace_info.error try: - for node_execution in workflow_node_executions: + + def emit_node_span(node_execution: _NodeExecutionLike) -> Span: + execution_id = _get_node_execution_id(node_execution) + existing_span = span_by_execution_id.get(execution_id) + if existing_span is not None: + return existing_span + + graph_parent_execution_id = graph_parent_index.get(execution_id) + structured_parent_execution_id = _resolve_structured_parent_execution_id( + node_execution, execution_id_by_node_id + ) + + if execution_id not in emitting_execution_ids: + emitting_execution_ids.add(execution_id) + try: + for parent_execution_id in (graph_parent_execution_id, structured_parent_execution_id): + if parent_execution_id is None or parent_execution_id == execution_id: + continue + if parent_execution_id in span_by_execution_id: + continue + parent_node_execution = node_execution_by_execution_id.get(parent_execution_id) + if parent_node_execution is not None: + emit_node_span(parent_node_execution) + finally: + emitting_execution_ids.discard(execution_id) + tenant_id = trace_info.tenant_id # Use from trace_info instead app_id = trace_info.metadata.get("app_id") # Use from trace_info instead inputs_value = node_execution.inputs or {} outputs_value = node_execution.outputs or {} created_at = node_execution.created_at or datetime.now() - elapsed_time = node_execution.elapsed_time + elapsed_time = node_execution.elapsed_time or 0 finished_at = created_at + timedelta(seconds=elapsed_time) process_data = node_execution.process_data or {} @@ -324,9 +720,17 @@ class ArizePhoenixDataTrace(BaseTraceInstance): node_metadata["prompt_tokens"] = usage_data.get("prompt_tokens", 0) node_metadata["completion_tokens"] = usage_data.get("completion_tokens", 0) 
- workflow_span_context = set_span_in_context(workflow_span) + parent_span = _resolve_node_parent( + execution_id=execution_id, + predecessor_execution_id=None, + structured_parent_execution_id=structured_parent_execution_id, + span_by_execution_id=span_by_execution_id, + graph_parent_index=graph_parent_index, + workflow_span=workflow_span, + ) + workflow_span_context = set_span_in_context(parent_span) node_span = self.tracer.start_span( - name=node_execution.node_type, + name=_resolve_workflow_node_span_name(node_execution, node_title_by_id), attributes={ SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind.value, SpanAttributes.INPUT_VALUE: safe_json_dumps(inputs_value), @@ -334,13 +738,20 @@ class ArizePhoenixDataTrace(BaseTraceInstance): SpanAttributes.OUTPUT_VALUE: safe_json_dumps(outputs_value), SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, SpanAttributes.METADATA: safe_json_dumps(node_metadata), - SpanAttributes.SESSION_ID: trace_info.conversation_id or "", + SpanAttributes.SESSION_ID: workflow_session_id or "", }, start_time=datetime_to_nanos(created_at), context=workflow_span_context, ) - + span_by_execution_id[execution_id] = node_span + node_span_error: Exception | str | None = None try: + if node_execution.node_type == "tool": + parent_span_carrier: dict[str, str] = {} + with use_span(node_span, end_on_exit=False): + self.propagator.inject(carrier=parent_span_carrier) + _publish_parent_span_context(execution_id, parent_span_carrier) + if node_execution.node_type == "llm": llm_attributes: dict[str, Any] = { SpanAttributes.INPUT_VALUE: json.dumps(process_data.get("prompts", []), ensure_ascii=False), @@ -362,17 +773,26 @@ class ArizePhoenixDataTrace(BaseTraceInstance): ) llm_attributes.update(self._construct_llm_attributes(process_data.get("prompts", []))) node_span.set_attributes(llm_attributes) + except Exception as e: + node_span_error = e + raise finally: - if node_execution.status == WorkflowNodeExecutionStatus.FAILED: + if 
node_span_error is not None: + set_span_status(node_span, node_span_error) + elif node_execution.status == WorkflowNodeExecutionStatus.FAILED: set_span_status(node_span, node_execution.error) else: set_span_status(node_span) node_span.end(end_time=datetime_to_nanos(finished_at)) + return node_span + + for node_execution in workflow_node_executions: + emit_node_span(node_execution) + except Exception as e: + workflow_span_error = e + raise finally: - if trace_info.error: - set_span_status(workflow_span, trace_info.error) - else: - set_span_status(workflow_span) + set_span_status(workflow_span, workflow_span_error) workflow_span.end(end_time=datetime_to_nanos(trace_info.end_time)) def message_trace(self, trace_info: MessageTraceInfo): @@ -735,22 +1155,39 @@ class ArizePhoenixDataTrace(BaseTraceInstance): finally: span.end(end_time=datetime_to_nanos(trace_info.end_time)) - def ensure_root_span(self, dify_trace_id: str | None): + def ensure_root_span( + self, + dify_trace_id: str | None, + *, + root_span_name: str | None = None, + root_span_attributes: Mapping[str, AttributeValue] | None = None, + ): """Ensure a unique root span exists for the given Dify trace ID.""" - if str(dify_trace_id) not in self.dify_trace_ids: - self.carrier: dict[str, str] = {} + trace_key = str(dify_trace_id) + if trace_key not in self.dify_trace_ids: + carrier: dict[str, str] = {} - root_span = self.tracer.start_span(name="Dify") - root_span.set_attribute(SpanAttributes.OPENINFERENCE_SPAN_KIND, OpenInferenceSpanKindValues.CHAIN.value) - root_span.set_attribute("dify_project_name", str(self.project)) - root_span.set_attribute("dify_trace_id", str(dify_trace_id)) + span_name = root_span_name.strip() if isinstance(root_span_name, str) and root_span_name.strip() else "Dify" + root_span_attributes_dict: dict[str, AttributeValue] = { + SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value, + "dify_project_name": str(self.project), + "dify_trace_id": trace_key, + } + if 
root_span_attributes: + root_span_attributes_dict.update(root_span_attributes) + + root_span = self.tracer.start_span(name=span_name, attributes=root_span_attributes_dict) with use_span(root_span, end_on_exit=False): - self.propagator.inject(carrier=self.carrier) + self.propagator.inject(carrier=carrier) set_span_status(root_span) root_span.end() - self.dify_trace_ids.add(str(dify_trace_id)) + self.dify_trace_ids.add(trace_key) + self.root_span_carriers[trace_key] = carrier + + self.carrier = self.root_span_carriers[trace_key] + return self.carrier def api_check(self): try: diff --git a/api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py index e9ecc2e083..dd260aeee5 100644 --- a/api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py +++ b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py @@ -1,10 +1,21 @@ from datetime import UTC, datetime, timedelta -from typing import cast +from types import SimpleNamespace +from typing import Any, cast from unittest.mock import MagicMock, patch +import dify_trace_arize_phoenix.arize_phoenix_trace as arize_phoenix_trace_module import pytest from dify_trace_arize_phoenix.arize_phoenix_trace import ( + _NODE_TYPE_TO_SPAN_KIND, ArizePhoenixDataTrace, + _build_graph_parent_index, + _get_node_span_kind, + _phoenix_parent_span_redis_key, + _resolve_node_parent, + _resolve_published_parent_span_context, + _resolve_structured_parent_execution_id, + _resolve_workflow_parent_context, + _resolve_workflow_session_id, datetime_to_nanos, error_to_string, safe_json_dumps, @@ -13,6 +24,7 @@ from dify_trace_arize_phoenix.arize_phoenix_trace import ( wrap_span_metadata, ) from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig +from openinference.semconv.trace import 
OpenInferenceSpanKindValues, SpanAttributes from opentelemetry.sdk.trace import Tracer from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes from opentelemetry.trace import StatusCode @@ -24,8 +36,12 @@ from core.ops.entities.trace_entity import ( ModerationTraceInfo, SuggestedQuestionTraceInfo, ToolTraceInfo, + TraceTaskName, + WorkflowNodeTraceInfo, WorkflowTraceInfo, ) +from core.ops.exceptions import PendingTraceParentContextError +from graphon.enums import BUILT_IN_NODE_TYPES, BuiltinNodeTypes # --- Helpers --- @@ -73,6 +89,80 @@ def _make_message_info(**kwargs): return MessageTraceInfo(**defaults) +def _get_start_span_call(start_span_mock, *, span_name: str): + for call in start_span_mock.call_args_list: + if call.kwargs.get("name") == span_name: + return call + raise AssertionError(f"Could not find start_span call with name={span_name!r}") + + +def _make_node_execution(**kwargs): + defaults = { + "node_type": "tool", + "status": "succeeded", + "inputs": {}, + "outputs": {}, + "created_at": _dt(), + "elapsed_time": 1.0, + "process_data": {}, + "metadata": {}, + "title": "Node", + "id": "node-execution-1", + "node_execution_id": "node-execution-1", + "node_id": "node-1", + "predecessor_node_id": None, + "iteration_id": None, + "loop_id": None, + "error": None, + } + defaults.update(kwargs) + node_execution = MagicMock() + for key, value in defaults.items(): + setattr(node_execution, key, value) + return node_execution + + +def _make_workflow_trace_info(**kwargs) -> WorkflowTraceInfo: + defaults = { + "workflow_id": "workflow-1", + "tenant_id": "tenant-1", + "workflow_run_id": "workflow-run-1", + "workflow_run_elapsed_time": 1.0, + "workflow_run_status": "succeeded", + "workflow_run_inputs": {"input": "value"}, + "workflow_run_outputs": {"output": "value"}, + "workflow_run_version": "1.0", + "total_tokens": 10, + "file_list": ["file-1"], + "query": "hello", + "metadata": {"app_id": "app-1"}, + "start_time": _dt(), + "end_time": _dt() + 
timedelta(seconds=1), + } + defaults.update(kwargs) + return WorkflowTraceInfo(**defaults) + + +def _make_workflow_node_trace_info(**kwargs) -> WorkflowNodeTraceInfo: + defaults = { + "workflow_id": "workflow-1", + "workflow_run_id": "workflow-run-1", + "tenant_id": "tenant-1", + "node_execution_id": "node-execution-1", + "node_id": "node-1", + "node_type": "tool", + "title": "Node 1", + "status": "succeeded", + "elapsed_time": 1.0, + "index": 1, + "metadata": {"app_id": "app-1"}, + "start_time": _dt(), + "end_time": _dt() + timedelta(seconds=1), + } + defaults.update(kwargs) + return WorkflowNodeTraceInfo(**defaults) + + # --- Utility Function Tests --- @@ -143,6 +233,258 @@ def test_wrap_span_metadata(): assert res == {"a": 1, "b": 2, "created_from": "Dify"} +class TestGetNodeSpanKind: + def test_all_node_types_are_mapped_correctly(self): + special_mappings = { + BuiltinNodeTypes.LLM: OpenInferenceSpanKindValues.LLM, + BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: OpenInferenceSpanKindValues.RETRIEVER, + BuiltinNodeTypes.TOOL: OpenInferenceSpanKindValues.TOOL, + BuiltinNodeTypes.AGENT: OpenInferenceSpanKindValues.AGENT, + } + + for node_type in BUILT_IN_NODE_TYPES: + expected_span_kind = special_mappings.get(node_type, OpenInferenceSpanKindValues.CHAIN) + actual_span_kind = _get_node_span_kind(node_type) + assert actual_span_kind == expected_span_kind, ( + f"Node type {node_type!r} was mapped to {actual_span_kind}, but {expected_span_kind} was expected." 
+ ) + + def test_unknown_string_defaults_to_chain(self): + assert _get_node_span_kind("some-future-node-type") == OpenInferenceSpanKindValues.CHAIN + + def test_stale_dataset_retrieval_not_in_mapping(self): + assert "dataset_retrieval" not in _NODE_TYPE_TO_SPAN_KIND + + +class TestWorkflowSessionResolution: + def test_prefers_conversation_id(self): + info = _make_workflow_trace_info(conversation_id="conversation-1") + + assert _resolve_workflow_session_id(info) == "conversation-1" + + def test_nested_workflow_keeps_own_conversation_id_when_parent_context_exists(self): + info = _make_workflow_trace_info( + conversation_id="conversation-1", + metadata={ + "app_id": "app-1", + "parent_trace_context": { + "parent_workflow_run_id": "outer-workflow-run-1", + "parent_node_execution_id": "outer-node-execution-1", + }, + }, + ) + + assert _resolve_workflow_session_id(info) == "conversation-1" + + def test_uses_parent_workflow_run_id_for_nested_parent_trace_context(self): + info = _make_workflow_trace_info( + conversation_id=None, + metadata={ + "app_id": "app-1", + "parent_trace_context": { + "parent_workflow_run_id": "outer-workflow-run-1", + "parent_node_execution_id": "outer-node-execution-1", + }, + }, + ) + + assert _resolve_workflow_session_id(info) == "outer-workflow-run-1" + + def test_falls_back_to_workflow_run_id(self): + info = _make_workflow_trace_info(conversation_id=None) + + assert _resolve_workflow_session_id(info) == "workflow-run-1" + + def test_parent_context_helper_delegates_to_resolved_parent_context(self): + info = MagicMock() + info.resolved_parent_context = ("outer-workflow-run-1", "outer-node-execution-1") + + assert _resolve_workflow_parent_context(info) == info.resolved_parent_context + + +class TestPhoenixParentSpanBridgeHelpers: + def test_parent_span_redis_key_is_stable(self): + assert _phoenix_parent_span_redis_key("outer-node-execution-1") == ( + "trace:phoenix:parent_span:outer-node-execution-1" + ) + + def 
test_pending_parent_exception_exposes_execution_id(self): + error = PendingTraceParentContextError("outer-node-execution-1") + + assert error.parent_node_execution_id == "outer-node-execution-1" + assert "outer-node-execution-1" in str(error) + + def test_resolve_parent_span_context_rejects_payload_without_traceparent(self, monkeypatch): + mock_redis = MagicMock() + mock_redis.get.return_value = '{"tracestate": "vendor=value"}' + monkeypatch.setattr(arize_phoenix_trace_module, "redis_client", mock_redis) + + with pytest.raises(ValueError, match="traceparent"): + _resolve_published_parent_span_context("outer-node-execution-1") + + @pytest.mark.parametrize( + "stored_payload", + [ + '{"traceparent": ""}', + '{"traceparent": "not-a-traceparent"}', + '{"traceparent": "00-aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa-bbbbbbbbbbbbbbbb"}', + ], + ) + def test_resolve_parent_span_context_rejects_malformed_traceparent(self, monkeypatch, stored_payload): + mock_redis = MagicMock() + mock_redis.get.return_value = stored_payload + monkeypatch.setattr(arize_phoenix_trace_module, "redis_client", mock_redis) + + with pytest.raises(ValueError, match="traceparent"): + _resolve_published_parent_span_context("outer-node-execution-1") + + +class TestWorkflowHierarchyHelpers: + def test_build_graph_parent_index_uses_predecessor_nodes_without_order_heuristics(self): + later_node = _make_workflow_node_trace_info( + node_execution_id="node-execution-3", + node_id="node-3", + predecessor_node_id="node-2", + index=3, + ) + root_node = _make_workflow_node_trace_info( + node_execution_id="node-execution-1", + node_id="node-1", + predecessor_node_id=None, + index=1, + ) + middle_node = _make_workflow_node_trace_info( + node_execution_id="node-execution-2", + node_id="node-2", + predecessor_node_id="node-1", + index=2, + ) + + graph_parent_index = _build_graph_parent_index([later_node, root_node, middle_node]) + + assert graph_parent_index == { + "node-execution-2": "node-execution-1", + "node-execution-3": 
"node-execution-2", + } + + def test_build_graph_parent_index_drops_ambiguous_parallel_like_predecessors(self): + first_parallel_node = _make_workflow_node_trace_info( + node_execution_id="parallel-node-execution-1", + node_id="parallel-node-1", + predecessor_node_id=None, + index=1, + parallel_id="parallel-1", + ) + second_parallel_node = _make_workflow_node_trace_info( + node_execution_id="parallel-node-execution-2", + node_id="parallel-node-1", + predecessor_node_id=None, + index=2, + parallel_id="parallel-2", + ) + child_node = _make_workflow_node_trace_info( + node_execution_id="child-node-execution-1", + node_id="child-node-1", + predecessor_node_id="parallel-node-1", + index=3, + ) + + graph_parent_index = _build_graph_parent_index([child_node, first_parallel_node, second_parallel_node]) + + assert graph_parent_index == {} + + def test_resolve_node_parent_prefers_predecessor_span(self): + workflow_span = MagicMock(name="workflow-span") + predecessor_span = MagicMock(name="predecessor-span") + graph_parent_span = MagicMock(name="graph-parent-span") + + parent = _resolve_node_parent( + execution_id="node-execution-2", + predecessor_execution_id="node-execution-1", + structured_parent_execution_id=None, + span_by_execution_id={ + "node-execution-1": predecessor_span, + "node-execution-0": graph_parent_span, + }, + graph_parent_index={ + "node-execution-2": "node-execution-0", + }, + workflow_span=workflow_span, + ) + + assert parent is predecessor_span + + def test_resolve_node_parent_falls_back_to_graph_parent_span(self): + workflow_span = MagicMock(name="workflow-span") + graph_parent_span = MagicMock(name="graph-parent-span") + + parent = _resolve_node_parent( + execution_id="node-execution-2", + predecessor_execution_id="missing-predecessor", + structured_parent_execution_id=None, + span_by_execution_id={ + "node-execution-0": graph_parent_span, + }, + graph_parent_index={ + "node-execution-2": "node-execution-0", + }, + workflow_span=workflow_span, + ) + + 
assert parent is graph_parent_span + + def test_resolve_node_parent_falls_back_to_workflow_span(self): + workflow_span = MagicMock(name="workflow-span") + + parent = _resolve_node_parent( + execution_id="node-execution-2", + predecessor_execution_id=None, + structured_parent_execution_id=None, + span_by_execution_id={}, + graph_parent_index={}, + workflow_span=workflow_span, + ) + + assert parent is workflow_span + + def test_resolve_structured_parent_execution_id_allows_body_nodes_to_use_enclosing_structure(self): + body_node = _make_workflow_node_trace_info( + node_execution_id="body-execution-1", + node_id="body-node-1", + node_type="tool", + loop_id="loop-node-1", + ) + + structured_parent_execution_id = _resolve_structured_parent_execution_id( + body_node, + execution_id_by_node_id={ + "loop-node-1": "loop-execution-1", + }, + ) + + assert structured_parent_execution_id == "loop-execution-1" + + def test_resolve_structured_parent_execution_id_reads_execution_metadata_dict_for_models(self): + body_node = SimpleNamespace( + node_execution_id="body-execution-1", + node_id="body-node-1", + execution_metadata_dict={ + "iteration_id": "iteration-node-1", + "loop_id": "loop-node-1", + }, + ) + + structured_parent_execution_id = _resolve_structured_parent_execution_id( + body_node, + execution_id_by_node_id={ + "iteration-node-1": "iteration-execution-1", + "loop-node-1": "loop-execution-1", + }, + ) + + assert structured_parent_execution_id == "iteration-execution-1" + + @patch("dify_trace_arize_phoenix.arize_phoenix_trace.GrpcOTLPSpanExporter") @patch("dify_trace_arize_phoenix.arize_phoenix_trace.trace_sdk.TracerProvider") def test_setup_tracer_arize(mock_provider, mock_exporter): @@ -173,12 +515,17 @@ def test_setup_tracer_exception(): @pytest.fixture def trace_instance(): - with patch("dify_trace_arize_phoenix.arize_phoenix_trace.setup_tracer") as mock_setup: + with ( + patch("dify_trace_arize_phoenix.arize_phoenix_trace.setup_tracer") as mock_setup, + 
patch("dify_trace_arize_phoenix.arize_phoenix_trace.redis_client", new=MagicMock()) as mock_redis, + ): mock_tracer = MagicMock(spec=Tracer) mock_processor = MagicMock() mock_setup.return_value = (mock_tracer, mock_processor) config = ArizeConfig(endpoint="http://a.com", api_key="k", space_id="s", project="p") - return ArizePhoenixDataTrace(config) + instance = ArizePhoenixDataTrace(config) + cast(Any, instance)._mock_redis_client = mock_redis + yield instance def test_trace_dispatch(trace_instance): @@ -273,23 +620,821 @@ def test_workflow_trace_no_app_id(mock_db, trace_instance): @patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") -def test_message_trace_success(mock_db, trace_instance): +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.DifyCoreRepositoryFactory") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.sessionmaker") +def test_workflow_trace_uses_canonical_root_context_for_top_level_workflow( + mock_sessionmaker, mock_repo_factory, mock_db, trace_instance +): + mock_db.engine = MagicMock() + info = _make_workflow_info(message_id="message-1", workflow_run_id="workflow-run-1") + repo = MagicMock() + repo.get_by_workflow_execution.return_value = [] + mock_repo_factory.create_workflow_node_execution_repository.return_value = repo + + root_carrier = {} + root_context = object() + + with ( + patch.object(trace_instance, "get_service_account_with_tenant", return_value=MagicMock()), + patch.object(trace_instance, "ensure_root_span", return_value=root_carrier) as mock_ensure_root_span, + patch.object(trace_instance.propagator, "extract", return_value=root_context) as mock_extract, + ): + trace_instance.workflow_trace(info) + + mock_ensure_root_span.assert_called_once_with( + info.resolved_trace_id, + root_span_name="workflow-run-1", + root_span_attributes={ + SpanAttributes.INPUT_VALUE: safe_json_dumps(info.workflow_run_inputs), + SpanAttributes.INPUT_MIME_TYPE: "application/json", + SpanAttributes.OUTPUT_VALUE: 
safe_json_dumps(info.workflow_run_outputs), + SpanAttributes.OUTPUT_MIME_TYPE: "application/json", + }, + ) + mock_extract.assert_called_once_with(carrier=root_carrier) + workflow_span_call = _get_start_span_call(trace_instance.tracer.start_span, span_name="workflow_workflow-run-1") + assert workflow_span_call.kwargs["context"] is root_context + + +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.DifyCoreRepositoryFactory") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.sessionmaker") +def test_workflow_trace_uses_workflow_run_id_for_root_span_and_populates_root_inputs_outputs( + mock_sessionmaker, + mock_repo_factory, + mock_db, + trace_instance, +): + mock_db.engine = MagicMock() + info = _make_workflow_info( + workflow_run_inputs={"prompt": "hello"}, + workflow_run_outputs={"result": "world"}, + metadata={ + "app_id": "app1", + "app_name": "Workflow Name", + }, + workflow_run_id="workflow-run-xyz", + ) + repo = MagicMock() + repo.get_by_workflow_execution.return_value = [] + mock_repo_factory.create_workflow_node_execution_repository.return_value = repo + + with patch.object(trace_instance, "get_service_account_with_tenant", return_value=MagicMock()): + trace_instance.workflow_trace(info) + + root_span_call = _get_start_span_call(trace_instance.tracer.start_span, span_name="workflow-run-xyz") + assert root_span_call.kwargs["attributes"][SpanAttributes.INPUT_VALUE] == safe_json_dumps(info.workflow_run_inputs) + assert root_span_call.kwargs["attributes"][SpanAttributes.OUTPUT_VALUE] == safe_json_dumps( + info.workflow_run_outputs + ) + assert root_span_call.kwargs["attributes"][SpanAttributes.INPUT_MIME_TYPE] == "application/json" + assert root_span_call.kwargs["attributes"][SpanAttributes.OUTPUT_MIME_TYPE] == "application/json" + + +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.DifyCoreRepositoryFactory") 
+@patch("dify_trace_arize_phoenix.arize_phoenix_trace.sessionmaker") +def test_workflow_trace_falls_back_to_dify_name_when_workflow_run_id_is_blank( + mock_sessionmaker, + mock_repo_factory, + mock_db, + trace_instance, +): + mock_db.engine = MagicMock() + info = _make_workflow_info( + metadata={ + "app_id": "app1", + "app_name": "", + }, + workflow_run_id="", + ) + repo = MagicMock() + repo.get_by_workflow_execution.return_value = [] + mock_repo_factory.create_workflow_node_execution_repository.return_value = repo + + with patch.object(trace_instance, "get_service_account_with_tenant", return_value=MagicMock()): + trace_instance.workflow_trace(info) + + root_span_call = _get_start_span_call(trace_instance.tracer.start_span, span_name="Dify") + assert root_span_call.kwargs["attributes"]["dify_trace_id"] == "" + + +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.DifyCoreRepositoryFactory") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.sessionmaker") +def test_workflow_trace_reuses_upstream_parent_workflow_context_when_no_parent_node_execution_id_is_available( + mock_sessionmaker, mock_repo_factory, mock_db, trace_instance +): + mock_db.engine = MagicMock() + info = _make_workflow_info( + message_id="message-1", + workflow_run_id="workflow-run-1", + metadata={ + "app_id": "app1", + "parent_trace_context": { + "parent_workflow_run_id": "outer-workflow-run-1", + }, + }, + ) + repo = MagicMock() + repo.get_by_workflow_execution.return_value = [] + mock_repo_factory.create_workflow_node_execution_repository.return_value = repo + + parent_carrier = {} + parent_context = object() + + with ( + patch.object(trace_instance, "get_service_account_with_tenant", return_value=MagicMock()), + patch.object(trace_instance, "ensure_root_span", return_value=parent_carrier) as mock_ensure_root_span, + patch.object(trace_instance.propagator, "extract", return_value=parent_context) as mock_extract, + ): + 
trace_instance.workflow_trace(info) + + mock_ensure_root_span.assert_called_once_with( + "outer-workflow-run-1", + root_span_name="workflow-run-1", + root_span_attributes={ + SpanAttributes.INPUT_VALUE: safe_json_dumps(info.workflow_run_inputs), + SpanAttributes.INPUT_MIME_TYPE: "application/json", + SpanAttributes.OUTPUT_VALUE: safe_json_dumps(info.workflow_run_outputs), + SpanAttributes.OUTPUT_MIME_TYPE: "application/json", + }, + ) + mock_extract.assert_called_once_with(carrier=parent_carrier) + workflow_span_call = _get_start_span_call(trace_instance.tracer.start_span, span_name="workflow_workflow-run-1") + assert workflow_span_call.kwargs["context"] is parent_context + + +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.DifyCoreRepositoryFactory") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.sessionmaker") +def test_workflow_trace_uses_published_parent_node_context_for_nested_workflow( + mock_sessionmaker, + mock_repo_factory, + mock_db, + trace_instance, +): + mock_db.engine = MagicMock() + info = _make_workflow_info( + message_id="message-1", + workflow_run_id="workflow-run-1", + metadata={ + "app_id": "app1", + "parent_trace_context": { + "parent_workflow_run_id": "outer-workflow-run-1", + "parent_node_execution_id": "outer-node-execution-1", + }, + }, + ) + repo = MagicMock() + repo.get_by_workflow_execution.return_value = [] + mock_repo_factory.create_workflow_node_execution_repository.return_value = repo + stored_carrier = '{"traceparent":"00-aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa-bbbbbbbbbbbbbbbb-01"}' + trace_instance._mock_redis_client.get.return_value = stored_carrier + parent_context = object() + + with ( + patch.object(trace_instance, "get_service_account_with_tenant", return_value=MagicMock()), + patch.object(trace_instance, "ensure_root_span") as mock_ensure_root_span, + patch.object(trace_instance.propagator, "extract", return_value=parent_context) as mock_extract, + ): + 
trace_instance.workflow_trace(info) + + trace_instance._mock_redis_client.get.assert_called_once_with( + _phoenix_parent_span_redis_key("outer-node-execution-1") + ) + mock_ensure_root_span.assert_not_called() + mock_extract.assert_called_once_with( + carrier={"traceparent": "00-aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa-bbbbbbbbbbbbbbbb-01"} + ) + workflow_span_call = _get_start_span_call(trace_instance.tracer.start_span, span_name="workflow_workflow-run-1") + assert workflow_span_call.kwargs["context"] is parent_context + + +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.DifyCoreRepositoryFactory") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.sessionmaker") +def test_workflow_trace_raises_pending_parent_error_when_parent_node_context_is_missing( + mock_sessionmaker, + mock_repo_factory, + mock_db, + trace_instance, +): + mock_db.engine = MagicMock() + info = _make_workflow_info( + message_id="message-1", + workflow_run_id="workflow-run-1", + metadata={ + "app_id": "app1", + "parent_trace_context": { + "parent_workflow_run_id": "outer-workflow-run-1", + "parent_node_execution_id": "outer-node-execution-1", + }, + }, + ) + repo = MagicMock() + repo.get_by_workflow_execution.return_value = [] + mock_repo_factory.create_workflow_node_execution_repository.return_value = repo + trace_instance._mock_redis_client.get.return_value = None + + with ( + patch.object(trace_instance, "get_service_account_with_tenant", return_value=MagicMock()), + patch.object(trace_instance, "ensure_root_span") as mock_ensure_root_span, + pytest.raises(PendingTraceParentContextError) as exc_info, + ): + trace_instance.workflow_trace(info) + + assert exc_info.value.parent_node_execution_id == "outer-node-execution-1" + trace_instance._mock_redis_client.get.assert_called_once_with( + _phoenix_parent_span_redis_key("outer-node-execution-1") + ) + mock_ensure_root_span.assert_not_called() + + 
+@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.DifyCoreRepositoryFactory") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.sessionmaker") +def test_workflow_trace_uses_parent_workflow_run_id_for_workflow_and_nodes_when_nested_context_is_present( + mock_sessionmaker, mock_repo_factory, mock_db, trace_instance +): + mock_db.engine = MagicMock() + info = _make_workflow_info( + conversation_id=None, + metadata={ + "app_id": "app1", + "parent_trace_context": { + "parent_workflow_run_id": "outer-workflow-run-1", + }, + }, + ) + repo = MagicMock() + node_execution = MagicMock() + node_execution.node_type = "tool" + node_execution.status = "succeeded" + node_execution.inputs = {"tool_input": "value"} + node_execution.outputs = {"tool_output": "value"} + node_execution.created_at = _dt() + node_execution.elapsed_time = 1.0 + node_execution.process_data = {} + node_execution.metadata = {} + node_execution.title = "Tool node" + node_execution.id = "node-1" + node_execution.error = None + repo.get_by_workflow_execution.return_value = [node_execution] + mock_repo_factory.create_workflow_node_execution_repository.return_value = repo + + with patch.object(trace_instance, "get_service_account_with_tenant", return_value=MagicMock()): + trace_instance.workflow_trace(info) + + workflow_span_call = _get_start_span_call(trace_instance.tracer.start_span, span_name="workflow_r1") + node_span_call = _get_start_span_call(trace_instance.tracer.start_span, span_name="tool_Tool node") + + assert workflow_span_call.kwargs["attributes"][SpanAttributes.SESSION_ID] == "outer-workflow-run-1" + assert node_span_call.kwargs["attributes"][SpanAttributes.SESSION_ID] == "outer-workflow-run-1" + + +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.DifyCoreRepositoryFactory") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.sessionmaker") +def 
test_workflow_trace_falls_back_to_node_type_when_node_title_is_blank( + mock_sessionmaker, mock_repo_factory, mock_db, trace_instance +): + mock_db.engine = MagicMock() + info = _make_workflow_info() + repo = MagicMock() + node_execution = _make_node_execution( + id="node-execution-1", + node_execution_id="node-execution-1", + node_id="node-1", + node_type="tool", + title=" ", + ) + repo.get_by_workflow_execution.return_value = [node_execution] + mock_repo_factory.create_workflow_node_execution_repository.return_value = repo + + with patch.object(trace_instance, "get_service_account_with_tenant", return_value=MagicMock()): + trace_instance.workflow_trace(info) + + node_span_call = _get_start_span_call(trace_instance.tracer.start_span, span_name="tool") + assert node_span_call.kwargs["attributes"][SpanAttributes.SESSION_ID] == "r1" + + +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.DifyCoreRepositoryFactory") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.sessionmaker") +def test_workflow_trace_prefers_workflow_graph_node_title_over_execution_title( + mock_sessionmaker, mock_repo_factory, mock_db, trace_instance +): + mock_db.engine = MagicMock() + info = _make_workflow_info( + workflow_data={ + "graph": { + "nodes": [ + { + "id": "nested-tool-node", + "data": { + "type": "tool", + "title": "nested workflow tool", + }, + } + ] + } + } + ) + repo = MagicMock() + node_execution = _make_node_execution( + id="node-execution-1", + node_execution_id="node-execution-1", + node_id="nested-tool-node", + node_type="tool", + title="2", + ) + repo.get_by_workflow_execution.return_value = [node_execution] + mock_repo_factory.create_workflow_node_execution_repository.return_value = repo + + with patch.object(trace_instance, "get_service_account_with_tenant", return_value=MagicMock()): + trace_instance.workflow_trace(info) + + node_span_call = _get_start_span_call(trace_instance.tracer.start_span, 
span_name="tool_nested workflow tool") + assert node_span_call.kwargs["attributes"][SpanAttributes.SESSION_ID] == "r1" + + +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.DifyCoreRepositoryFactory") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.sessionmaker") +def test_workflow_trace_keeps_nested_conversation_session_while_reusing_parent_root_context( + mock_sessionmaker, mock_repo_factory, mock_db, trace_instance +): + mock_db.engine = MagicMock() + info = _make_workflow_info( + conversation_id="conversation-1", + message_id="message-1", + workflow_run_id="workflow-run-1", + metadata={ + "app_id": "app1", + "parent_trace_context": { + "parent_workflow_run_id": "outer-workflow-run-1", + }, + }, + ) + repo = MagicMock() + node_execution = _make_node_execution( + id="node-execution-1", + node_execution_id="node-execution-1", + node_id="node-1", + node_type="tool", + ) + repo.get_by_workflow_execution.return_value = [node_execution] + mock_repo_factory.create_workflow_node_execution_repository.return_value = repo + + parent_carrier = {} + parent_context = object() + + with ( + patch.object(trace_instance, "get_service_account_with_tenant", return_value=MagicMock()), + patch.object(trace_instance, "ensure_root_span", return_value=parent_carrier) as mock_ensure_root_span, + patch.object(trace_instance.propagator, "extract", return_value=parent_context) as mock_extract, + ): + trace_instance.workflow_trace(info) + + mock_ensure_root_span.assert_called_once_with( + "outer-workflow-run-1", + root_span_name="workflow-run-1", + root_span_attributes={ + SpanAttributes.INPUT_VALUE: safe_json_dumps(info.workflow_run_inputs), + SpanAttributes.INPUT_MIME_TYPE: "application/json", + SpanAttributes.OUTPUT_VALUE: safe_json_dumps(info.workflow_run_outputs), + SpanAttributes.OUTPUT_MIME_TYPE: "application/json", + }, + ) + mock_extract.assert_called_once_with(carrier=parent_carrier) + workflow_span_call = 
_get_start_span_call(trace_instance.tracer.start_span, span_name="workflow_workflow-run-1") + node_span_call = _get_start_span_call(trace_instance.tracer.start_span, span_name="tool_Node") + assert workflow_span_call.kwargs["context"] is parent_context + assert workflow_span_call.kwargs["attributes"][SpanAttributes.SESSION_ID] == "conversation-1" + assert node_span_call.kwargs["attributes"][SpanAttributes.SESSION_ID] == "conversation-1" + + +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.DifyCoreRepositoryFactory") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.sessionmaker") +def test_workflow_trace_publishes_tool_node_parent_span_context_to_redis( + mock_sessionmaker, + mock_repo_factory, + mock_db, + trace_instance, +): + mock_db.engine = MagicMock() + info = _make_workflow_info() + repo = MagicMock() + node_execution = _make_node_execution( + id="tool-execution-1", + node_execution_id="tool-execution-1", + node_id="tool-node-1", + node_type="tool", + ) + repo.get_by_workflow_execution.return_value = [node_execution] + mock_repo_factory.create_workflow_node_execution_repository.return_value = repo + + workflow_span = MagicMock(name="workflow-span") + workflow_span._context_label = "workflow" + tool_span = MagicMock(name="tool-span") + tool_span._context_label = "tool" + trace_instance.tracer.start_span.side_effect = [workflow_span, tool_span] + + def inject_side_effect(carrier): + carrier["traceparent"] = "00-aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa-bbbbbbbbbbbbbbbb-01" + + with ( + patch.object(trace_instance, "get_service_account_with_tenant", return_value=MagicMock()), + patch.object(trace_instance, "ensure_root_span", return_value={}), + patch.object(trace_instance.propagator, "extract", return_value="root-context"), + patch.object(trace_instance.propagator, "inject", side_effect=inject_side_effect) as mock_inject, + patch( + "dify_trace_arize_phoenix.arize_phoenix_trace.set_span_in_context", + 
side_effect=lambda span: f"context:{span._context_label}", + ), + ): + trace_instance.workflow_trace(info) + + mock_inject.assert_called_once() + trace_instance._mock_redis_client.setex.assert_called_once_with( + _phoenix_parent_span_redis_key("tool-execution-1"), + 300, + '{"traceparent": "00-aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa-bbbbbbbbbbbbbbbb-01"}', + ) + + +@pytest.mark.parametrize( + ("failing_step", "expected_message"), + [ + ("inject", "inject failed"), + ("publish", "publish failed"), + ], +) +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.DifyCoreRepositoryFactory") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.sessionmaker") +def test_workflow_trace_cleans_up_tool_span_when_parent_context_publish_fails( + mock_sessionmaker, + mock_repo_factory, + mock_db, + trace_instance, + failing_step, + expected_message, +): + mock_db.engine = MagicMock() + info = _make_workflow_info() + repo = MagicMock() + node_execution = _make_node_execution( + id="tool-execution-1", + node_execution_id="tool-execution-1", + node_id="tool-node-1", + node_type="tool", + ) + repo.get_by_workflow_execution.return_value = [node_execution] + mock_repo_factory.create_workflow_node_execution_repository.return_value = repo + + workflow_span = MagicMock(name="workflow-span") + workflow_span._context_label = "workflow" + tool_span = MagicMock(name="tool-span") + tool_span._context_label = "tool" + trace_instance.tracer.start_span.side_effect = [workflow_span, tool_span] + + inject_side_effect = None + if failing_step == "inject": + inject_side_effect = RuntimeError(expected_message) + else: + trace_instance._mock_redis_client.setex.side_effect = RuntimeError(expected_message) + + def inject_side_effect(carrier): + carrier["traceparent"] = "00-aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa-bbbbbbbbbbbbbbbb-01" + + with ( + patch.object(trace_instance, "get_service_account_with_tenant", return_value=MagicMock()), + 
patch.object(trace_instance, "ensure_root_span", return_value={}), + patch.object(trace_instance.propagator, "extract", return_value="root-context"), + patch.object(trace_instance.propagator, "inject", side_effect=inject_side_effect), + patch( + "dify_trace_arize_phoenix.arize_phoenix_trace.set_span_in_context", + side_effect=lambda span: f"context:{span._context_label}", + ), + pytest.raises(RuntimeError, match=expected_message), + ): + trace_instance.workflow_trace(info) + + tool_span.end.assert_called_once() + workflow_span.end.assert_called_once() + + +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.DifyCoreRepositoryFactory") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.sessionmaker") +def test_workflow_trace_parents_serial_nodes_to_resolved_predecessor_span( + mock_sessionmaker, mock_repo_factory, mock_db, trace_instance +): + mock_db.engine = MagicMock() + info = _make_workflow_info() + repo = MagicMock() + second_node = _make_node_execution( + id="node-execution-2", + node_execution_id="node-execution-2", + node_id="node-2", + node_type="llm", + predecessor_node_id="node-1", + process_data={ + "prompts": [{"role": "user", "content": "hi"}], + "model_provider": "openai", + "model_name": "gpt-4", + }, + ) + first_node = _make_node_execution( + id="node-execution-1", + node_execution_id="node-execution-1", + node_id="node-1", + node_type="tool", + ) + repo.get_by_workflow_execution.return_value = [second_node, first_node] + mock_repo_factory.create_workflow_node_execution_repository.return_value = repo + + workflow_span = MagicMock(name="workflow-span") + workflow_span._context_label = "workflow" + first_node_span = MagicMock(name="first-node-span") + first_node_span._context_label = "node-1" + second_node_span = MagicMock(name="second-node-span") + second_node_span._context_label = "node-2" + trace_instance.tracer.start_span.side_effect = [workflow_span, first_node_span, 
second_node_span] + + with ( + patch.object(trace_instance, "get_service_account_with_tenant", return_value=MagicMock()), + patch.object(trace_instance, "ensure_root_span", return_value={}), + patch.object(trace_instance.propagator, "extract", return_value="root-context"), + patch( + "dify_trace_arize_phoenix.arize_phoenix_trace.set_span_in_context", + side_effect=lambda span: f"context:{span._context_label}", + ), + ): + trace_instance.workflow_trace(info) + + first_node_call = _get_start_span_call(trace_instance.tracer.start_span, span_name="tool_Node") + second_node_call = _get_start_span_call(trace_instance.tracer.start_span, span_name="llm_Node") + assert first_node_call.kwargs["context"] == "context:workflow" + assert second_node_call.kwargs["context"] == "context:node-1" + + +@pytest.mark.parametrize( + ("enclosing_node_type", "structured_field"), + [ + ("loop", "loop_id"), + ("iteration", "iteration_id"), + ], +) +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.DifyCoreRepositoryFactory") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.sessionmaker") +def test_workflow_trace_parents_structured_start_nodes_to_enclosing_structure_span( + mock_sessionmaker, + mock_repo_factory, + mock_db, + trace_instance, + enclosing_node_type, + structured_field, +): + mock_db.engine = MagicMock() + info = _make_workflow_info() + repo = MagicMock() + enclosing_node = _make_node_execution( + id=f"{enclosing_node_type}-execution-1", + node_execution_id=f"{enclosing_node_type}-execution-1", + node_id=f"{enclosing_node_type}-node-1", + node_type=enclosing_node_type, + ) + structured_kwargs = {structured_field: f"{enclosing_node_type}-node-1"} + start_node = _make_node_execution( + id="start-execution-1", + node_execution_id="start-execution-1", + node_id="start-node-1", + node_type="start", + **structured_kwargs, + ) + repo.get_by_workflow_execution.return_value = [start_node, enclosing_node] + 
mock_repo_factory.create_workflow_node_execution_repository.return_value = repo + + workflow_span = MagicMock(name="workflow-span") + workflow_span._context_label = "workflow" + enclosing_node_span = MagicMock(name="enclosing-node-span") + enclosing_node_span._context_label = enclosing_node_type + start_node_span = MagicMock(name="start-node-span") + start_node_span._context_label = "start" + trace_instance.tracer.start_span.side_effect = [workflow_span, enclosing_node_span, start_node_span] + + with ( + patch.object(trace_instance, "get_service_account_with_tenant", return_value=MagicMock()), + patch.object(trace_instance, "ensure_root_span", return_value={}), + patch.object(trace_instance.propagator, "extract", return_value="root-context"), + patch( + "dify_trace_arize_phoenix.arize_phoenix_trace.set_span_in_context", + side_effect=lambda span: f"context:{span._context_label}", + ), + ): + trace_instance.workflow_trace(info) + + start_node_call = _get_start_span_call(trace_instance.tracer.start_span, span_name="start_Node") + assert start_node_call.kwargs["context"] == f"context:{enclosing_node_type}" + + +@pytest.mark.parametrize( + ("enclosing_node_type", "structured_field"), + [ + ("loop", "loop_id"), + ("iteration", "iteration_id"), + ], +) +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.DifyCoreRepositoryFactory") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.sessionmaker") +def test_workflow_trace_keeps_duplicate_body_node_children_under_enclosing_structure( + mock_sessionmaker, + mock_repo_factory, + mock_db, + trace_instance, + enclosing_node_type, + structured_field, +): + mock_db.engine = MagicMock() + info = _make_workflow_info() + repo = MagicMock() + enclosing_node = _make_node_execution( + id=f"{enclosing_node_type}-execution-1", + node_execution_id=f"{enclosing_node_type}-execution-1", + node_id=f"{enclosing_node_type}-node-1", + node_type=enclosing_node_type, + ) + 
structured_kwargs = {structured_field: f"{enclosing_node_type}-node-1"} + repeated_body_node_1 = _make_node_execution( + id="body-execution-1", + node_execution_id="body-execution-1", + node_id="body-node-1", + node_type="tool", + **structured_kwargs, + ) + repeated_body_node_2 = _make_node_execution( + id="body-execution-2", + node_execution_id="body-execution-2", + node_id="body-node-1", + node_type="tool", + **structured_kwargs, + ) + child_node = _make_node_execution( + id="child-execution-1", + node_execution_id="child-execution-1", + node_id="child-node-1", + node_type="llm", + predecessor_node_id="body-node-1", + process_data={ + "prompts": [{"role": "user", "content": "hi"}], + "model_provider": "openai", + "model_name": "gpt-4", + }, + **structured_kwargs, + ) + repo.get_by_workflow_execution.return_value = [ + child_node, + repeated_body_node_1, + repeated_body_node_2, + enclosing_node, + ] + mock_repo_factory.create_workflow_node_execution_repository.return_value = repo + + workflow_span = MagicMock(name="workflow-span") + workflow_span._context_label = "workflow" + enclosing_node_span = MagicMock(name="enclosing-node-span") + enclosing_node_span._context_label = enclosing_node_type + child_node_span = MagicMock(name="child-node-span") + child_node_span._context_label = "child" + repeated_body_node_1_span = MagicMock(name="repeated-body-node-1-span") + repeated_body_node_1_span._context_label = "body-1" + repeated_body_node_2_span = MagicMock(name="repeated-body-node-2-span") + repeated_body_node_2_span._context_label = "body-2" + trace_instance.tracer.start_span.side_effect = [ + workflow_span, + enclosing_node_span, + child_node_span, + repeated_body_node_1_span, + repeated_body_node_2_span, + ] + + with ( + patch.object(trace_instance, "get_service_account_with_tenant", return_value=MagicMock()), + patch.object(trace_instance, "ensure_root_span", return_value={}), + patch.object(trace_instance.propagator, "extract", return_value="root-context"), + 
patch( + "dify_trace_arize_phoenix.arize_phoenix_trace.set_span_in_context", + side_effect=lambda span: f"context:{span._context_label}", + ), + ): + trace_instance.workflow_trace(info) + + child_node_call = _get_start_span_call(trace_instance.tracer.start_span, span_name="llm_Node") + assert child_node_call.kwargs["context"] == f"context:{enclosing_node_type}" + + +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.DifyCoreRepositoryFactory") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.sessionmaker") +def test_workflow_trace_falls_back_to_workflow_span_for_parallel_like_ambiguous_predecessors( + mock_sessionmaker, mock_repo_factory, mock_db, trace_instance +): + mock_db.engine = MagicMock() + info = _make_workflow_info() + repo = MagicMock() + child_node = _make_node_execution( + id="child-execution-1", + node_execution_id="child-execution-1", + node_id="child-node-1", + node_type="llm", + predecessor_node_id="parallel-node-1", + process_data={ + "prompts": [{"role": "user", "content": "hi"}], + "model_provider": "openai", + "model_name": "gpt-4", + }, + ) + first_parallel_node = _make_node_execution( + id="parallel-execution-1", + node_execution_id="parallel-execution-1", + node_id="parallel-node-1", + node_type="tool", + parallel_id="parallel-1", + ) + second_parallel_node = _make_node_execution( + id="parallel-execution-2", + node_execution_id="parallel-execution-2", + node_id="parallel-node-1", + node_type="tool", + parallel_id="parallel-2", + ) + repo.get_by_workflow_execution.return_value = [child_node, first_parallel_node, second_parallel_node] + mock_repo_factory.create_workflow_node_execution_repository.return_value = repo + + workflow_span = MagicMock(name="workflow-span") + workflow_span._context_label = "workflow" + child_node_span = MagicMock(name="child-node-span") + child_node_span._context_label = "child" + first_parallel_node_span = MagicMock(name="first-parallel-node-span") + 
first_parallel_node_span._context_label = "parallel-1" + second_parallel_node_span = MagicMock(name="second-parallel-node-span") + second_parallel_node_span._context_label = "parallel-2" + trace_instance.tracer.start_span.side_effect = [ + workflow_span, + child_node_span, + first_parallel_node_span, + second_parallel_node_span, + ] + + with ( + patch.object(trace_instance, "get_service_account_with_tenant", return_value=MagicMock()), + patch.object(trace_instance, "ensure_root_span", return_value={}), + patch.object(trace_instance.propagator, "extract", return_value="root-context"), + patch( + "dify_trace_arize_phoenix.arize_phoenix_trace.set_span_in_context", + side_effect=lambda span: f"context:{span._context_label}", + ), + ): + trace_instance.workflow_trace(info) + + child_node_call = _get_start_span_call(trace_instance.tracer.start_span, span_name="llm_Node") + assert child_node_call.kwargs["context"] == "context:workflow" + + +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") +def test_message_trace_keeps_conversation_id_as_session(mock_db, trace_instance): mock_db.engine = MagicMock() info = _make_message_info() info.message_data = MagicMock() - info.message_data.from_account_id = "acc1" + info.message_data.conversation_id = "conversation-2" + info.message_data.from_account_id = "acc2" info.message_data.from_end_user_id = None - info.message_data.query = "q" - info.message_data.answer = "a" - info.message_data.status = "s" - info.message_data.model_id = "m" - info.message_data.model_provider = "p" + info.message_data.query = "q2" + info.message_data.answer = "a2" + info.message_data.status = "s2" + info.message_data.model_id = "m2" + info.message_data.model_provider = "p2" info.message_data.message_metadata = "{}" info.message_data.error = None info.error = None + root_span = MagicMock() + message_span = MagicMock() + llm_span = MagicMock() + trace_instance.tracer.start_span.side_effect = [root_span, message_span, llm_span] + 
trace_instance.message_trace(info) - assert trace_instance.tracer.start_span.call_count >= 1 + + message_span_call = _get_start_span_call( + trace_instance.tracer.start_span, span_name=TraceTaskName.MESSAGE_TRACE.value + ) + assert message_span_call.kwargs["attributes"][SpanAttributes.SESSION_ID] == "conversation-2" @patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") @@ -397,3 +1542,30 @@ def test_api_check_success(trace_instance): def test_ensure_root_span_basic(trace_instance): trace_instance.ensure_root_span("tid") assert "tid" in trace_instance.dify_trace_ids + + +def test_ensure_root_span_uses_custom_name_and_attributes(trace_instance): + root_attributes = { + SpanAttributes.INPUT_VALUE: '{"input":"value"}', + SpanAttributes.OUTPUT_VALUE: '{"output":"value"}', + } + + trace_instance.ensure_root_span("tid", root_span_name="Workflow Name", root_span_attributes=root_attributes) + + trace_instance.tracer.start_span.assert_called_once_with( + name="Workflow Name", + attributes={ + SpanAttributes.OPENINFERENCE_SPAN_KIND: "CHAIN", + "dify_project_name": "p", + "dify_trace_id": "tid", + SpanAttributes.INPUT_VALUE: '{"input":"value"}', + SpanAttributes.OUTPUT_VALUE: '{"output":"value"}', + }, + ) + + +def test_ensure_root_span_falls_back_to_dify_name_when_custom_name_is_blank(trace_instance): + trace_instance.ensure_root_span("tid", root_span_name=" ") + + trace_instance.tracer.start_span.assert_called_once() + assert trace_instance.tracer.start_span.call_args.kwargs["name"] == "Dify" diff --git a/api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_arize_phoenix_trace.py b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_arize_phoenix_trace.py deleted file mode 100644 index a01c63ae61..0000000000 --- a/api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_arize_phoenix_trace.py +++ /dev/null @@ -1,36 +0,0 @@ -from dify_trace_arize_phoenix.arize_phoenix_trace import _NODE_TYPE_TO_SPAN_KIND, _get_node_span_kind -from 
openinference.semconv.trace import OpenInferenceSpanKindValues - -from graphon.enums import BUILT_IN_NODE_TYPES, BuiltinNodeTypes - - -class TestGetNodeSpanKind: - """Tests for _get_node_span_kind helper.""" - - def test_all_node_types_are_mapped_correctly(self): - """Ensure every built-in node type is mapped to the correct span kind.""" - # Mappings for node types that have a specialised span kind. - special_mappings = { - BuiltinNodeTypes.LLM: OpenInferenceSpanKindValues.LLM, - BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: OpenInferenceSpanKindValues.RETRIEVER, - BuiltinNodeTypes.TOOL: OpenInferenceSpanKindValues.TOOL, - BuiltinNodeTypes.AGENT: OpenInferenceSpanKindValues.AGENT, - } - - # Test that every built-in node type is mapped to the correct span kind. - # Node types not in `special_mappings` should default to CHAIN. - for node_type in BUILT_IN_NODE_TYPES: - expected_span_kind = special_mappings.get(node_type, OpenInferenceSpanKindValues.CHAIN) - actual_span_kind = _get_node_span_kind(node_type) - assert actual_span_kind == expected_span_kind, ( - f"Node type {node_type!r} was mapped to {actual_span_kind}, but {expected_span_kind} was expected." - ) - - def test_unknown_string_defaults_to_chain(self): - """An unrecognised node type string should still return CHAIN.""" - assert _get_node_span_kind("some-future-node-type") == OpenInferenceSpanKindValues.CHAIN - - def test_stale_dataset_retrieval_not_in_mapping(self): - """The old 'dataset_retrieval' string was never a valid NodeType value; - make sure it is not present in the mapping dictionary.""" - assert "dataset_retrieval" not in _NODE_TYPE_TO_SPAN_KIND diff --git a/api/tasks/ops_trace_task.py b/api/tasks/ops_trace_task.py index c95b8db078..49fe68ad7e 100644 --- a/api/tasks/ops_trace_task.py +++ b/api/tasks/ops_trace_task.py @@ -1,11 +1,31 @@ +""" +Celery task for asynchronous ops trace dispatch. + +Trace providers may report explicitly retryable dispatch failures through the +core retryable exception contract. 
The task preserves the payload file only +when Celery accepts the retry request; successful dispatches and terminal +failures clean up the stored payload. + +One concrete producer today is Phoenix nested workflow tracing. The outer +workflow tool span publishes a restorable parent span context asynchronously, +while the nested workflow trace may be picked up by Celery first. In that +ordering window, the provider raises a retryable core exception instead of +dropping the trace or emitting it under the wrong parent. The task intentionally +does not know that the provider is Phoenix; it only honors the core retryable +dispatch contract. +""" + import json import logging from celery import shared_task +from celery.exceptions import Retry from flask import current_app +from configs import dify_config from core.ops.entities.config_entity import OPS_FILE_PATH, OPS_TRACE_FAILED_KEY from core.ops.entities.trace_entity import trace_info_info_map +from core.ops.exceptions import RetryableTraceDispatchError from core.rag.models.document import Document from extensions.ext_redis import redis_client from extensions.ext_storage import storage @@ -14,9 +34,17 @@ from models.workflow import WorkflowRun logger = logging.getLogger(__name__) +_RETRYABLE_TRACE_DISPATCH_LIMIT = dify_config.OPS_TRACE_RETRYABLE_DISPATCH_MAX_RETRIES +_RETRYABLE_TRACE_DISPATCH_DELAY_SECONDS = dify_config.OPS_TRACE_RETRYABLE_DISPATCH_DELAY_SECONDS -@shared_task(queue="ops_trace") -def process_trace_tasks(file_info): + +@shared_task( + queue="ops_trace", + bind=True, + max_retries=_RETRYABLE_TRACE_DISPATCH_LIMIT, + default_retry_delay=_RETRYABLE_TRACE_DISPATCH_DELAY_SECONDS, +) +def process_trace_tasks(self, file_info): """ Async process trace tasks Usage: process_trace_tasks.delay(tasks_data) @@ -29,6 +57,7 @@ def process_trace_tasks(file_info): file_data = json.loads(storage.load(file_path)) trace_info = file_data.get("trace_info") trace_info_type = file_data.get("trace_info_type") + 
enterprise_trace_dispatched = bool(file_data.get("_enterprise_trace_dispatched")) trace_instance = OpsTraceManager.get_ops_trace_instance(app_id) if trace_info.get("message_data"): @@ -38,6 +67,8 @@ def process_trace_tasks(file_info): if trace_info.get("documents"): trace_info["documents"] = [Document.model_validate(doc) for doc in trace_info["documents"]] + should_delete_file = True + try: trace_type = trace_info_info_map.get(trace_info_type) if trace_type: @@ -45,30 +76,66 @@ def process_trace_tasks(file_info): from extensions.ext_enterprise_telemetry import is_enabled as is_ee_telemetry_enabled - if is_ee_telemetry_enabled(): + if is_ee_telemetry_enabled() and not enterprise_trace_dispatched: from enterprise.telemetry.enterprise_trace import EnterpriseOtelTrace try: EnterpriseOtelTrace().trace(trace_info) except Exception: logger.exception("Enterprise trace failed for app_id: %s", app_id) + else: + file_data["_enterprise_trace_dispatched"] = True + enterprise_trace_dispatched = True if trace_instance: with current_app.app_context(): trace_instance.trace(trace_info) logger.info("Processing trace tasks success, app_id: %s", app_id) + except RetryableTraceDispatchError as e: + # Retryable dispatch failures represent a transient provider-side + # ordering gap, not corrupt payload data. Keep the payload only after + # Celery accepts the retry request; otherwise this attempt becomes a + # terminal failure and the stored file is cleaned up in `finally`. + # + # Enterprise telemetry runs before provider dispatch. If it already ran + # and provider dispatch asks for a retry, persist that private flag so + # the next attempt does not emit the same enterprise trace twice. 
+ if self.request.retries >= _RETRYABLE_TRACE_DISPATCH_LIMIT: + logger.exception("Retryable trace dispatch budget exhausted, app_id: %s", app_id) + failed_key = f"{OPS_TRACE_FAILED_KEY}_{app_id}" + redis_client.incr(failed_key) + else: + logger.warning( + "Retryable trace dispatch failure, scheduling retry %s/%s for app_id %s: %s", + self.request.retries + 1, + _RETRYABLE_TRACE_DISPATCH_LIMIT, + app_id, + e, + ) + try: + if enterprise_trace_dispatched: + storage.save(file_path, json.dumps(file_data).encode("utf-8")) + raise self.retry(exc=e, countdown=_RETRYABLE_TRACE_DISPATCH_DELAY_SECONDS) + except Retry: + should_delete_file = False + raise + except Exception: + logger.exception("Failed to schedule trace dispatch retry, app_id: %s", app_id) + failed_key = f"{OPS_TRACE_FAILED_KEY}_{app_id}" + redis_client.incr(failed_key) except Exception as e: logger.exception("Processing trace tasks failed, app_id: %s", app_id) failed_key = f"{OPS_TRACE_FAILED_KEY}_{app_id}" redis_client.incr(failed_key) finally: - try: - storage.delete(file_path) - except Exception as e: - logger.warning( - "Failed to delete trace file %s for app_id %s: %s", - file_path, - app_id, - e, - ) + if should_delete_file: + try: + storage.delete(file_path) + except Exception as e: + logger.warning( + "Failed to delete trace file %s for app_id %s: %s", + file_path, + app_id, + e, + ) diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_generator.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_generator.py index 0e9f8b6f35..2e4e469eb5 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_generator.py @@ -1,3 +1,4 @@ +import contextlib from types import SimpleNamespace from unittest.mock import MagicMock @@ -24,6 +25,76 @@ def test_should_prepare_user_inputs_keeps_validation_when_flag_false(): assert WorkflowAppGenerator()._should_prepare_user_inputs(args) +def 
test_generate_includes_parent_trace_context_in_extras(monkeypatch): + generator = WorkflowAppGenerator() + + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.WorkflowAppGenerator._bind_file_access_scope", + lambda *args, **kwargs: contextlib.nullcontext(), + ) + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.WorkflowAppConfigManager.get_app_config", + lambda *args, **kwargs: SimpleNamespace( + app_id="app-1", tenant_id="tenant-1", workflow_id="workflow-1", variables=[] + ), + ) + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.file_factory.build_from_mappings", lambda *args, **kwargs: [] + ) + monkeypatch.setattr("core.app.apps.workflow.app_generator.TraceQueueManager", MagicMock()) + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_execution_repository", + MagicMock(return_value=MagicMock()), + ) + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", + MagicMock(return_value=MagicMock()), + ) + monkeypatch.setattr("core.app.apps.workflow.app_generator.db", SimpleNamespace(engine=MagicMock())) + monkeypatch.setattr(generator, "_prepare_user_inputs", lambda *, user_inputs, **kwargs: user_inputs) + + captured = {} + + def fake_workflow_app_generate_entity(**kwargs): + captured["workflow_app_generate_entity_kwargs"] = kwargs + return SimpleNamespace(**kwargs) + + def fake_generate(**kwargs): + captured["application_generate_entity"] = kwargs["application_generate_entity"] + return {"data": {}} + + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.WorkflowAppGenerateEntity", fake_workflow_app_generate_entity + ) + monkeypatch.setattr(generator, "_generate", fake_generate) + + result = generator.generate( + app_model=SimpleNamespace(tenant_id="tenant-1", id="app-1"), + workflow=SimpleNamespace(features_dict={}), + user=SimpleNamespace(id="user-1", session_id="session-1"), + args={ + 
"inputs": {"query": "hello"}, + "files": [], + "external_trace_id": "trace-1", + "parent_trace_context": { + "parent_workflow_run_id": "outer-workflow-run-1", + "parent_node_execution_id": "outer-node-execution-1", + }, + }, + invoke_from="service-api", + streaming=False, + call_depth=0, + ) + + assert result == {"data": {}} + extras = captured["workflow_app_generate_entity_kwargs"]["extras"] + assert extras["external_trace_id"] == "trace-1" + assert extras["parent_trace_context"].model_dump() == { + "parent_workflow_run_id": "outer-workflow-run-1", + "parent_node_execution_id": "outer-node-execution-1", + } + + def test_resume_delegates_to_generate(mocker: MockerFixture): generator = WorkflowAppGenerator() mock_generate = mocker.patch.object(generator, "_generate", return_value="ok") diff --git a/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py b/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py index 7e87c088ce..9cefa97bef 100644 --- a/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py +++ b/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py @@ -7,6 +7,7 @@ import pytest from core.app.entities.app_invoke_entities import WorkflowAppGenerateEntity from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer +from core.ops.ops_trace_manager import TraceTask, TraceTaskName from core.workflow.system_variables import SystemVariableKey, build_system_variables from graphon.entities import WorkflowNodeExecution from graphon.entities.pause_reason import SchedulingPause @@ -217,6 +218,59 @@ class TestWorkflowPersistenceLayer: assert exec_repo.saved[-1].status == WorkflowExecutionStatus.FAILED assert trace_tasks + def test_handle_graph_run_succeeded_enqueues_parent_trace_context(self, monkeypatch): + trace_tasks: list[TraceTask] = [] + trace_manager = SimpleNamespace(user_id="user", add_trace_task=lambda task: trace_tasks.append(task)) + layer, _, _, _ = _make_layer( + extras={ + 
"external_trace_id": "trace", + "parent_trace_context": { + "parent_workflow_run_id": "outer-workflow-run-1", + "parent_node_execution_id": "outer-node-execution-1", + }, + }, + trace_manager=trace_manager, + ) + layer._handle_graph_run_started() + + captured: dict[str, object] = {} + + def fake_workflow_trace( + self: TraceTask, + *, + workflow_run_id: str | None, + conversation_id: str | None, + user_id: str | None, + total_tokens_override: int | None = None, + ): + captured["trace_type"] = self.trace_type + captured["external_trace_id"] = self.kwargs.get("external_trace_id") + captured["parent_trace_context"] = self.kwargs.get("parent_trace_context") + captured["workflow_run_id"] = workflow_run_id + return {"ok": True} + + monkeypatch.setattr(TraceTask, "workflow_trace", fake_workflow_trace) + + layer._handle_graph_run_succeeded(GraphRunSucceededEvent(outputs={"ok": True})) + + assert trace_tasks + trace_task = trace_tasks[0] + assert trace_task.trace_type == TraceTaskName.WORKFLOW_TRACE + assert trace_task.kwargs["external_trace_id"] == "trace" + assert trace_task.kwargs["parent_trace_context"] == { + "parent_workflow_run_id": "outer-workflow-run-1", + "parent_node_execution_id": "outer-node-execution-1", + } + + trace_task.execute() + + assert captured["trace_type"] == TraceTaskName.WORKFLOW_TRACE + assert captured["external_trace_id"] == "trace" + assert captured["parent_trace_context"] == { + "parent_workflow_run_id": "outer-workflow-run-1", + "parent_node_execution_id": "outer-node-execution-1", + } + def test_handle_graph_run_aborted_sets_status(self): layer, exec_repo, _, _ = _make_layer() layer._handle_graph_run_started() diff --git a/api/tests/unit_tests/core/helper/test_trace_id_helper.py b/api/tests/unit_tests/core/helper/test_trace_id_helper.py index 27bfe1af05..96e2d44730 100644 --- a/api/tests/unit_tests/core/helper/test_trace_id_helper.py +++ b/api/tests/unit_tests/core/helper/test_trace_id_helper.py @@ -1,6 +1,12 @@ import pytest -from 
core.helper.trace_id_helper import extract_external_trace_id_from_args, get_external_trace_id, is_valid_trace_id +from core.helper.trace_id_helper import ( + ParentTraceContext, + extract_external_trace_id_from_args, + extract_parent_trace_context_from_args, + get_external_trace_id, + is_valid_trace_id, +) class DummyRequest: @@ -84,3 +90,92 @@ class TestTraceIdHelper: def test_extract_external_trace_id_from_args(self, args, expected): """Test extraction of external_trace_id from args mapping""" assert extract_external_trace_id_from_args(args) == expected + + @pytest.mark.parametrize( + ("args", "expected"), + [ + ( + { + "parent_trace_context": { + "parent_workflow_run_id": "workflow-run-1", + "parent_node_execution_id": "node-execution-1", + } + }, + { + "parent_trace_context": ParentTraceContext( + parent_workflow_run_id="workflow-run-1", + parent_node_execution_id="node-execution-1", + ) + }, + ), + ( + { + "parent_trace_context": { + "parent_workflow_run_id": "workflow-run-1", + } + }, + {}, + ), + ( + { + "parent_trace_context": { + "parent_node_execution_id": "node-execution-1", + } + }, + {}, + ), + ( + { + "parent_trace_context": { + "parent_workflow_run_id": 123, + "parent_node_execution_id": "node-execution-1", + } + }, + {}, + ), + ( + { + "parent_trace_context": { + "parent_workflow_run_id": "workflow-run-1", + "parent_node_execution_id": None, + } + }, + {}, + ), + ({}, {}), + ], + ) + def test_extract_parent_trace_context_from_args(self, args, expected): + """Test extraction of parent_trace_context from args mapping""" + assert extract_parent_trace_context_from_args(args) == expected + + def test_extract_parent_trace_context_returns_typed_context(self): + """Parent trace context is parsed into a Pydantic value object.""" + result = extract_parent_trace_context_from_args( + { + "parent_trace_context": { + "parent_workflow_run_id": "workflow-run-1", + "parent_node_execution_id": "node-execution-1", + } + } + ) + + assert result == { + 
"parent_trace_context": ParentTraceContext( + parent_workflow_run_id="workflow-run-1", + parent_node_execution_id="node-execution-1", + ) + } + + def test_extract_parent_trace_context_rejects_incomplete_typed_context(self): + """Typed parent trace context follows the same completeness rule as raw mappings.""" + result = extract_parent_trace_context_from_args( + { + "parent_trace_context": ParentTraceContext( + parent_workflow_run_id="workflow-run-1", + parent_node_execution_id=None, + ) + } + ) + + assert result == {} diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py index 72a73dd936..6c563b0912 100644 --- a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py +++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py @@ -147,6 +147,142 @@ def test_workflow_tool_does_not_use_pause_state_config(monkeypatch: pytest.Monke assert call_kwargs["pause_state_config"] is None +def test_workflow_tool_passes_parent_trace_context_from_runtime(monkeypatch: pytest.MonkeyPatch): + """Ensure nested workflow runtime metadata is forwarded as parent trace context.""" + tool = _build_tool() + tool.set_parent_trace_context( + parent_workflow_run_id="outer-workflow-run-1", + parent_node_execution_id="outer-node-execution-1", + ) + + monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None) + monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None) + + mock_user = Mock() + monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user) + + generate_mock = MagicMock(return_value={"data": {}}) + monkeypatch.setattr("core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate", generate_mock) + monkeypatch.setattr("libs.login.current_user", lambda *args, **kwargs: None) + + list(tool.invoke("test_user", {})) + + call_kwargs = generate_mock.call_args.kwargs + assert call_kwargs["args"]["parent_trace_context"].model_dump() == { + 
"parent_workflow_run_id": "outer-workflow-run-1", + "parent_node_execution_id": "outer-node-execution-1", + } + + +def test_workflow_tool_keeps_user_inputs_named_like_trace_runtime_keys(monkeypatch: pytest.MonkeyPatch): + """Ensure private trace context does not overwrite same-named workflow inputs.""" + tool = _build_tool() + tool.entity.parameters = [ + ToolParameter.get_simple_instance( + name="outer_workflow_run_id", + llm_description="User workflow input", + typ=ToolParameter.ToolParameterType.STRING, + required=False, + ), + ToolParameter.get_simple_instance( + name="outer_node_execution_id", + llm_description="User node input", + typ=ToolParameter.ToolParameterType.STRING, + required=False, + ), + ] + tool.set_parent_trace_context( + parent_workflow_run_id="outer-workflow-run-1", + parent_node_execution_id="outer-node-execution-1", + ) + + monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None) + monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None) + + mock_user = Mock() + monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user) + + generate_mock = MagicMock(return_value={"data": {}}) + monkeypatch.setattr("core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate", generate_mock) + monkeypatch.setattr("libs.login.current_user", lambda *args, **kwargs: None) + + list( + tool.invoke( + "test_user", + { + "outer_workflow_run_id": "user-workflow-input", + "outer_node_execution_id": "user-node-input", + }, + ) + ) + + call_kwargs = generate_mock.call_args.kwargs + assert call_kwargs["args"]["inputs"]["outer_workflow_run_id"] == "user-workflow-input" + assert call_kwargs["args"]["inputs"]["outer_node_execution_id"] == "user-node-input" + assert call_kwargs["args"]["parent_trace_context"].model_dump() == { + "parent_workflow_run_id": "outer-workflow-run-1", + "parent_node_execution_id": "outer-node-execution-1", + } + + +def test_workflow_tool_can_clear_parent_trace_context(monkeypatch: 
pytest.MonkeyPatch): + """Ensure reused WorkflowTool instances do not keep stale parent trace context.""" + tool = _build_tool() + tool.set_parent_trace_context( + parent_workflow_run_id="outer-workflow-run-1", + parent_node_execution_id="outer-node-execution-1", + ) + tool.clear_parent_trace_context() + + monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None) + monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None) + + mock_user = Mock() + monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user) + + generate_mock = MagicMock(return_value={"data": {}}) + monkeypatch.setattr("core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate", generate_mock) + monkeypatch.setattr("libs.login.current_user", lambda *args, **kwargs: None) + + list(tool.invoke("test_user", {})) + + call_kwargs = generate_mock.call_args.kwargs + assert "parent_trace_context" not in call_kwargs["args"] + + +@pytest.mark.parametrize( + "runtime_parameters", + [ + {}, + {"outer_workflow_run_id": "outer-workflow-run-1"}, + {"outer_node_execution_id": "outer-node-execution-1"}, + {"outer_workflow_run_id": None, "outer_node_execution_id": None}, + ], +) +def test_workflow_tool_omits_parent_trace_context_when_runtime_is_incomplete( + monkeypatch: pytest.MonkeyPatch, + runtime_parameters: dict[str, Any], +): + """Ensure incomplete runtime metadata does not leak parent trace context into generator args.""" + tool = _build_tool() + tool.runtime.runtime_parameters = runtime_parameters + + monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None) + monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None) + + mock_user = Mock() + monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user) + + generate_mock = MagicMock(return_value={"data": {}}) + monkeypatch.setattr("core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate", generate_mock) + monkeypatch.setattr("libs.login.current_user", lambda 
*args, **kwargs: None) + + list(tool.invoke("test_user", {})) + + call_kwargs = generate_mock.call_args.kwargs + assert "parent_trace_context" not in call_kwargs["args"] + + def test_workflow_tool_should_generate_variable_messages_for_outputs(monkeypatch: pytest.MonkeyPatch): """Test that WorkflowTool should generate variable messages when there are outputs""" tool = _build_tool() diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py index f17c95fc13..4d30746e5c 100644 --- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py @@ -15,9 +15,9 @@ from graphon.model_runtime.entities.llm_entities import LLMUsage from graphon.node_events import StreamChunkEvent, StreamCompletedEvent from graphon.nodes.tool.entities import ToolNodeData from graphon.nodes.tool_runtime_entities import ToolRuntimeHandle, ToolRuntimeMessage -from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.runtime import GraphRuntimeState from graphon.variables.segments import ArrayFileSegment -from tests.workflow_test_utils import build_test_graph_init_params +from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool if TYPE_CHECKING: # pragma: no cover - imported for type checking only from graphon.nodes.tool.tool_node import ToolNode @@ -106,7 +106,7 @@ def tool_node(monkeypatch) -> ToolNode: call_depth=0, ) - variable_pool = VariablePool.from_bootstrap(system_variables=build_system_variables(user_id="user-id")) + variable_pool = build_test_variable_pool(variables=build_system_variables(user_id="user-id")) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) config = graph_config["nodes"][0] @@ -234,3 +234,22 @@ def test_image_link_messages_use_tool_file_id_metadata(tool_node: ToolNode): files_segment = 
completed_events[0].node_run_result.outputs["files"] assert isinstance(files_segment, ArrayFileSegment) assert files_segment.value == [file_obj] + + +def test_tool_node_passes_node_execution_id_when_runtime_accepts_it(tool_node: ToolNode): + runtime_handle = ToolRuntimeHandle(raw=object()) + tool_node._runtime.get_runtime = MagicMock(return_value=runtime_handle) + tool_node.ensure_execution_id = MagicMock(return_value="node-execution-id") + + result = tool_node._get_tool_runtime( + variable_pool=tool_node.graph_runtime_state.variable_pool, + node_execution_id="node-execution-id", + ) + + assert result is runtime_handle + tool_node._runtime.get_runtime.assert_called_once_with( + node_id="node-instance", + node_data=tool_node.node_data, + variable_pool=tool_node.graph_runtime_state.variable_pool, + node_execution_id="node-execution-id", + ) diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py index 438af211f3..aece73ce8c 100644 --- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py @@ -147,6 +147,69 @@ def test_get_runtime_converts_graph_provider_type_for_tool_manager(runtime: Dify assert workflow_tool.provider_type == CoreToolProviderType.BUILT_IN +def test_get_runtime_stores_parent_trace_context_for_workflow_tools( + runtime: DifyToolNodeRuntime, +) -> None: + variable_pool: VariablePool = build_test_variable_pool( + variables=build_system_variables( + conversation_id="conversation-id", + workflow_execution_id="workflow-run-id", + ) + ) + workflow_runtime = MagicMock() + workflow_runtime.runtime.runtime_parameters = {} + node_data = ToolNodeData.model_validate( + { + "type": "tool", + "title": "Tool", + "provider_id": "provider", + "provider_type": ToolProviderType.WORKFLOW, + "provider_name": "provider", + "tool_name": "lookup", + "tool_label": "Lookup", + 
"tool_configurations": {}, + "tool_parameters": {}, + } + ) + + with patch.object(ToolManager, "get_workflow_tool_runtime", return_value=workflow_runtime): + tool_runtime = runtime.get_runtime( + node_id="node-id", + node_data=node_data, + variable_pool=variable_pool, + node_execution_id="node-execution-id", + ) + + assert tool_runtime.raw.parent_trace_context.model_dump() == { + "parent_workflow_run_id": "workflow-run-id", + "parent_node_execution_id": "node-execution-id", + } + assert workflow_runtime.runtime.runtime_parameters == {} + + +def test_get_runtime_leaves_non_workflow_tool_runtime_parameters_unchanged( + runtime: DifyToolNodeRuntime, +) -> None: + variable_pool: VariablePool = build_test_variable_pool( + variables=build_system_variables( + conversation_id="conversation-id", + workflow_execution_id="workflow-run-id", + ) + ) + builtin_runtime = MagicMock() + builtin_runtime.runtime.runtime_parameters = {} + + with patch.object(ToolManager, "get_workflow_tool_runtime", return_value=builtin_runtime): + runtime.get_runtime( + node_id="node-id", + node_data=_build_tool_node_data(), + variable_pool=variable_pool, + node_execution_id="node-execution-id", + ) + + assert builtin_runtime.runtime.runtime_parameters == {} + + def test_get_runtime_parameters_reads_required_flags(runtime: DifyToolNodeRuntime) -> None: tool_runtime = ToolRuntimeHandle( raw=SimpleNamespace( diff --git a/api/tests/unit_tests/core/workflow/test_node_runtime.py b/api/tests/unit_tests/core/workflow/test_node_runtime.py index 0d13151f42..d2925fd1a8 100644 --- a/api/tests/unit_tests/core/workflow/test_node_runtime.py +++ b/api/tests/unit_tests/core/workflow/test_node_runtime.py @@ -316,6 +316,81 @@ def test_dify_tool_file_manager_delegates_file_generator_lookup(monkeypatch: pyt get_file_generator.assert_called_once_with("tool-file-id") +def test_dify_tool_node_runtime_injects_outer_workflow_run_id_for_workflow_tools( + monkeypatch: pytest.MonkeyPatch, +) -> None: + runtime_tool = 
SimpleNamespace(runtime=SimpleNamespace(runtime_parameters={})) + get_runtime = MagicMock(return_value=runtime_tool) + monkeypatch.setattr(node_runtime.ToolManager, "get_workflow_tool_runtime", get_runtime) + monkeypatch.setattr( + node_runtime, + "get_system_text", + lambda _pool, key: ( + "outer-workflow-run-id" if key == node_runtime.SystemVariableKey.WORKFLOW_EXECUTION_ID else None + ), + ) + + runtime = node_runtime.DifyToolNodeRuntime(_build_run_context()) + node_data = ToolNodeData( + title="Workflow Tool Node", + desc=None, + provider_id="workflow-provider-id", + provider_type=ToolProviderType.WORKFLOW, + provider_name="workflow-provider", + tool_name="workflow-tool", + tool_label="Workflow Tool", + tool_configurations={}, + tool_parameters={}, + ) + + handle = runtime.get_runtime( + node_id="tool-node", + node_data=node_data, + variable_pool=object(), + node_execution_id="node-execution-id", + ) + + assert handle.raw.tool is runtime_tool + assert handle.raw.parent_trace_context.model_dump() == { + "parent_workflow_run_id": "outer-workflow-run-id", + "parent_node_execution_id": "node-execution-id", + } + assert runtime_tool.runtime.runtime_parameters == {} + get_runtime.assert_called_once() + + +def test_dify_tool_node_runtime_does_not_inject_outer_workflow_run_id_for_non_workflow_tools( + monkeypatch: pytest.MonkeyPatch, +) -> None: + runtime_tool = SimpleNamespace(runtime=SimpleNamespace(runtime_parameters={})) + get_runtime = MagicMock(return_value=runtime_tool) + monkeypatch.setattr(node_runtime.ToolManager, "get_workflow_tool_runtime", get_runtime) + monkeypatch.setattr(node_runtime, "get_system_text", lambda _pool, _key: None) + + runtime = node_runtime.DifyToolNodeRuntime(_build_run_context()) + node_data = ToolNodeData( + title="Builtin Tool Node", + desc=None, + provider_id="builtin-provider-id", + provider_type=ToolProviderType.BUILT_IN, + provider_name="builtin-provider", + tool_name="builtin-tool", + tool_label="Builtin Tool", + 
tool_configurations={}, + tool_parameters={}, + ) + + handle = runtime.get_runtime( + node_id="tool-node", + node_data=node_data, + variable_pool=object(), + ) + + assert handle.raw.tool is runtime_tool + assert "outer_workflow_run_id" not in runtime_tool.runtime.runtime_parameters + get_runtime.assert_called_once() + + def test_dify_human_input_runtime_builds_debug_repository(monkeypatch: pytest.MonkeyPatch) -> None: repository = MagicMock() repository_cls = MagicMock(return_value=repository) diff --git a/api/tests/unit_tests/tasks/test_ops_trace_task.py b/api/tests/unit_tests/tasks/test_ops_trace_task.py new file mode 100644 index 0000000000..5844c55c04 --- /dev/null +++ b/api/tests/unit_tests/tasks/test_ops_trace_task.py @@ -0,0 +1,301 @@ +import json +import sys +from contextlib import contextmanager +from types import ModuleType +from unittest.mock import MagicMock, patch + +import pytest +from celery.exceptions import Retry + +from core.ops.entities.config_entity import OPS_TRACE_FAILED_KEY +from core.ops.exceptions import RetryableTraceDispatchError +from tasks.ops_trace_task import process_trace_tasks + + +@contextmanager +def fake_app_context(): + yield + + +class FakeCurrentApp: + def app_context(self): + return fake_app_context() + + +def _install_trace_manager( + trace_instance: MagicMock, + *, + enterprise_enabled: bool = False, + enterprise_trace_cls: MagicMock | None = None, +) -> dict[str, ModuleType]: + ops_trace_manager_module = ModuleType("core.ops.ops_trace_manager") + + class StubOpsTraceManager: + @staticmethod + def get_ops_trace_instance(app_id: str) -> MagicMock: + return trace_instance + + telemetry_module = ModuleType("extensions.ext_enterprise_telemetry") + telemetry_module.is_enabled = lambda: enterprise_enabled + + ops_trace_manager_module.OpsTraceManager = StubOpsTraceManager + modules = { + "core.ops.ops_trace_manager": ops_trace_manager_module, + "extensions.ext_enterprise_telemetry": telemetry_module, + } + if enterprise_trace_cls 
is not None: + enterprise_module = ModuleType("enterprise") + enterprise_telemetry_module = ModuleType("enterprise.telemetry") + enterprise_trace_module = ModuleType("enterprise.telemetry.enterprise_trace") + enterprise_trace_module.EnterpriseOtelTrace = enterprise_trace_cls + modules.update( + { + "enterprise": enterprise_module, + "enterprise.telemetry": enterprise_telemetry_module, + "enterprise.telemetry.enterprise_trace": enterprise_trace_module, + } + ) + return modules + + +def _make_payload() -> str: + return json.dumps({"trace_info": {}, "trace_info_type": None}) + + +def _decode_saved_payload(payload: bytes | str) -> dict[str, object]: + if isinstance(payload, bytes): + payload = payload.decode("utf-8") + return json.loads(payload) + + +def _retryable_dispatch_error() -> RetryableTraceDispatchError: + return RetryableTraceDispatchError("transient trace dispatch failure") + + +def _run_task(file_info: dict[str, str], retries: int = 0) -> None: + process_trace_tasks.push_request(retries=retries) + try: + process_trace_tasks.run(file_info) + finally: + process_trace_tasks.pop_request() + + +def test_process_trace_tasks_retries_retryable_dispatch_failure_and_preserves_payload(): + file_info = {"app_id": "app-id", "file_id": "file-id"} + trace_instance = MagicMock() + pending_error = _retryable_dispatch_error() + trace_instance.trace.side_effect = pending_error + retry_error = Retry() + + with ( + patch.dict(sys.modules, _install_trace_manager(trace_instance)), + patch("tasks.ops_trace_task.current_app", FakeCurrentApp()), + patch("tasks.ops_trace_task.storage.load", return_value=_make_payload()), + patch("tasks.ops_trace_task.storage.delete") as mock_delete, + patch("tasks.ops_trace_task.redis_client.incr") as mock_incr, + patch.object(process_trace_tasks, "retry", side_effect=retry_error) as mock_retry, + pytest.raises(Retry), + ): + _run_task(file_info) + + mock_retry.assert_called_once_with( + exc=pending_error, + 
countdown=process_trace_tasks.default_retry_delay, + ) + mock_delete.assert_not_called() + mock_incr.assert_not_called() + + +def test_process_trace_tasks_marks_enterprise_trace_dispatched_before_retryable_dispatch_retry(): + file_info = {"app_id": "app-id", "file_id": "file-id"} + trace_instance = MagicMock() + pending_error = _retryable_dispatch_error() + trace_instance.trace.side_effect = pending_error + retry_error = Retry() + enterprise_tracer = MagicMock() + enterprise_trace_cls = MagicMock(return_value=enterprise_tracer) + + with ( + patch.dict( + sys.modules, + _install_trace_manager( + trace_instance, + enterprise_enabled=True, + enterprise_trace_cls=enterprise_trace_cls, + ), + ), + patch("tasks.ops_trace_task.current_app", FakeCurrentApp()), + patch("tasks.ops_trace_task.storage.load", return_value=_make_payload()), + patch("tasks.ops_trace_task.storage.save") as mock_save, + patch("tasks.ops_trace_task.storage.delete") as mock_delete, + patch("tasks.ops_trace_task.redis_client.incr") as mock_incr, + patch.object(process_trace_tasks, "retry", side_effect=retry_error) as mock_retry, + pytest.raises(Retry), + ): + _run_task(file_info) + + enterprise_tracer.trace.assert_called_once_with({}) + saved_path, saved_payload = mock_save.call_args.args + assert saved_path == "ops_trace/app-id/file-id.json" + assert _decode_saved_payload(saved_payload)["_enterprise_trace_dispatched"] is True + mock_retry.assert_called_once_with( + exc=pending_error, + countdown=process_trace_tasks.default_retry_delay, + ) + mock_delete.assert_not_called() + mock_incr.assert_not_called() + + +def test_process_trace_tasks_does_not_mark_failed_enterprise_trace_as_dispatched_before_retry(): + file_info = {"app_id": "app-id", "file_id": "file-id"} + trace_instance = MagicMock() + pending_error = _retryable_dispatch_error() + trace_instance.trace.side_effect = pending_error + retry_error = Retry() + enterprise_tracer = MagicMock() + enterprise_tracer.trace.side_effect = 
RuntimeError("enterprise trace failed") + enterprise_trace_cls = MagicMock(return_value=enterprise_tracer) + + with ( + patch.dict( + sys.modules, + _install_trace_manager( + trace_instance, + enterprise_enabled=True, + enterprise_trace_cls=enterprise_trace_cls, + ), + ), + patch("tasks.ops_trace_task.current_app", FakeCurrentApp()), + patch("tasks.ops_trace_task.storage.load", return_value=_make_payload()), + patch("tasks.ops_trace_task.storage.save") as mock_save, + patch("tasks.ops_trace_task.storage.delete") as mock_delete, + patch("tasks.ops_trace_task.redis_client.incr") as mock_incr, + patch.object(process_trace_tasks, "retry", side_effect=retry_error) as mock_retry, + pytest.raises(Retry), + ): + _run_task(file_info) + + enterprise_tracer.trace.assert_called_once_with({}) + mock_save.assert_not_called() + mock_retry.assert_called_once_with( + exc=pending_error, + countdown=process_trace_tasks.default_retry_delay, + ) + mock_delete.assert_not_called() + mock_incr.assert_not_called() + + +def test_process_trace_tasks_skips_enterprise_trace_when_retry_payload_was_already_dispatched(): + file_info = {"app_id": "app-id", "file_id": "file-id"} + trace_instance = MagicMock() + enterprise_trace_cls = MagicMock() + payload = json.dumps({"trace_info": {}, "trace_info_type": None, "_enterprise_trace_dispatched": True}) + + with ( + patch.dict( + sys.modules, + _install_trace_manager( + trace_instance, + enterprise_enabled=True, + enterprise_trace_cls=enterprise_trace_cls, + ), + ), + patch("tasks.ops_trace_task.current_app", FakeCurrentApp()), + patch("tasks.ops_trace_task.storage.load", return_value=payload), + patch("tasks.ops_trace_task.storage.save") as mock_save, + patch("tasks.ops_trace_task.storage.delete") as mock_delete, + patch("tasks.ops_trace_task.redis_client.incr") as mock_incr, + ): + _run_task(file_info) + + enterprise_trace_cls.assert_not_called() + trace_instance.trace.assert_called_once_with({}) + mock_save.assert_not_called() + 
mock_delete.assert_called_once_with("ops_trace/app-id/file-id.json") + mock_incr.assert_not_called() + + +def test_process_trace_tasks_default_retry_window_covers_parent_span_context_ttl(): + assert process_trace_tasks.max_retries * process_trace_tasks.default_retry_delay >= 300 + + +def test_process_trace_tasks_deletes_payload_on_success(): + file_info = {"app_id": "app-id", "file_id": "file-id"} + trace_instance = MagicMock() + + with ( + patch.dict(sys.modules, _install_trace_manager(trace_instance)), + patch("tasks.ops_trace_task.current_app", FakeCurrentApp()), + patch("tasks.ops_trace_task.storage.load", return_value=_make_payload()), + patch("tasks.ops_trace_task.storage.delete") as mock_delete, + patch("tasks.ops_trace_task.redis_client.incr") as mock_incr, + ): + _run_task(file_info) + + trace_instance.trace.assert_called_once_with({}) + mock_delete.assert_called_once_with("ops_trace/app-id/file-id.json") + mock_incr.assert_not_called() + + +def test_process_trace_tasks_deletes_payload_and_counts_terminal_failure(): + file_info = {"app_id": "app-id", "file_id": "file-id"} + trace_instance = MagicMock() + trace_instance.trace.side_effect = RuntimeError("trace failed") + + with ( + patch.dict(sys.modules, _install_trace_manager(trace_instance)), + patch("tasks.ops_trace_task.current_app", FakeCurrentApp()), + patch("tasks.ops_trace_task.storage.load", return_value=_make_payload()), + patch("tasks.ops_trace_task.storage.delete") as mock_delete, + patch("tasks.ops_trace_task.redis_client.incr") as mock_incr, + ): + _run_task(file_info) + + mock_delete.assert_called_once_with("ops_trace/app-id/file-id.json") + mock_incr.assert_called_once_with(f"{OPS_TRACE_FAILED_KEY}_app-id") + + +def test_process_trace_tasks_treats_retry_enqueue_failure_as_terminal_failure(): + file_info = {"app_id": "app-id", "file_id": "file-id"} + trace_instance = MagicMock() + pending_error = _retryable_dispatch_error() + retry_enqueue_error = RuntimeError("retry enqueue failed") + 
trace_instance.trace.side_effect = pending_error + + with ( + patch.dict(sys.modules, _install_trace_manager(trace_instance)), + patch("tasks.ops_trace_task.current_app", FakeCurrentApp()), + patch("tasks.ops_trace_task.storage.load", return_value=_make_payload()), + patch("tasks.ops_trace_task.storage.delete") as mock_delete, + patch("tasks.ops_trace_task.redis_client.incr") as mock_incr, + patch.object(process_trace_tasks, "retry", side_effect=retry_enqueue_error) as mock_retry, + ): + _run_task(file_info) + + mock_retry.assert_called_once_with( + exc=pending_error, + countdown=process_trace_tasks.default_retry_delay, + ) + mock_delete.assert_called_once_with("ops_trace/app-id/file-id.json") + mock_incr.assert_called_once_with(f"{OPS_TRACE_FAILED_KEY}_app-id") + + +def test_process_trace_tasks_deletes_payload_and_counts_exhausted_retryable_dispatch_failure(): + file_info = {"app_id": "app-id", "file_id": "file-id"} + trace_instance = MagicMock() + pending_error = _retryable_dispatch_error() + trace_instance.trace.side_effect = pending_error + + with ( + patch.dict(sys.modules, _install_trace_manager(trace_instance)), + patch("tasks.ops_trace_task.current_app", FakeCurrentApp()), + patch("tasks.ops_trace_task.storage.load", return_value=_make_payload()), + patch("tasks.ops_trace_task.storage.delete") as mock_delete, + patch("tasks.ops_trace_task.redis_client.incr") as mock_incr, + patch.object(process_trace_tasks, "retry") as mock_retry, + ): + _run_task(file_info, retries=process_trace_tasks.max_retries) + + mock_retry.assert_not_called() + mock_delete.assert_called_once_with("ops_trace/app-id/file-id.json") + mock_incr.assert_called_once_with(f"{OPS_TRACE_FAILED_KEY}_app-id") diff --git a/docker/.env.example b/docker/.env.example index d9891d842a..5a012973c0 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -1,5 +1,6 @@ # ------------------------------------------------------------------ # Essential defaults for Docker Compose deployments. 
+# Only include variables required for services to start. # # For a default deployment, copy this file to .env and run: # docker compose up -d diff --git a/docker/envs/core-services/shared.env.example b/docker/envs/core-services/shared.env.example index af1c3ce74e..80cfe42c38 100644 --- a/docker/envs/core-services/shared.env.example +++ b/docker/envs/core-services/shared.env.example @@ -71,6 +71,8 @@ LOG_TZ=UTC DEBUG=false FLASK_DEBUG=false ENABLE_REQUEST_LOGGING=False +OPS_TRACE_RETRYABLE_DISPATCH_MAX_RETRIES=60 +OPS_TRACE_RETRYABLE_DISPATCH_DELAY_SECONDS=5 WORKFLOW_LOG_CLEANUP_ENABLED=false WORKFLOW_LOG_RETENTION_DAYS=30 WORKFLOW_LOG_CLEANUP_BATCH_SIZE=100 From 6164408da15f5043ad0c4b8ee5378800ddefe9ea Mon Sep 17 00:00:00 2001 From: yyh <92089059+lyzno1@users.noreply.github.com> Date: Mon, 11 May 2026 16:42:09 +0800 Subject: [PATCH 2/4] fix(web): align tag filter dropdown icon (#36041) --- web/features/tag-management/components/tag-filter.tsx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/web/features/tag-management/components/tag-filter.tsx b/web/features/tag-management/components/tag-filter.tsx index bca8465730..2c6938dc4d 100644 --- a/web/features/tag-management/components/tag-filter.tsx +++ b/web/features/tag-management/components/tag-filter.tsx @@ -78,11 +78,11 @@ export const TagFilter = ({ !!value.length && 'pr-6 shadow-xs', )} > - + - + {!value.length && t('tag.placeholder', { ns: 'common' })} {!!value.length && currentTagName} From a60cb3b80098234534f245c45c8239e087d20bf9 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Mon, 11 May 2026 18:17:12 +0900 Subject: [PATCH 3/4] chore: port WorkflowComment (#36039) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/models/comment.py | 65 ++++++++++++------- .../unit_tests/models/test_comment_models.py | 32 ++++++++- 2 files changed, 69 insertions(+), 28 deletions(-) diff --git a/api/models/comment.py b/api/models/comment.py index 
5d4a08e783..6d151fe13d 100644 --- a/api/models/comment.py +++ b/api/models/comment.py @@ -1,19 +1,22 @@ """Workflow comment models.""" +from __future__ import annotations + from datetime import datetime -from typing import Optional import sqlalchemy as sa from sqlalchemy import Index, func from sqlalchemy.orm import Mapped, mapped_column, relationship +from models.base import TypeBase + from .account import Account -from .base import Base, gen_uuidv7_string +from .base import gen_uuidv7_string from .engine import db from .types import StringUUID -class WorkflowComment(Base): +class WorkflowComment(TypeBase): """Workflow comment model for canvas commenting functionality. Comments are associated with apps rather than specific workflow versions, @@ -42,27 +45,33 @@ class WorkflowComment(Base): Index("workflow_comments_created_at_idx", "created_at"), ) - id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv7_string) + id: Mapped[str] = mapped_column(StringUUID, default_factory=gen_uuidv7_string, init=False) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) position_x: Mapped[float] = mapped_column(sa.Float) position_y: Mapped[float] = mapped_column(sa.Float) content: Mapped[str] = mapped_column(sa.Text, nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) - created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) updated_at: Mapped[datetime] = mapped_column( - sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + sa.DateTime, + nullable=False, + server_default=func.current_timestamp(), + onupdate=func.current_timestamp(), + init=False, ) - resolved: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, 
server_default=sa.text("false")) - resolved_at: Mapped[datetime | None] = mapped_column(sa.DateTime) - resolved_by: Mapped[str | None] = mapped_column(StringUUID) + resolved_at: Mapped[datetime | None] = mapped_column(sa.DateTime, default=None) + resolved_by: Mapped[str | None] = mapped_column(StringUUID, default=None) + resolved: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"), default=False) # Relationships - replies: Mapped[list["WorkflowCommentReply"]] = relationship( - "WorkflowCommentReply", back_populates="comment", cascade="all, delete-orphan" + replies: Mapped[list[WorkflowCommentReply]] = relationship( + lambda: WorkflowCommentReply, back_populates="comment", cascade="all, delete-orphan", init=False ) - mentions: Mapped[list["WorkflowCommentMention"]] = relationship( - "WorkflowCommentMention", back_populates="comment", cascade="all, delete-orphan" + mentions: Mapped[list[WorkflowCommentMention]] = relationship( + lambda: WorkflowCommentMention, back_populates="comment", cascade="all, delete-orphan", init=False ) @property @@ -131,7 +140,7 @@ class WorkflowComment(Base): return participants -class WorkflowCommentReply(Base): +class WorkflowCommentReply(TypeBase): """Workflow comment reply model. 
Attributes: @@ -149,18 +158,24 @@ class WorkflowCommentReply(Base): Index("comment_replies_created_at_idx", "created_at"), ) - id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv7_string) + id: Mapped[str] = mapped_column(StringUUID, default_factory=gen_uuidv7_string, init=False) comment_id: Mapped[str] = mapped_column( StringUUID, sa.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False ) content: Mapped[str] = mapped_column(sa.Text, nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) - created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) updated_at: Mapped[datetime] = mapped_column( - sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + sa.DateTime, + nullable=False, + server_default=func.current_timestamp(), + onupdate=func.current_timestamp(), + init=False, ) # Relationships - comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="replies") + comment: Mapped[WorkflowComment] = relationship(lambda: WorkflowComment, back_populates="replies", init=False) @property def created_by_account(self): @@ -174,7 +189,7 @@ class WorkflowCommentReply(Base): self._created_by_account_cache = account -class WorkflowCommentMention(Base): +class WorkflowCommentMention(TypeBase): """Workflow comment mention model. 
Mentions are only for internal accounts since end users @@ -194,18 +209,18 @@ class WorkflowCommentMention(Base): Index("comment_mentions_user_idx", "mentioned_user_id"), ) - id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv7_string) + id: Mapped[str] = mapped_column(StringUUID, default_factory=gen_uuidv7_string, init=False) comment_id: Mapped[str] = mapped_column( StringUUID, sa.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False ) - reply_id: Mapped[str | None] = mapped_column( - StringUUID, sa.ForeignKey("workflow_comment_replies.id", ondelete="CASCADE"), nullable=True - ) mentioned_user_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + reply_id: Mapped[str | None] = mapped_column( + StringUUID, sa.ForeignKey("workflow_comment_replies.id", ondelete="CASCADE"), nullable=True, default=None + ) # Relationships - comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="mentions") - reply: Mapped[Optional["WorkflowCommentReply"]] = relationship("WorkflowCommentReply") + comment: Mapped[WorkflowComment] = relationship(lambda: WorkflowComment, back_populates="mentions", init=False) + reply: Mapped[WorkflowCommentReply | None] = relationship(lambda: WorkflowCommentReply, init=False) @property def mentioned_user_account(self): diff --git a/api/tests/unit_tests/models/test_comment_models.py b/api/tests/unit_tests/models/test_comment_models.py index 277335cbef..8c8985aff8 100644 --- a/api/tests/unit_tests/models/test_comment_models.py +++ b/api/tests/unit_tests/models/test_comment_models.py @@ -4,7 +4,15 @@ from models.comment import WorkflowComment, WorkflowCommentMention, WorkflowComm def test_workflow_comment_account_properties_and_cache() -> None: - comment = WorkflowComment(created_by="user-1", resolved_by="user-2", content="hello", position_x=1, position_y=2) + comment = WorkflowComment( + created_by="user-1", + resolved_by="user-2", + content="hello", + position_x=1, + position_y=2, + 
tenant_id="xxx", + app_id="yyy", + ) created_account = Mock(id="user-1") resolved_account = Mock(id="user-2") @@ -21,6 +29,8 @@ def test_workflow_comment_account_properties_and_cache() -> None: get_mock.assert_not_called() comment_without_resolver = WorkflowComment( + tenant_id="xxx", + app_id="yyy", created_by="user-1", resolved_by=None, content="hello", @@ -37,7 +47,15 @@ def test_workflow_comment_counts_and_participants() -> None: reply_2 = WorkflowCommentReply(comment_id="comment-1", content="reply-2", created_by="user-2") mention_1 = WorkflowCommentMention(comment_id="comment-1", mentioned_user_id="user-3") mention_2 = WorkflowCommentMention(comment_id="comment-1", mentioned_user_id="user-4") - comment = WorkflowComment(created_by="user-1", resolved_by=None, content="hello", position_x=1, position_y=2) + comment = WorkflowComment( + created_by="user-1", + resolved_by=None, + content="hello", + position_x=1, + position_y=2, + tenant_id="xxx", + app_id="yyy", + ) comment.replies = [reply_1, reply_2] comment.mentions = [mention_1, mention_2] @@ -63,7 +81,15 @@ def test_workflow_comment_counts_and_participants() -> None: def test_workflow_comment_participants_use_cached_accounts() -> None: reply = WorkflowCommentReply(comment_id="comment-1", content="reply-1", created_by="user-2") mention = WorkflowCommentMention(comment_id="comment-1", mentioned_user_id="user-3") - comment = WorkflowComment(created_by="user-1", resolved_by=None, content="hello", position_x=1, position_y=2) + comment = WorkflowComment( + created_by="user-1", + resolved_by=None, + content="hello", + position_x=1, + position_y=2, + tenant_id="xxx", + app_id="yyy", + ) comment.replies = [reply] comment.mentions = [mention] From 59dab7deac1810dd42816cc52450fb2faba0fb2f Mon Sep 17 00:00:00 2001 From: yyh <92089059+lyzno1@users.noreply.github.com> Date: Mon, 11 May 2026 17:58:19 +0800 Subject: [PATCH 4/4] refactor(apps): simplify query state and debounce URL writes (#36043) --- 
.../skills/how-to-write-component/SKILL.md | 14 +- .../components/apps/__tests__/list.spec.tsx | 52 ++-- web/app/components/apps/constants.ts | 1 + .../__tests__/use-apps-query-state.spec.tsx | 291 ++++++++---------- .../apps/hooks/use-apps-query-state.ts | 101 +++--- web/app/components/apps/list.tsx | 83 ++--- 6 files changed, 233 insertions(+), 309 deletions(-) create mode 100644 web/app/components/apps/constants.ts diff --git a/.agents/skills/how-to-write-component/SKILL.md b/.agents/skills/how-to-write-component/SKILL.md index f33a9dd75e..ac77112993 100644 --- a/.agents/skills/how-to-write-component/SKILL.md +++ b/.agents/skills/how-to-write-component/SKILL.md @@ -55,9 +55,17 @@ Use this as the decision guide for React/TypeScript component structure. Existin - Avoid unnecessary DOM hierarchy. Do not add wrapper elements unless they provide layout, semantics, accessibility, state ownership, or integration with a library API; prefer fragments or styling an existing element when possible. - Avoid shallow wrappers and prop renaming unless the wrapper adds validation, orchestration, error handling, state ownership, or a real semantic boundary. -## Navigation, Effects, And Performance +## You Might Not Need An Effect + +- Use Effects only to synchronize with external systems such as browser APIs, non-React widgets, subscriptions, timers, analytics that must run because the component was shown, or imperative DOM integration. +- Do not use Effects to transform props or state for rendering. Calculate derived values during render, and use `useMemo` only when the calculation is actually expensive. +- Do not use Effects to handle user actions. Put action-specific logic in the event handler where the cause is known. +- Do not use Effects to copy one state value into another state value representing the same concept. Pick one source of truth and derive the rest during render. +- Do not reset or adjust state from props with an Effect. 
Prefer a `key` reset, storing a stable ID and deriving the selected object, or guarded same-component render-time adjustment when truly necessary. +- Prefer framework data APIs or TanStack Query for data fetching instead of writing request Effects in components. +- If an Effect still seems necessary, first name the external system it synchronizes with. If there is no external system, remove the Effect and restructure the state or event flow. + +## Navigation And Performance - Prefer `Link` for normal navigation. Use router APIs only for command-flow side effects such as mutation success, guarded redirects, or form submission. -- Treat `useEffect` as a last resort. First try deriving values during render, moving event-driven work into handlers, or using existing hooks/APIs for persistence, subscriptions, media queries, timers, and DOM sync. -- Do not use `useEffect` directly in components. If unavoidable, encapsulate it in a purpose-built hook so the component consumes a declarative API. - Avoid `memo`, `useMemo`, and `useCallback` unless there is a clear performance reason. 
diff --git a/web/app/components/apps/__tests__/list.spec.tsx b/web/app/components/apps/__tests__/list.spec.tsx index 41d2ccbc80..0c6f1702d7 100644 --- a/web/app/components/apps/__tests__/list.spec.tsx +++ b/web/app/components/apps/__tests__/list.spec.tsx @@ -48,16 +48,24 @@ vi.mock('@/context/app-context', () => ({ }), })) -const mockSetQuery = vi.fn() +const mockSetKeywords = vi.fn() +const mockSetTagIDs = vi.fn() +const mockSetIsCreatedByMe = vi.fn() +const mockSetCategory = vi.fn() const mockQueryState = { + category: 'all', tagIDs: [] as string[], keywords: '', isCreatedByMe: false, } vi.mock('../hooks/use-apps-query-state', () => ({ - default: () => ({ + isAppListCategory: (value: string) => value === 'all' || Object.values(AppModeEnum).includes(value as AppModeEnum), + useAppsQueryState: () => ({ query: mockQueryState, - setQuery: mockSetQuery, + setCategory: mockSetCategory, + setKeywords: mockSetKeywords, + setTagIDs: mockSetTagIDs, + setIsCreatedByMe: mockSetIsCreatedByMe, }), })) @@ -244,6 +252,7 @@ describe('List', () => { mockServiceState.hasNextPage = false mockServiceState.isLoading = false mockServiceState.isFetchingNextPage = false + mockQueryState.category = 'all' mockQueryState.tagIDs = [] mockQueryState.keywords = '' mockQueryState.isCreatedByMe = false @@ -317,25 +326,21 @@ describe('List', () => { }) describe('Tab Navigation', () => { - it('should update URL when workflow tab is clicked', async () => { - const { onUrlUpdate } = renderList() + it('should update category when workflow tab is clicked', () => { + renderList() fireEvent.click(screen.getByText('app.types.workflow')) - await vi.waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) - const lastCall = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1]![0] - expect(lastCall.searchParams.get('category')).toBe(AppModeEnum.WORKFLOW) + expect(mockSetCategory).toHaveBeenCalledWith(AppModeEnum.WORKFLOW) }) - it('should update URL when all tab is clicked', async () => { - const { 
onUrlUpdate } = renderList('?category=workflow') + it('should update category when all tab is clicked', () => { + mockQueryState.category = AppModeEnum.WORKFLOW + renderList() fireEvent.click(screen.getByText('app.types.all')) - await vi.waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) - const lastCall = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1]![0] - // nuqs removes the default value ('all') from URL params - expect(lastCall.searchParams.has('category')).toBe(false) + expect(mockSetCategory).toHaveBeenCalledWith('all') }) }) @@ -351,7 +356,7 @@ describe('List', () => { const input = screen.getByRole('textbox') fireEvent.change(input, { target: { value: 'test search' } }) - expect(mockSetQuery).toHaveBeenCalled() + expect(mockSetKeywords).toHaveBeenCalledWith('test search') }) it('should handle search clear button click', () => { @@ -364,7 +369,7 @@ describe('List', () => { if (clearButton) fireEvent.click(clearButton) - expect(mockSetQuery).toHaveBeenCalled() + expect(mockSetKeywords).toHaveBeenCalledWith('') }) }) @@ -373,8 +378,9 @@ describe('List', () => { mockQueryState.tagIDs = ['tag-1'] mockQueryState.keywords = 'sales' mockQueryState.isCreatedByMe = true + mockQueryState.category = AppModeEnum.WORKFLOW - renderList('?category=workflow') + renderList() const options = mockAppListInfiniteOptions.mock.calls.at(-1)?.[0] as AppListInfiniteOptions @@ -412,7 +418,7 @@ describe('List', () => { const checkbox = screen.getByTestId('checkbox-undefined') fireEvent.click(checkbox) - expect(mockSetQuery).toHaveBeenCalled() + expect(mockSetIsCreatedByMe).toHaveBeenCalledWith(true) }) }) @@ -506,8 +512,8 @@ describe('List', () => { expect(screen.getByText('app.types.completion'))!.toBeInTheDocument() }) - it('should update URL for each app type tab click', async () => { - const { onUrlUpdate } = renderList() + it('should update category for each app type tab click', () => { + renderList() const appTypeTexts = [ { mode: AppModeEnum.WORKFLOW, text: 
'app.types.workflow' }, @@ -518,11 +524,9 @@ describe('List', () => { ] for (const { mode, text } of appTypeTexts) { - onUrlUpdate.mockClear() + mockSetCategory.mockClear() fireEvent.click(screen.getByText(text)) - await vi.waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) - const lastCall = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1]![0] - expect(lastCall.searchParams.get('category')).toBe(mode) + expect(mockSetCategory).toHaveBeenCalledWith(mode) } }) }) diff --git a/web/app/components/apps/constants.ts b/web/app/components/apps/constants.ts new file mode 100644 index 0000000000..95c3dcff42 --- /dev/null +++ b/web/app/components/apps/constants.ts @@ -0,0 +1 @@ +export const APP_LIST_SEARCH_DEBOUNCE_MS = 500 diff --git a/web/app/components/apps/hooks/__tests__/use-apps-query-state.spec.tsx b/web/app/components/apps/hooks/__tests__/use-apps-query-state.spec.tsx index 4b0c63f580..782f6ec353 100644 --- a/web/app/components/apps/hooks/__tests__/use-apps-query-state.spec.tsx +++ b/web/app/components/apps/hooks/__tests__/use-apps-query-state.spec.tsx @@ -1,6 +1,8 @@ import { act, waitFor } from '@testing-library/react' import { renderHookWithNuqs } from '@/test/nuqs-testing' -import useAppsQueryState from '../use-apps-query-state' +import { AppModeEnum } from '@/types/app' +import { APP_LIST_SEARCH_DEBOUNCE_MS } from '../../constants' +import { useAppsQueryState } from '../use-apps-query-state' const renderWithAdapter = (searchParams = '') => { return renderHookWithNuqs(() => useAppsQueryState(), { searchParams }) @@ -11,214 +13,161 @@ describe('useAppsQueryState', () => { vi.clearAllMocks() }) - describe('Initialization', () => { - it('should expose query and setQuery when initialized', () => { - const { result } = renderWithAdapter() + it('should expose app list query state actions', () => { + const { result } = renderWithAdapter() - expect(result.current.query).toBeDefined() - expect(typeof result.current.setQuery).toBe('function') + 
expect(result.current.query).toEqual({ + category: 'all', + tagIDs: [], + keywords: '', + isCreatedByMe: false, }) + expect(typeof result.current.setCategory).toBe('function') + expect(typeof result.current.setKeywords).toBe('function') + expect(typeof result.current.setTagIDs).toBe('function') + expect(typeof result.current.setIsCreatedByMe).toBe('function') + }) - it('should default to empty filters when search params are missing', () => { - const { result } = renderWithAdapter() + it('should parse app list filters from URL', () => { + const { result } = renderWithAdapter( + '?category=workflow&tagIDs=tag1;tag2&keywords=search+term&isCreatedByMe=true', + ) - expect(result.current.query.tagIDs).toBeUndefined() - expect(result.current.query.keywords).toBeUndefined() - expect(result.current.query.isCreatedByMe).toBe(false) + expect(result.current.query).toEqual({ + category: AppModeEnum.WORKFLOW, + tagIDs: ['tag1', 'tag2'], + keywords: 'search term', + isCreatedByMe: true, }) }) - describe('Parsing search params', () => { - it('should parse tagIDs when URL includes tagIDs', () => { - const { result } = renderWithAdapter('?tagIDs=tag1;tag2;tag3') + it('should update category URL state', async () => { + const { result, onUrlUpdate } = renderWithAdapter() - expect(result.current.query.tagIDs).toEqual(['tag1', 'tag2', 'tag3']) + act(() => { + result.current.setCategory(AppModeEnum.WORKFLOW) }) - it('should parse keywords when URL includes keywords', () => { - const { result } = renderWithAdapter('?keywords=search+term') - - expect(result.current.query.keywords).toBe('search term') - }) - - it('should parse isCreatedByMe when URL includes true value', () => { - const { result } = renderWithAdapter('?isCreatedByMe=true') - - expect(result.current.query.isCreatedByMe).toBe(true) - }) - - it('should parse all params when URL includes multiple filters', () => { - const { result } = renderWithAdapter( - '?tagIDs=tag1;tag2&keywords=test&isCreatedByMe=true', - ) - - 
expect(result.current.query.tagIDs).toEqual(['tag1', 'tag2']) - expect(result.current.query.keywords).toBe('test') - expect(result.current.query.isCreatedByMe).toBe(true) - }) + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls.at(-1)![0] + expect(result.current.query.category).toBe(AppModeEnum.WORKFLOW) + expect(update.searchParams.get('category')).toBe(AppModeEnum.WORKFLOW) + expect(update.options.history).toBe('push') }) - describe('Updating query state', () => { - it('should update keywords when setQuery receives keywords', () => { - const { result } = renderWithAdapter() + it('should remove category from URL when set to all', async () => { + const { result, onUrlUpdate } = renderWithAdapter('?category=workflow') - act(() => { - result.current.setQuery({ keywords: 'new search' }) - }) - - expect(result.current.query.keywords).toBe('new search') + act(() => { + result.current.setCategory('all') }) - it('should update tagIDs when setQuery receives tagIDs', () => { - const { result } = renderWithAdapter() - - act(() => { - result.current.setQuery({ tagIDs: ['tag1', 'tag2'] }) - }) - - expect(result.current.query.tagIDs).toEqual(['tag1', 'tag2']) - }) - - it('should update isCreatedByMe when setQuery receives true', () => { - const { result } = renderWithAdapter() - - act(() => { - result.current.setQuery({ isCreatedByMe: true }) - }) - - expect(result.current.query.isCreatedByMe).toBe(true) - }) - - it('should support partial updates when setQuery uses callback', () => { - const { result } = renderWithAdapter() - - act(() => { - result.current.setQuery({ keywords: 'initial' }) - }) - - act(() => { - result.current.setQuery(prev => ({ ...prev, isCreatedByMe: true })) - }) - - expect(result.current.query.keywords).toBe('initial') - expect(result.current.query.isCreatedByMe).toBe(true) - }) + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls.at(-1)![0] + 
expect(result.current.query.category).toBe('all') + expect(update.searchParams.has('category')).toBe(false) }) - describe('URL synchronization', () => { - it('should sync keywords to URL when keywords change', async () => { + it('should update keywords state immediately while debouncing URL writes', async () => { + vi.useFakeTimers() + try { const { result, onUrlUpdate } = renderWithAdapter() act(() => { - result.current.setQuery({ keywords: 'search' }) + result.current.setKeywords('search') }) - await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) - const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1]![0] + expect(result.current.query.keywords).toBe('search') + expect(onUrlUpdate).not.toHaveBeenCalled() + + await act(async () => { + await vi.advanceTimersByTimeAsync(APP_LIST_SEARCH_DEBOUNCE_MS + 100) + }) + + expect(onUrlUpdate).toHaveBeenCalled() + const update = onUrlUpdate.mock.calls.at(-1)![0] expect(update.searchParams.get('keywords')).toBe('search') - expect(update.options.history).toBe('push') - }) + } + finally { + vi.useRealTimers() + } + }) - it('should sync tagIDs to URL when tagIDs change', async () => { - const { result, onUrlUpdate } = renderWithAdapter() - - act(() => { - result.current.setQuery({ tagIDs: ['tag1', 'tag2'] }) - }) - - await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) - const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1]![0] - expect(update.searchParams.get('tagIDs')).toBe('tag1;tag2') - }) - - it('should sync isCreatedByMe to URL when enabled', async () => { - const { result, onUrlUpdate } = renderWithAdapter() - - act(() => { - result.current.setQuery({ isCreatedByMe: true }) - }) - - await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) - const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1]![0] - expect(update.searchParams.get('isCreatedByMe')).toBe('true') - }) - - it('should remove keywords from URL when keywords are cleared', async () => { + 
it('should remove keywords from URL when cleared', async () => { + vi.useFakeTimers() + try { const { result, onUrlUpdate } = renderWithAdapter('?keywords=existing') act(() => { - result.current.setQuery({ keywords: '' }) + result.current.setKeywords('') }) - await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) - const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1]![0] + expect(result.current.query.keywords).toBe('') + + await act(async () => { + await vi.advanceTimersByTimeAsync(APP_LIST_SEARCH_DEBOUNCE_MS + 100) + }) + + expect(onUrlUpdate).toHaveBeenCalled() + const update = onUrlUpdate.mock.calls.at(-1)![0] expect(update.searchParams.has('keywords')).toBe(false) - }) - - it('should remove tagIDs from URL when tagIDs are empty', async () => { - const { result, onUrlUpdate } = renderWithAdapter('?tagIDs=tag1;tag2') - - act(() => { - result.current.setQuery({ tagIDs: [] }) - }) - - await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) - const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1]![0] - expect(update.searchParams.has('tagIDs')).toBe(false) - }) - - it('should remove isCreatedByMe from URL when disabled', async () => { - const { result, onUrlUpdate } = renderWithAdapter('?isCreatedByMe=true') - - act(() => { - result.current.setQuery({ isCreatedByMe: false }) - }) - - await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) - const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1]![0] - expect(update.searchParams.has('isCreatedByMe')).toBe(false) - }) + } + finally { + vi.useRealTimers() + } }) - describe('Edge cases', () => { - it('should treat empty tagIDs as empty list when URL param is empty', () => { - const { result } = renderWithAdapter('?tagIDs=') + it('should update tag filter URL state', async () => { + const { result, onUrlUpdate } = renderWithAdapter() - expect(result.current.query.tagIDs).toEqual([]) + act(() => { + result.current.setTagIDs(['tag1', 'tag2']) }) - 
it('should treat empty keywords as undefined when URL param is empty', () => { - const { result } = renderWithAdapter('?keywords=') - - expect(result.current.query.keywords).toBeUndefined() - }) - - it('should decode keywords with spaces when URL contains encoded spaces', () => { - const { result } = renderWithAdapter('?keywords=test+with+spaces') - - expect(result.current.query.keywords).toBe('test with spaces') - }) + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls.at(-1)![0] + expect(result.current.query.tagIDs).toEqual(['tag1', 'tag2']) + expect(update.searchParams.get('tagIDs')).toBe('tag1;tag2') + expect(update.options.history).toBe('push') }) - describe('Integration scenarios', () => { - it('should keep accumulated filters when updates are sequential', () => { - const { result } = renderWithAdapter() + it('should remove tagIDs from URL when empty', async () => { + const { result, onUrlUpdate } = renderWithAdapter('?tagIDs=tag1;tag2') - act(() => { - result.current.setQuery({ keywords: 'first' }) - }) - - act(() => { - result.current.setQuery(prev => ({ ...prev, tagIDs: ['tag1'] })) - }) - - act(() => { - result.current.setQuery(prev => ({ ...prev, isCreatedByMe: true })) - }) - - expect(result.current.query.keywords).toBe('first') - expect(result.current.query.tagIDs).toEqual(['tag1']) - expect(result.current.query.isCreatedByMe).toBe(true) + act(() => { + result.current.setTagIDs([]) }) + + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls.at(-1)![0] + expect(result.current.query.tagIDs).toEqual([]) + expect(update.searchParams.has('tagIDs')).toBe(false) + }) + + it('should update created-by-me URL state', async () => { + const { result, onUrlUpdate } = renderWithAdapter() + + act(() => { + result.current.setIsCreatedByMe(true) + }) + + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls.at(-1)![0] + 
expect(result.current.query.isCreatedByMe).toBe(true) + expect(update.searchParams.get('isCreatedByMe')).toBe('true') + expect(update.options.history).toBe('push') + }) + + it('should remove isCreatedByMe from URL when disabled', async () => { + const { result, onUrlUpdate } = renderWithAdapter('?isCreatedByMe=true') + + act(() => { + result.current.setIsCreatedByMe(false) + }) + + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls.at(-1)![0] + expect(result.current.query.isCreatedByMe).toBe(false) + expect(update.searchParams.has('isCreatedByMe')).toBe(false) }) }) diff --git a/web/app/components/apps/hooks/use-apps-query-state.ts b/web/app/components/apps/hooks/use-apps-query-state.ts index ecf7707e8a..a0109eb061 100644 --- a/web/app/components/apps/hooks/use-apps-query-state.ts +++ b/web/app/components/apps/hooks/use-apps-query-state.ts @@ -1,57 +1,56 @@ -import { parseAsArrayOf, parseAsBoolean, parseAsString, useQueryStates } from 'nuqs' +import { debounce, parseAsArrayOf, parseAsBoolean, parseAsString, parseAsStringLiteral, useQueryStates } from 'nuqs' import { useCallback, useMemo } from 'react' +import { AppModes } from '@/types/app' +import { APP_LIST_SEARCH_DEBOUNCE_MS } from '../constants' -type AppsQuery = { - tagIDs?: string[] - keywords?: string - isCreatedByMe?: boolean +const APP_LIST_CATEGORY_VALUES = ['all', ...AppModes] as const +export type AppListCategory = typeof APP_LIST_CATEGORY_VALUES[number] + +const appListCategorySet = new Set(APP_LIST_CATEGORY_VALUES) + +export const isAppListCategory = (value: string): value is AppListCategory => { + return appListCategorySet.has(value) } -const normalizeKeywords = (value: string | null) => value || undefined - -function useAppsQueryState() { - const [urlQuery, setUrlQuery] = useQueryStates( - { - tagIDs: parseAsArrayOf(parseAsString, ';'), - keywords: parseAsString, - isCreatedByMe: parseAsBoolean, - }, - { - history: 'push', - }, - ) - - const query = 
useMemo(() => ({ - tagIDs: urlQuery.tagIDs ?? undefined, - keywords: normalizeKeywords(urlQuery.keywords), - isCreatedByMe: urlQuery.isCreatedByMe ?? false, - }), [urlQuery.isCreatedByMe, urlQuery.keywords, urlQuery.tagIDs]) - - const setQuery = useCallback((next: AppsQuery | ((prev: AppsQuery) => AppsQuery)) => { - const buildPatch = (patch: AppsQuery) => { - const result: Partial = {} - if ('tagIDs' in patch) - result.tagIDs = patch.tagIDs && patch.tagIDs.length > 0 ? patch.tagIDs : null - if ('keywords' in patch) - result.keywords = patch.keywords ? patch.keywords : null - if ('isCreatedByMe' in patch) - result.isCreatedByMe = patch.isCreatedByMe ? true : null - return result - } - - if (typeof next === 'function') { - setUrlQuery(prev => buildPatch(next({ - tagIDs: prev.tagIDs ?? undefined, - keywords: normalizeKeywords(prev.keywords), - isCreatedByMe: prev.isCreatedByMe ?? false, - }))) - return - } - - setUrlQuery(buildPatch(next)) - }, [setUrlQuery]) - - return useMemo(() => ({ query, setQuery }), [query, setQuery]) +const appListQueryParsers = { + category: parseAsStringLiteral(APP_LIST_CATEGORY_VALUES) + .withDefault('all') + .withOptions({ history: 'push' }), + tagIDs: parseAsArrayOf(parseAsString, ';') + .withDefault([]) + .withOptions({ history: 'push' }), + keywords: parseAsString.withDefault('').withOptions({ + limitUrlUpdates: debounce(APP_LIST_SEARCH_DEBOUNCE_MS), + }), + isCreatedByMe: parseAsBoolean + .withDefault(false) + .withOptions({ history: 'push' }), } -export default useAppsQueryState +export function useAppsQueryState() { + const [query, setQuery] = useQueryStates(appListQueryParsers) + + const setCategory = useCallback((category: AppListCategory) => { + setQuery({ category }) + }, [setQuery]) + + const setKeywords = useCallback((keywords: string) => { + setQuery({ keywords }) + }, [setQuery]) + + const setTagIDs = useCallback((tagIDs: string[]) => { + setQuery({ tagIDs }) + }, [setQuery]) + + const setIsCreatedByMe = 
useCallback((isCreatedByMe: boolean) => { + setQuery({ isCreatedByMe }) + }, [setQuery]) + + return useMemo(() => ({ + query, + setCategory, + setKeywords, + setTagIDs, + setIsCreatedByMe, + }), [query, setCategory, setKeywords, setTagIDs, setIsCreatedByMe]) +} diff --git a/web/app/components/apps/list.tsx b/web/app/components/apps/list.tsx index 0fd31dfb79..e2e8e737fc 100644 --- a/web/app/components/apps/list.tsx +++ b/web/app/components/apps/list.tsx @@ -4,8 +4,7 @@ import type { FC } from 'react' import type { AppListQuery } from '@/contract/console/apps' import { cn } from '@langgenius/dify-ui/cn' import { keepPreviousData, useInfiniteQuery, useSuspenseQuery } from '@tanstack/react-query' -import { useDebounceFn } from 'ahooks' -import { parseAsStringLiteral, useQueryState } from 'nuqs' +import { useDebounce } from 'ahooks' import { useCallback, useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import Checkbox from '@/app/components/base/checkbox' @@ -18,12 +17,13 @@ import { CheckModal } from '@/hooks/use-pay' import dynamic from '@/next/dynamic' import { consoleQuery } from '@/service/client' import { systemFeaturesQueryOptions } from '@/service/system-features' -import { AppModeEnum, AppModes } from '@/types/app' +import { AppModeEnum } from '@/types/app' import AppCard from './app-card' import { AppCardSkeleton } from './app-card-skeleton' +import { APP_LIST_SEARCH_DEBOUNCE_MS } from './constants' import Empty from './empty' import Footer from './footer' -import useAppsQueryStateHook from './hooks/use-apps-query-state' +import { isAppListCategory, useAppsQueryState } from './hooks/use-apps-query-state' import { useDSLDragDrop } from './hooks/use-dsl-drag-drop' import { useWorkflowOnlineUsers } from './hooks/use-workflow-online-users' import NewAppCard from './new-app-card' @@ -35,18 +35,6 @@ const CreateFromDSLModal = dynamic(() => import('@/app/components/app/create-fro ssr: false, }) -const 
APP_LIST_CATEGORY_VALUES = ['all', ...AppModes] as const -type AppListCategory = typeof APP_LIST_CATEGORY_VALUES[number] -const appListCategorySet = new Set(APP_LIST_CATEGORY_VALUES) - -const isAppListCategory = (value: string): value is AppListCategory => { - return appListCategorySet.has(value) -} - -const parseAsAppListCategory = parseAsStringLiteral(APP_LIST_CATEGORY_VALUES) - .withDefault('all') - .withOptions({ history: 'push' }) - type Props = { controlRefreshList?: number } @@ -56,28 +44,21 @@ const List: FC = ({ const { t } = useTranslation() const { data: systemFeatures } = useSuspenseQuery(systemFeaturesQueryOptions()) const { isCurrentWorkspaceEditor, isCurrentWorkspaceDatasetOperator, isLoadingCurrentWorkspace } = useAppContext() - const [activeTab, setActiveTab] = useQueryState( - 'category', - parseAsAppListCategory, - ) // eslint-disable-next-line react/use-state -- custom URL query hook, not React.useState - const appsQuery = useAppsQueryStateHook() - const { query: { tagIDs = [], keywords = '', isCreatedByMe: queryIsCreatedByMe = false }, setQuery } = appsQuery - const [isCreatedByMe, setIsCreatedByMe] = useState(queryIsCreatedByMe) - const [tagFilterValue, setTagFilterValue] = useState(tagIDs) - const [searchKeywords, setSearchKeywords] = useState(keywords) + const { + query: { category, tagIDs, keywords, isCreatedByMe }, + setCategory, + setKeywords, + setTagIDs, + setIsCreatedByMe, + } = useAppsQueryState() + const debouncedKeywords = useDebounce(keywords, { wait: APP_LIST_SEARCH_DEBOUNCE_MS }) const newAppCardRef = useRef(null) const containerRef = useRef(null) const [showTagManagementModal, setShowTagManagementModal] = useState(false) const [showCreateFromDSLModal, setShowCreateFromDSLModal] = useState(false) const [droppedDSLFile, setDroppedDSLFile] = useState() - const setKeywords = useCallback((keywords: string) => { - setQuery(prev => ({ ...prev, keywords })) - }, [setQuery]) - const setTagIDs = useCallback((tagIDs: string[]) => { - 
setQuery(prev => ({ ...prev, tagIDs })) - }, [setQuery]) const handleDSLFileDropped = useCallback((file: File) => { setDroppedDSLFile(file) @@ -93,11 +74,11 @@ const List: FC = ({ const appListQuery = useMemo(() => ({ page: 1, limit: 30, - name: searchKeywords, + name: debouncedKeywords, ...(tagIDs.length ? { tag_ids: tagIDs } : {}), ...(isCreatedByMe ? { is_created_by_me: isCreatedByMe } : {}), - ...(activeTab !== 'all' ? { mode: activeTab } : {}), - }), [activeTab, isCreatedByMe, searchKeywords, tagIDs]) + ...(category !== 'all' ? { mode: category } : {}), + }), [category, debouncedKeywords, isCreatedByMe, tagIDs]) const { data, @@ -177,27 +158,9 @@ const List: FC = ({ return () => observer?.disconnect() }, [isLoading, isFetchingNextPage, fetchNextPage, error, hasNextPage, isCurrentWorkspaceDatasetOperator]) - const { run: handleSearch } = useDebounceFn(() => { - setSearchKeywords(keywords) - }, { wait: 500 }) - const handleKeywordsChange = (value: string) => { - setKeywords(value) - handleSearch() - } - - const { run: handleTagsUpdate } = useDebounceFn(() => { - setTagIDs(tagFilterValue) - }, { wait: 500 }) - const handleTagsChange = (value: string[]) => { - setTagFilterValue(value) - handleTagsUpdate() - } - const handleCreatedByMeChange = useCallback(() => { - const newValue = !isCreatedByMe - setIsCreatedByMe(newValue) - setQuery(prev => ({ ...prev, isCreatedByMe: newValue })) - }, [isCreatedByMe, setQuery]) + setIsCreatedByMe(!isCreatedByMe) + }, [isCreatedByMe, setIsCreatedByMe]) const pages = useMemo(() => data?.pages ?? [], [data?.pages]) const apps = useMemo(() => pages.flatMap(({ data: pageApps }) => pageApps), [pages]) @@ -232,10 +195,10 @@ const List: FC = ({
{ if (isAppListCategory(nextValue)) - setActiveTab(nextValue) + setCategory(nextValue) }} options={options} /> @@ -246,14 +209,14 @@ const List: FC = ({ {t('showMyCreatedAppsOnly', { ns: 'app' })}
- setShowTagManagementModal(true)} /> + setShowTagManagementModal(true)} /> handleKeywordsChange(e.target.value)} - onClear={() => handleKeywordsChange('')} + onChange={e => setKeywords(e.target.value)} + onClear={() => setKeywords('')} /> @@ -267,7 +230,7 @@ const List: FC = ({ ref={newAppCardRef} isLoading={isLoadingCurrentWorkspace} onSuccess={refetch} - selectedAppType={activeTab} + selectedAppType={category} className={cn(!hasAnyApp && 'z-10')} /> )}