Merge branch 'main' into 4-27-app-deploy

This commit is contained in:
Stephen Zhou 2026-05-11 20:30:16 +08:00
commit 0a32344504
No known key found for this signature in database
33 changed files with 3031 additions and 447 deletions

View File

@ -55,9 +55,17 @@ Use this as the decision guide for React/TypeScript component structure. Existin
- Avoid unnecessary DOM hierarchy. Do not add wrapper elements unless they provide layout, semantics, accessibility, state ownership, or integration with a library API; prefer fragments or styling an existing element when possible.
- Avoid shallow wrappers and prop renaming unless the wrapper adds validation, orchestration, error handling, state ownership, or a real semantic boundary.
## Navigation, Effects, And Performance
## You Might Not Need An Effect
- Use Effects only to synchronize with external systems such as browser APIs, non-React widgets, subscriptions, timers, analytics that must run because the component was shown, or imperative DOM integration.
- Do not use Effects to transform props or state for rendering. Calculate derived values during render, and use `useMemo` only when the calculation is actually expensive.
- Do not use Effects to handle user actions. Put action-specific logic in the event handler where the cause is known.
- Do not use Effects to copy one state value into another state value representing the same concept. Pick one source of truth and derive the rest during render.
- Do not reset or adjust state from props with an Effect. Prefer a `key` reset, storing a stable ID and deriving the selected object, or guarded same-component render-time adjustment when truly necessary.
- Prefer framework data APIs or TanStack Query for data fetching instead of writing request Effects in components.
- If an Effect still seems necessary, first name the external system it synchronizes with. If there is no external system, remove the Effect and restructure the state or event flow.
## Navigation And Performance
- Prefer `Link` for normal navigation. Use router APIs only for command-flow side effects such as mutation success, guarded redirects, or form submission.
- Treat `useEffect` as a last resort. First try deriving values during render, moving event-driven work into handlers, or using existing hooks/APIs for persistence, subscriptions, media queries, timers, and DOM sync.
- Do not use `useEffect` directly in components. If unavoidable, encapsulate it in a purpose-built hook so the component consumes a declarative API.
- Avoid `memo`, `useMemo`, and `useCallback` unless there is a clear performance reason.

View File

@ -88,6 +88,10 @@ REDIS_HEALTH_CHECK_INTERVAL=30
CELERY_BROKER_URL=redis://:difyai123456@localhost:${REDIS_PORT}/1
CELERY_BACKEND=redis
# Ops trace retry configuration
OPS_TRACE_RETRYABLE_DISPATCH_MAX_RETRIES=60
OPS_TRACE_RETRYABLE_DISPATCH_DELAY_SECONDS=5
# Database configuration
DB_TYPE=postgresql
DB_USERNAME=postgres
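A quick sanity check on the two retry settings added above: with the defaults, the worst-case retry window is 60 attempts at 5 seconds each, i.e. 300 seconds, which happens to line up with the 300-second Phoenix parent-span carrier TTL introduced later in this commit. This is an observation, not an enforced invariant; the values are independent knobs.

# Hedged sketch: the default retry budget spans exactly the carrier TTL.
max_retries = 60   # OPS_TRACE_RETRYABLE_DISPATCH_MAX_RETRIES
delay_seconds = 5  # OPS_TRACE_RETRYABLE_DISPATCH_DELAY_SECONDS
assert max_retries * delay_seconds == 300  # _PHOENIX_PARENT_SPAN_CONTEXT_TTL_SECONDS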

View File

@ -1137,6 +1137,18 @@ class MultiModalTransferConfig(BaseSettings):
)
class OpsTraceConfig(BaseSettings):
OPS_TRACE_RETRYABLE_DISPATCH_MAX_RETRIES: PositiveInt = Field(
description="Maximum retry attempts for transient ops trace provider dispatch failures.",
default=60,
)
OPS_TRACE_RETRYABLE_DISPATCH_DELAY_SECONDS: PositiveInt = Field(
description="Delay in seconds between transient ops trace provider dispatch retry attempts.",
default=5,
)
class CeleryBeatConfig(BaseSettings):
CELERY_BEAT_SCHEDULER_TIME: int = Field(
description="Interval in days for Celery Beat scheduler execution, default to 1 day",
@ -1417,6 +1429,7 @@ class FeatureConfig(
ModelLoadBalanceConfig,
ModerationConfig,
MultiModalTransferConfig,
OpsTraceConfig,
PositionConfig,
RagEtlConfig,
RepositoryConfig,

View File

@ -32,7 +32,7 @@ from core.app.entities.task_entities import (
)
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
from core.db.session_factory import session_factory
from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.helper.trace_id_helper import extract_external_trace_id_from_args, extract_parent_trace_context_from_args
from core.ops.ops_trace_manager import TraceQueueManager
from core.repositories import DifyCoreRepositoryFactory
from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
@ -166,6 +166,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
extras = {
**extract_external_trace_id_from_args(args),
**extract_parent_trace_context_from_args(args),
}
workflow_run_id = str(workflow_run_id or uuid.uuid4())
# FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args

View File

@ -15,6 +15,7 @@ from datetime import datetime
from typing import Any, Union
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.helper.trace_id_helper import ParentTraceContext
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
@ -403,8 +404,13 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
conversation_id = self._system_variables().get(SystemVariableKey.CONVERSATION_ID.value)
external_trace_id = None
parent_trace_context = None
if isinstance(self._application_generate_entity, (WorkflowAppGenerateEntity, AdvancedChatAppGenerateEntity)):
external_trace_id = self._application_generate_entity.extras.get("external_trace_id")
extras = self._application_generate_entity.extras
external_trace_id = extras.get("external_trace_id")
parent_trace_context = extras.get("parent_trace_context")
if isinstance(parent_trace_context, ParentTraceContext):
parent_trace_context = parent_trace_context.model_dump(exclude_none=True)
trace_task = TraceTask(
TraceTaskName.WORKFLOW_TRACE,
@ -412,6 +418,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
conversation_id=conversation_id,
user_id=self._trace_manager.user_id,
external_trace_id=external_trace_id,
parent_trace_context=parent_trace_context,
)
self._trace_manager.add_trace_task(trace_task)

View File

@ -3,6 +3,17 @@ import re
from collections.abc import Mapping
from typing import Any
from pydantic import BaseModel, ConfigDict, StrictStr, ValidationError
class ParentTraceContext(BaseModel):
"""Typed parent trace context propagated from an outer workflow tool node."""
parent_workflow_run_id: StrictStr
parent_node_execution_id: StrictStr | None = None
model_config = ConfigDict(extra="forbid")
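A minimal sketch of how this value object behaves, assuming only the model defined above: `extra="forbid"` rejects unknown keys, and `model_dump(exclude_none=True)` is the shape later code serializes into trace metadata.

from pydantic import ValidationError

ctx = ParentTraceContext(parent_workflow_run_id="run-1", parent_node_execution_id="node-1")
assert ctx.model_dump(exclude_none=True) == {
    "parent_workflow_run_id": "run-1",
    "parent_node_execution_id": "node-1",
}

# Unknown keys fail validation instead of being silently dropped.
try:
    ParentTraceContext.model_validate({"parent_workflow_run_id": "run-1", "extra_key": "x"})
except ValidationError:
    pass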
def is_valid_trace_id(trace_id: str) -> bool:
"""
@ -61,6 +72,30 @@ def extract_external_trace_id_from_args(args: Mapping[str, Any]):
return {}
def extract_parent_trace_context_from_args(args: Mapping[str, Any]) -> dict[str, ParentTraceContext]:
"""
Extract 'parent_trace_context' from args.
Returns a dict suitable for use in extras when both parent identifiers exist.
Returns an empty dict if the context is missing or incomplete.
"""
parent_trace_context = args.get("parent_trace_context")
if isinstance(parent_trace_context, ParentTraceContext):
context = parent_trace_context
elif isinstance(parent_trace_context, Mapping):
try:
context = ParentTraceContext.model_validate(parent_trace_context)
except ValidationError:
return {}
else:
return {}
if context.parent_node_execution_id is None:
return {}
return {"parent_trace_context": context}
def get_trace_id_from_otel_context() -> str | None:
"""
Retrieve the current trace ID from the active OpenTelemetry trace context.

View File

@ -5,6 +5,8 @@ from typing import Any, Union
from pydantic import BaseModel, ConfigDict, field_serializer, field_validator
from core.helper.trace_id_helper import ParentTraceContext
class BaseTraceInfo(BaseModel):
message_id: str | None = None
@ -51,8 +53,8 @@ class BaseTraceInfo(BaseModel):
def resolved_parent_context(self) -> tuple[str | None, str | None]:
"""Resolve cross-workflow parent linking from metadata.
Extracts typed parent IDs from the untyped ``parent_trace_context``
metadata dict (set by tool_node when invoking nested workflows).
Extracts typed parent IDs from the ``parent_trace_context`` metadata
payload (set by tool_node when invoking nested workflows).
Returns:
(trace_correlation_override, parent_span_id_source) where
@ -60,13 +62,18 @@ class BaseTraceInfo(BaseModel):
parent_span_id_source is the outer node_execution_id.
"""
parent_ctx = self.metadata.get("parent_trace_context")
if not isinstance(parent_ctx, dict):
if isinstance(parent_ctx, ParentTraceContext):
context = parent_ctx
elif isinstance(parent_ctx, Mapping):
try:
context = ParentTraceContext.model_validate(parent_ctx)
except ValueError:
return None, None
else:
return None, None
trace_override = parent_ctx.get("parent_workflow_run_id")
parent_span = parent_ctx.get("parent_node_execution_id")
return (
trace_override if isinstance(trace_override, str) else None,
parent_span if isinstance(parent_span, str) else None,
context.parent_workflow_run_id,
context.parent_node_execution_id,
)
@field_serializer("start_time", "end_time")

View File

@ -0,0 +1,22 @@
"""Core exceptions shared by ops trace dispatchers and trace providers.
Provider packages may raise these types to request generic task behavior, but
generic Celery tasks should not import provider-specific exception classes.
"""
class RetryableTraceDispatchError(RuntimeError):
"""Base class for transient trace dispatch failures that Celery may retry."""
class PendingTraceParentContextError(RetryableTraceDispatchError):
"""Raised when a nested trace arrives before its parent span context is available."""
parent_node_execution_id: str
def __init__(self, parent_node_execution_id: str) -> None:
self.parent_node_execution_id = parent_node_execution_id
super().__init__(
"Pending trace parent context for parent_node_execution_id="
f"{parent_node_execution_id}. Retry after the parent span context is published."
)
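A hedged sketch of the intended contract, with a hypothetical `provider_dispatch` standing in for real provider logic: the provider raises the specific subclass, while generic task code catches only the core base class and never imports anything Phoenix-specific.

def provider_dispatch(parent_node_execution_id: str) -> None:
    # Hypothetical provider-side check standing in for real dispatch logic.
    raise PendingTraceParentContextError(parent_node_execution_id)

try:
    provider_dispatch("outer-node-execution-1")
except RetryableTraceDispatchError as exc:
    # Generic task code: only the core retryable contract is visible here.
    print(f"retryable dispatch failure: {exc}")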

View File

@ -16,6 +16,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token
from core.helper.trace_id_helper import ParentTraceContext
from core.ops.entities.config_entity import (
OPS_FILE_PATH,
BaseTracingConfig,
@ -52,6 +53,17 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
def _dump_parent_trace_context(parent_trace_context: Any) -> dict[str, str] | None:
if isinstance(parent_trace_context, ParentTraceContext):
return parent_trace_context.model_dump(exclude_none=True)
if isinstance(parent_trace_context, dict):
try:
return ParentTraceContext.model_validate(parent_trace_context).model_dump(exclude_none=True)
except ValueError:
return None
return None
class _AppTracingConfig(TypedDict, total=False):
enabled: bool
tracing_provider: str | None
@ -857,8 +869,9 @@ class TraceTask:
}
parent_trace_context = self.kwargs.get("parent_trace_context")
if parent_trace_context:
metadata["parent_trace_context"] = parent_trace_context
dumped_parent_trace_context = _dump_parent_trace_context(parent_trace_context)
if dumped_parent_trace_context:
metadata["parent_trace_context"] = dumped_parent_trace_context
workflow_trace_info = WorkflowTraceInfo(
trace_id=self.trace_id,
@ -1371,13 +1384,14 @@ class TraceTask:
}
parent_trace_context = node_data.get("parent_trace_context")
if parent_trace_context:
metadata["parent_trace_context"] = parent_trace_context
dumped_parent_trace_context = _dump_parent_trace_context(parent_trace_context)
if dumped_parent_trace_context:
metadata["parent_trace_context"] = dumped_parent_trace_context
message_id: str | None = None
conversation_id = node_data.get("conversation_id")
workflow_execution_id = node_data.get("workflow_execution_id")
if conversation_id and workflow_execution_id and not parent_trace_context:
if conversation_id and workflow_execution_id and not dumped_parent_trace_context:
with Session(db.engine) as session:
msg_id = session.scalar(
select(Message.id).where(

View File

@ -9,6 +9,7 @@ from sqlalchemy import select
from core.app.file_access import DatabaseFileAccessController
from core.db.session_factory import session_factory
from core.helper.trace_id_helper import ParentTraceContext, extract_parent_trace_context_from_args
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import (
@ -36,6 +37,8 @@ class WorkflowTool(Tool):
Workflow tool.
"""
_parent_trace_context: ParentTraceContext | None
def __init__(
self,
workflow_app_id: str,
@ -54,6 +57,7 @@ class WorkflowTool(Tool):
self.workflow_call_depth = workflow_call_depth
self.label = label
self._latest_usage = LLMUsage.empty_usage()
self._parent_trace_context = None
super().__init__(entity=entity, runtime=runtime)
@ -94,11 +98,17 @@ class WorkflowTool(Tool):
self._latest_usage = LLMUsage.empty_usage()
generator_args: dict[str, Any] = {"inputs": tool_parameters, "files": files}
if self._parent_trace_context:
generator_args.update(
extract_parent_trace_context_from_args({"parent_trace_context": self._parent_trace_context})
)
result = generator.generate(
app_model=app,
workflow=workflow,
user=user,
args={"inputs": tool_parameters, "files": files},
args=generator_args,
invoke_from=self.runtime.invoke_from,
streaming=False,
call_depth=self.workflow_call_depth + 1,
@ -194,7 +204,7 @@ class WorkflowTool(Tool):
:return: the new tool
"""
return self.__class__(
forked = self.__class__(
entity=self.entity.model_copy(),
runtime=runtime,
workflow_app_id=self.workflow_app_id,
@ -204,6 +214,24 @@ class WorkflowTool(Tool):
version=self.version,
label=self.label,
)
forked._parent_trace_context = self._parent_trace_context.model_copy() if self._parent_trace_context else None
return forked
def set_parent_trace_context(
self,
*,
parent_workflow_run_id: str,
parent_node_execution_id: str,
) -> None:
"""Attach outer workflow trace context without exposing it as tool input."""
self._parent_trace_context = ParentTraceContext(
parent_workflow_run_id=parent_workflow_run_id,
parent_node_execution_id=parent_node_execution_id,
)
def clear_parent_trace_context(self) -> None:
"""Remove parent trace context before invoking this tool outside a nested workflow."""
self._parent_trace_context = None
def _resolve_user(self, user_id: str) -> Account | EndUser | None:
"""

View File

@ -10,6 +10,7 @@ from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
from core.app.file_access import DatabaseFileAccessController
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.helper.trace_id_helper import ParentTraceContext
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
from core.model_manager import ModelInstance
@ -358,6 +359,7 @@ class _WorkflowToolRuntimeBinding:
tool: Tool
conversation_id: str | None = None
parent_trace_context: ParentTraceContext | None = None
class DifyToolNodeRuntime(ToolNodeRuntimeProtocol):
@ -398,7 +400,25 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol):
conversation_id = (
None if variable_pool is None else get_system_text(variable_pool, SystemVariableKey.CONVERSATION_ID)
)
return ToolRuntimeHandle(raw=_WorkflowToolRuntimeBinding(tool=tool_runtime, conversation_id=conversation_id))
parent_trace_context: ParentTraceContext | None = None
if self._is_workflow_tool_provider(node_data):
outer_workflow_run_id = (
None
if variable_pool is None
else get_system_text(variable_pool, SystemVariableKey.WORKFLOW_EXECUTION_ID)
)
if isinstance(outer_workflow_run_id, str) and isinstance(node_execution_id, str):
parent_trace_context = ParentTraceContext(
parent_workflow_run_id=outer_workflow_run_id,
parent_node_execution_id=node_execution_id,
)
return ToolRuntimeHandle(
raw=_WorkflowToolRuntimeBinding(
tool=tool_runtime,
conversation_id=conversation_id,
parent_trace_context=parent_trace_context,
)
)
def get_runtime_parameters(
self,
@ -422,6 +442,13 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol):
runtime_binding = self._binding_from_handle(tool_runtime)
tool = runtime_binding.tool
callback = DifyWorkflowCallbackHandler()
if runtime_binding.parent_trace_context and hasattr(tool, "set_parent_trace_context"):
tool.set_parent_trace_context(
parent_workflow_run_id=runtime_binding.parent_trace_context.parent_workflow_run_id,
parent_node_execution_id=runtime_binding.parent_trace_context.parent_node_execution_id,
)
elif hasattr(tool, "clear_parent_trace_context"):
tool.clear_parent_trace_context()
try:
messages = ToolEngine.generic_invoke(
@ -514,6 +541,10 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol):
credential_id=node_data.credential_id,
)
@staticmethod
def _is_workflow_tool_provider(node_data: ToolNodeData) -> bool:
return node_data.provider_type.value == CoreToolProviderType.WORKFLOW.value
def _adapt_messages(
self,
messages: Generator[CoreToolInvokeMessage, None, None],

View File

@ -1,19 +1,22 @@
"""Workflow comment models."""
from __future__ import annotations
from datetime import datetime
from typing import Optional
import sqlalchemy as sa
from sqlalchemy import Index, func
from sqlalchemy.orm import Mapped, mapped_column, relationship
from models.base import TypeBase
from .account import Account
from .base import Base, gen_uuidv7_string
from .base import gen_uuidv7_string
from .engine import db
from .types import StringUUID
class WorkflowComment(Base):
class WorkflowComment(TypeBase):
"""Workflow comment model for canvas commenting functionality.
Comments are associated with apps rather than specific workflow versions,
@ -42,27 +45,33 @@ class WorkflowComment(Base):
Index("workflow_comments_created_at_idx", "created_at"),
)
id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv7_string)
id: Mapped[str] = mapped_column(StringUUID, default_factory=gen_uuidv7_string, init=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
position_x: Mapped[float] = mapped_column(sa.Float)
position_y: Mapped[float] = mapped_column(sa.Float)
content: Mapped[str] = mapped_column(sa.Text, nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
sa.DateTime,
nullable=False,
server_default=func.current_timestamp(),
onupdate=func.current_timestamp(),
init=False,
)
resolved: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
resolved_at: Mapped[datetime | None] = mapped_column(sa.DateTime)
resolved_by: Mapped[str | None] = mapped_column(StringUUID)
resolved_at: Mapped[datetime | None] = mapped_column(sa.DateTime, default=None)
resolved_by: Mapped[str | None] = mapped_column(StringUUID, default=None)
resolved: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"), default=False)
# Relationships
replies: Mapped[list["WorkflowCommentReply"]] = relationship(
"WorkflowCommentReply", back_populates="comment", cascade="all, delete-orphan"
replies: Mapped[list[WorkflowCommentReply]] = relationship(
lambda: WorkflowCommentReply, back_populates="comment", cascade="all, delete-orphan", init=False
)
mentions: Mapped[list["WorkflowCommentMention"]] = relationship(
"WorkflowCommentMention", back_populates="comment", cascade="all, delete-orphan"
mentions: Mapped[list[WorkflowCommentMention]] = relationship(
lambda: WorkflowCommentMention, back_populates="comment", cascade="all, delete-orphan", init=False
)
@property
@ -131,7 +140,7 @@ class WorkflowComment(Base):
return participants
class WorkflowCommentReply(Base):
class WorkflowCommentReply(TypeBase):
"""Workflow comment reply model.
Attributes:
@ -149,18 +158,24 @@ class WorkflowCommentReply(Base):
Index("comment_replies_created_at_idx", "created_at"),
)
id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv7_string)
id: Mapped[str] = mapped_column(StringUUID, default_factory=gen_uuidv7_string, init=False)
comment_id: Mapped[str] = mapped_column(
StringUUID, sa.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
)
content: Mapped[str] = mapped_column(sa.Text, nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
sa.DateTime,
nullable=False,
server_default=func.current_timestamp(),
onupdate=func.current_timestamp(),
init=False,
)
# Relationships
comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="replies")
comment: Mapped[WorkflowComment] = relationship(lambda: WorkflowComment, back_populates="replies", init=False)
@property
def created_by_account(self):
@ -174,7 +189,7 @@ class WorkflowCommentReply(Base):
self._created_by_account_cache = account
class WorkflowCommentMention(Base):
class WorkflowCommentMention(TypeBase):
"""Workflow comment mention model.
Mentions are only for internal accounts since end users
@ -194,18 +209,18 @@ class WorkflowCommentMention(Base):
Index("comment_mentions_user_idx", "mentioned_user_id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv7_string)
id: Mapped[str] = mapped_column(StringUUID, default_factory=gen_uuidv7_string, init=False)
comment_id: Mapped[str] = mapped_column(
StringUUID, sa.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
)
reply_id: Mapped[str | None] = mapped_column(
StringUUID, sa.ForeignKey("workflow_comment_replies.id", ondelete="CASCADE"), nullable=True
)
mentioned_user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
reply_id: Mapped[str | None] = mapped_column(
StringUUID, sa.ForeignKey("workflow_comment_replies.id", ondelete="CASCADE"), nullable=True, default=None
)
# Relationships
comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="mentions")
reply: Mapped[Optional["WorkflowCommentReply"]] = relationship("WorkflowCommentReply")
comment: Mapped[WorkflowComment] = relationship(lambda: WorkflowComment, back_populates="mentions", init=False)
reply: Mapped[WorkflowCommentReply | None] = relationship(lambda: WorkflowCommentReply, init=False)
@property
def mentioned_user_account(self):

View File

@ -1,9 +1,11 @@
import json
import logging
import os
import re
import traceback
from collections.abc import Mapping, Sequence
from datetime import datetime, timedelta
from typing import Any, Union, cast
from typing import Any, Protocol, Union, cast
from urllib.parse import urlparse
from openinference.semconv.trace import (
@ -19,7 +21,7 @@ from opentelemetry.sdk import trace as trace_sdk
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.semconv.attributes import exception_attributes
from opentelemetry.trace import Span, Status, StatusCode, set_span_in_context, use_span
from opentelemetry.trace import Span, Status, StatusCode, get_current_span, set_span_in_context, use_span
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from opentelemetry.util.types import AttributeValue
from sqlalchemy.orm import sessionmaker
@ -36,16 +38,106 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
from core.ops.exceptions import PendingTraceParentContextError
from core.ops.utils import JSON_DICT_ADAPTER
from core.repositories import DifyCoreRepositoryFactory
from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from graphon.enums import WorkflowNodeExecutionStatus
from models.model import EndUser, MessageFile
from models.workflow import WorkflowNodeExecutionTriggeredFrom
logger = logging.getLogger(__name__)
# This parent-span carrier store is intentionally Phoenix-local for the current
# nested workflow tracing feature. If other trace providers need the same
# cross-task parent restoration behavior, move the storage and retry signaling
# behind a core trace coordination interface instead of duplicating it here.
_PHOENIX_PARENT_SPAN_CONTEXT_TTL_SECONDS = 300
_TRACEPARENT_PATTERN = re.compile(
r"^(?P<version>[0-9a-f]{2})-(?P<trace_id>[0-9a-f]{32})-(?P<span_id>[0-9a-f]{16})-(?P<flags>[0-9a-f]{2})$"
)
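For reference, the canonical example header from the W3C Trace Context spec satisfies this pattern: a two-hex version, a 32-hex non-zero trace-id, a 16-hex non-zero span-id, and two-hex flags.

# "00" version, non-zero trace-id and span-id, "01" (sampled) flags.
sample_traceparent = "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01"
match = _TRACEPARENT_PATTERN.fullmatch(sample_traceparent)
assert match is not None and match.group("trace_id") != "0" * 32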
def _phoenix_parent_span_redis_key(parent_node_execution_id: str) -> str:
"""Build the Redis key that stores a restorable Phoenix parent span carrier."""
return f"trace:phoenix:parent_span:{parent_node_execution_id}"
def _publish_parent_span_context(parent_node_execution_id: str, carrier: Mapping[str, str]) -> None:
"""Persist a tracecontext carrier so nested workflow spans can restore the tool span parent."""
redis_client.setex(
_phoenix_parent_span_redis_key(parent_node_execution_id),
_PHOENIX_PARENT_SPAN_CONTEXT_TTL_SECONDS,
safe_json_dumps(dict(carrier)),
)
def _resolve_published_parent_span_context(parent_node_execution_id: str) -> dict[str, str]:
"""Load a previously published tool-span carrier for nested workflow parenting."""
raw_carrier = redis_client.get(_phoenix_parent_span_redis_key(parent_node_execution_id))
if raw_carrier is None:
raise PendingTraceParentContextError(parent_node_execution_id)
if isinstance(raw_carrier, bytes):
raw_carrier = raw_carrier.decode("utf-8")
carrier = json.loads(raw_carrier)
if not isinstance(carrier, dict):
raise ValueError(
"Phoenix parent span context must be stored as a JSON object: "
f"parent_node_execution_id={parent_node_execution_id}"
)
normalized_carrier = {str(key): str(value) for key, value in carrier.items()}
if not normalized_carrier:
raise ValueError(
f"Phoenix parent span context payload is empty: parent_node_execution_id={parent_node_execution_id}"
)
traceparent = normalized_carrier.get("traceparent")
if not isinstance(traceparent, str):
raise ValueError(
"Phoenix parent span context payload is missing traceparent: "
f"parent_node_execution_id={parent_node_execution_id}"
)
traceparent_match = _TRACEPARENT_PATTERN.fullmatch(traceparent)
if traceparent_match is None:
raise ValueError(
"Phoenix parent span context payload has invalid traceparent format: "
f"parent_node_execution_id={parent_node_execution_id}"
)
if traceparent_match.group("version") == "ff":
raise ValueError(
"Phoenix parent span context payload has unsupported traceparent version: "
f"parent_node_execution_id={parent_node_execution_id}"
)
if traceparent_match.group("trace_id") == "0" * 32:
raise ValueError(
"Phoenix parent span context payload has zero trace_id in traceparent: "
f"parent_node_execution_id={parent_node_execution_id}"
)
if traceparent_match.group("span_id") == "0" * 16:
raise ValueError(
"Phoenix parent span context payload has zero span_id in traceparent: "
f"parent_node_execution_id={parent_node_execution_id}"
)
extracted_context = TraceContextTextMapPropagator().extract(carrier=normalized_carrier)
extracted_span_context = get_current_span(extracted_context).get_span_context()
if not extracted_span_context.is_valid or not extracted_span_context.is_remote:
raise ValueError(
"Phoenix parent span context payload could not be restored into a valid parent span: "
f"parent_node_execution_id={parent_node_execution_id}"
)
return normalized_carrier
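A round-trip sketch of the two helpers above, assuming a live Redis client and reusing the W3C sample header from earlier; real callers publish from inside the tool span via the propagator rather than hand-building a carrier.

carrier = {"traceparent": "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01"}
_publish_parent_span_context("outer-node-execution-1", carrier)
restored = _resolve_published_parent_span_context("outer-node-execution-1")
assert restored["traceparent"] == carrier["traceparent"]
# If nothing was published yet, the resolver raises
# PendingTraceParentContextError so Celery can retry the dispatch.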
def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[trace_sdk.Tracer, SimpleSpanProcessor]:
"""Configure OpenTelemetry tracer with OTLP exporter for Arize/Phoenix."""
@ -177,6 +269,246 @@ def _get_node_span_kind(node_type: str) -> OpenInferenceSpanKindValues:
return _NODE_TYPE_TO_SPAN_KIND.get(node_type, OpenInferenceSpanKindValues.CHAIN)
def _resolve_workflow_session_id(trace_info: WorkflowTraceInfo) -> str:
"""Resolve the workflow session ID for Phoenix workflow spans."""
if trace_info.conversation_id:
return trace_info.conversation_id
parent_workflow_run_id, _ = _resolve_workflow_parent_context(trace_info)
if parent_workflow_run_id:
return parent_workflow_run_id
return trace_info.workflow_run_id
def _resolve_workflow_parent_context(trace_info: BaseTraceInfo) -> tuple[str | None, str | None]:
"""Expose the typed parent context already resolved on the trace info."""
return trace_info.resolved_parent_context
def _resolve_workflow_root_trace_id(trace_info: WorkflowTraceInfo) -> str:
"""Resolve the canonical root trace ID for Phoenix workflow spans."""
trace_correlation_override, _ = _resolve_workflow_parent_context(trace_info)
return trace_correlation_override or trace_info.resolved_trace_id or trace_info.workflow_run_id
class _NodeExecutionIdentityLike(Protocol):
@property
def node_execution_id(self) -> str | None: ...
@property
def node_id(self) -> str: ...
@property
def predecessor_node_id(self) -> str | None: ...
class _NodeExecutionLike(_NodeExecutionIdentityLike, Protocol):
@property
def id(self) -> str: ...
@property
def node_type(self) -> str: ...
@property
def title(self) -> str | None: ...
@property
def inputs(self) -> Mapping[str, Any] | None: ...
@property
def process_data(self) -> Mapping[str, Any] | None: ...
@property
def outputs(self) -> Mapping[str, Any] | None: ...
@property
def status(self) -> WorkflowNodeExecutionStatus: ...
@property
def error(self) -> str | None: ...
@property
def elapsed_time(self) -> float | None: ...
@property
def metadata(self) -> Mapping[Any, Any] | None: ...
@property
def created_at(self) -> datetime | None: ...
_PHOENIX_STRUCTURED_NODE_TYPES = frozenset({"start", "end", "loop", "iteration"})
def _resolve_workflow_span_name(trace_info: WorkflowTraceInfo) -> str:
"""Resolve the Phoenix workflow span display name."""
workflow_run_id = trace_info.workflow_run_id.strip() if trace_info.workflow_run_id else ""
if workflow_run_id:
return f"{TraceTaskName.WORKFLOW_TRACE.value}_{workflow_run_id}"
return TraceTaskName.WORKFLOW_TRACE.value
def _build_node_title_by_id(trace_info: WorkflowTraceInfo) -> dict[str, str]:
"""Build an authoritative node-title index from the persisted workflow graph."""
workflow_data = trace_info.workflow_data
workflow_graph = getattr(workflow_data, "graph_dict", None)
if not isinstance(workflow_graph, Mapping):
workflow_graph = workflow_data.get("graph") if isinstance(workflow_data, Mapping) else None
if not isinstance(workflow_graph, Mapping):
return {}
graph_nodes = workflow_graph.get("nodes")
if not isinstance(graph_nodes, Sequence):
return {}
node_title_by_id: dict[str, str] = {}
for graph_node in graph_nodes:
if not isinstance(graph_node, Mapping):
continue
node_id = graph_node.get("id")
node_data = graph_node.get("data")
if not isinstance(node_id, str) or not isinstance(node_data, Mapping):
continue
node_title = node_data.get("title")
if isinstance(node_title, str) and node_title.strip():
node_title_by_id[node_id] = node_title.strip()
return node_title_by_id
def _resolve_workflow_node_span_name(
node_execution: _NodeExecutionLike,
node_title_by_id: Mapping[str, str] | None = None,
) -> str:
"""Resolve the Phoenix workflow node span display name."""
node_type = str(node_execution.node_type or "")
graph_node_title = None
if node_title_by_id is not None and isinstance(node_execution.node_id, str):
graph_node_title = node_title_by_id.get(node_execution.node_id)
node_title = graph_node_title or (node_execution.title.strip() if isinstance(node_execution.title, str) else "")
if node_title:
return f"{node_type}_{node_title}"
return node_type
def _get_node_execution_id(node_execution: _NodeExecutionIdentityLike) -> str:
"""Return the stable execution identifier for a workflow node execution."""
return str(getattr(node_execution, "id", None) or node_execution.node_execution_id)
def _build_execution_id_by_node_id(node_executions: Sequence[_NodeExecutionIdentityLike]) -> dict[str, str]:
"""Index unique workflow graph node ids by execution id.
This Phoenix-local hierarchy reconstruction intentionally drops ambiguous
node ids instead of guessing based on repository order. That keeps parent
selection deterministic until upstream tracing exposes explicit parent span
data for repeated executions.
"""
execution_id_by_node_id: dict[str, str] = {}
ambiguous_node_ids: set[str] = set()
for node_execution in node_executions:
node_id = node_execution.node_id
if not isinstance(node_id, str):
continue
execution_id = _get_node_execution_id(node_execution)
if node_id in ambiguous_node_ids:
continue
existing_execution_id = execution_id_by_node_id.get(node_id)
if existing_execution_id is None:
execution_id_by_node_id[node_id] = execution_id
continue
if existing_execution_id != execution_id:
ambiguous_node_ids.add(node_id)
execution_id_by_node_id.pop(node_id, None)
return execution_id_by_node_id
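A small illustration of the ambiguity rule with hypothetical stand-in records: a node id that maps to two different execution ids is dropped from the index rather than guessed.

from types import SimpleNamespace

executions = [
    SimpleNamespace(node_id="loop-node", id="exec-1", node_execution_id=None, predecessor_node_id=None),
    SimpleNamespace(node_id="loop-node", id="exec-2", node_execution_id=None, predecessor_node_id=None),
    SimpleNamespace(node_id="end-node", id="exec-3", node_execution_id=None, predecessor_node_id="loop-node"),
]
# "loop-node" ran twice, so it is ambiguous; only "end-node" survives.
assert _build_execution_id_by_node_id(executions) == {"end-node": "exec-3"}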
def _build_graph_parent_index(node_executions: Sequence[_NodeExecutionIdentityLike]) -> dict[str, str]:
"""Build an execution-id parent index from predecessor node ids."""
execution_id_by_node_id = _build_execution_id_by_node_id(node_executions)
graph_parent_index: dict[str, str] = {}
for node_execution in node_executions:
predecessor_node_id = node_execution.predecessor_node_id
if not isinstance(predecessor_node_id, str):
continue
predecessor_execution_id = execution_id_by_node_id.get(predecessor_node_id)
if predecessor_execution_id is not None:
execution_id = _get_node_execution_id(node_execution)
graph_parent_index[execution_id] = predecessor_execution_id
return graph_parent_index
def _resolve_structured_parent_execution_id(
node_execution: object, execution_id_by_node_id: Mapping[str, str]
) -> str | None:
"""Resolve Phoenix-local structured parents from loop/iteration node ids.
Any execution carrying ``iteration_id`` or ``loop_id`` belongs to an
enclosing structured node. When predecessor node ids are ambiguous because
the graph node repeats inside that structure, Phoenix can still keep the
child span under the enclosing loop/iteration span without relying on
execution-order heuristics.
"""
execution_metadata = getattr(node_execution, "execution_metadata_dict", None)
if not isinstance(execution_metadata, Mapping):
execution_metadata = getattr(node_execution, "metadata", None)
if not isinstance(execution_metadata, Mapping):
execution_metadata = {}
for enclosing_node_id in (
getattr(node_execution, "iteration_id", None),
getattr(node_execution, "loop_id", None),
execution_metadata.get("iteration_id"),
execution_metadata.get("loop_id"),
):
if not isinstance(enclosing_node_id, str):
continue
enclosing_execution_id = execution_id_by_node_id.get(enclosing_node_id)
if enclosing_execution_id is not None:
return enclosing_execution_id
return None
def _resolve_node_parent(
execution_id: str,
predecessor_execution_id: str | None,
structured_parent_execution_id: str | None,
span_by_execution_id: Mapping[str, Span],
graph_parent_index: Mapping[str, str],
workflow_span: Span,
) -> Span:
"""Resolve the parent span for a workflow node execution."""
if predecessor_execution_id is not None:
predecessor_span = span_by_execution_id.get(predecessor_execution_id)
if predecessor_span is not None:
return predecessor_span
graph_parent_execution_id = graph_parent_index.get(execution_id)
if graph_parent_execution_id is not None:
graph_parent_span = span_by_execution_id.get(graph_parent_execution_id)
if graph_parent_span is not None:
return graph_parent_span
if structured_parent_execution_id is not None:
structured_parent_span = span_by_execution_id.get(structured_parent_execution_id)
if structured_parent_span is not None:
return structured_parent_span
return workflow_span
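A precedence sketch for the resolver above, using mock spans in place of real OpenTelemetry spans: an explicit predecessor span wins, then the graph-derived parent, then the enclosing structured (loop/iteration) span, with the workflow span as the final fallback.

from unittest.mock import MagicMock

loop_span, workflow_span = MagicMock(name="loop_span"), MagicMock(name="workflow_span")
parent = _resolve_node_parent(
    execution_id="exec-3",
    predecessor_execution_id=None,            # no explicit predecessor span
    structured_parent_execution_id="exec-9",  # enclosing loop execution
    span_by_execution_id={"exec-9": loop_span},
    graph_parent_index={},                    # no graph-derived parent
    workflow_span=workflow_span,
)
assert parent is loop_span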
class ArizePhoenixDataTrace(BaseTraceInstance):
def __init__(
self,
@ -189,6 +521,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
self.propagator = TraceContextTextMapPropagator()
self.dify_trace_ids: set[str] = set()
self.root_span_carriers: dict[str, dict[str, str]] = {}
self.carrier: dict[str, str] = {}
def trace(self, trace_info: BaseTraceInfo):
logger.info("[Arize/Phoenix] Trace Entity Info: %s", trace_info)
@ -235,13 +569,41 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
file_list=safe_json_dumps(file_list),
query=trace_info.query or "",
)
workflow_session_id = _resolve_workflow_session_id(trace_info)
parent_workflow_run_id, parent_node_execution_id = _resolve_workflow_parent_context(trace_info)
logger.info(
"[Arize/Phoenix] Workflow session resolution: workflow_run_id=%s conversation_id=%s "
"parent_workflow_run_id=%s parent_node_execution_id=%s resolved_session_id=%s",
trace_info.workflow_run_id,
trace_info.conversation_id,
parent_workflow_run_id,
parent_node_execution_id,
workflow_session_id,
)
dify_trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id
self.ensure_root_span(dify_trace_id)
root_span_context = self.propagator.extract(carrier=self.carrier)
if parent_node_execution_id:
workflow_parent_carrier = _resolve_published_parent_span_context(parent_node_execution_id)
else:
root_trace_id = _resolve_workflow_root_trace_id(trace_info)
workflow_root_span_name: str | None = trace_info.workflow_run_id
if not isinstance(workflow_root_span_name, str) or not workflow_root_span_name.strip():
workflow_root_span_name = None
workflow_parent_carrier = self.ensure_root_span(
root_trace_id,
root_span_name=workflow_root_span_name,
root_span_attributes={
SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.workflow_run_inputs),
SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
SpanAttributes.OUTPUT_VALUE: safe_json_dumps(trace_info.workflow_run_outputs),
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
},
)
workflow_span_context = self.propagator.extract(carrier=workflow_parent_carrier)
workflow_span = self.tracer.start_span(
name=TraceTaskName.WORKFLOW_TRACE.value,
name=_resolve_workflow_span_name(trace_info),
attributes={
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.workflow_run_inputs),
@ -249,10 +611,10 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
SpanAttributes.OUTPUT_VALUE: safe_json_dumps(trace_info.workflow_run_outputs),
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
SpanAttributes.METADATA: safe_json_dumps(metadata),
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
SpanAttributes.SESSION_ID: workflow_session_id or "",
},
start_time=datetime_to_nanos(trace_info.start_time),
context=root_span_context,
context=workflow_span_context,
)
# Through workflow_run_id, get all_nodes_execution using repository
@ -276,16 +638,50 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution(
workflow_execution_id=trace_info.workflow_run_id
)
node_title_by_id = _build_node_title_by_id(trace_info)
execution_id_by_node_id = _build_execution_id_by_node_id(workflow_node_executions)
graph_parent_index = _build_graph_parent_index(workflow_node_executions)
node_execution_by_execution_id = {
_get_node_execution_id(node_execution): node_execution for node_execution in workflow_node_executions
}
span_by_execution_id: dict[str, Span] = {}
emitting_execution_ids: set[str] = set()
workflow_span_error: Exception | str | None = trace_info.error
try:
for node_execution in workflow_node_executions:
def emit_node_span(node_execution: _NodeExecutionLike) -> Span:
execution_id = _get_node_execution_id(node_execution)
existing_span = span_by_execution_id.get(execution_id)
if existing_span is not None:
return existing_span
graph_parent_execution_id = graph_parent_index.get(execution_id)
structured_parent_execution_id = _resolve_structured_parent_execution_id(
node_execution, execution_id_by_node_id
)
if execution_id not in emitting_execution_ids:
emitting_execution_ids.add(execution_id)
try:
for parent_execution_id in (graph_parent_execution_id, structured_parent_execution_id):
if parent_execution_id is None or parent_execution_id == execution_id:
continue
if parent_execution_id in span_by_execution_id:
continue
parent_node_execution = node_execution_by_execution_id.get(parent_execution_id)
if parent_node_execution is not None:
emit_node_span(parent_node_execution)
finally:
emitting_execution_ids.discard(execution_id)
tenant_id = trace_info.tenant_id # Use from trace_info instead
app_id = trace_info.metadata.get("app_id") # Use from trace_info instead
inputs_value = node_execution.inputs or {}
outputs_value = node_execution.outputs or {}
created_at = node_execution.created_at or datetime.now()
elapsed_time = node_execution.elapsed_time
elapsed_time = node_execution.elapsed_time or 0
finished_at = created_at + timedelta(seconds=elapsed_time)
process_data = node_execution.process_data or {}
@ -324,9 +720,17 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
node_metadata["prompt_tokens"] = usage_data.get("prompt_tokens", 0)
node_metadata["completion_tokens"] = usage_data.get("completion_tokens", 0)
workflow_span_context = set_span_in_context(workflow_span)
parent_span = _resolve_node_parent(
execution_id=execution_id,
predecessor_execution_id=None,
structured_parent_execution_id=structured_parent_execution_id,
span_by_execution_id=span_by_execution_id,
graph_parent_index=graph_parent_index,
workflow_span=workflow_span,
)
workflow_span_context = set_span_in_context(parent_span)
node_span = self.tracer.start_span(
name=node_execution.node_type,
name=_resolve_workflow_node_span_name(node_execution, node_title_by_id),
attributes={
SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind.value,
SpanAttributes.INPUT_VALUE: safe_json_dumps(inputs_value),
@ -334,13 +738,20 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
SpanAttributes.OUTPUT_VALUE: safe_json_dumps(outputs_value),
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
SpanAttributes.METADATA: safe_json_dumps(node_metadata),
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
SpanAttributes.SESSION_ID: workflow_session_id or "",
},
start_time=datetime_to_nanos(created_at),
context=workflow_span_context,
)
span_by_execution_id[execution_id] = node_span
node_span_error: Exception | str | None = None
try:
if node_execution.node_type == "tool":
parent_span_carrier: dict[str, str] = {}
with use_span(node_span, end_on_exit=False):
self.propagator.inject(carrier=parent_span_carrier)
_publish_parent_span_context(execution_id, parent_span_carrier)
if node_execution.node_type == "llm":
llm_attributes: dict[str, Any] = {
SpanAttributes.INPUT_VALUE: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
@ -362,17 +773,26 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
)
llm_attributes.update(self._construct_llm_attributes(process_data.get("prompts", [])))
node_span.set_attributes(llm_attributes)
except Exception as e:
node_span_error = e
raise
finally:
if node_execution.status == WorkflowNodeExecutionStatus.FAILED:
if node_span_error is not None:
set_span_status(node_span, node_span_error)
elif node_execution.status == WorkflowNodeExecutionStatus.FAILED:
set_span_status(node_span, node_execution.error)
else:
set_span_status(node_span)
node_span.end(end_time=datetime_to_nanos(finished_at))
return node_span
for node_execution in workflow_node_executions:
emit_node_span(node_execution)
except Exception as e:
workflow_span_error = e
raise
finally:
if trace_info.error:
set_span_status(workflow_span, trace_info.error)
else:
set_span_status(workflow_span)
set_span_status(workflow_span, workflow_span_error)
workflow_span.end(end_time=datetime_to_nanos(trace_info.end_time))
def message_trace(self, trace_info: MessageTraceInfo):
@ -735,22 +1155,39 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
finally:
span.end(end_time=datetime_to_nanos(trace_info.end_time))
def ensure_root_span(self, dify_trace_id: str | None):
def ensure_root_span(
self,
dify_trace_id: str | None,
*,
root_span_name: str | None = None,
root_span_attributes: Mapping[str, AttributeValue] | None = None,
):
"""Ensure a unique root span exists for the given Dify trace ID."""
if str(dify_trace_id) not in self.dify_trace_ids:
self.carrier: dict[str, str] = {}
trace_key = str(dify_trace_id)
if trace_key not in self.dify_trace_ids:
carrier: dict[str, str] = {}
root_span = self.tracer.start_span(name="Dify")
root_span.set_attribute(SpanAttributes.OPENINFERENCE_SPAN_KIND, OpenInferenceSpanKindValues.CHAIN.value)
root_span.set_attribute("dify_project_name", str(self.project))
root_span.set_attribute("dify_trace_id", str(dify_trace_id))
span_name = root_span_name.strip() if isinstance(root_span_name, str) and root_span_name.strip() else "Dify"
root_span_attributes_dict: dict[str, AttributeValue] = {
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
"dify_project_name": str(self.project),
"dify_trace_id": trace_key,
}
if root_span_attributes:
root_span_attributes_dict.update(root_span_attributes)
root_span = self.tracer.start_span(name=span_name, attributes=root_span_attributes_dict)
with use_span(root_span, end_on_exit=False):
self.propagator.inject(carrier=self.carrier)
self.propagator.inject(carrier=carrier)
set_span_status(root_span)
root_span.end()
self.dify_trace_ids.add(str(dify_trace_id))
self.dify_trace_ids.add(trace_key)
self.root_span_carriers[trace_key] = carrier
self.carrier = self.root_span_carriers[trace_key]
return self.carrier
def api_check(self):
try:

View File

@ -1,36 +0,0 @@
from dify_trace_arize_phoenix.arize_phoenix_trace import _NODE_TYPE_TO_SPAN_KIND, _get_node_span_kind
from openinference.semconv.trace import OpenInferenceSpanKindValues
from graphon.enums import BUILT_IN_NODE_TYPES, BuiltinNodeTypes
class TestGetNodeSpanKind:
"""Tests for _get_node_span_kind helper."""
def test_all_node_types_are_mapped_correctly(self):
"""Ensure every built-in node type is mapped to the correct span kind."""
# Mappings for node types that have a specialised span kind.
special_mappings = {
BuiltinNodeTypes.LLM: OpenInferenceSpanKindValues.LLM,
BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: OpenInferenceSpanKindValues.RETRIEVER,
BuiltinNodeTypes.TOOL: OpenInferenceSpanKindValues.TOOL,
BuiltinNodeTypes.AGENT: OpenInferenceSpanKindValues.AGENT,
}
# Test that every built-in node type is mapped to the correct span kind.
# Node types not in `special_mappings` should default to CHAIN.
for node_type in BUILT_IN_NODE_TYPES:
expected_span_kind = special_mappings.get(node_type, OpenInferenceSpanKindValues.CHAIN)
actual_span_kind = _get_node_span_kind(node_type)
assert actual_span_kind == expected_span_kind, (
f"Node type {node_type!r} was mapped to {actual_span_kind}, but {expected_span_kind} was expected."
)
def test_unknown_string_defaults_to_chain(self):
"""An unrecognised node type string should still return CHAIN."""
assert _get_node_span_kind("some-future-node-type") == OpenInferenceSpanKindValues.CHAIN
def test_stale_dataset_retrieval_not_in_mapping(self):
"""The old 'dataset_retrieval' string was never a valid NodeType value;
make sure it is not present in the mapping dictionary."""
assert "dataset_retrieval" not in _NODE_TYPE_TO_SPAN_KIND

View File

@ -1,11 +1,31 @@
"""
Celery task for asynchronous ops trace dispatch.
Trace providers may report explicitly retryable dispatch failures through the
core retryable exception contract. The task preserves the payload file only
when Celery accepts the retry request; successful dispatches and terminal
failures clean up the stored payload.
One concrete producer today is Phoenix nested workflow tracing. The outer
workflow tool span publishes a restorable parent span context asynchronously,
while the nested workflow trace may be picked up by Celery first. In that
ordering window, the provider raises a retryable core exception instead of
dropping the trace or emitting it under the wrong parent. The task intentionally
does not know that the provider is Phoenix; it only honors the core retryable
dispatch contract.
"""
import json
import logging
from celery import shared_task
from celery.exceptions import Retry
from flask import current_app
from configs import dify_config
from core.ops.entities.config_entity import OPS_FILE_PATH, OPS_TRACE_FAILED_KEY
from core.ops.entities.trace_entity import trace_info_info_map
from core.ops.exceptions import RetryableTraceDispatchError
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
@ -14,9 +34,17 @@ from models.workflow import WorkflowRun
logger = logging.getLogger(__name__)
_RETRYABLE_TRACE_DISPATCH_LIMIT = dify_config.OPS_TRACE_RETRYABLE_DISPATCH_MAX_RETRIES
_RETRYABLE_TRACE_DISPATCH_DELAY_SECONDS = dify_config.OPS_TRACE_RETRYABLE_DISPATCH_DELAY_SECONDS
@shared_task(queue="ops_trace")
def process_trace_tasks(file_info):
@shared_task(
queue="ops_trace",
bind=True,
max_retries=_RETRYABLE_TRACE_DISPATCH_LIMIT,
default_retry_delay=_RETRYABLE_TRACE_DISPATCH_DELAY_SECONDS,
)
def process_trace_tasks(self, file_info):
"""
Async process trace tasks
Usage: process_trace_tasks.delay(tasks_data)
@ -29,6 +57,7 @@ def process_trace_tasks(file_info):
file_data = json.loads(storage.load(file_path))
trace_info = file_data.get("trace_info")
trace_info_type = file_data.get("trace_info_type")
enterprise_trace_dispatched = bool(file_data.get("_enterprise_trace_dispatched"))
trace_instance = OpsTraceManager.get_ops_trace_instance(app_id)
if trace_info.get("message_data"):
@ -38,6 +67,8 @@ def process_trace_tasks(file_info):
if trace_info.get("documents"):
trace_info["documents"] = [Document.model_validate(doc) for doc in trace_info["documents"]]
should_delete_file = True
try:
trace_type = trace_info_info_map.get(trace_info_type)
if trace_type:
@ -45,30 +76,66 @@ def process_trace_tasks(file_info):
from extensions.ext_enterprise_telemetry import is_enabled as is_ee_telemetry_enabled
if is_ee_telemetry_enabled():
if is_ee_telemetry_enabled() and not enterprise_trace_dispatched:
from enterprise.telemetry.enterprise_trace import EnterpriseOtelTrace
try:
EnterpriseOtelTrace().trace(trace_info)
except Exception:
logger.exception("Enterprise trace failed for app_id: %s", app_id)
else:
file_data["_enterprise_trace_dispatched"] = True
enterprise_trace_dispatched = True
if trace_instance:
with current_app.app_context():
trace_instance.trace(trace_info)
logger.info("Processing trace tasks success, app_id: %s", app_id)
except RetryableTraceDispatchError as e:
# Retryable dispatch failures represent a transient provider-side
# ordering gap, not corrupt payload data. Keep the payload only after
# Celery accepts the retry request; otherwise this attempt becomes a
# terminal failure and the stored file is cleaned up in `finally`.
#
# Enterprise telemetry runs before provider dispatch. If it already ran
# and provider dispatch asks for a retry, persist that private flag so
# the next attempt does not emit the same enterprise trace twice.
if self.request.retries >= _RETRYABLE_TRACE_DISPATCH_LIMIT:
logger.exception("Retryable trace dispatch budget exhausted, app_id: %s", app_id)
failed_key = f"{OPS_TRACE_FAILED_KEY}_{app_id}"
redis_client.incr(failed_key)
else:
logger.warning(
"Retryable trace dispatch failure, scheduling retry %s/%s for app_id %s: %s",
self.request.retries + 1,
_RETRYABLE_TRACE_DISPATCH_LIMIT,
app_id,
e,
)
try:
if enterprise_trace_dispatched:
storage.save(file_path, json.dumps(file_data).encode("utf-8"))
raise self.retry(exc=e, countdown=_RETRYABLE_TRACE_DISPATCH_DELAY_SECONDS)
except Retry:
should_delete_file = False
raise
except Exception:
logger.exception("Failed to schedule trace dispatch retry, app_id: %s", app_id)
failed_key = f"{OPS_TRACE_FAILED_KEY}_{app_id}"
redis_client.incr(failed_key)
except Exception as e:
logger.exception("Processing trace tasks failed, app_id: %s", app_id)
failed_key = f"{OPS_TRACE_FAILED_KEY}_{app_id}"
redis_client.incr(failed_key)
finally:
try:
storage.delete(file_path)
except Exception as e:
logger.warning(
"Failed to delete trace file %s for app_id %s: %s",
file_path,
app_id,
e,
)
if should_delete_file:
try:
storage.delete(file_path)
except Exception as e:
logger.warning(
"Failed to delete trace file %s for app_id %s: %s",
file_path,
app_id,
e,
)
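A plain-Python sketch of the retry guard above (hypothetical helper name, no Celery involved): another attempt is scheduled only while the retry budget allows it; once exhausted, the failure is counted as terminal.

def should_retry(retries_so_far: int, limit: int = 60) -> bool:
    # Mirrors the `self.request.retries >= limit` guard in the task body.
    return retries_so_far < limit

assert should_retry(0) and should_retry(59)
assert not should_retry(60)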

View File

@ -1,3 +1,4 @@
import contextlib
from types import SimpleNamespace
from unittest.mock import MagicMock
@ -24,6 +25,76 @@ def test_should_prepare_user_inputs_keeps_validation_when_flag_false():
assert WorkflowAppGenerator()._should_prepare_user_inputs(args)
def test_generate_includes_parent_trace_context_in_extras(monkeypatch):
generator = WorkflowAppGenerator()
monkeypatch.setattr(
"core.app.apps.workflow.app_generator.WorkflowAppGenerator._bind_file_access_scope",
lambda *args, **kwargs: contextlib.nullcontext(),
)
monkeypatch.setattr(
"core.app.apps.workflow.app_generator.WorkflowAppConfigManager.get_app_config",
lambda *args, **kwargs: SimpleNamespace(
app_id="app-1", tenant_id="tenant-1", workflow_id="workflow-1", variables=[]
),
)
monkeypatch.setattr(
"core.app.apps.workflow.app_generator.file_factory.build_from_mappings", lambda *args, **kwargs: []
)
monkeypatch.setattr("core.app.apps.workflow.app_generator.TraceQueueManager", MagicMock())
monkeypatch.setattr(
"core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_execution_repository",
MagicMock(return_value=MagicMock()),
)
monkeypatch.setattr(
"core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_node_execution_repository",
MagicMock(return_value=MagicMock()),
)
monkeypatch.setattr("core.app.apps.workflow.app_generator.db", SimpleNamespace(engine=MagicMock()))
monkeypatch.setattr(generator, "_prepare_user_inputs", lambda *, user_inputs, **kwargs: user_inputs)
captured = {}
def fake_workflow_app_generate_entity(**kwargs):
captured["workflow_app_generate_entity_kwargs"] = kwargs
return SimpleNamespace(**kwargs)
def fake_generate(**kwargs):
captured["application_generate_entity"] = kwargs["application_generate_entity"]
return {"data": {}}
monkeypatch.setattr(
"core.app.apps.workflow.app_generator.WorkflowAppGenerateEntity", fake_workflow_app_generate_entity
)
monkeypatch.setattr(generator, "_generate", fake_generate)
result = generator.generate(
app_model=SimpleNamespace(tenant_id="tenant-1", id="app-1"),
workflow=SimpleNamespace(features_dict={}),
user=SimpleNamespace(id="user-1", session_id="session-1"),
args={
"inputs": {"query": "hello"},
"files": [],
"external_trace_id": "trace-1",
"parent_trace_context": {
"parent_workflow_run_id": "outer-workflow-run-1",
"parent_node_execution_id": "outer-node-execution-1",
},
},
invoke_from="service-api",
streaming=False,
call_depth=0,
)
assert result == {"data": {}}
extras = captured["workflow_app_generate_entity_kwargs"]["extras"]
assert extras["external_trace_id"] == "trace-1"
assert extras["parent_trace_context"].model_dump() == {
"parent_workflow_run_id": "outer-workflow-run-1",
"parent_node_execution_id": "outer-node-execution-1",
}
def test_resume_delegates_to_generate(mocker: MockerFixture):
generator = WorkflowAppGenerator()
mock_generate = mocker.patch.object(generator, "_generate", return_value="ok")

View File

@ -7,6 +7,7 @@ import pytest
from core.app.entities.app_invoke_entities import WorkflowAppGenerateEntity
from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.ops.ops_trace_manager import TraceTask, TraceTaskName
from core.workflow.system_variables import SystemVariableKey, build_system_variables
from graphon.entities import WorkflowNodeExecution
from graphon.entities.pause_reason import SchedulingPause
@ -217,6 +218,59 @@ class TestWorkflowPersistenceLayer:
assert exec_repo.saved[-1].status == WorkflowExecutionStatus.FAILED
assert trace_tasks
def test_handle_graph_run_succeeded_enqueues_parent_trace_context(self, monkeypatch):
trace_tasks: list[TraceTask] = []
trace_manager = SimpleNamespace(user_id="user", add_trace_task=lambda task: trace_tasks.append(task))
layer, _, _, _ = _make_layer(
extras={
"external_trace_id": "trace",
"parent_trace_context": {
"parent_workflow_run_id": "outer-workflow-run-1",
"parent_node_execution_id": "outer-node-execution-1",
},
},
trace_manager=trace_manager,
)
layer._handle_graph_run_started()
captured: dict[str, object] = {}
def fake_workflow_trace(
self: TraceTask,
*,
workflow_run_id: str | None,
conversation_id: str | None,
user_id: str | None,
total_tokens_override: int | None = None,
):
captured["trace_type"] = self.trace_type
captured["external_trace_id"] = self.kwargs.get("external_trace_id")
captured["parent_trace_context"] = self.kwargs.get("parent_trace_context")
captured["workflow_run_id"] = workflow_run_id
return {"ok": True}
monkeypatch.setattr(TraceTask, "workflow_trace", fake_workflow_trace)
layer._handle_graph_run_succeeded(GraphRunSucceededEvent(outputs={"ok": True}))
assert trace_tasks
trace_task = trace_tasks[0]
assert trace_task.trace_type == TraceTaskName.WORKFLOW_TRACE
assert trace_task.kwargs["external_trace_id"] == "trace"
assert trace_task.kwargs["parent_trace_context"] == {
"parent_workflow_run_id": "outer-workflow-run-1",
"parent_node_execution_id": "outer-node-execution-1",
}
trace_task.execute()
assert captured["trace_type"] == TraceTaskName.WORKFLOW_TRACE
assert captured["external_trace_id"] == "trace"
assert captured["parent_trace_context"] == {
"parent_workflow_run_id": "outer-workflow-run-1",
"parent_node_execution_id": "outer-node-execution-1",
}
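For orientation, a hedged sketch of the enqueue step this test exercises: the persistence layer copies per-run trace metadata from the generate-entity extras into the `TraceTask` kwargs. The helper name and the exact `TraceTask` constructor signature are assumptions:

```python
from core.ops.ops_trace_manager import TraceTask, TraceTaskName


def _enqueue_workflow_trace(trace_manager, workflow_run_id: str, extras: dict) -> None:
    # external_trace_id and parent_trace_context ride along in the task
    # kwargs, which is what the kwargs assertions above verify.
    trace_manager.add_trace_task(
        TraceTask(
            TraceTaskName.WORKFLOW_TRACE,
            workflow_run_id=workflow_run_id,
            external_trace_id=extras.get("external_trace_id"),
            parent_trace_context=extras.get("parent_trace_context"),
        )
    )
```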
def test_handle_graph_run_aborted_sets_status(self):
layer, exec_repo, _, _ = _make_layer()
layer._handle_graph_run_started()

View File

@ -1,6 +1,12 @@
import pytest
from core.helper.trace_id_helper import extract_external_trace_id_from_args, get_external_trace_id, is_valid_trace_id
from core.helper.trace_id_helper import (
ParentTraceContext,
extract_external_trace_id_from_args,
extract_parent_trace_context_from_args,
get_external_trace_id,
is_valid_trace_id,
)
class DummyRequest:
@ -84,3 +90,92 @@ class TestTraceIdHelper:
def test_extract_external_trace_id_from_args(self, args, expected):
"""Test extraction of external_trace_id from args mapping"""
assert extract_external_trace_id_from_args(args) == expected
@pytest.mark.parametrize(
("args", "expected"),
[
(
{
"parent_trace_context": {
"parent_workflow_run_id": "workflow-run-1",
"parent_node_execution_id": "node-execution-1",
}
},
{
"parent_trace_context": ParentTraceContext(
parent_workflow_run_id="workflow-run-1",
parent_node_execution_id="node-execution-1",
)
},
),
(
{
"parent_trace_context": {
"parent_workflow_run_id": "workflow-run-1",
}
},
{},
),
(
{
"parent_trace_context": {
"parent_node_execution_id": "node-execution-1",
}
},
{},
),
(
{
"parent_trace_context": {
"parent_workflow_run_id": 123,
"parent_node_execution_id": "node-execution-1",
}
},
{},
),
(
{
"parent_trace_context": {
"parent_workflow_run_id": "workflow-run-1",
"parent_node_execution_id": None,
}
},
{},
),
({}, {}),
],
)
def test_extract_parent_trace_context_from_args(self, args, expected):
"""Test extraction of parent_trace_context from args mapping"""
assert extract_parent_trace_context_from_args(args) == expected
def test_extract_parent_trace_context_returns_typed_context(self):
"""Parent trace context is parsed into a Pydantic value object."""
result = extract_parent_trace_context_from_args(
{
"parent_trace_context": {
"parent_workflow_run_id": "workflow-run-1",
"parent_node_execution_id": "node-execution-1",
}
}
)
assert result == {
"parent_trace_context": ParentTraceContext(
parent_workflow_run_id="workflow-run-1",
parent_node_execution_id="node-execution-1",
)
}
def test_extract_parent_trace_context_rejects_incomplete_typed_context(self):
"""Typed parent trace context follows the same completeness rule as raw mappings."""
result = extract_parent_trace_context_from_args(
{
"parent_trace_context": ParentTraceContext(
parent_workflow_run_id="workflow-run-1",
parent_node_execution_id=None,
)
}
)
assert result == {}
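Taken together, these cases pin down the extraction rule: accept either a raw mapping or an already-typed `ParentTraceContext`, and surface it only when both identifiers are present as strings. A sketch of a helper that satisfies the suite (building on the `ParentTraceContext` model sketched earlier; the real implementation may differ in detail):

```python
from collections.abc import Mapping
from typing import Any


def extract_parent_trace_context_from_args(args: Mapping[str, Any]) -> dict[str, Any]:
    raw = args.get("parent_trace_context")
    if isinstance(raw, ParentTraceContext):
        context = raw
    elif isinstance(raw, Mapping):
        try:
            context = ParentTraceContext.model_validate(raw)
        except ValueError:
            # e.g. parent_workflow_run_id=123 fails string validation
            return {}
    else:
        return {}
    # Completeness rule: both ids must be strings, or the whole context is
    # dropped (the {} expectations in the parametrization above).
    if not isinstance(context.parent_workflow_run_id, str) or not isinstance(
        context.parent_node_execution_id, str
    ):
        return {}
    return {"parent_trace_context": context}
```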

View File

@ -147,6 +147,142 @@ def test_workflow_tool_does_not_use_pause_state_config(monkeypatch: pytest.Monke
assert call_kwargs["pause_state_config"] is None
def test_workflow_tool_passes_parent_trace_context_from_runtime(monkeypatch: pytest.MonkeyPatch):
"""Ensure nested workflow runtime metadata is forwarded as parent trace context."""
tool = _build_tool()
tool.set_parent_trace_context(
parent_workflow_run_id="outer-workflow-run-1",
parent_node_execution_id="outer-node-execution-1",
)
monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None)
monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None)
mock_user = Mock()
monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user)
generate_mock = MagicMock(return_value={"data": {}})
monkeypatch.setattr("core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate", generate_mock)
monkeypatch.setattr("libs.login.current_user", lambda *args, **kwargs: None)
list(tool.invoke("test_user", {}))
call_kwargs = generate_mock.call_args.kwargs
assert call_kwargs["args"]["parent_trace_context"].model_dump() == {
"parent_workflow_run_id": "outer-workflow-run-1",
"parent_node_execution_id": "outer-node-execution-1",
}
def test_workflow_tool_keeps_user_inputs_named_like_trace_runtime_keys(monkeypatch: pytest.MonkeyPatch):
"""Ensure private trace context does not overwrite same-named workflow inputs."""
tool = _build_tool()
tool.entity.parameters = [
ToolParameter.get_simple_instance(
name="outer_workflow_run_id",
llm_description="User workflow input",
typ=ToolParameter.ToolParameterType.STRING,
required=False,
),
ToolParameter.get_simple_instance(
name="outer_node_execution_id",
llm_description="User node input",
typ=ToolParameter.ToolParameterType.STRING,
required=False,
),
]
tool.set_parent_trace_context(
parent_workflow_run_id="outer-workflow-run-1",
parent_node_execution_id="outer-node-execution-1",
)
monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None)
monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None)
mock_user = Mock()
monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user)
generate_mock = MagicMock(return_value={"data": {}})
monkeypatch.setattr("core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate", generate_mock)
monkeypatch.setattr("libs.login.current_user", lambda *args, **kwargs: None)
list(
tool.invoke(
"test_user",
{
"outer_workflow_run_id": "user-workflow-input",
"outer_node_execution_id": "user-node-input",
},
)
)
call_kwargs = generate_mock.call_args.kwargs
assert call_kwargs["args"]["inputs"]["outer_workflow_run_id"] == "user-workflow-input"
assert call_kwargs["args"]["inputs"]["outer_node_execution_id"] == "user-node-input"
assert call_kwargs["args"]["parent_trace_context"].model_dump() == {
"parent_workflow_run_id": "outer-workflow-run-1",
"parent_node_execution_id": "outer-node-execution-1",
}
def test_workflow_tool_can_clear_parent_trace_context(monkeypatch: pytest.MonkeyPatch):
"""Ensure reused WorkflowTool instances do not keep stale parent trace context."""
tool = _build_tool()
tool.set_parent_trace_context(
parent_workflow_run_id="outer-workflow-run-1",
parent_node_execution_id="outer-node-execution-1",
)
tool.clear_parent_trace_context()
monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None)
monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None)
mock_user = Mock()
monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user)
generate_mock = MagicMock(return_value={"data": {}})
monkeypatch.setattr("core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate", generate_mock)
monkeypatch.setattr("libs.login.current_user", lambda *args, **kwargs: None)
list(tool.invoke("test_user", {}))
call_kwargs = generate_mock.call_args.kwargs
assert "parent_trace_context" not in call_kwargs["args"]
@pytest.mark.parametrize(
"runtime_parameters",
[
{},
{"outer_workflow_run_id": "outer-workflow-run-1"},
{"outer_node_execution_id": "outer-node-execution-1"},
{"outer_workflow_run_id": None, "outer_node_execution_id": None},
],
)
def test_workflow_tool_omits_parent_trace_context_when_runtime_is_incomplete(
monkeypatch: pytest.MonkeyPatch,
runtime_parameters: dict[str, Any],
):
"""Ensure incomplete runtime metadata does not leak parent trace context into generator args."""
tool = _build_tool()
tool.runtime.runtime_parameters = runtime_parameters
monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None)
monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None)
mock_user = Mock()
monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user)
generate_mock = MagicMock(return_value={"data": {}})
monkeypatch.setattr("core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate", generate_mock)
monkeypatch.setattr("libs.login.current_user", lambda *args, **kwargs: None)
list(tool.invoke("test_user", {}))
call_kwargs = generate_mock.call_args.kwargs
assert "parent_trace_context" not in call_kwargs["args"]
def test_workflow_tool_should_generate_variable_messages_for_outputs(monkeypatch: pytest.MonkeyPatch):
"""Test that WorkflowTool should generate variable messages when there are outputs"""
tool = _build_tool()

View File

@ -15,9 +15,9 @@ from graphon.model_runtime.entities.llm_entities import LLMUsage
from graphon.node_events import StreamChunkEvent, StreamCompletedEvent
from graphon.nodes.tool.entities import ToolNodeData
from graphon.nodes.tool_runtime_entities import ToolRuntimeHandle, ToolRuntimeMessage
from graphon.runtime import GraphRuntimeState, VariablePool
from graphon.runtime import GraphRuntimeState
from graphon.variables.segments import ArrayFileSegment
from tests.workflow_test_utils import build_test_graph_init_params
from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool
if TYPE_CHECKING: # pragma: no cover - imported for type checking only
from graphon.nodes.tool.tool_node import ToolNode
@ -106,7 +106,7 @@ def tool_node(monkeypatch) -> ToolNode:
call_depth=0,
)
variable_pool = VariablePool.from_bootstrap(system_variables=build_system_variables(user_id="user-id"))
variable_pool = build_test_variable_pool(variables=build_system_variables(user_id="user-id"))
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
config = graph_config["nodes"][0]
@ -234,3 +234,22 @@ def test_image_link_messages_use_tool_file_id_metadata(tool_node: ToolNode):
files_segment = completed_events[0].node_run_result.outputs["files"]
assert isinstance(files_segment, ArrayFileSegment)
assert files_segment.value == [file_obj]
def test_tool_node_passes_node_execution_id_when_runtime_accepts_it(tool_node: ToolNode):
runtime_handle = ToolRuntimeHandle(raw=object())
tool_node._runtime.get_runtime = MagicMock(return_value=runtime_handle)
tool_node.ensure_execution_id = MagicMock(return_value="node-execution-id")
result = tool_node._get_tool_runtime(
variable_pool=tool_node.graph_runtime_state.variable_pool,
node_execution_id="node-execution-id",
)
assert result is runtime_handle
tool_node._runtime.get_runtime.assert_called_once_with(
node_id="node-instance",
node_data=tool_node.node_data,
variable_pool=tool_node.graph_runtime_state.variable_pool,
node_execution_id="node-execution-id",
)

View File

@ -147,6 +147,69 @@ def test_get_runtime_converts_graph_provider_type_for_tool_manager(runtime: Dify
assert workflow_tool.provider_type == CoreToolProviderType.BUILT_IN
def test_get_runtime_stores_parent_trace_context_for_workflow_tools(
runtime: DifyToolNodeRuntime,
) -> None:
variable_pool: VariablePool = build_test_variable_pool(
variables=build_system_variables(
conversation_id="conversation-id",
workflow_execution_id="workflow-run-id",
)
)
workflow_runtime = MagicMock()
workflow_runtime.runtime.runtime_parameters = {}
node_data = ToolNodeData.model_validate(
{
"type": "tool",
"title": "Tool",
"provider_id": "provider",
"provider_type": ToolProviderType.WORKFLOW,
"provider_name": "provider",
"tool_name": "lookup",
"tool_label": "Lookup",
"tool_configurations": {},
"tool_parameters": {},
}
)
with patch.object(ToolManager, "get_workflow_tool_runtime", return_value=workflow_runtime):
tool_runtime = runtime.get_runtime(
node_id="node-id",
node_data=node_data,
variable_pool=variable_pool,
node_execution_id="node-execution-id",
)
assert tool_runtime.raw.parent_trace_context.model_dump() == {
"parent_workflow_run_id": "workflow-run-id",
"parent_node_execution_id": "node-execution-id",
}
assert workflow_runtime.runtime.runtime_parameters == {}
def test_get_runtime_leaves_non_workflow_tool_runtime_parameters_unchanged(
runtime: DifyToolNodeRuntime,
) -> None:
variable_pool: VariablePool = build_test_variable_pool(
variables=build_system_variables(
conversation_id="conversation-id",
workflow_execution_id="workflow-run-id",
)
)
builtin_runtime = MagicMock()
builtin_runtime.runtime.runtime_parameters = {}
with patch.object(ToolManager, "get_workflow_tool_runtime", return_value=builtin_runtime):
runtime.get_runtime(
node_id="node-id",
node_data=_build_tool_node_data(),
variable_pool=variable_pool,
node_execution_id="node-execution-id",
)
assert builtin_runtime.runtime.runtime_parameters == {}
def test_get_runtime_parameters_reads_required_flags(runtime: DifyToolNodeRuntime) -> None:
tool_runtime = ToolRuntimeHandle(
raw=SimpleNamespace(

View File

@ -316,6 +316,81 @@ def test_dify_tool_file_manager_delegates_file_generator_lookup(monkeypatch: pyt
get_file_generator.assert_called_once_with("tool-file-id")
def test_dify_tool_node_runtime_injects_outer_workflow_run_id_for_workflow_tools(
monkeypatch: pytest.MonkeyPatch,
) -> None:
runtime_tool = SimpleNamespace(runtime=SimpleNamespace(runtime_parameters={}))
get_runtime = MagicMock(return_value=runtime_tool)
monkeypatch.setattr(node_runtime.ToolManager, "get_workflow_tool_runtime", get_runtime)
monkeypatch.setattr(
node_runtime,
"get_system_text",
lambda _pool, key: (
"outer-workflow-run-id" if key == node_runtime.SystemVariableKey.WORKFLOW_EXECUTION_ID else None
),
)
runtime = node_runtime.DifyToolNodeRuntime(_build_run_context())
node_data = ToolNodeData(
title="Workflow Tool Node",
desc=None,
provider_id="workflow-provider-id",
provider_type=ToolProviderType.WORKFLOW,
provider_name="workflow-provider",
tool_name="workflow-tool",
tool_label="Workflow Tool",
tool_configurations={},
tool_parameters={},
)
handle = runtime.get_runtime(
node_id="tool-node",
node_data=node_data,
variable_pool=object(),
node_execution_id="node-execution-id",
)
assert handle.raw.tool is runtime_tool
assert handle.raw.parent_trace_context.model_dump() == {
"parent_workflow_run_id": "outer-workflow-run-id",
"parent_node_execution_id": "node-execution-id",
}
assert runtime_tool.runtime.runtime_parameters == {}
get_runtime.assert_called_once()
def test_dify_tool_node_runtime_does_not_inject_outer_workflow_run_id_for_non_workflow_tools(
monkeypatch: pytest.MonkeyPatch,
) -> None:
runtime_tool = SimpleNamespace(runtime=SimpleNamespace(runtime_parameters={}))
get_runtime = MagicMock(return_value=runtime_tool)
monkeypatch.setattr(node_runtime.ToolManager, "get_workflow_tool_runtime", get_runtime)
monkeypatch.setattr(node_runtime, "get_system_text", lambda _pool, _key: None)
runtime = node_runtime.DifyToolNodeRuntime(_build_run_context())
node_data = ToolNodeData(
title="Builtin Tool Node",
desc=None,
provider_id="builtin-provider-id",
provider_type=ToolProviderType.BUILT_IN,
provider_name="builtin-provider",
tool_name="builtin-tool",
tool_label="Builtin Tool",
tool_configurations={},
tool_parameters={},
)
handle = runtime.get_runtime(
node_id="tool-node",
node_data=node_data,
variable_pool=object(),
)
assert handle.raw.tool is runtime_tool
assert "outer_workflow_run_id" not in runtime_tool.runtime.runtime_parameters
get_runtime.assert_called_once()
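Both tests pin the same injection rule from opposite directions. A condensed sketch of the decision they imply (the free-function shape is illustrative, not the actual `DifyToolNodeRuntime` internals):

```python
def _build_parent_trace_context(node_data, variable_pool, node_execution_id):
    # Only workflow-as-tool nodes get a parent trace context, and only when
    # both halves of the identity are available; runtime_parameters are
    # never mutated either way.
    if node_data.provider_type != ToolProviderType.WORKFLOW or not node_execution_id:
        return None
    outer_run_id = get_system_text(variable_pool, SystemVariableKey.WORKFLOW_EXECUTION_ID)
    if not outer_run_id:
        return None
    return ParentTraceContext(
        parent_workflow_run_id=outer_run_id,
        parent_node_execution_id=node_execution_id,
    )
```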
def test_dify_human_input_runtime_builds_debug_repository(monkeypatch: pytest.MonkeyPatch) -> None:
repository = MagicMock()
repository_cls = MagicMock(return_value=repository)

View File

@ -4,7 +4,15 @@ from models.comment import WorkflowComment, WorkflowCommentMention, WorkflowComm
def test_workflow_comment_account_properties_and_cache() -> None:
comment = WorkflowComment(created_by="user-1", resolved_by="user-2", content="hello", position_x=1, position_y=2)
comment = WorkflowComment(
created_by="user-1",
resolved_by="user-2",
content="hello",
position_x=1,
position_y=2,
tenant_id="xxx",
app_id="yyy",
)
created_account = Mock(id="user-1")
resolved_account = Mock(id="user-2")
@ -21,6 +29,8 @@ def test_workflow_comment_account_properties_and_cache() -> None:
get_mock.assert_not_called()
comment_without_resolver = WorkflowComment(
tenant_id="xxx",
app_id="yyy",
created_by="user-1",
resolved_by=None,
content="hello",
@ -37,7 +47,15 @@ def test_workflow_comment_counts_and_participants() -> None:
reply_2 = WorkflowCommentReply(comment_id="comment-1", content="reply-2", created_by="user-2")
mention_1 = WorkflowCommentMention(comment_id="comment-1", mentioned_user_id="user-3")
mention_2 = WorkflowCommentMention(comment_id="comment-1", mentioned_user_id="user-4")
comment = WorkflowComment(created_by="user-1", resolved_by=None, content="hello", position_x=1, position_y=2)
comment = WorkflowComment(
created_by="user-1",
resolved_by=None,
content="hello",
position_x=1,
position_y=2,
tenant_id="xxx",
app_id="yyy",
)
comment.replies = [reply_1, reply_2]
comment.mentions = [mention_1, mention_2]
@ -63,7 +81,15 @@ def test_workflow_comment_counts_and_participants() -> None:
def test_workflow_comment_participants_use_cached_accounts() -> None:
reply = WorkflowCommentReply(comment_id="comment-1", content="reply-1", created_by="user-2")
mention = WorkflowCommentMention(comment_id="comment-1", mentioned_user_id="user-3")
comment = WorkflowComment(created_by="user-1", resolved_by=None, content="hello", position_x=1, position_y=2)
comment = WorkflowComment(
created_by="user-1",
resolved_by=None,
content="hello",
position_x=1,
position_y=2,
tenant_id="xxx",
app_id="yyy",
)
comment.replies = [reply]
comment.mentions = [mention]

View File

@ -0,0 +1,301 @@
import json
import sys
from contextlib import contextmanager
from types import ModuleType
from unittest.mock import MagicMock, patch
import pytest
from celery.exceptions import Retry
from core.ops.entities.config_entity import OPS_TRACE_FAILED_KEY
from core.ops.exceptions import RetryableTraceDispatchError
from tasks.ops_trace_task import process_trace_tasks
@contextmanager
def fake_app_context():
yield
class FakeCurrentApp:
def app_context(self):
return fake_app_context()
def _install_trace_manager(
trace_instance: MagicMock,
*,
enterprise_enabled: bool = False,
enterprise_trace_cls: MagicMock | None = None,
) -> dict[str, ModuleType]:
ops_trace_manager_module = ModuleType("core.ops.ops_trace_manager")
class StubOpsTraceManager:
@staticmethod
def get_ops_trace_instance(app_id: str) -> MagicMock:
return trace_instance
telemetry_module = ModuleType("extensions.ext_enterprise_telemetry")
telemetry_module.is_enabled = lambda: enterprise_enabled
ops_trace_manager_module.OpsTraceManager = StubOpsTraceManager
modules = {
"core.ops.ops_trace_manager": ops_trace_manager_module,
"extensions.ext_enterprise_telemetry": telemetry_module,
}
if enterprise_trace_cls is not None:
enterprise_module = ModuleType("enterprise")
enterprise_telemetry_module = ModuleType("enterprise.telemetry")
enterprise_trace_module = ModuleType("enterprise.telemetry.enterprise_trace")
enterprise_trace_module.EnterpriseOtelTrace = enterprise_trace_cls
modules.update(
{
"enterprise": enterprise_module,
"enterprise.telemetry": enterprise_telemetry_module,
"enterprise.telemetry.enterprise_trace": enterprise_trace_module,
}
)
return modules
def _make_payload() -> str:
return json.dumps({"trace_info": {}, "trace_info_type": None})
def _decode_saved_payload(payload: bytes | str) -> dict[str, object]:
if isinstance(payload, bytes):
payload = payload.decode("utf-8")
return json.loads(payload)
def _retryable_dispatch_error() -> RetryableTraceDispatchError:
return RetryableTraceDispatchError("transient trace dispatch failure")
def _run_task(file_info: dict[str, str], retries: int = 0) -> None:
process_trace_tasks.push_request(retries=retries)
try:
process_trace_tasks.run(file_info)
finally:
process_trace_tasks.pop_request()
def test_process_trace_tasks_retries_retryable_dispatch_failure_and_preserves_payload():
file_info = {"app_id": "app-id", "file_id": "file-id"}
trace_instance = MagicMock()
pending_error = _retryable_dispatch_error()
trace_instance.trace.side_effect = pending_error
retry_error = Retry()
with (
patch.dict(sys.modules, _install_trace_manager(trace_instance)),
patch("tasks.ops_trace_task.current_app", FakeCurrentApp()),
patch("tasks.ops_trace_task.storage.load", return_value=_make_payload()),
patch("tasks.ops_trace_task.storage.delete") as mock_delete,
patch("tasks.ops_trace_task.redis_client.incr") as mock_incr,
patch.object(process_trace_tasks, "retry", side_effect=retry_error) as mock_retry,
pytest.raises(Retry),
):
_run_task(file_info)
mock_retry.assert_called_once_with(
exc=pending_error,
countdown=process_trace_tasks.default_retry_delay,
)
mock_delete.assert_not_called()
mock_incr.assert_not_called()
def test_process_trace_tasks_marks_enterprise_trace_dispatched_before_retryable_dispatch_retry():
file_info = {"app_id": "app-id", "file_id": "file-id"}
trace_instance = MagicMock()
pending_error = _retryable_dispatch_error()
trace_instance.trace.side_effect = pending_error
retry_error = Retry()
enterprise_tracer = MagicMock()
enterprise_trace_cls = MagicMock(return_value=enterprise_tracer)
with (
patch.dict(
sys.modules,
_install_trace_manager(
trace_instance,
enterprise_enabled=True,
enterprise_trace_cls=enterprise_trace_cls,
),
),
patch("tasks.ops_trace_task.current_app", FakeCurrentApp()),
patch("tasks.ops_trace_task.storage.load", return_value=_make_payload()),
patch("tasks.ops_trace_task.storage.save") as mock_save,
patch("tasks.ops_trace_task.storage.delete") as mock_delete,
patch("tasks.ops_trace_task.redis_client.incr") as mock_incr,
patch.object(process_trace_tasks, "retry", side_effect=retry_error) as mock_retry,
pytest.raises(Retry),
):
_run_task(file_info)
enterprise_tracer.trace.assert_called_once_with({})
saved_path, saved_payload = mock_save.call_args.args
assert saved_path == "ops_trace/app-id/file-id.json"
assert _decode_saved_payload(saved_payload)["_enterprise_trace_dispatched"] is True
mock_retry.assert_called_once_with(
exc=pending_error,
countdown=process_trace_tasks.default_retry_delay,
)
mock_delete.assert_not_called()
mock_incr.assert_not_called()
def test_process_trace_tasks_does_not_mark_failed_enterprise_trace_as_dispatched_before_retry():
file_info = {"app_id": "app-id", "file_id": "file-id"}
trace_instance = MagicMock()
pending_error = _retryable_dispatch_error()
trace_instance.trace.side_effect = pending_error
retry_error = Retry()
enterprise_tracer = MagicMock()
enterprise_tracer.trace.side_effect = RuntimeError("enterprise trace failed")
enterprise_trace_cls = MagicMock(return_value=enterprise_tracer)
with (
patch.dict(
sys.modules,
_install_trace_manager(
trace_instance,
enterprise_enabled=True,
enterprise_trace_cls=enterprise_trace_cls,
),
),
patch("tasks.ops_trace_task.current_app", FakeCurrentApp()),
patch("tasks.ops_trace_task.storage.load", return_value=_make_payload()),
patch("tasks.ops_trace_task.storage.save") as mock_save,
patch("tasks.ops_trace_task.storage.delete") as mock_delete,
patch("tasks.ops_trace_task.redis_client.incr") as mock_incr,
patch.object(process_trace_tasks, "retry", side_effect=retry_error) as mock_retry,
pytest.raises(Retry),
):
_run_task(file_info)
enterprise_tracer.trace.assert_called_once_with({})
mock_save.assert_not_called()
mock_retry.assert_called_once_with(
exc=pending_error,
countdown=process_trace_tasks.default_retry_delay,
)
mock_delete.assert_not_called()
mock_incr.assert_not_called()
def test_process_trace_tasks_skips_enterprise_trace_when_retry_payload_was_already_dispatched():
file_info = {"app_id": "app-id", "file_id": "file-id"}
trace_instance = MagicMock()
enterprise_trace_cls = MagicMock()
payload = json.dumps({"trace_info": {}, "trace_info_type": None, "_enterprise_trace_dispatched": True})
with (
patch.dict(
sys.modules,
_install_trace_manager(
trace_instance,
enterprise_enabled=True,
enterprise_trace_cls=enterprise_trace_cls,
),
),
patch("tasks.ops_trace_task.current_app", FakeCurrentApp()),
patch("tasks.ops_trace_task.storage.load", return_value=payload),
patch("tasks.ops_trace_task.storage.save") as mock_save,
patch("tasks.ops_trace_task.storage.delete") as mock_delete,
patch("tasks.ops_trace_task.redis_client.incr") as mock_incr,
):
_run_task(file_info)
enterprise_trace_cls.assert_not_called()
trace_instance.trace.assert_called_once_with({})
mock_save.assert_not_called()
mock_delete.assert_called_once_with("ops_trace/app-id/file-id.json")
mock_incr.assert_not_called()
def test_process_trace_tasks_default_retry_window_covers_parent_span_context_ttl():
assert process_trace_tasks.max_retries * process_trace_tasks.default_retry_delay >= 300
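With the defaults added in this change (`OPS_TRACE_RETRYABLE_DISPATCH_MAX_RETRIES=60`, `OPS_TRACE_RETRYABLE_DISPATCH_DELAY_SECONDS=5`), the window is 60 × 5 s = 300 s, so this guard keeps the retry budget at least as long as what the test name identifies as the parent span context TTL.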
def test_process_trace_tasks_deletes_payload_on_success():
file_info = {"app_id": "app-id", "file_id": "file-id"}
trace_instance = MagicMock()
with (
patch.dict(sys.modules, _install_trace_manager(trace_instance)),
patch("tasks.ops_trace_task.current_app", FakeCurrentApp()),
patch("tasks.ops_trace_task.storage.load", return_value=_make_payload()),
patch("tasks.ops_trace_task.storage.delete") as mock_delete,
patch("tasks.ops_trace_task.redis_client.incr") as mock_incr,
):
_run_task(file_info)
trace_instance.trace.assert_called_once_with({})
mock_delete.assert_called_once_with("ops_trace/app-id/file-id.json")
mock_incr.assert_not_called()
def test_process_trace_tasks_deletes_payload_and_counts_terminal_failure():
file_info = {"app_id": "app-id", "file_id": "file-id"}
trace_instance = MagicMock()
trace_instance.trace.side_effect = RuntimeError("trace failed")
with (
patch.dict(sys.modules, _install_trace_manager(trace_instance)),
patch("tasks.ops_trace_task.current_app", FakeCurrentApp()),
patch("tasks.ops_trace_task.storage.load", return_value=_make_payload()),
patch("tasks.ops_trace_task.storage.delete") as mock_delete,
patch("tasks.ops_trace_task.redis_client.incr") as mock_incr,
):
_run_task(file_info)
mock_delete.assert_called_once_with("ops_trace/app-id/file-id.json")
mock_incr.assert_called_once_with(f"{OPS_TRACE_FAILED_KEY}_app-id")
def test_process_trace_tasks_treats_retry_enqueue_failure_as_terminal_failure():
file_info = {"app_id": "app-id", "file_id": "file-id"}
trace_instance = MagicMock()
pending_error = _retryable_dispatch_error()
retry_enqueue_error = RuntimeError("retry enqueue failed")
trace_instance.trace.side_effect = pending_error
with (
patch.dict(sys.modules, _install_trace_manager(trace_instance)),
patch("tasks.ops_trace_task.current_app", FakeCurrentApp()),
patch("tasks.ops_trace_task.storage.load", return_value=_make_payload()),
patch("tasks.ops_trace_task.storage.delete") as mock_delete,
patch("tasks.ops_trace_task.redis_client.incr") as mock_incr,
patch.object(process_trace_tasks, "retry", side_effect=retry_enqueue_error) as mock_retry,
):
_run_task(file_info)
mock_retry.assert_called_once_with(
exc=pending_error,
countdown=process_trace_tasks.default_retry_delay,
)
mock_delete.assert_called_once_with("ops_trace/app-id/file-id.json")
mock_incr.assert_called_once_with(f"{OPS_TRACE_FAILED_KEY}_app-id")
def test_process_trace_tasks_deletes_payload_and_counts_exhausted_retryable_dispatch_failure():
file_info = {"app_id": "app-id", "file_id": "file-id"}
trace_instance = MagicMock()
pending_error = _retryable_dispatch_error()
trace_instance.trace.side_effect = pending_error
with (
patch.dict(sys.modules, _install_trace_manager(trace_instance)),
patch("tasks.ops_trace_task.current_app", FakeCurrentApp()),
patch("tasks.ops_trace_task.storage.load", return_value=_make_payload()),
patch("tasks.ops_trace_task.storage.delete") as mock_delete,
patch("tasks.ops_trace_task.redis_client.incr") as mock_incr,
patch.object(process_trace_tasks, "retry") as mock_retry,
):
_run_task(file_info, retries=process_trace_tasks.max_retries)
mock_retry.assert_not_called()
mock_delete.assert_called_once_with("ops_trace/app-id/file-id.json")
mock_incr.assert_called_once_with(f"{OPS_TRACE_FAILED_KEY}_app-id")
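The suite above fully specifies the task's failure handling: keep the payload while a retry is pending; delete it and bump the per-app failure counter on any terminal outcome. A condensed sketch of that control flow (Flask app-context setup and the `_enterprise_trace_dispatched` bookkeeping are elided; wiring `max_retries`/`default_retry_delay` to the `OPS_TRACE_RETRYABLE_DISPATCH_*` settings is assumed):

```python
import json

from celery import shared_task
from celery.exceptions import Retry

from core.ops.entities.config_entity import OPS_TRACE_FAILED_KEY
from core.ops.exceptions import RetryableTraceDispatchError
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage


@shared_task(bind=True, queue="ops_trace")
def process_trace_tasks(self, file_info):
    app_id, file_id = file_info["app_id"], file_info["file_id"]
    file_path = f"ops_trace/{app_id}/{file_id}.json"
    payload = json.loads(storage.load(file_path))

    # Imported lazily, which is why the tests can stub it via sys.modules.
    from core.ops.ops_trace_manager import OpsTraceManager

    try:
        trace_instance = OpsTraceManager.get_ops_trace_instance(app_id)
        if trace_instance:
            trace_instance.trace(payload["trace_info"])
    except RetryableTraceDispatchError as e:
        if self.request.retries < self.max_retries:
            try:
                # Payload stays in storage so the retry attempt can reload it.
                raise self.retry(exc=e, countdown=self.default_retry_delay)
            except Retry:
                raise
            except Exception:
                # Retry enqueue itself failed: fall through to terminal handling.
                pass
        storage.delete(file_path)
        redis_client.incr(f"{OPS_TRACE_FAILED_KEY}_{app_id}")
    except Exception:
        # Non-retryable dispatch failure: count it and drop the payload.
        storage.delete(file_path)
        redis_client.incr(f"{OPS_TRACE_FAILED_KEY}_{app_id}")
    else:
        storage.delete(file_path)
```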

View File

@ -1,5 +1,6 @@
# ------------------------------------------------------------------
# Essential defaults for Docker Compose deployments.
# Only include variables required for services to start.
#
# For a default deployment, copy this file to .env and run:
# docker compose up -d

View File

@ -71,6 +71,8 @@ LOG_TZ=UTC
DEBUG=false
FLASK_DEBUG=false
ENABLE_REQUEST_LOGGING=False
OPS_TRACE_RETRYABLE_DISPATCH_MAX_RETRIES=60
OPS_TRACE_RETRYABLE_DISPATCH_DELAY_SECONDS=5
WORKFLOW_LOG_CLEANUP_ENABLED=false
WORKFLOW_LOG_RETENTION_DAYS=30
WORKFLOW_LOG_CLEANUP_BATCH_SIZE=100

View File

@ -48,16 +48,24 @@ vi.mock('@/context/app-context', () => ({
}),
}))
const mockSetQuery = vi.fn()
const mockSetKeywords = vi.fn()
const mockSetTagIDs = vi.fn()
const mockSetIsCreatedByMe = vi.fn()
const mockSetCategory = vi.fn()
const mockQueryState = {
category: 'all',
tagIDs: [] as string[],
keywords: '',
isCreatedByMe: false,
}
vi.mock('../hooks/use-apps-query-state', () => ({
default: () => ({
isAppListCategory: (value: string) => value === 'all' || Object.values(AppModeEnum).includes(value as AppModeEnum),
useAppsQueryState: () => ({
query: mockQueryState,
setQuery: mockSetQuery,
setCategory: mockSetCategory,
setKeywords: mockSetKeywords,
setTagIDs: mockSetTagIDs,
setIsCreatedByMe: mockSetIsCreatedByMe,
}),
}))
@ -244,6 +252,7 @@ describe('List', () => {
mockServiceState.hasNextPage = false
mockServiceState.isLoading = false
mockServiceState.isFetchingNextPage = false
mockQueryState.category = 'all'
mockQueryState.tagIDs = []
mockQueryState.keywords = ''
mockQueryState.isCreatedByMe = false
@ -317,25 +326,21 @@ describe('List', () => {
})
describe('Tab Navigation', () => {
it('should update URL when workflow tab is clicked', async () => {
const { onUrlUpdate } = renderList()
it('should update category when workflow tab is clicked', () => {
renderList()
fireEvent.click(screen.getByText('app.types.workflow'))
await vi.waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
const lastCall = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1]![0]
expect(lastCall.searchParams.get('category')).toBe(AppModeEnum.WORKFLOW)
expect(mockSetCategory).toHaveBeenCalledWith(AppModeEnum.WORKFLOW)
})
it('should update URL when all tab is clicked', async () => {
const { onUrlUpdate } = renderList('?category=workflow')
it('should update category when all tab is clicked', () => {
mockQueryState.category = AppModeEnum.WORKFLOW
renderList()
fireEvent.click(screen.getByText('app.types.all'))
await vi.waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
const lastCall = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1]![0]
// nuqs removes the default value ('all') from URL params
expect(lastCall.searchParams.has('category')).toBe(false)
expect(mockSetCategory).toHaveBeenCalledWith('all')
})
})
@ -351,7 +356,7 @@ describe('List', () => {
const input = screen.getByRole('textbox')
fireEvent.change(input, { target: { value: 'test search' } })
expect(mockSetQuery).toHaveBeenCalled()
expect(mockSetKeywords).toHaveBeenCalledWith('test search')
})
it('should handle search clear button click', () => {
@ -364,7 +369,7 @@ describe('List', () => {
if (clearButton)
fireEvent.click(clearButton)
expect(mockSetQuery).toHaveBeenCalled()
expect(mockSetKeywords).toHaveBeenCalledWith('')
})
})
@ -373,8 +378,9 @@ describe('List', () => {
mockQueryState.tagIDs = ['tag-1']
mockQueryState.keywords = 'sales'
mockQueryState.isCreatedByMe = true
mockQueryState.category = AppModeEnum.WORKFLOW
renderList('?category=workflow')
renderList()
const options = mockAppListInfiniteOptions.mock.calls.at(-1)?.[0] as AppListInfiniteOptions
@ -412,7 +418,7 @@ describe('List', () => {
const checkbox = screen.getByTestId('checkbox-undefined')
fireEvent.click(checkbox)
expect(mockSetQuery).toHaveBeenCalled()
expect(mockSetIsCreatedByMe).toHaveBeenCalledWith(true)
})
})
@ -506,8 +512,8 @@ describe('List', () => {
expect(screen.getByText('app.types.completion')).toBeInTheDocument()
})
it('should update URL for each app type tab click', async () => {
const { onUrlUpdate } = renderList()
it('should update category for each app type tab click', () => {
renderList()
const appTypeTexts = [
{ mode: AppModeEnum.WORKFLOW, text: 'app.types.workflow' },
@ -518,11 +524,9 @@ describe('List', () => {
]
for (const { mode, text } of appTypeTexts) {
onUrlUpdate.mockClear()
mockSetCategory.mockClear()
fireEvent.click(screen.getByText(text))
await vi.waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
const lastCall = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1]![0]
expect(lastCall.searchParams.get('category')).toBe(mode)
expect(mockSetCategory).toHaveBeenCalledWith(mode)
}
})
})

View File

@ -0,0 +1 @@
export const APP_LIST_SEARCH_DEBOUNCE_MS = 500

View File

@ -1,6 +1,8 @@
import { act, waitFor } from '@testing-library/react'
import { renderHookWithNuqs } from '@/test/nuqs-testing'
import useAppsQueryState from '../use-apps-query-state'
import { AppModeEnum } from '@/types/app'
import { APP_LIST_SEARCH_DEBOUNCE_MS } from '../../constants'
import { useAppsQueryState } from '../use-apps-query-state'
const renderWithAdapter = (searchParams = '') => {
return renderHookWithNuqs(() => useAppsQueryState(), { searchParams })
@ -11,214 +13,161 @@ describe('useAppsQueryState', () => {
vi.clearAllMocks()
})
describe('Initialization', () => {
it('should expose query and setQuery when initialized', () => {
const { result } = renderWithAdapter()
it('should expose app list query state actions', () => {
const { result } = renderWithAdapter()
expect(result.current.query).toBeDefined()
expect(typeof result.current.setQuery).toBe('function')
expect(result.current.query).toEqual({
category: 'all',
tagIDs: [],
keywords: '',
isCreatedByMe: false,
})
expect(typeof result.current.setCategory).toBe('function')
expect(typeof result.current.setKeywords).toBe('function')
expect(typeof result.current.setTagIDs).toBe('function')
expect(typeof result.current.setIsCreatedByMe).toBe('function')
})
it('should default to empty filters when search params are missing', () => {
const { result } = renderWithAdapter()
it('should parse app list filters from URL', () => {
const { result } = renderWithAdapter(
'?category=workflow&tagIDs=tag1;tag2&keywords=search+term&isCreatedByMe=true',
)
expect(result.current.query.tagIDs).toBeUndefined()
expect(result.current.query.keywords).toBeUndefined()
expect(result.current.query.isCreatedByMe).toBe(false)
expect(result.current.query).toEqual({
category: AppModeEnum.WORKFLOW,
tagIDs: ['tag1', 'tag2'],
keywords: 'search term',
isCreatedByMe: true,
})
})
describe('Parsing search params', () => {
it('should parse tagIDs when URL includes tagIDs', () => {
const { result } = renderWithAdapter('?tagIDs=tag1;tag2;tag3')
it('should update category URL state', async () => {
const { result, onUrlUpdate } = renderWithAdapter()
expect(result.current.query.tagIDs).toEqual(['tag1', 'tag2', 'tag3'])
act(() => {
result.current.setCategory(AppModeEnum.WORKFLOW)
})
it('should parse keywords when URL includes keywords', () => {
const { result } = renderWithAdapter('?keywords=search+term')
expect(result.current.query.keywords).toBe('search term')
})
it('should parse isCreatedByMe when URL includes true value', () => {
const { result } = renderWithAdapter('?isCreatedByMe=true')
expect(result.current.query.isCreatedByMe).toBe(true)
})
it('should parse all params when URL includes multiple filters', () => {
const { result } = renderWithAdapter(
'?tagIDs=tag1;tag2&keywords=test&isCreatedByMe=true',
)
expect(result.current.query.tagIDs).toEqual(['tag1', 'tag2'])
expect(result.current.query.keywords).toBe('test')
expect(result.current.query.isCreatedByMe).toBe(true)
})
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
const update = onUrlUpdate.mock.calls.at(-1)![0]
expect(result.current.query.category).toBe(AppModeEnum.WORKFLOW)
expect(update.searchParams.get('category')).toBe(AppModeEnum.WORKFLOW)
expect(update.options.history).toBe('push')
})
describe('Updating query state', () => {
it('should update keywords when setQuery receives keywords', () => {
const { result } = renderWithAdapter()
it('should remove category from URL when set to all', async () => {
const { result, onUrlUpdate } = renderWithAdapter('?category=workflow')
act(() => {
result.current.setQuery({ keywords: 'new search' })
})
expect(result.current.query.keywords).toBe('new search')
act(() => {
result.current.setCategory('all')
})
it('should update tagIDs when setQuery receives tagIDs', () => {
const { result } = renderWithAdapter()
act(() => {
result.current.setQuery({ tagIDs: ['tag1', 'tag2'] })
})
expect(result.current.query.tagIDs).toEqual(['tag1', 'tag2'])
})
it('should update isCreatedByMe when setQuery receives true', () => {
const { result } = renderWithAdapter()
act(() => {
result.current.setQuery({ isCreatedByMe: true })
})
expect(result.current.query.isCreatedByMe).toBe(true)
})
it('should support partial updates when setQuery uses callback', () => {
const { result } = renderWithAdapter()
act(() => {
result.current.setQuery({ keywords: 'initial' })
})
act(() => {
result.current.setQuery(prev => ({ ...prev, isCreatedByMe: true }))
})
expect(result.current.query.keywords).toBe('initial')
expect(result.current.query.isCreatedByMe).toBe(true)
})
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
const update = onUrlUpdate.mock.calls.at(-1)![0]
expect(result.current.query.category).toBe('all')
expect(update.searchParams.has('category')).toBe(false)
})
describe('URL synchronization', () => {
it('should sync keywords to URL when keywords change', async () => {
it('should update keywords state immediately while debouncing URL writes', async () => {
vi.useFakeTimers()
try {
const { result, onUrlUpdate } = renderWithAdapter()
act(() => {
result.current.setQuery({ keywords: 'search' })
result.current.setKeywords('search')
})
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1]![0]
expect(result.current.query.keywords).toBe('search')
expect(onUrlUpdate).not.toHaveBeenCalled()
await act(async () => {
await vi.advanceTimersByTimeAsync(APP_LIST_SEARCH_DEBOUNCE_MS + 100)
})
expect(onUrlUpdate).toHaveBeenCalled()
const update = onUrlUpdate.mock.calls.at(-1)![0]
expect(update.searchParams.get('keywords')).toBe('search')
expect(update.options.history).toBe('push')
})
}
finally {
vi.useRealTimers()
}
})
it('should sync tagIDs to URL when tagIDs change', async () => {
const { result, onUrlUpdate } = renderWithAdapter()
act(() => {
result.current.setQuery({ tagIDs: ['tag1', 'tag2'] })
})
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1]![0]
expect(update.searchParams.get('tagIDs')).toBe('tag1;tag2')
})
it('should sync isCreatedByMe to URL when enabled', async () => {
const { result, onUrlUpdate } = renderWithAdapter()
act(() => {
result.current.setQuery({ isCreatedByMe: true })
})
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1]![0]
expect(update.searchParams.get('isCreatedByMe')).toBe('true')
})
it('should remove keywords from URL when keywords are cleared', async () => {
it('should remove keywords from URL when cleared', async () => {
vi.useFakeTimers()
try {
const { result, onUrlUpdate } = renderWithAdapter('?keywords=existing')
act(() => {
result.current.setQuery({ keywords: '' })
result.current.setKeywords('')
})
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1]![0]
expect(result.current.query.keywords).toBe('')
await act(async () => {
await vi.advanceTimersByTimeAsync(APP_LIST_SEARCH_DEBOUNCE_MS + 100)
})
expect(onUrlUpdate).toHaveBeenCalled()
const update = onUrlUpdate.mock.calls.at(-1)![0]
expect(update.searchParams.has('keywords')).toBe(false)
})
it('should remove tagIDs from URL when tagIDs are empty', async () => {
const { result, onUrlUpdate } = renderWithAdapter('?tagIDs=tag1;tag2')
act(() => {
result.current.setQuery({ tagIDs: [] })
})
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1]![0]
expect(update.searchParams.has('tagIDs')).toBe(false)
})
it('should remove isCreatedByMe from URL when disabled', async () => {
const { result, onUrlUpdate } = renderWithAdapter('?isCreatedByMe=true')
act(() => {
result.current.setQuery({ isCreatedByMe: false })
})
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1]![0]
expect(update.searchParams.has('isCreatedByMe')).toBe(false)
})
}
finally {
vi.useRealTimers()
}
})
describe('Edge cases', () => {
it('should treat empty tagIDs as empty list when URL param is empty', () => {
const { result } = renderWithAdapter('?tagIDs=')
it('should update tag filter URL state', async () => {
const { result, onUrlUpdate } = renderWithAdapter()
expect(result.current.query.tagIDs).toEqual([])
act(() => {
result.current.setTagIDs(['tag1', 'tag2'])
})
it('should treat empty keywords as undefined when URL param is empty', () => {
const { result } = renderWithAdapter('?keywords=')
expect(result.current.query.keywords).toBeUndefined()
})
it('should decode keywords with spaces when URL contains encoded spaces', () => {
const { result } = renderWithAdapter('?keywords=test+with+spaces')
expect(result.current.query.keywords).toBe('test with spaces')
})
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
const update = onUrlUpdate.mock.calls.at(-1)![0]
expect(result.current.query.tagIDs).toEqual(['tag1', 'tag2'])
expect(update.searchParams.get('tagIDs')).toBe('tag1;tag2')
expect(update.options.history).toBe('push')
})
describe('Integration scenarios', () => {
it('should keep accumulated filters when updates are sequential', () => {
const { result } = renderWithAdapter()
it('should remove tagIDs from URL when empty', async () => {
const { result, onUrlUpdate } = renderWithAdapter('?tagIDs=tag1;tag2')
act(() => {
result.current.setQuery({ keywords: 'first' })
})
act(() => {
result.current.setQuery(prev => ({ ...prev, tagIDs: ['tag1'] }))
})
act(() => {
result.current.setQuery(prev => ({ ...prev, isCreatedByMe: true }))
})
expect(result.current.query.keywords).toBe('first')
expect(result.current.query.tagIDs).toEqual(['tag1'])
expect(result.current.query.isCreatedByMe).toBe(true)
act(() => {
result.current.setTagIDs([])
})
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
const update = onUrlUpdate.mock.calls.at(-1)![0]
expect(result.current.query.tagIDs).toEqual([])
expect(update.searchParams.has('tagIDs')).toBe(false)
})
it('should update created-by-me URL state', async () => {
const { result, onUrlUpdate } = renderWithAdapter()
act(() => {
result.current.setIsCreatedByMe(true)
})
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
const update = onUrlUpdate.mock.calls.at(-1)![0]
expect(result.current.query.isCreatedByMe).toBe(true)
expect(update.searchParams.get('isCreatedByMe')).toBe('true')
expect(update.options.history).toBe('push')
})
it('should remove isCreatedByMe from URL when disabled', async () => {
const { result, onUrlUpdate } = renderWithAdapter('?isCreatedByMe=true')
act(() => {
result.current.setIsCreatedByMe(false)
})
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
const update = onUrlUpdate.mock.calls.at(-1)![0]
expect(result.current.query.isCreatedByMe).toBe(false)
expect(update.searchParams.has('isCreatedByMe')).toBe(false)
})
})

View File

@ -1,57 +1,56 @@
import { parseAsArrayOf, parseAsBoolean, parseAsString, useQueryStates } from 'nuqs'
import { debounce, parseAsArrayOf, parseAsBoolean, parseAsString, parseAsStringLiteral, useQueryStates } from 'nuqs'
import { useCallback, useMemo } from 'react'
import { AppModes } from '@/types/app'
import { APP_LIST_SEARCH_DEBOUNCE_MS } from '../constants'
type AppsQuery = {
tagIDs?: string[]
keywords?: string
isCreatedByMe?: boolean
const APP_LIST_CATEGORY_VALUES = ['all', ...AppModes] as const
export type AppListCategory = typeof APP_LIST_CATEGORY_VALUES[number]
const appListCategorySet = new Set<string>(APP_LIST_CATEGORY_VALUES)
export const isAppListCategory = (value: string): value is AppListCategory => {
return appListCategorySet.has(value)
}
const normalizeKeywords = (value: string | null) => value || undefined
function useAppsQueryState() {
const [urlQuery, setUrlQuery] = useQueryStates(
{
tagIDs: parseAsArrayOf(parseAsString, ';'),
keywords: parseAsString,
isCreatedByMe: parseAsBoolean,
},
{
history: 'push',
},
)
const query = useMemo<AppsQuery>(() => ({
tagIDs: urlQuery.tagIDs ?? undefined,
keywords: normalizeKeywords(urlQuery.keywords),
isCreatedByMe: urlQuery.isCreatedByMe ?? false,
}), [urlQuery.isCreatedByMe, urlQuery.keywords, urlQuery.tagIDs])
const setQuery = useCallback((next: AppsQuery | ((prev: AppsQuery) => AppsQuery)) => {
const buildPatch = (patch: AppsQuery) => {
const result: Partial<typeof urlQuery> = {}
if ('tagIDs' in patch)
result.tagIDs = patch.tagIDs && patch.tagIDs.length > 0 ? patch.tagIDs : null
if ('keywords' in patch)
result.keywords = patch.keywords ? patch.keywords : null
if ('isCreatedByMe' in patch)
result.isCreatedByMe = patch.isCreatedByMe ? true : null
return result
}
if (typeof next === 'function') {
setUrlQuery(prev => buildPatch(next({
tagIDs: prev.tagIDs ?? undefined,
keywords: normalizeKeywords(prev.keywords),
isCreatedByMe: prev.isCreatedByMe ?? false,
})))
return
}
setUrlQuery(buildPatch(next))
}, [setUrlQuery])
return useMemo(() => ({ query, setQuery }), [query, setQuery])
const appListQueryParsers = {
category: parseAsStringLiteral(APP_LIST_CATEGORY_VALUES)
.withDefault('all')
.withOptions({ history: 'push' }),
tagIDs: parseAsArrayOf(parseAsString, ';')
.withDefault([])
.withOptions({ history: 'push' }),
keywords: parseAsString.withDefault('').withOptions({
limitUrlUpdates: debounce(APP_LIST_SEARCH_DEBOUNCE_MS),
}),
isCreatedByMe: parseAsBoolean
.withDefault(false)
.withOptions({ history: 'push' }),
}
export default useAppsQueryState
export function useAppsQueryState() {
const [query, setQuery] = useQueryStates(appListQueryParsers)
const setCategory = useCallback((category: AppListCategory) => {
setQuery({ category })
}, [setQuery])
const setKeywords = useCallback((keywords: string) => {
setQuery({ keywords })
}, [setQuery])
const setTagIDs = useCallback((tagIDs: string[]) => {
setQuery({ tagIDs })
}, [setQuery])
const setIsCreatedByMe = useCallback((isCreatedByMe: boolean) => {
setQuery({ isCreatedByMe })
}, [setQuery])
return useMemo(() => ({
query,
setCategory,
setKeywords,
setTagIDs,
setIsCreatedByMe,
}), [query, setCategory, setKeywords, setTagIDs, setIsCreatedByMe])
}

View File

@ -4,8 +4,7 @@ import type { FC } from 'react'
import type { AppListQuery } from '@/contract/console/apps'
import { cn } from '@langgenius/dify-ui/cn'
import { keepPreviousData, useInfiniteQuery, useSuspenseQuery } from '@tanstack/react-query'
import { useDebounceFn } from 'ahooks'
import { parseAsStringLiteral, useQueryState } from 'nuqs'
import { useDebounce } from 'ahooks'
import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import Checkbox from '@/app/components/base/checkbox'
@ -18,12 +17,13 @@ import { CheckModal } from '@/hooks/use-pay'
import dynamic from '@/next/dynamic'
import { consoleQuery } from '@/service/client'
import { systemFeaturesQueryOptions } from '@/service/system-features'
import { AppModeEnum, AppModes } from '@/types/app'
import { AppModeEnum } from '@/types/app'
import AppCard from './app-card'
import { AppCardSkeleton } from './app-card-skeleton'
import { APP_LIST_SEARCH_DEBOUNCE_MS } from './constants'
import Empty from './empty'
import Footer from './footer'
import useAppsQueryStateHook from './hooks/use-apps-query-state'
import { isAppListCategory, useAppsQueryState } from './hooks/use-apps-query-state'
import { useDSLDragDrop } from './hooks/use-dsl-drag-drop'
import { useWorkflowOnlineUsers } from './hooks/use-workflow-online-users'
import NewAppCard from './new-app-card'
@ -35,18 +35,6 @@ const CreateFromDSLModal = dynamic(() => import('@/app/components/app/create-fro
ssr: false,
})
const APP_LIST_CATEGORY_VALUES = ['all', ...AppModes] as const
type AppListCategory = typeof APP_LIST_CATEGORY_VALUES[number]
const appListCategorySet = new Set<string>(APP_LIST_CATEGORY_VALUES)
const isAppListCategory = (value: string): value is AppListCategory => {
return appListCategorySet.has(value)
}
const parseAsAppListCategory = parseAsStringLiteral(APP_LIST_CATEGORY_VALUES)
.withDefault('all')
.withOptions({ history: 'push' })
type Props = {
controlRefreshList?: number
}
@ -56,28 +44,21 @@ const List: FC<Props> = ({
const { t } = useTranslation()
const { data: systemFeatures } = useSuspenseQuery(systemFeaturesQueryOptions())
const { isCurrentWorkspaceEditor, isCurrentWorkspaceDatasetOperator, isLoadingCurrentWorkspace } = useAppContext()
const [activeTab, setActiveTab] = useQueryState(
'category',
parseAsAppListCategory,
)
// eslint-disable-next-line react/use-state -- custom URL query hook, not React.useState
const appsQuery = useAppsQueryStateHook()
const { query: { tagIDs = [], keywords = '', isCreatedByMe: queryIsCreatedByMe = false }, setQuery } = appsQuery
const [isCreatedByMe, setIsCreatedByMe] = useState(queryIsCreatedByMe)
const [tagFilterValue, setTagFilterValue] = useState<string[]>(tagIDs)
const [searchKeywords, setSearchKeywords] = useState(keywords)
const {
query: { category, tagIDs, keywords, isCreatedByMe },
setCategory,
setKeywords,
setTagIDs,
setIsCreatedByMe,
} = useAppsQueryState()
const debouncedKeywords = useDebounce(keywords, { wait: APP_LIST_SEARCH_DEBOUNCE_MS })
const newAppCardRef = useRef<HTMLDivElement>(null)
const containerRef = useRef<HTMLDivElement>(null)
const [showTagManagementModal, setShowTagManagementModal] = useState(false)
const [showCreateFromDSLModal, setShowCreateFromDSLModal] = useState(false)
const [droppedDSLFile, setDroppedDSLFile] = useState<File | undefined>()
const setKeywords = useCallback((keywords: string) => {
setQuery(prev => ({ ...prev, keywords }))
}, [setQuery])
const setTagIDs = useCallback((tagIDs: string[]) => {
setQuery(prev => ({ ...prev, tagIDs }))
}, [setQuery])
const handleDSLFileDropped = useCallback((file: File) => {
setDroppedDSLFile(file)
@ -93,11 +74,11 @@ const List: FC<Props> = ({
const appListQuery = useMemo<AppListQuery>(() => ({
page: 1,
limit: 30,
name: searchKeywords,
name: debouncedKeywords,
...(tagIDs.length ? { tag_ids: tagIDs } : {}),
...(isCreatedByMe ? { is_created_by_me: isCreatedByMe } : {}),
...(activeTab !== 'all' ? { mode: activeTab } : {}),
}), [activeTab, isCreatedByMe, searchKeywords, tagIDs])
...(category !== 'all' ? { mode: category } : {}),
}), [category, debouncedKeywords, isCreatedByMe, tagIDs])
const {
data,
@ -177,27 +158,9 @@ const List: FC<Props> = ({
return () => observer?.disconnect()
}, [isLoading, isFetchingNextPage, fetchNextPage, error, hasNextPage, isCurrentWorkspaceDatasetOperator])
const { run: handleSearch } = useDebounceFn(() => {
setSearchKeywords(keywords)
}, { wait: 500 })
const handleKeywordsChange = (value: string) => {
setKeywords(value)
handleSearch()
}
const { run: handleTagsUpdate } = useDebounceFn(() => {
setTagIDs(tagFilterValue)
}, { wait: 500 })
const handleTagsChange = (value: string[]) => {
setTagFilterValue(value)
handleTagsUpdate()
}
const handleCreatedByMeChange = useCallback(() => {
const newValue = !isCreatedByMe
setIsCreatedByMe(newValue)
setQuery(prev => ({ ...prev, isCreatedByMe: newValue }))
}, [isCreatedByMe, setQuery])
setIsCreatedByMe(!isCreatedByMe)
}, [isCreatedByMe, setIsCreatedByMe])
const pages = useMemo(() => data?.pages ?? [], [data?.pages])
const apps = useMemo(() => pages.flatMap(({ data: pageApps }) => pageApps), [pages])
@ -232,10 +195,10 @@ const List: FC<Props> = ({
<div className="sticky top-0 z-10 flex flex-wrap items-center justify-between gap-y-2 bg-background-body px-12 pt-7 pb-5">
<TabSliderNew
value={activeTab}
value={category}
onChange={(nextValue) => {
if (isAppListCategory(nextValue))
setActiveTab(nextValue)
setCategory(nextValue)
}}
options={options}
/>
@ -246,14 +209,14 @@ const List: FC<Props> = ({
{t('showMyCreatedAppsOnly', { ns: 'app' })}
</div>
</label>
<TagFilter type="app" value={tagFilterValue} onChange={handleTagsChange} onOpenTagManagement={() => setShowTagManagementModal(true)} />
<TagFilter type="app" value={tagIDs} onChange={setTagIDs} onOpenTagManagement={() => setShowTagManagementModal(true)} />
<Input
showLeftIcon
showClearIcon
wrapperClassName="w-[200px]"
value={keywords}
onChange={e => handleKeywordsChange(e.target.value)}
onClear={() => handleKeywordsChange('')}
onChange={e => setKeywords(e.target.value)}
onClear={() => setKeywords('')}
/>
</div>
</div>
@ -267,7 +230,7 @@ const List: FC<Props> = ({
ref={newAppCardRef}
isLoading={isLoadingCurrentWorkspace}
onSuccess={refetch}
selectedAppType={activeTab}
selectedAppType={category}
className={cn(!hasAnyApp && 'z-10')}
/>
)}

View File

@ -78,11 +78,11 @@ export const TagFilter = ({
!!value.length && 'pr-6 shadow-xs',
)}
>
<span className="flex min-w-0 items-center gap-1">
<span className="flex w-full min-w-0 items-center gap-1">
<span className="p-px">
<Tag01Icon className="h-3.5 w-3.5 text-text-tertiary" aria-hidden="true" />
</span>
<span className="min-w-0 truncate text-[13px] leading-4.5 text-text-secondary">
<span className="min-w-0 grow truncate text-[13px] leading-4.5 text-text-secondary">
{!value.length && t('tag.placeholder', { ns: 'common' })}
{!!value.length && currentTagName}
</span>