mirror of
https://github.com/langgenius/dify.git
synced 2026-05-11 23:18:39 +08:00
Merge branch 'main' into 4-27-app-deploy
This commit is contained in:
commit
0a32344504
@ -55,9 +55,17 @@ Use this as the decision guide for React/TypeScript component structure. Existin
|
||||
- Avoid unnecessary DOM hierarchy. Do not add wrapper elements unless they provide layout, semantics, accessibility, state ownership, or integration with a library API; prefer fragments or styling an existing element when possible.
|
||||
- Avoid shallow wrappers and prop renaming unless the wrapper adds validation, orchestration, error handling, state ownership, or a real semantic boundary.
|
||||
|
||||
## Navigation, Effects, And Performance
|
||||
## You Might Not Need An Effect
|
||||
|
||||
- Use Effects only to synchronize with external systems such as browser APIs, non-React widgets, subscriptions, timers, analytics that must run because the component was shown, or imperative DOM integration.
|
||||
- Do not use Effects to transform props or state for rendering. Calculate derived values during render, and use `useMemo` only when the calculation is actually expensive.
|
||||
- Do not use Effects to handle user actions. Put action-specific logic in the event handler where the cause is known.
|
||||
- Do not use Effects to copy one state value into another state value representing the same concept. Pick one source of truth and derive the rest during render.
|
||||
- Do not reset or adjust state from props with an Effect. Prefer a `key` reset, storing a stable ID and deriving the selected object, or guarded same-component render-time adjustment when truly necessary.
|
||||
- Prefer framework data APIs or TanStack Query for data fetching instead of writing request Effects in components.
|
||||
- If an Effect still seems necessary, first name the external system it synchronizes with. If there is no external system, remove the Effect and restructure the state or event flow.
|
||||
|
||||
## Navigation And Performance
|
||||
|
||||
- Prefer `Link` for normal navigation. Use router APIs only for command-flow side effects such as mutation success, guarded redirects, or form submission.
|
||||
- Treat `useEffect` as a last resort. First try deriving values during render, moving event-driven work into handlers, or using existing hooks/APIs for persistence, subscriptions, media queries, timers, and DOM sync.
|
||||
- Do not use `useEffect` directly in components. If unavoidable, encapsulate it in a purpose-built hook so the component consumes a declarative API.
|
||||
- Avoid `memo`, `useMemo`, and `useCallback` unless there is a clear performance reason.
|
||||
|
||||
@ -88,6 +88,10 @@ REDIS_HEALTH_CHECK_INTERVAL=30
|
||||
CELERY_BROKER_URL=redis://:difyai123456@localhost:${REDIS_PORT}/1
|
||||
CELERY_BACKEND=redis
|
||||
|
||||
# Ops trace retry configuration
|
||||
OPS_TRACE_RETRYABLE_DISPATCH_MAX_RETRIES=60
|
||||
OPS_TRACE_RETRYABLE_DISPATCH_DELAY_SECONDS=5
|
||||
|
||||
# Database configuration
|
||||
DB_TYPE=postgresql
|
||||
DB_USERNAME=postgres
|
||||
|
||||
@ -1137,6 +1137,18 @@ class MultiModalTransferConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class OpsTraceConfig(BaseSettings):
|
||||
OPS_TRACE_RETRYABLE_DISPATCH_MAX_RETRIES: PositiveInt = Field(
|
||||
description="Maximum retry attempts for transient ops trace provider dispatch failures.",
|
||||
default=60,
|
||||
)
|
||||
|
||||
OPS_TRACE_RETRYABLE_DISPATCH_DELAY_SECONDS: PositiveInt = Field(
|
||||
description="Delay in seconds between transient ops trace provider dispatch retry attempts.",
|
||||
default=5,
|
||||
)
|
||||
|
||||
|
||||
class CeleryBeatConfig(BaseSettings):
|
||||
CELERY_BEAT_SCHEDULER_TIME: int = Field(
|
||||
description="Interval in days for Celery Beat scheduler execution, default to 1 day",
|
||||
@ -1417,6 +1429,7 @@ class FeatureConfig(
|
||||
ModelLoadBalanceConfig,
|
||||
ModerationConfig,
|
||||
MultiModalTransferConfig,
|
||||
OpsTraceConfig,
|
||||
PositionConfig,
|
||||
RagEtlConfig,
|
||||
RepositoryConfig,
|
||||
|
||||
@ -32,7 +32,7 @@ from core.app.entities.task_entities import (
|
||||
)
|
||||
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
|
||||
from core.db.session_factory import session_factory
|
||||
from core.helper.trace_id_helper import extract_external_trace_id_from_args
|
||||
from core.helper.trace_id_helper import extract_external_trace_id_from_args, extract_parent_trace_context_from_args
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
|
||||
@ -166,6 +166,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
|
||||
extras = {
|
||||
**extract_external_trace_id_from_args(args),
|
||||
**extract_parent_trace_context_from_args(args),
|
||||
}
|
||||
workflow_run_id = str(workflow_run_id or uuid.uuid4())
|
||||
# FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args
|
||||
|
||||
@ -15,6 +15,7 @@ from datetime import datetime
|
||||
from typing import Any, Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
|
||||
from core.helper.trace_id_helper import ParentTraceContext
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
|
||||
@ -403,8 +404,13 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
|
||||
conversation_id = self._system_variables().get(SystemVariableKey.CONVERSATION_ID.value)
|
||||
external_trace_id = None
|
||||
parent_trace_context = None
|
||||
if isinstance(self._application_generate_entity, (WorkflowAppGenerateEntity, AdvancedChatAppGenerateEntity)):
|
||||
external_trace_id = self._application_generate_entity.extras.get("external_trace_id")
|
||||
extras = self._application_generate_entity.extras
|
||||
external_trace_id = extras.get("external_trace_id")
|
||||
parent_trace_context = extras.get("parent_trace_context")
|
||||
if isinstance(parent_trace_context, ParentTraceContext):
|
||||
parent_trace_context = parent_trace_context.model_dump(exclude_none=True)
|
||||
|
||||
trace_task = TraceTask(
|
||||
TraceTaskName.WORKFLOW_TRACE,
|
||||
@ -412,6 +418,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
conversation_id=conversation_id,
|
||||
user_id=self._trace_manager.user_id,
|
||||
external_trace_id=external_trace_id,
|
||||
parent_trace_context=parent_trace_context,
|
||||
)
|
||||
self._trace_manager.add_trace_task(trace_task)
|
||||
|
||||
|
||||
@ -3,6 +3,17 @@ import re
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, StrictStr, ValidationError
|
||||
|
||||
|
||||
class ParentTraceContext(BaseModel):
|
||||
"""Typed parent trace context propagated from an outer workflow tool node."""
|
||||
|
||||
parent_workflow_run_id: StrictStr
|
||||
parent_node_execution_id: StrictStr | None = None
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
def is_valid_trace_id(trace_id: str) -> bool:
|
||||
"""
|
||||
@ -61,6 +72,30 @@ def extract_external_trace_id_from_args(args: Mapping[str, Any]):
|
||||
return {}
|
||||
|
||||
|
||||
def extract_parent_trace_context_from_args(args: Mapping[str, Any]) -> dict[str, ParentTraceContext]:
|
||||
"""
|
||||
Extract 'parent_trace_context' from args.
|
||||
|
||||
Returns a dict suitable for use in extras when both parent identifiers exist.
|
||||
Returns an empty dict if the context is missing or incomplete.
|
||||
"""
|
||||
parent_trace_context = args.get("parent_trace_context")
|
||||
if isinstance(parent_trace_context, ParentTraceContext):
|
||||
context = parent_trace_context
|
||||
elif isinstance(parent_trace_context, Mapping):
|
||||
try:
|
||||
context = ParentTraceContext.model_validate(parent_trace_context)
|
||||
except ValidationError:
|
||||
return {}
|
||||
else:
|
||||
return {}
|
||||
|
||||
if context.parent_node_execution_id is None:
|
||||
return {}
|
||||
|
||||
return {"parent_trace_context": context}
|
||||
|
||||
|
||||
def get_trace_id_from_otel_context() -> str | None:
|
||||
"""
|
||||
Retrieve the current trace ID from the active OpenTelemetry trace context.
|
||||
|
||||
@ -5,6 +5,8 @@ from typing import Any, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, field_serializer, field_validator
|
||||
|
||||
from core.helper.trace_id_helper import ParentTraceContext
|
||||
|
||||
|
||||
class BaseTraceInfo(BaseModel):
|
||||
message_id: str | None = None
|
||||
@ -51,8 +53,8 @@ class BaseTraceInfo(BaseModel):
|
||||
def resolved_parent_context(self) -> tuple[str | None, str | None]:
|
||||
"""Resolve cross-workflow parent linking from metadata.
|
||||
|
||||
Extracts typed parent IDs from the untyped ``parent_trace_context``
|
||||
metadata dict (set by tool_node when invoking nested workflows).
|
||||
Extracts typed parent IDs from the ``parent_trace_context`` metadata
|
||||
payload (set by tool_node when invoking nested workflows).
|
||||
|
||||
Returns:
|
||||
(trace_correlation_override, parent_span_id_source) where
|
||||
@ -60,13 +62,18 @@ class BaseTraceInfo(BaseModel):
|
||||
parent_span_id_source is the outer node_execution_id.
|
||||
"""
|
||||
parent_ctx = self.metadata.get("parent_trace_context")
|
||||
if not isinstance(parent_ctx, dict):
|
||||
if isinstance(parent_ctx, ParentTraceContext):
|
||||
context = parent_ctx
|
||||
elif isinstance(parent_ctx, Mapping):
|
||||
try:
|
||||
context = ParentTraceContext.model_validate(parent_ctx)
|
||||
except ValueError:
|
||||
return None, None
|
||||
else:
|
||||
return None, None
|
||||
trace_override = parent_ctx.get("parent_workflow_run_id")
|
||||
parent_span = parent_ctx.get("parent_node_execution_id")
|
||||
return (
|
||||
trace_override if isinstance(trace_override, str) else None,
|
||||
parent_span if isinstance(parent_span, str) else None,
|
||||
context.parent_workflow_run_id,
|
||||
context.parent_node_execution_id,
|
||||
)
|
||||
|
||||
@field_serializer("start_time", "end_time")
|
||||
|
||||
22
api/core/ops/exceptions.py
Normal file
22
api/core/ops/exceptions.py
Normal file
@ -0,0 +1,22 @@
|
||||
"""Core exceptions shared by ops trace dispatchers and trace providers.
|
||||
|
||||
Provider packages may raise these types to request generic task behavior, but
|
||||
generic Celery tasks should not import provider-specific exception classes.
|
||||
"""
|
||||
|
||||
|
||||
class RetryableTraceDispatchError(RuntimeError):
|
||||
"""Base class for transient trace dispatch failures that Celery may retry."""
|
||||
|
||||
|
||||
class PendingTraceParentContextError(RetryableTraceDispatchError):
|
||||
"""Raised when a nested trace arrives before its parent span context is available."""
|
||||
|
||||
parent_node_execution_id: str
|
||||
|
||||
def __init__(self, parent_node_execution_id: str) -> None:
|
||||
self.parent_node_execution_id = parent_node_execution_id
|
||||
super().__init__(
|
||||
"Pending trace parent context for parent_node_execution_id="
|
||||
f"{parent_node_execution_id}. Retry after the parent span context is published."
|
||||
)
|
||||
@ -16,6 +16,7 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token
|
||||
from core.helper.trace_id_helper import ParentTraceContext
|
||||
from core.ops.entities.config_entity import (
|
||||
OPS_FILE_PATH,
|
||||
BaseTracingConfig,
|
||||
@ -52,6 +53,17 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _dump_parent_trace_context(parent_trace_context: Any) -> dict[str, str] | None:
|
||||
if isinstance(parent_trace_context, ParentTraceContext):
|
||||
return parent_trace_context.model_dump(exclude_none=True)
|
||||
if isinstance(parent_trace_context, dict):
|
||||
try:
|
||||
return ParentTraceContext.model_validate(parent_trace_context).model_dump(exclude_none=True)
|
||||
except ValueError:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
class _AppTracingConfig(TypedDict, total=False):
|
||||
enabled: bool
|
||||
tracing_provider: str | None
|
||||
@ -857,8 +869,9 @@ class TraceTask:
|
||||
}
|
||||
|
||||
parent_trace_context = self.kwargs.get("parent_trace_context")
|
||||
if parent_trace_context:
|
||||
metadata["parent_trace_context"] = parent_trace_context
|
||||
dumped_parent_trace_context = _dump_parent_trace_context(parent_trace_context)
|
||||
if dumped_parent_trace_context:
|
||||
metadata["parent_trace_context"] = dumped_parent_trace_context
|
||||
|
||||
workflow_trace_info = WorkflowTraceInfo(
|
||||
trace_id=self.trace_id,
|
||||
@ -1371,13 +1384,14 @@ class TraceTask:
|
||||
}
|
||||
|
||||
parent_trace_context = node_data.get("parent_trace_context")
|
||||
if parent_trace_context:
|
||||
metadata["parent_trace_context"] = parent_trace_context
|
||||
dumped_parent_trace_context = _dump_parent_trace_context(parent_trace_context)
|
||||
if dumped_parent_trace_context:
|
||||
metadata["parent_trace_context"] = dumped_parent_trace_context
|
||||
|
||||
message_id: str | None = None
|
||||
conversation_id = node_data.get("conversation_id")
|
||||
workflow_execution_id = node_data.get("workflow_execution_id")
|
||||
if conversation_id and workflow_execution_id and not parent_trace_context:
|
||||
if conversation_id and workflow_execution_id and not dumped_parent_trace_context:
|
||||
with Session(db.engine) as session:
|
||||
msg_id = session.scalar(
|
||||
select(Message.id).where(
|
||||
|
||||
@ -9,6 +9,7 @@ from sqlalchemy import select
|
||||
|
||||
from core.app.file_access import DatabaseFileAccessController
|
||||
from core.db.session_factory import session_factory
|
||||
from core.helper.trace_id_helper import ParentTraceContext, extract_parent_trace_context_from_args
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.tool_entities import (
|
||||
@ -36,6 +37,8 @@ class WorkflowTool(Tool):
|
||||
Workflow tool.
|
||||
"""
|
||||
|
||||
_parent_trace_context: ParentTraceContext | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workflow_app_id: str,
|
||||
@ -54,6 +57,7 @@ class WorkflowTool(Tool):
|
||||
self.workflow_call_depth = workflow_call_depth
|
||||
self.label = label
|
||||
self._latest_usage = LLMUsage.empty_usage()
|
||||
self._parent_trace_context = None
|
||||
|
||||
super().__init__(entity=entity, runtime=runtime)
|
||||
|
||||
@ -94,11 +98,17 @@ class WorkflowTool(Tool):
|
||||
|
||||
self._latest_usage = LLMUsage.empty_usage()
|
||||
|
||||
generator_args: dict[str, Any] = {"inputs": tool_parameters, "files": files}
|
||||
if self._parent_trace_context:
|
||||
generator_args.update(
|
||||
extract_parent_trace_context_from_args({"parent_trace_context": self._parent_trace_context})
|
||||
)
|
||||
|
||||
result = generator.generate(
|
||||
app_model=app,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args={"inputs": tool_parameters, "files": files},
|
||||
args=generator_args,
|
||||
invoke_from=self.runtime.invoke_from,
|
||||
streaming=False,
|
||||
call_depth=self.workflow_call_depth + 1,
|
||||
@ -194,7 +204,7 @@ class WorkflowTool(Tool):
|
||||
|
||||
:return: the new tool
|
||||
"""
|
||||
return self.__class__(
|
||||
forked = self.__class__(
|
||||
entity=self.entity.model_copy(),
|
||||
runtime=runtime,
|
||||
workflow_app_id=self.workflow_app_id,
|
||||
@ -204,6 +214,24 @@ class WorkflowTool(Tool):
|
||||
version=self.version,
|
||||
label=self.label,
|
||||
)
|
||||
forked._parent_trace_context = self._parent_trace_context.model_copy() if self._parent_trace_context else None
|
||||
return forked
|
||||
|
||||
def set_parent_trace_context(
|
||||
self,
|
||||
*,
|
||||
parent_workflow_run_id: str,
|
||||
parent_node_execution_id: str,
|
||||
) -> None:
|
||||
"""Attach outer workflow trace context without exposing it as tool input."""
|
||||
self._parent_trace_context = ParentTraceContext(
|
||||
parent_workflow_run_id=parent_workflow_run_id,
|
||||
parent_node_execution_id=parent_node_execution_id,
|
||||
)
|
||||
|
||||
def clear_parent_trace_context(self) -> None:
|
||||
"""Remove parent trace context before invoking this tool outside a nested workflow."""
|
||||
self._parent_trace_context = None
|
||||
|
||||
def _resolve_user(self, user_id: str) -> Account | EndUser | None:
|
||||
"""
|
||||
|
||||
@ -10,6 +10,7 @@ from sqlalchemy.orm import Session
|
||||
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
|
||||
from core.app.file_access import DatabaseFileAccessController
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
from core.helper.trace_id_helper import ParentTraceContext
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
|
||||
from core.model_manager import ModelInstance
|
||||
@ -358,6 +359,7 @@ class _WorkflowToolRuntimeBinding:
|
||||
|
||||
tool: Tool
|
||||
conversation_id: str | None = None
|
||||
parent_trace_context: ParentTraceContext | None = None
|
||||
|
||||
|
||||
class DifyToolNodeRuntime(ToolNodeRuntimeProtocol):
|
||||
@ -398,7 +400,25 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol):
|
||||
conversation_id = (
|
||||
None if variable_pool is None else get_system_text(variable_pool, SystemVariableKey.CONVERSATION_ID)
|
||||
)
|
||||
return ToolRuntimeHandle(raw=_WorkflowToolRuntimeBinding(tool=tool_runtime, conversation_id=conversation_id))
|
||||
parent_trace_context: ParentTraceContext | None = None
|
||||
if self._is_workflow_tool_provider(node_data):
|
||||
outer_workflow_run_id = (
|
||||
None
|
||||
if variable_pool is None
|
||||
else get_system_text(variable_pool, SystemVariableKey.WORKFLOW_EXECUTION_ID)
|
||||
)
|
||||
if isinstance(outer_workflow_run_id, str) and isinstance(node_execution_id, str):
|
||||
parent_trace_context = ParentTraceContext(
|
||||
parent_workflow_run_id=outer_workflow_run_id,
|
||||
parent_node_execution_id=node_execution_id,
|
||||
)
|
||||
return ToolRuntimeHandle(
|
||||
raw=_WorkflowToolRuntimeBinding(
|
||||
tool=tool_runtime,
|
||||
conversation_id=conversation_id,
|
||||
parent_trace_context=parent_trace_context,
|
||||
)
|
||||
)
|
||||
|
||||
def get_runtime_parameters(
|
||||
self,
|
||||
@ -422,6 +442,13 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol):
|
||||
runtime_binding = self._binding_from_handle(tool_runtime)
|
||||
tool = runtime_binding.tool
|
||||
callback = DifyWorkflowCallbackHandler()
|
||||
if runtime_binding.parent_trace_context and hasattr(tool, "set_parent_trace_context"):
|
||||
tool.set_parent_trace_context(
|
||||
parent_workflow_run_id=runtime_binding.parent_trace_context.parent_workflow_run_id,
|
||||
parent_node_execution_id=runtime_binding.parent_trace_context.parent_node_execution_id,
|
||||
)
|
||||
elif hasattr(tool, "clear_parent_trace_context"):
|
||||
tool.clear_parent_trace_context()
|
||||
|
||||
try:
|
||||
messages = ToolEngine.generic_invoke(
|
||||
@ -514,6 +541,10 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol):
|
||||
credential_id=node_data.credential_id,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _is_workflow_tool_provider(node_data: ToolNodeData) -> bool:
|
||||
return node_data.provider_type.value == CoreToolProviderType.WORKFLOW.value
|
||||
|
||||
def _adapt_messages(
|
||||
self,
|
||||
messages: Generator[CoreToolInvokeMessage, None, None],
|
||||
|
||||
@ -1,19 +1,22 @@
|
||||
"""Workflow comment models."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import Index, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from models.base import TypeBase
|
||||
|
||||
from .account import Account
|
||||
from .base import Base, gen_uuidv7_string
|
||||
from .base import gen_uuidv7_string
|
||||
from .engine import db
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
class WorkflowComment(Base):
|
||||
class WorkflowComment(TypeBase):
|
||||
"""Workflow comment model for canvas commenting functionality.
|
||||
|
||||
Comments are associated with apps rather than specific workflow versions,
|
||||
@ -42,27 +45,33 @@ class WorkflowComment(Base):
|
||||
Index("workflow_comments_created_at_idx", "created_at"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv7_string)
|
||||
id: Mapped[str] = mapped_column(StringUUID, default_factory=gen_uuidv7_string, init=False)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
position_x: Mapped[float] = mapped_column(sa.Float)
|
||||
position_y: Mapped[float] = mapped_column(sa.Float)
|
||||
content: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
sa.DateTime,
|
||||
nullable=False,
|
||||
server_default=func.current_timestamp(),
|
||||
onupdate=func.current_timestamp(),
|
||||
init=False,
|
||||
)
|
||||
resolved: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
resolved_at: Mapped[datetime | None] = mapped_column(sa.DateTime)
|
||||
resolved_by: Mapped[str | None] = mapped_column(StringUUID)
|
||||
resolved_at: Mapped[datetime | None] = mapped_column(sa.DateTime, default=None)
|
||||
resolved_by: Mapped[str | None] = mapped_column(StringUUID, default=None)
|
||||
|
||||
resolved: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"), default=False)
|
||||
# Relationships
|
||||
replies: Mapped[list["WorkflowCommentReply"]] = relationship(
|
||||
"WorkflowCommentReply", back_populates="comment", cascade="all, delete-orphan"
|
||||
replies: Mapped[list[WorkflowCommentReply]] = relationship(
|
||||
lambda: WorkflowCommentReply, back_populates="comment", cascade="all, delete-orphan", init=False
|
||||
)
|
||||
mentions: Mapped[list["WorkflowCommentMention"]] = relationship(
|
||||
"WorkflowCommentMention", back_populates="comment", cascade="all, delete-orphan"
|
||||
mentions: Mapped[list[WorkflowCommentMention]] = relationship(
|
||||
lambda: WorkflowCommentMention, back_populates="comment", cascade="all, delete-orphan", init=False
|
||||
)
|
||||
|
||||
@property
|
||||
@ -131,7 +140,7 @@ class WorkflowComment(Base):
|
||||
return participants
|
||||
|
||||
|
||||
class WorkflowCommentReply(Base):
|
||||
class WorkflowCommentReply(TypeBase):
|
||||
"""Workflow comment reply model.
|
||||
|
||||
Attributes:
|
||||
@ -149,18 +158,24 @@ class WorkflowCommentReply(Base):
|
||||
Index("comment_replies_created_at_idx", "created_at"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv7_string)
|
||||
id: Mapped[str] = mapped_column(StringUUID, default_factory=gen_uuidv7_string, init=False)
|
||||
comment_id: Mapped[str] = mapped_column(
|
||||
StringUUID, sa.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
content: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
sa.DateTime,
|
||||
nullable=False,
|
||||
server_default=func.current_timestamp(),
|
||||
onupdate=func.current_timestamp(),
|
||||
init=False,
|
||||
)
|
||||
# Relationships
|
||||
comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="replies")
|
||||
comment: Mapped[WorkflowComment] = relationship(lambda: WorkflowComment, back_populates="replies", init=False)
|
||||
|
||||
@property
|
||||
def created_by_account(self):
|
||||
@ -174,7 +189,7 @@ class WorkflowCommentReply(Base):
|
||||
self._created_by_account_cache = account
|
||||
|
||||
|
||||
class WorkflowCommentMention(Base):
|
||||
class WorkflowCommentMention(TypeBase):
|
||||
"""Workflow comment mention model.
|
||||
|
||||
Mentions are only for internal accounts since end users
|
||||
@ -194,18 +209,18 @@ class WorkflowCommentMention(Base):
|
||||
Index("comment_mentions_user_idx", "mentioned_user_id"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv7_string)
|
||||
id: Mapped[str] = mapped_column(StringUUID, default_factory=gen_uuidv7_string, init=False)
|
||||
comment_id: Mapped[str] = mapped_column(
|
||||
StringUUID, sa.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
reply_id: Mapped[str | None] = mapped_column(
|
||||
StringUUID, sa.ForeignKey("workflow_comment_replies.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
mentioned_user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
reply_id: Mapped[str | None] = mapped_column(
|
||||
StringUUID, sa.ForeignKey("workflow_comment_replies.id", ondelete="CASCADE"), nullable=True, default=None
|
||||
)
|
||||
|
||||
# Relationships
|
||||
comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="mentions")
|
||||
reply: Mapped[Optional["WorkflowCommentReply"]] = relationship("WorkflowCommentReply")
|
||||
comment: Mapped[WorkflowComment] = relationship(lambda: WorkflowComment, back_populates="mentions", init=False)
|
||||
reply: Mapped[WorkflowCommentReply | None] = relationship(lambda: WorkflowCommentReply, init=False)
|
||||
|
||||
@property
|
||||
def mentioned_user_account(self):
|
||||
|
||||
@ -1,9 +1,11 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import traceback
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Union, cast
|
||||
from typing import Any, Protocol, Union, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from openinference.semconv.trace import (
|
||||
@ -19,7 +21,7 @@ from opentelemetry.sdk import trace as trace_sdk
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
|
||||
from opentelemetry.semconv.attributes import exception_attributes
|
||||
from opentelemetry.trace import Span, Status, StatusCode, set_span_in_context, use_span
|
||||
from opentelemetry.trace import Span, Status, StatusCode, get_current_span, set_span_in_context, use_span
|
||||
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
|
||||
from opentelemetry.util.types import AttributeValue
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
@ -36,16 +38,106 @@ from core.ops.entities.trace_entity import (
|
||||
TraceTaskName,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.ops.exceptions import PendingTraceParentContextError
|
||||
from core.ops.utils import JSON_DICT_ADAPTER
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from graphon.enums import WorkflowNodeExecutionStatus
|
||||
from models.model import EndUser, MessageFile
|
||||
from models.workflow import WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# This parent-span carrier store is intentionally Phoenix-local for the current
|
||||
# nested workflow tracing feature. If other trace providers need the same
|
||||
# cross-task parent restoration behavior, move the storage and retry signaling
|
||||
# behind a core trace coordination interface instead of duplicating it here.
|
||||
_PHOENIX_PARENT_SPAN_CONTEXT_TTL_SECONDS = 300
|
||||
_TRACEPARENT_PATTERN = re.compile(
|
||||
r"^(?P<version>[0-9a-f]{2})-(?P<trace_id>[0-9a-f]{32})-(?P<span_id>[0-9a-f]{16})-(?P<flags>[0-9a-f]{2})$"
|
||||
)
|
||||
|
||||
|
||||
def _phoenix_parent_span_redis_key(parent_node_execution_id: str) -> str:
|
||||
"""Build the Redis key that stores a restorable Phoenix parent span carrier."""
|
||||
return f"trace:phoenix:parent_span:{parent_node_execution_id}"
|
||||
|
||||
|
||||
def _publish_parent_span_context(parent_node_execution_id: str, carrier: Mapping[str, str]) -> None:
|
||||
"""Persist a tracecontext carrier so nested workflow spans can restore the tool span parent."""
|
||||
redis_client.setex(
|
||||
_phoenix_parent_span_redis_key(parent_node_execution_id),
|
||||
_PHOENIX_PARENT_SPAN_CONTEXT_TTL_SECONDS,
|
||||
safe_json_dumps(dict(carrier)),
|
||||
)
|
||||
|
||||
|
||||
def _resolve_published_parent_span_context(parent_node_execution_id: str) -> dict[str, str]:
|
||||
"""Load a previously published tool-span carrier for nested workflow parenting."""
|
||||
raw_carrier = redis_client.get(_phoenix_parent_span_redis_key(parent_node_execution_id))
|
||||
if raw_carrier is None:
|
||||
raise PendingTraceParentContextError(parent_node_execution_id)
|
||||
|
||||
if isinstance(raw_carrier, bytes):
|
||||
raw_carrier = raw_carrier.decode("utf-8")
|
||||
|
||||
carrier = json.loads(raw_carrier)
|
||||
if not isinstance(carrier, dict):
|
||||
raise ValueError(
|
||||
"Phoenix parent span context must be stored as a JSON object: "
|
||||
f"parent_node_execution_id={parent_node_execution_id}"
|
||||
)
|
||||
|
||||
normalized_carrier = {str(key): str(value) for key, value in carrier.items()}
|
||||
if not normalized_carrier:
|
||||
raise ValueError(
|
||||
f"Phoenix parent span context payload is empty: parent_node_execution_id={parent_node_execution_id}"
|
||||
)
|
||||
|
||||
traceparent = normalized_carrier.get("traceparent")
|
||||
if not isinstance(traceparent, str):
|
||||
raise ValueError(
|
||||
"Phoenix parent span context payload is missing traceparent: "
|
||||
f"parent_node_execution_id={parent_node_execution_id}"
|
||||
)
|
||||
|
||||
traceparent_match = _TRACEPARENT_PATTERN.fullmatch(traceparent)
|
||||
if traceparent_match is None:
|
||||
raise ValueError(
|
||||
"Phoenix parent span context payload has invalid traceparent format: "
|
||||
f"parent_node_execution_id={parent_node_execution_id}"
|
||||
)
|
||||
|
||||
if traceparent_match.group("version") == "ff":
|
||||
raise ValueError(
|
||||
"Phoenix parent span context payload has unsupported traceparent version: "
|
||||
f"parent_node_execution_id={parent_node_execution_id}"
|
||||
)
|
||||
|
||||
if traceparent_match.group("trace_id") == "0" * 32:
|
||||
raise ValueError(
|
||||
"Phoenix parent span context payload has zero trace_id in traceparent: "
|
||||
f"parent_node_execution_id={parent_node_execution_id}"
|
||||
)
|
||||
|
||||
if traceparent_match.group("span_id") == "0" * 16:
|
||||
raise ValueError(
|
||||
"Phoenix parent span context payload has zero span_id in traceparent: "
|
||||
f"parent_node_execution_id={parent_node_execution_id}"
|
||||
)
|
||||
|
||||
extracted_context = TraceContextTextMapPropagator().extract(carrier=normalized_carrier)
|
||||
extracted_span_context = get_current_span(extracted_context).get_span_context()
|
||||
if not extracted_span_context.is_valid or not extracted_span_context.is_remote:
|
||||
raise ValueError(
|
||||
"Phoenix parent span context payload could not be restored into a valid parent span: "
|
||||
f"parent_node_execution_id={parent_node_execution_id}"
|
||||
)
|
||||
|
||||
return normalized_carrier
|
||||
|
||||
|
||||
def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[trace_sdk.Tracer, SimpleSpanProcessor]:
|
||||
"""Configure OpenTelemetry tracer with OTLP exporter for Arize/Phoenix."""
|
||||
@ -177,6 +269,246 @@ def _get_node_span_kind(node_type: str) -> OpenInferenceSpanKindValues:
|
||||
return _NODE_TYPE_TO_SPAN_KIND.get(node_type, OpenInferenceSpanKindValues.CHAIN)
|
||||
|
||||
|
||||
def _resolve_workflow_session_id(trace_info: WorkflowTraceInfo) -> str:
|
||||
"""Resolve the workflow session ID for Phoenix workflow spans."""
|
||||
if trace_info.conversation_id:
|
||||
return trace_info.conversation_id
|
||||
|
||||
parent_workflow_run_id, _ = _resolve_workflow_parent_context(trace_info)
|
||||
if parent_workflow_run_id:
|
||||
return parent_workflow_run_id
|
||||
|
||||
return trace_info.workflow_run_id
|
||||
|
||||
|
||||
def _resolve_workflow_parent_context(trace_info: BaseTraceInfo) -> tuple[str | None, str | None]:
|
||||
"""Expose the typed parent context already resolved on the trace info."""
|
||||
return trace_info.resolved_parent_context
|
||||
|
||||
|
||||
def _resolve_workflow_root_trace_id(trace_info: WorkflowTraceInfo) -> str:
|
||||
"""Resolve the canonical root trace ID for Phoenix workflow spans."""
|
||||
trace_correlation_override, _ = _resolve_workflow_parent_context(trace_info)
|
||||
return trace_correlation_override or trace_info.resolved_trace_id or trace_info.workflow_run_id
|
||||
|
||||
|
||||
class _NodeExecutionIdentityLike(Protocol):
|
||||
@property
|
||||
def node_execution_id(self) -> str | None: ...
|
||||
|
||||
@property
|
||||
def node_id(self) -> str: ...
|
||||
|
||||
@property
|
||||
def predecessor_node_id(self) -> str | None: ...
|
||||
|
||||
|
||||
class _NodeExecutionLike(_NodeExecutionIdentityLike, Protocol):
|
||||
@property
|
||||
def id(self) -> str: ...
|
||||
|
||||
@property
|
||||
def node_type(self) -> str: ...
|
||||
|
||||
@property
|
||||
def title(self) -> str | None: ...
|
||||
|
||||
@property
|
||||
def inputs(self) -> Mapping[str, Any] | None: ...
|
||||
|
||||
@property
|
||||
def process_data(self) -> Mapping[str, Any] | None: ...
|
||||
|
||||
@property
|
||||
def outputs(self) -> Mapping[str, Any] | None: ...
|
||||
|
||||
@property
|
||||
def status(self) -> WorkflowNodeExecutionStatus: ...
|
||||
|
||||
@property
|
||||
def error(self) -> str | None: ...
|
||||
|
||||
@property
|
||||
def elapsed_time(self) -> float | None: ...
|
||||
|
||||
@property
|
||||
def metadata(self) -> Mapping[Any, Any] | None: ...
|
||||
|
||||
@property
|
||||
def created_at(self) -> datetime | None: ...
|
||||
|
||||
|
||||
_PHOENIX_STRUCTURED_NODE_TYPES = frozenset({"start", "end", "loop", "iteration"})
|
||||
|
||||
|
||||
def _resolve_workflow_span_name(trace_info: WorkflowTraceInfo) -> str:
|
||||
"""Resolve the Phoenix workflow span display name."""
|
||||
workflow_run_id = trace_info.workflow_run_id.strip() if trace_info.workflow_run_id else ""
|
||||
if workflow_run_id:
|
||||
return f"{TraceTaskName.WORKFLOW_TRACE.value}_{workflow_run_id}"
|
||||
return TraceTaskName.WORKFLOW_TRACE.value
|
||||
|
||||
|
||||
def _build_node_title_by_id(trace_info: WorkflowTraceInfo) -> dict[str, str]:
|
||||
"""Build an authoritative node-title index from the persisted workflow graph."""
|
||||
workflow_data = trace_info.workflow_data
|
||||
workflow_graph = getattr(workflow_data, "graph_dict", None)
|
||||
if not isinstance(workflow_graph, Mapping):
|
||||
workflow_graph = workflow_data.get("graph") if isinstance(workflow_data, Mapping) else None
|
||||
if not isinstance(workflow_graph, Mapping):
|
||||
return {}
|
||||
|
||||
graph_nodes = workflow_graph.get("nodes")
|
||||
if not isinstance(graph_nodes, Sequence):
|
||||
return {}
|
||||
|
||||
node_title_by_id: dict[str, str] = {}
|
||||
for graph_node in graph_nodes:
|
||||
if not isinstance(graph_node, Mapping):
|
||||
continue
|
||||
node_id = graph_node.get("id")
|
||||
node_data = graph_node.get("data")
|
||||
if not isinstance(node_id, str) or not isinstance(node_data, Mapping):
|
||||
continue
|
||||
node_title = node_data.get("title")
|
||||
if isinstance(node_title, str) and node_title.strip():
|
||||
node_title_by_id[node_id] = node_title.strip()
|
||||
|
||||
return node_title_by_id
|
||||
|
||||
|
||||
def _resolve_workflow_node_span_name(
|
||||
node_execution: _NodeExecutionLike,
|
||||
node_title_by_id: Mapping[str, str] | None = None,
|
||||
) -> str:
|
||||
"""Resolve the Phoenix workflow node span display name."""
|
||||
node_type = str(node_execution.node_type or "")
|
||||
graph_node_title = None
|
||||
if node_title_by_id is not None and isinstance(node_execution.node_id, str):
|
||||
graph_node_title = node_title_by_id.get(node_execution.node_id)
|
||||
|
||||
node_title = graph_node_title or (node_execution.title.strip() if isinstance(node_execution.title, str) else "")
|
||||
if node_title:
|
||||
return f"{node_type}_{node_title}"
|
||||
return node_type
|
||||
|
||||
|
||||
def _get_node_execution_id(node_execution: _NodeExecutionIdentityLike) -> str:
|
||||
"""Return the stable execution identifier for a workflow node execution."""
|
||||
return str(getattr(node_execution, "id", None) or node_execution.node_execution_id)
|
||||
|
||||
|
||||
def _build_execution_id_by_node_id(node_executions: Sequence[_NodeExecutionIdentityLike]) -> dict[str, str]:
|
||||
"""Index unique workflow graph node ids by execution id.
|
||||
|
||||
This Phoenix-local hierarchy reconstruction intentionally drops ambiguous
|
||||
node ids instead of guessing based on repository order. That keeps parent
|
||||
selection deterministic until upstream tracing exposes explicit parent span
|
||||
data for repeated executions.
|
||||
"""
|
||||
execution_id_by_node_id: dict[str, str] = {}
|
||||
ambiguous_node_ids: set[str] = set()
|
||||
|
||||
for node_execution in node_executions:
|
||||
node_id = node_execution.node_id
|
||||
if not isinstance(node_id, str):
|
||||
continue
|
||||
execution_id = _get_node_execution_id(node_execution)
|
||||
|
||||
if node_id in ambiguous_node_ids:
|
||||
continue
|
||||
|
||||
existing_execution_id = execution_id_by_node_id.get(node_id)
|
||||
if existing_execution_id is None:
|
||||
execution_id_by_node_id[node_id] = execution_id
|
||||
continue
|
||||
|
||||
if existing_execution_id != execution_id:
|
||||
ambiguous_node_ids.add(node_id)
|
||||
execution_id_by_node_id.pop(node_id, None)
|
||||
|
||||
return execution_id_by_node_id
|
||||
|
||||
|
||||
def _build_graph_parent_index(node_executions: Sequence[_NodeExecutionIdentityLike]) -> dict[str, str]:
|
||||
"""Build an execution-id parent index from predecessor node ids."""
|
||||
execution_id_by_node_id = _build_execution_id_by_node_id(node_executions)
|
||||
graph_parent_index: dict[str, str] = {}
|
||||
|
||||
for node_execution in node_executions:
|
||||
predecessor_node_id = node_execution.predecessor_node_id
|
||||
if not isinstance(predecessor_node_id, str):
|
||||
continue
|
||||
|
||||
predecessor_execution_id = execution_id_by_node_id.get(predecessor_node_id)
|
||||
if predecessor_execution_id is not None:
|
||||
execution_id = _get_node_execution_id(node_execution)
|
||||
graph_parent_index[execution_id] = predecessor_execution_id
|
||||
|
||||
return graph_parent_index
|
||||
|
||||
|
||||
def _resolve_structured_parent_execution_id(
|
||||
node_execution: object, execution_id_by_node_id: Mapping[str, str]
|
||||
) -> str | None:
|
||||
"""Resolve Phoenix-local structured parents from loop/iteration node ids.
|
||||
|
||||
Any execution carrying ``iteration_id`` or ``loop_id`` belongs to an
|
||||
enclosing structured node. When predecessor node ids are ambiguous because
|
||||
the graph node repeats inside that structure, Phoenix can still keep the
|
||||
child span under the enclosing loop/iteration span without relying on
|
||||
execution-order heuristics.
|
||||
"""
|
||||
execution_metadata = getattr(node_execution, "execution_metadata_dict", None)
|
||||
if not isinstance(execution_metadata, Mapping):
|
||||
execution_metadata = getattr(node_execution, "metadata", None)
|
||||
if not isinstance(execution_metadata, Mapping):
|
||||
execution_metadata = {}
|
||||
|
||||
for enclosing_node_id in (
|
||||
getattr(node_execution, "iteration_id", None),
|
||||
getattr(node_execution, "loop_id", None),
|
||||
execution_metadata.get("iteration_id"),
|
||||
execution_metadata.get("loop_id"),
|
||||
):
|
||||
if not isinstance(enclosing_node_id, str):
|
||||
continue
|
||||
|
||||
enclosing_execution_id = execution_id_by_node_id.get(enclosing_node_id)
|
||||
if enclosing_execution_id is not None:
|
||||
return enclosing_execution_id
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_node_parent(
|
||||
execution_id: str,
|
||||
predecessor_execution_id: str | None,
|
||||
structured_parent_execution_id: str | None,
|
||||
span_by_execution_id: Mapping[str, Span],
|
||||
graph_parent_index: Mapping[str, str],
|
||||
workflow_span: Span,
|
||||
) -> Span:
|
||||
"""Resolve the parent span for a workflow node execution."""
|
||||
if predecessor_execution_id is not None:
|
||||
predecessor_span = span_by_execution_id.get(predecessor_execution_id)
|
||||
if predecessor_span is not None:
|
||||
return predecessor_span
|
||||
|
||||
graph_parent_execution_id = graph_parent_index.get(execution_id)
|
||||
if graph_parent_execution_id is not None:
|
||||
graph_parent_span = span_by_execution_id.get(graph_parent_execution_id)
|
||||
if graph_parent_span is not None:
|
||||
return graph_parent_span
|
||||
|
||||
if structured_parent_execution_id is not None:
|
||||
structured_parent_span = span_by_execution_id.get(structured_parent_execution_id)
|
||||
if structured_parent_span is not None:
|
||||
return structured_parent_span
|
||||
|
||||
return workflow_span
|
||||
|
||||
|
||||
class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
def __init__(
|
||||
self,
|
||||
@ -189,6 +521,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
|
||||
self.propagator = TraceContextTextMapPropagator()
|
||||
self.dify_trace_ids: set[str] = set()
|
||||
self.root_span_carriers: dict[str, dict[str, str]] = {}
|
||||
self.carrier: dict[str, str] = {}
|
||||
|
||||
def trace(self, trace_info: BaseTraceInfo):
|
||||
logger.info("[Arize/Phoenix] Trace Entity Info: %s", trace_info)
|
||||
@ -235,13 +569,41 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
file_list=safe_json_dumps(file_list),
|
||||
query=trace_info.query or "",
|
||||
)
|
||||
workflow_session_id = _resolve_workflow_session_id(trace_info)
|
||||
parent_workflow_run_id, parent_node_execution_id = _resolve_workflow_parent_context(trace_info)
|
||||
logger.info(
|
||||
"[Arize/Phoenix] Workflow session resolution: workflow_run_id=%s conversation_id=%s "
|
||||
"parent_workflow_run_id=%s parent_node_execution_id=%s resolved_session_id=%s",
|
||||
trace_info.workflow_run_id,
|
||||
trace_info.conversation_id,
|
||||
parent_workflow_run_id,
|
||||
parent_node_execution_id,
|
||||
workflow_session_id,
|
||||
)
|
||||
|
||||
dify_trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id
|
||||
self.ensure_root_span(dify_trace_id)
|
||||
root_span_context = self.propagator.extract(carrier=self.carrier)
|
||||
if parent_node_execution_id:
|
||||
workflow_parent_carrier = _resolve_published_parent_span_context(parent_node_execution_id)
|
||||
else:
|
||||
root_trace_id = _resolve_workflow_root_trace_id(trace_info)
|
||||
workflow_root_span_name: str | None = trace_info.workflow_run_id
|
||||
if not isinstance(workflow_root_span_name, str) or not workflow_root_span_name.strip():
|
||||
workflow_root_span_name = None
|
||||
|
||||
workflow_parent_carrier = self.ensure_root_span(
|
||||
root_trace_id,
|
||||
root_span_name=workflow_root_span_name,
|
||||
root_span_attributes={
|
||||
SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.workflow_run_inputs),
|
||||
SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
|
||||
SpanAttributes.OUTPUT_VALUE: safe_json_dumps(trace_info.workflow_run_outputs),
|
||||
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
|
||||
},
|
||||
)
|
||||
|
||||
workflow_span_context = self.propagator.extract(carrier=workflow_parent_carrier)
|
||||
|
||||
workflow_span = self.tracer.start_span(
|
||||
name=TraceTaskName.WORKFLOW_TRACE.value,
|
||||
name=_resolve_workflow_span_name(trace_info),
|
||||
attributes={
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
|
||||
SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.workflow_run_inputs),
|
||||
@ -249,10 +611,10 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
SpanAttributes.OUTPUT_VALUE: safe_json_dumps(trace_info.workflow_run_outputs),
|
||||
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
|
||||
SpanAttributes.METADATA: safe_json_dumps(metadata),
|
||||
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
|
||||
SpanAttributes.SESSION_ID: workflow_session_id or "",
|
||||
},
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=root_span_context,
|
||||
context=workflow_span_context,
|
||||
)
|
||||
|
||||
# Through workflow_run_id, get all_nodes_execution using repository
|
||||
@ -276,16 +638,50 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution(
|
||||
workflow_execution_id=trace_info.workflow_run_id
|
||||
)
|
||||
node_title_by_id = _build_node_title_by_id(trace_info)
|
||||
execution_id_by_node_id = _build_execution_id_by_node_id(workflow_node_executions)
|
||||
graph_parent_index = _build_graph_parent_index(workflow_node_executions)
|
||||
node_execution_by_execution_id = {
|
||||
_get_node_execution_id(node_execution): node_execution for node_execution in workflow_node_executions
|
||||
}
|
||||
span_by_execution_id: dict[str, Span] = {}
|
||||
emitting_execution_ids: set[str] = set()
|
||||
|
||||
workflow_span_error: Exception | str | None = trace_info.error
|
||||
try:
|
||||
for node_execution in workflow_node_executions:
|
||||
|
||||
def emit_node_span(node_execution: _NodeExecutionLike) -> Span:
|
||||
execution_id = _get_node_execution_id(node_execution)
|
||||
existing_span = span_by_execution_id.get(execution_id)
|
||||
if existing_span is not None:
|
||||
return existing_span
|
||||
|
||||
graph_parent_execution_id = graph_parent_index.get(execution_id)
|
||||
structured_parent_execution_id = _resolve_structured_parent_execution_id(
|
||||
node_execution, execution_id_by_node_id
|
||||
)
|
||||
|
||||
if execution_id not in emitting_execution_ids:
|
||||
emitting_execution_ids.add(execution_id)
|
||||
try:
|
||||
for parent_execution_id in (graph_parent_execution_id, structured_parent_execution_id):
|
||||
if parent_execution_id is None or parent_execution_id == execution_id:
|
||||
continue
|
||||
if parent_execution_id in span_by_execution_id:
|
||||
continue
|
||||
parent_node_execution = node_execution_by_execution_id.get(parent_execution_id)
|
||||
if parent_node_execution is not None:
|
||||
emit_node_span(parent_node_execution)
|
||||
finally:
|
||||
emitting_execution_ids.discard(execution_id)
|
||||
|
||||
tenant_id = trace_info.tenant_id # Use from trace_info instead
|
||||
app_id = trace_info.metadata.get("app_id") # Use from trace_info instead
|
||||
inputs_value = node_execution.inputs or {}
|
||||
outputs_value = node_execution.outputs or {}
|
||||
|
||||
created_at = node_execution.created_at or datetime.now()
|
||||
elapsed_time = node_execution.elapsed_time
|
||||
elapsed_time = node_execution.elapsed_time or 0
|
||||
finished_at = created_at + timedelta(seconds=elapsed_time)
|
||||
|
||||
process_data = node_execution.process_data or {}
|
||||
@ -324,9 +720,17 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
node_metadata["prompt_tokens"] = usage_data.get("prompt_tokens", 0)
|
||||
node_metadata["completion_tokens"] = usage_data.get("completion_tokens", 0)
|
||||
|
||||
workflow_span_context = set_span_in_context(workflow_span)
|
||||
parent_span = _resolve_node_parent(
|
||||
execution_id=execution_id,
|
||||
predecessor_execution_id=None,
|
||||
structured_parent_execution_id=structured_parent_execution_id,
|
||||
span_by_execution_id=span_by_execution_id,
|
||||
graph_parent_index=graph_parent_index,
|
||||
workflow_span=workflow_span,
|
||||
)
|
||||
workflow_span_context = set_span_in_context(parent_span)
|
||||
node_span = self.tracer.start_span(
|
||||
name=node_execution.node_type,
|
||||
name=_resolve_workflow_node_span_name(node_execution, node_title_by_id),
|
||||
attributes={
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind.value,
|
||||
SpanAttributes.INPUT_VALUE: safe_json_dumps(inputs_value),
|
||||
@ -334,13 +738,20 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
SpanAttributes.OUTPUT_VALUE: safe_json_dumps(outputs_value),
|
||||
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
|
||||
SpanAttributes.METADATA: safe_json_dumps(node_metadata),
|
||||
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
|
||||
SpanAttributes.SESSION_ID: workflow_session_id or "",
|
||||
},
|
||||
start_time=datetime_to_nanos(created_at),
|
||||
context=workflow_span_context,
|
||||
)
|
||||
|
||||
span_by_execution_id[execution_id] = node_span
|
||||
node_span_error: Exception | str | None = None
|
||||
try:
|
||||
if node_execution.node_type == "tool":
|
||||
parent_span_carrier: dict[str, str] = {}
|
||||
with use_span(node_span, end_on_exit=False):
|
||||
self.propagator.inject(carrier=parent_span_carrier)
|
||||
_publish_parent_span_context(execution_id, parent_span_carrier)
|
||||
|
||||
if node_execution.node_type == "llm":
|
||||
llm_attributes: dict[str, Any] = {
|
||||
SpanAttributes.INPUT_VALUE: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
|
||||
@ -362,17 +773,26 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
)
|
||||
llm_attributes.update(self._construct_llm_attributes(process_data.get("prompts", [])))
|
||||
node_span.set_attributes(llm_attributes)
|
||||
except Exception as e:
|
||||
node_span_error = e
|
||||
raise
|
||||
finally:
|
||||
if node_execution.status == WorkflowNodeExecutionStatus.FAILED:
|
||||
if node_span_error is not None:
|
||||
set_span_status(node_span, node_span_error)
|
||||
elif node_execution.status == WorkflowNodeExecutionStatus.FAILED:
|
||||
set_span_status(node_span, node_execution.error)
|
||||
else:
|
||||
set_span_status(node_span)
|
||||
node_span.end(end_time=datetime_to_nanos(finished_at))
|
||||
return node_span
|
||||
|
||||
for node_execution in workflow_node_executions:
|
||||
emit_node_span(node_execution)
|
||||
except Exception as e:
|
||||
workflow_span_error = e
|
||||
raise
|
||||
finally:
|
||||
if trace_info.error:
|
||||
set_span_status(workflow_span, trace_info.error)
|
||||
else:
|
||||
set_span_status(workflow_span)
|
||||
set_span_status(workflow_span, workflow_span_error)
|
||||
workflow_span.end(end_time=datetime_to_nanos(trace_info.end_time))
|
||||
|
||||
def message_trace(self, trace_info: MessageTraceInfo):
|
||||
@ -735,22 +1155,39 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
finally:
|
||||
span.end(end_time=datetime_to_nanos(trace_info.end_time))
|
||||
|
||||
def ensure_root_span(self, dify_trace_id: str | None):
|
||||
def ensure_root_span(
|
||||
self,
|
||||
dify_trace_id: str | None,
|
||||
*,
|
||||
root_span_name: str | None = None,
|
||||
root_span_attributes: Mapping[str, AttributeValue] | None = None,
|
||||
):
|
||||
"""Ensure a unique root span exists for the given Dify trace ID."""
|
||||
if str(dify_trace_id) not in self.dify_trace_ids:
|
||||
self.carrier: dict[str, str] = {}
|
||||
trace_key = str(dify_trace_id)
|
||||
if trace_key not in self.dify_trace_ids:
|
||||
carrier: dict[str, str] = {}
|
||||
|
||||
root_span = self.tracer.start_span(name="Dify")
|
||||
root_span.set_attribute(SpanAttributes.OPENINFERENCE_SPAN_KIND, OpenInferenceSpanKindValues.CHAIN.value)
|
||||
root_span.set_attribute("dify_project_name", str(self.project))
|
||||
root_span.set_attribute("dify_trace_id", str(dify_trace_id))
|
||||
span_name = root_span_name.strip() if isinstance(root_span_name, str) and root_span_name.strip() else "Dify"
|
||||
root_span_attributes_dict: dict[str, AttributeValue] = {
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
|
||||
"dify_project_name": str(self.project),
|
||||
"dify_trace_id": trace_key,
|
||||
}
|
||||
if root_span_attributes:
|
||||
root_span_attributes_dict.update(root_span_attributes)
|
||||
|
||||
root_span = self.tracer.start_span(name=span_name, attributes=root_span_attributes_dict)
|
||||
|
||||
with use_span(root_span, end_on_exit=False):
|
||||
self.propagator.inject(carrier=self.carrier)
|
||||
self.propagator.inject(carrier=carrier)
|
||||
|
||||
set_span_status(root_span)
|
||||
root_span.end()
|
||||
self.dify_trace_ids.add(str(dify_trace_id))
|
||||
self.dify_trace_ids.add(trace_key)
|
||||
self.root_span_carriers[trace_key] = carrier
|
||||
|
||||
self.carrier = self.root_span_carriers[trace_key]
|
||||
return self.carrier
|
||||
|
||||
def api_check(self):
|
||||
try:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,36 +0,0 @@
|
||||
from dify_trace_arize_phoenix.arize_phoenix_trace import _NODE_TYPE_TO_SPAN_KIND, _get_node_span_kind
|
||||
from openinference.semconv.trace import OpenInferenceSpanKindValues
|
||||
|
||||
from graphon.enums import BUILT_IN_NODE_TYPES, BuiltinNodeTypes
|
||||
|
||||
|
||||
class TestGetNodeSpanKind:
|
||||
"""Tests for _get_node_span_kind helper."""
|
||||
|
||||
def test_all_node_types_are_mapped_correctly(self):
|
||||
"""Ensure every built-in node type is mapped to the correct span kind."""
|
||||
# Mappings for node types that have a specialised span kind.
|
||||
special_mappings = {
|
||||
BuiltinNodeTypes.LLM: OpenInferenceSpanKindValues.LLM,
|
||||
BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: OpenInferenceSpanKindValues.RETRIEVER,
|
||||
BuiltinNodeTypes.TOOL: OpenInferenceSpanKindValues.TOOL,
|
||||
BuiltinNodeTypes.AGENT: OpenInferenceSpanKindValues.AGENT,
|
||||
}
|
||||
|
||||
# Test that every built-in node type is mapped to the correct span kind.
|
||||
# Node types not in `special_mappings` should default to CHAIN.
|
||||
for node_type in BUILT_IN_NODE_TYPES:
|
||||
expected_span_kind = special_mappings.get(node_type, OpenInferenceSpanKindValues.CHAIN)
|
||||
actual_span_kind = _get_node_span_kind(node_type)
|
||||
assert actual_span_kind == expected_span_kind, (
|
||||
f"Node type {node_type!r} was mapped to {actual_span_kind}, but {expected_span_kind} was expected."
|
||||
)
|
||||
|
||||
def test_unknown_string_defaults_to_chain(self):
|
||||
"""An unrecognised node type string should still return CHAIN."""
|
||||
assert _get_node_span_kind("some-future-node-type") == OpenInferenceSpanKindValues.CHAIN
|
||||
|
||||
def test_stale_dataset_retrieval_not_in_mapping(self):
|
||||
"""The old 'dataset_retrieval' string was never a valid NodeType value;
|
||||
make sure it is not present in the mapping dictionary."""
|
||||
assert "dataset_retrieval" not in _NODE_TYPE_TO_SPAN_KIND
|
||||
@ -1,11 +1,31 @@
|
||||
"""
|
||||
Celery task for asynchronous ops trace dispatch.
|
||||
|
||||
Trace providers may report explicitly retryable dispatch failures through the
|
||||
core retryable exception contract. The task preserves the payload file only
|
||||
when Celery accepts the retry request; successful dispatches and terminal
|
||||
failures clean up the stored payload.
|
||||
|
||||
One concrete producer today is Phoenix nested workflow tracing. The outer
|
||||
workflow tool span publishes a restorable parent span context asynchronously,
|
||||
while the nested workflow trace may be picked up by Celery first. In that
|
||||
ordering window, the provider raises a retryable core exception instead of
|
||||
dropping the trace or emitting it under the wrong parent. The task intentionally
|
||||
does not know that the provider is Phoenix; it only honors the core retryable
|
||||
dispatch contract.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
from celery import shared_task
|
||||
from celery.exceptions import Retry
|
||||
from flask import current_app
|
||||
|
||||
from configs import dify_config
|
||||
from core.ops.entities.config_entity import OPS_FILE_PATH, OPS_TRACE_FAILED_KEY
|
||||
from core.ops.entities.trace_entity import trace_info_info_map
|
||||
from core.ops.exceptions import RetryableTraceDispatchError
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_redis import redis_client
|
||||
from extensions.ext_storage import storage
|
||||
@ -14,9 +34,17 @@ from models.workflow import WorkflowRun
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_RETRYABLE_TRACE_DISPATCH_LIMIT = dify_config.OPS_TRACE_RETRYABLE_DISPATCH_MAX_RETRIES
|
||||
_RETRYABLE_TRACE_DISPATCH_DELAY_SECONDS = dify_config.OPS_TRACE_RETRYABLE_DISPATCH_DELAY_SECONDS
|
||||
|
||||
@shared_task(queue="ops_trace")
|
||||
def process_trace_tasks(file_info):
|
||||
|
||||
@shared_task(
|
||||
queue="ops_trace",
|
||||
bind=True,
|
||||
max_retries=_RETRYABLE_TRACE_DISPATCH_LIMIT,
|
||||
default_retry_delay=_RETRYABLE_TRACE_DISPATCH_DELAY_SECONDS,
|
||||
)
|
||||
def process_trace_tasks(self, file_info):
|
||||
"""
|
||||
Async process trace tasks
|
||||
Usage: process_trace_tasks.delay(tasks_data)
|
||||
@ -29,6 +57,7 @@ def process_trace_tasks(file_info):
|
||||
file_data = json.loads(storage.load(file_path))
|
||||
trace_info = file_data.get("trace_info")
|
||||
trace_info_type = file_data.get("trace_info_type")
|
||||
enterprise_trace_dispatched = bool(file_data.get("_enterprise_trace_dispatched"))
|
||||
trace_instance = OpsTraceManager.get_ops_trace_instance(app_id)
|
||||
|
||||
if trace_info.get("message_data"):
|
||||
@ -38,6 +67,8 @@ def process_trace_tasks(file_info):
|
||||
if trace_info.get("documents"):
|
||||
trace_info["documents"] = [Document.model_validate(doc) for doc in trace_info["documents"]]
|
||||
|
||||
should_delete_file = True
|
||||
|
||||
try:
|
||||
trace_type = trace_info_info_map.get(trace_info_type)
|
||||
if trace_type:
|
||||
@ -45,30 +76,66 @@ def process_trace_tasks(file_info):
|
||||
|
||||
from extensions.ext_enterprise_telemetry import is_enabled as is_ee_telemetry_enabled
|
||||
|
||||
if is_ee_telemetry_enabled():
|
||||
if is_ee_telemetry_enabled() and not enterprise_trace_dispatched:
|
||||
from enterprise.telemetry.enterprise_trace import EnterpriseOtelTrace
|
||||
|
||||
try:
|
||||
EnterpriseOtelTrace().trace(trace_info)
|
||||
except Exception:
|
||||
logger.exception("Enterprise trace failed for app_id: %s", app_id)
|
||||
else:
|
||||
file_data["_enterprise_trace_dispatched"] = True
|
||||
enterprise_trace_dispatched = True
|
||||
|
||||
if trace_instance:
|
||||
with current_app.app_context():
|
||||
trace_instance.trace(trace_info)
|
||||
|
||||
logger.info("Processing trace tasks success, app_id: %s", app_id)
|
||||
except RetryableTraceDispatchError as e:
|
||||
# Retryable dispatch failures represent a transient provider-side
|
||||
# ordering gap, not corrupt payload data. Keep the payload only after
|
||||
# Celery accepts the retry request; otherwise this attempt becomes a
|
||||
# terminal failure and the stored file is cleaned up in `finally`.
|
||||
#
|
||||
# Enterprise telemetry runs before provider dispatch. If it already ran
|
||||
# and provider dispatch asks for a retry, persist that private flag so
|
||||
# the next attempt does not emit the same enterprise trace twice.
|
||||
if self.request.retries >= _RETRYABLE_TRACE_DISPATCH_LIMIT:
|
||||
logger.exception("Retryable trace dispatch budget exhausted, app_id: %s", app_id)
|
||||
failed_key = f"{OPS_TRACE_FAILED_KEY}_{app_id}"
|
||||
redis_client.incr(failed_key)
|
||||
else:
|
||||
logger.warning(
|
||||
"Retryable trace dispatch failure, scheduling retry %s/%s for app_id %s: %s",
|
||||
self.request.retries + 1,
|
||||
_RETRYABLE_TRACE_DISPATCH_LIMIT,
|
||||
app_id,
|
||||
e,
|
||||
)
|
||||
try:
|
||||
if enterprise_trace_dispatched:
|
||||
storage.save(file_path, json.dumps(file_data).encode("utf-8"))
|
||||
raise self.retry(exc=e, countdown=_RETRYABLE_TRACE_DISPATCH_DELAY_SECONDS)
|
||||
except Retry:
|
||||
should_delete_file = False
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Failed to schedule trace dispatch retry, app_id: %s", app_id)
|
||||
failed_key = f"{OPS_TRACE_FAILED_KEY}_{app_id}"
|
||||
redis_client.incr(failed_key)
|
||||
except Exception as e:
|
||||
logger.exception("Processing trace tasks failed, app_id: %s", app_id)
|
||||
failed_key = f"{OPS_TRACE_FAILED_KEY}_{app_id}"
|
||||
redis_client.incr(failed_key)
|
||||
finally:
|
||||
try:
|
||||
storage.delete(file_path)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to delete trace file %s for app_id %s: %s",
|
||||
file_path,
|
||||
app_id,
|
||||
e,
|
||||
)
|
||||
if should_delete_file:
|
||||
try:
|
||||
storage.delete(file_path)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to delete trace file %s for app_id %s: %s",
|
||||
file_path,
|
||||
app_id,
|
||||
e,
|
||||
)
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import contextlib
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
@ -24,6 +25,76 @@ def test_should_prepare_user_inputs_keeps_validation_when_flag_false():
|
||||
assert WorkflowAppGenerator()._should_prepare_user_inputs(args)
|
||||
|
||||
|
||||
def test_generate_includes_parent_trace_context_in_extras(monkeypatch):
|
||||
generator = WorkflowAppGenerator()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow.app_generator.WorkflowAppGenerator._bind_file_access_scope",
|
||||
lambda *args, **kwargs: contextlib.nullcontext(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow.app_generator.WorkflowAppConfigManager.get_app_config",
|
||||
lambda *args, **kwargs: SimpleNamespace(
|
||||
app_id="app-1", tenant_id="tenant-1", workflow_id="workflow-1", variables=[]
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow.app_generator.file_factory.build_from_mappings", lambda *args, **kwargs: []
|
||||
)
|
||||
monkeypatch.setattr("core.app.apps.workflow.app_generator.TraceQueueManager", MagicMock())
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_execution_repository",
|
||||
MagicMock(return_value=MagicMock()),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_node_execution_repository",
|
||||
MagicMock(return_value=MagicMock()),
|
||||
)
|
||||
monkeypatch.setattr("core.app.apps.workflow.app_generator.db", SimpleNamespace(engine=MagicMock()))
|
||||
monkeypatch.setattr(generator, "_prepare_user_inputs", lambda *, user_inputs, **kwargs: user_inputs)
|
||||
|
||||
captured = {}
|
||||
|
||||
def fake_workflow_app_generate_entity(**kwargs):
|
||||
captured["workflow_app_generate_entity_kwargs"] = kwargs
|
||||
return SimpleNamespace(**kwargs)
|
||||
|
||||
def fake_generate(**kwargs):
|
||||
captured["application_generate_entity"] = kwargs["application_generate_entity"]
|
||||
return {"data": {}}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow.app_generator.WorkflowAppGenerateEntity", fake_workflow_app_generate_entity
|
||||
)
|
||||
monkeypatch.setattr(generator, "_generate", fake_generate)
|
||||
|
||||
result = generator.generate(
|
||||
app_model=SimpleNamespace(tenant_id="tenant-1", id="app-1"),
|
||||
workflow=SimpleNamespace(features_dict={}),
|
||||
user=SimpleNamespace(id="user-1", session_id="session-1"),
|
||||
args={
|
||||
"inputs": {"query": "hello"},
|
||||
"files": [],
|
||||
"external_trace_id": "trace-1",
|
||||
"parent_trace_context": {
|
||||
"parent_workflow_run_id": "outer-workflow-run-1",
|
||||
"parent_node_execution_id": "outer-node-execution-1",
|
||||
},
|
||||
},
|
||||
invoke_from="service-api",
|
||||
streaming=False,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
assert result == {"data": {}}
|
||||
extras = captured["workflow_app_generate_entity_kwargs"]["extras"]
|
||||
assert extras["external_trace_id"] == "trace-1"
|
||||
assert extras["parent_trace_context"].model_dump() == {
|
||||
"parent_workflow_run_id": "outer-workflow-run-1",
|
||||
"parent_node_execution_id": "outer-node-execution-1",
|
||||
}
|
||||
|
||||
|
||||
def test_resume_delegates_to_generate(mocker: MockerFixture):
|
||||
generator = WorkflowAppGenerator()
|
||||
mock_generate = mocker.patch.object(generator, "_generate", return_value="ok")
|
||||
|
||||
@ -7,6 +7,7 @@ import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import WorkflowAppGenerateEntity
|
||||
from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
||||
from core.ops.ops_trace_manager import TraceTask, TraceTaskName
|
||||
from core.workflow.system_variables import SystemVariableKey, build_system_variables
|
||||
from graphon.entities import WorkflowNodeExecution
|
||||
from graphon.entities.pause_reason import SchedulingPause
|
||||
@ -217,6 +218,59 @@ class TestWorkflowPersistenceLayer:
|
||||
assert exec_repo.saved[-1].status == WorkflowExecutionStatus.FAILED
|
||||
assert trace_tasks
|
||||
|
||||
def test_handle_graph_run_succeeded_enqueues_parent_trace_context(self, monkeypatch):
|
||||
trace_tasks: list[TraceTask] = []
|
||||
trace_manager = SimpleNamespace(user_id="user", add_trace_task=lambda task: trace_tasks.append(task))
|
||||
layer, _, _, _ = _make_layer(
|
||||
extras={
|
||||
"external_trace_id": "trace",
|
||||
"parent_trace_context": {
|
||||
"parent_workflow_run_id": "outer-workflow-run-1",
|
||||
"parent_node_execution_id": "outer-node-execution-1",
|
||||
},
|
||||
},
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
layer._handle_graph_run_started()
|
||||
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
def fake_workflow_trace(
|
||||
self: TraceTask,
|
||||
*,
|
||||
workflow_run_id: str | None,
|
||||
conversation_id: str | None,
|
||||
user_id: str | None,
|
||||
total_tokens_override: int | None = None,
|
||||
):
|
||||
captured["trace_type"] = self.trace_type
|
||||
captured["external_trace_id"] = self.kwargs.get("external_trace_id")
|
||||
captured["parent_trace_context"] = self.kwargs.get("parent_trace_context")
|
||||
captured["workflow_run_id"] = workflow_run_id
|
||||
return {"ok": True}
|
||||
|
||||
monkeypatch.setattr(TraceTask, "workflow_trace", fake_workflow_trace)
|
||||
|
||||
layer._handle_graph_run_succeeded(GraphRunSucceededEvent(outputs={"ok": True}))
|
||||
|
||||
assert trace_tasks
|
||||
trace_task = trace_tasks[0]
|
||||
assert trace_task.trace_type == TraceTaskName.WORKFLOW_TRACE
|
||||
assert trace_task.kwargs["external_trace_id"] == "trace"
|
||||
assert trace_task.kwargs["parent_trace_context"] == {
|
||||
"parent_workflow_run_id": "outer-workflow-run-1",
|
||||
"parent_node_execution_id": "outer-node-execution-1",
|
||||
}
|
||||
|
||||
trace_task.execute()
|
||||
|
||||
assert captured["trace_type"] == TraceTaskName.WORKFLOW_TRACE
|
||||
assert captured["external_trace_id"] == "trace"
|
||||
assert captured["parent_trace_context"] == {
|
||||
"parent_workflow_run_id": "outer-workflow-run-1",
|
||||
"parent_node_execution_id": "outer-node-execution-1",
|
||||
}
|
||||
|
||||
def test_handle_graph_run_aborted_sets_status(self):
|
||||
layer, exec_repo, _, _ = _make_layer()
|
||||
layer._handle_graph_run_started()
|
||||
|
||||
@ -1,6 +1,12 @@
|
||||
import pytest
|
||||
|
||||
from core.helper.trace_id_helper import extract_external_trace_id_from_args, get_external_trace_id, is_valid_trace_id
|
||||
from core.helper.trace_id_helper import (
|
||||
ParentTraceContext,
|
||||
extract_external_trace_id_from_args,
|
||||
extract_parent_trace_context_from_args,
|
||||
get_external_trace_id,
|
||||
is_valid_trace_id,
|
||||
)
|
||||
|
||||
|
||||
class DummyRequest:
|
||||
@ -84,3 +90,92 @@ class TestTraceIdHelper:
|
||||
def test_extract_external_trace_id_from_args(self, args, expected):
|
||||
"""Test extraction of external_trace_id from args mapping"""
|
||||
assert extract_external_trace_id_from_args(args) == expected
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("args", "expected"),
|
||||
[
|
||||
(
|
||||
{
|
||||
"parent_trace_context": {
|
||||
"parent_workflow_run_id": "workflow-run-1",
|
||||
"parent_node_execution_id": "node-execution-1",
|
||||
}
|
||||
},
|
||||
{
|
||||
"parent_trace_context": ParentTraceContext(
|
||||
parent_workflow_run_id="workflow-run-1",
|
||||
parent_node_execution_id="node-execution-1",
|
||||
)
|
||||
},
|
||||
),
|
||||
(
|
||||
{
|
||||
"parent_trace_context": {
|
||||
"parent_workflow_run_id": "workflow-run-1",
|
||||
}
|
||||
},
|
||||
{},
|
||||
),
|
||||
(
|
||||
{
|
||||
"parent_trace_context": {
|
||||
"parent_node_execution_id": "node-execution-1",
|
||||
}
|
||||
},
|
||||
{},
|
||||
),
|
||||
(
|
||||
{
|
||||
"parent_trace_context": {
|
||||
"parent_workflow_run_id": 123,
|
||||
"parent_node_execution_id": "node-execution-1",
|
||||
}
|
||||
},
|
||||
{},
|
||||
),
|
||||
(
|
||||
{
|
||||
"parent_trace_context": {
|
||||
"parent_workflow_run_id": "workflow-run-1",
|
||||
"parent_node_execution_id": None,
|
||||
}
|
||||
},
|
||||
{},
|
||||
),
|
||||
({}, {}),
|
||||
],
|
||||
)
|
||||
def test_extract_parent_trace_context_from_args(self, args, expected):
|
||||
"""Test extraction of parent_trace_context from args mapping"""
|
||||
assert extract_parent_trace_context_from_args(args) == expected
|
||||
|
||||
def test_extract_parent_trace_context_returns_typed_context(self):
|
||||
"""Parent trace context is parsed into a Pydantic value object."""
|
||||
result = extract_parent_trace_context_from_args(
|
||||
{
|
||||
"parent_trace_context": {
|
||||
"parent_workflow_run_id": "workflow-run-1",
|
||||
"parent_node_execution_id": "node-execution-1",
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
assert result == {
|
||||
"parent_trace_context": ParentTraceContext(
|
||||
parent_workflow_run_id="workflow-run-1",
|
||||
parent_node_execution_id="node-execution-1",
|
||||
)
|
||||
}
|
||||
|
||||
def test_extract_parent_trace_context_rejects_incomplete_typed_context(self):
|
||||
"""Typed parent trace context follows the same completeness rule as raw mappings."""
|
||||
result = extract_parent_trace_context_from_args(
|
||||
{
|
||||
"parent_trace_context": ParentTraceContext(
|
||||
parent_workflow_run_id="workflow-run-1",
|
||||
parent_node_execution_id=None,
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
assert result == {}
|
||||
|
||||
@ -147,6 +147,142 @@ def test_workflow_tool_does_not_use_pause_state_config(monkeypatch: pytest.Monke
|
||||
assert call_kwargs["pause_state_config"] is None
|
||||
|
||||
|
||||
def test_workflow_tool_passes_parent_trace_context_from_runtime(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Ensure nested workflow runtime metadata is forwarded as parent trace context."""
|
||||
tool = _build_tool()
|
||||
tool.set_parent_trace_context(
|
||||
parent_workflow_run_id="outer-workflow-run-1",
|
||||
parent_node_execution_id="outer-node-execution-1",
|
||||
)
|
||||
|
||||
monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None)
|
||||
|
||||
mock_user = Mock()
|
||||
monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user)
|
||||
|
||||
generate_mock = MagicMock(return_value={"data": {}})
|
||||
monkeypatch.setattr("core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate", generate_mock)
|
||||
monkeypatch.setattr("libs.login.current_user", lambda *args, **kwargs: None)
|
||||
|
||||
list(tool.invoke("test_user", {}))
|
||||
|
||||
call_kwargs = generate_mock.call_args.kwargs
|
||||
assert call_kwargs["args"]["parent_trace_context"].model_dump() == {
|
||||
"parent_workflow_run_id": "outer-workflow-run-1",
|
||||
"parent_node_execution_id": "outer-node-execution-1",
|
||||
}
|
||||
|
||||
|
||||
def test_workflow_tool_keeps_user_inputs_named_like_trace_runtime_keys(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Ensure private trace context does not overwrite same-named workflow inputs."""
|
||||
tool = _build_tool()
|
||||
tool.entity.parameters = [
|
||||
ToolParameter.get_simple_instance(
|
||||
name="outer_workflow_run_id",
|
||||
llm_description="User workflow input",
|
||||
typ=ToolParameter.ToolParameterType.STRING,
|
||||
required=False,
|
||||
),
|
||||
ToolParameter.get_simple_instance(
|
||||
name="outer_node_execution_id",
|
||||
llm_description="User node input",
|
||||
typ=ToolParameter.ToolParameterType.STRING,
|
||||
required=False,
|
||||
),
|
||||
]
|
||||
tool.set_parent_trace_context(
|
||||
parent_workflow_run_id="outer-workflow-run-1",
|
||||
parent_node_execution_id="outer-node-execution-1",
|
||||
)
|
||||
|
||||
monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None)
|
||||
|
||||
mock_user = Mock()
|
||||
monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user)
|
||||
|
||||
generate_mock = MagicMock(return_value={"data": {}})
|
||||
monkeypatch.setattr("core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate", generate_mock)
|
||||
monkeypatch.setattr("libs.login.current_user", lambda *args, **kwargs: None)
|
||||
|
||||
list(
|
||||
tool.invoke(
|
||||
"test_user",
|
||||
{
|
||||
"outer_workflow_run_id": "user-workflow-input",
|
||||
"outer_node_execution_id": "user-node-input",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
call_kwargs = generate_mock.call_args.kwargs
|
||||
assert call_kwargs["args"]["inputs"]["outer_workflow_run_id"] == "user-workflow-input"
|
||||
assert call_kwargs["args"]["inputs"]["outer_node_execution_id"] == "user-node-input"
|
||||
assert call_kwargs["args"]["parent_trace_context"].model_dump() == {
|
||||
"parent_workflow_run_id": "outer-workflow-run-1",
|
||||
"parent_node_execution_id": "outer-node-execution-1",
|
||||
}
|
||||
|
||||
|
||||
def test_workflow_tool_can_clear_parent_trace_context(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Ensure reused WorkflowTool instances do not keep stale parent trace context."""
|
||||
tool = _build_tool()
|
||||
tool.set_parent_trace_context(
|
||||
parent_workflow_run_id="outer-workflow-run-1",
|
||||
parent_node_execution_id="outer-node-execution-1",
|
||||
)
|
||||
tool.clear_parent_trace_context()
|
||||
|
||||
monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None)
|
||||
|
||||
mock_user = Mock()
|
||||
monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user)
|
||||
|
||||
generate_mock = MagicMock(return_value={"data": {}})
|
||||
monkeypatch.setattr("core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate", generate_mock)
|
||||
monkeypatch.setattr("libs.login.current_user", lambda *args, **kwargs: None)
|
||||
|
||||
list(tool.invoke("test_user", {}))
|
||||
|
||||
call_kwargs = generate_mock.call_args.kwargs
|
||||
assert "parent_trace_context" not in call_kwargs["args"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"runtime_parameters",
|
||||
[
|
||||
{},
|
||||
{"outer_workflow_run_id": "outer-workflow-run-1"},
|
||||
{"outer_node_execution_id": "outer-node-execution-1"},
|
||||
{"outer_workflow_run_id": None, "outer_node_execution_id": None},
|
||||
],
|
||||
)
|
||||
def test_workflow_tool_omits_parent_trace_context_when_runtime_is_incomplete(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
runtime_parameters: dict[str, Any],
|
||||
):
|
||||
"""Ensure incomplete runtime metadata does not leak parent trace context into generator args."""
|
||||
tool = _build_tool()
|
||||
tool.runtime.runtime_parameters = runtime_parameters
|
||||
|
||||
monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None)
|
||||
|
||||
mock_user = Mock()
|
||||
monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user)
|
||||
|
||||
generate_mock = MagicMock(return_value={"data": {}})
|
||||
monkeypatch.setattr("core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate", generate_mock)
|
||||
monkeypatch.setattr("libs.login.current_user", lambda *args, **kwargs: None)
|
||||
|
||||
list(tool.invoke("test_user", {}))
|
||||
|
||||
call_kwargs = generate_mock.call_args.kwargs
|
||||
assert "parent_trace_context" not in call_kwargs["args"]
|
||||
|
||||
|
||||
def test_workflow_tool_should_generate_variable_messages_for_outputs(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Test that WorkflowTool should generate variable messages when there are outputs"""
|
||||
tool = _build_tool()
|
||||
|
||||
@ -15,9 +15,9 @@ from graphon.model_runtime.entities.llm_entities import LLMUsage
|
||||
from graphon.node_events import StreamChunkEvent, StreamCompletedEvent
|
||||
from graphon.nodes.tool.entities import ToolNodeData
|
||||
from graphon.nodes.tool_runtime_entities import ToolRuntimeHandle, ToolRuntimeMessage
|
||||
from graphon.runtime import GraphRuntimeState, VariablePool
|
||||
from graphon.runtime import GraphRuntimeState
|
||||
from graphon.variables.segments import ArrayFileSegment
|
||||
from tests.workflow_test_utils import build_test_graph_init_params
|
||||
from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover - imported for type checking only
|
||||
from graphon.nodes.tool.tool_node import ToolNode
|
||||
@ -106,7 +106,7 @@ def tool_node(monkeypatch) -> ToolNode:
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
variable_pool = VariablePool.from_bootstrap(system_variables=build_system_variables(user_id="user-id"))
|
||||
variable_pool = build_test_variable_pool(variables=build_system_variables(user_id="user-id"))
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
|
||||
|
||||
config = graph_config["nodes"][0]
|
||||
@ -234,3 +234,22 @@ def test_image_link_messages_use_tool_file_id_metadata(tool_node: ToolNode):
|
||||
files_segment = completed_events[0].node_run_result.outputs["files"]
|
||||
assert isinstance(files_segment, ArrayFileSegment)
|
||||
assert files_segment.value == [file_obj]
|
||||
|
||||
|
||||
def test_tool_node_passes_node_execution_id_when_runtime_accepts_it(tool_node: ToolNode):
|
||||
runtime_handle = ToolRuntimeHandle(raw=object())
|
||||
tool_node._runtime.get_runtime = MagicMock(return_value=runtime_handle)
|
||||
tool_node.ensure_execution_id = MagicMock(return_value="node-execution-id")
|
||||
|
||||
result = tool_node._get_tool_runtime(
|
||||
variable_pool=tool_node.graph_runtime_state.variable_pool,
|
||||
node_execution_id="node-execution-id",
|
||||
)
|
||||
|
||||
assert result is runtime_handle
|
||||
tool_node._runtime.get_runtime.assert_called_once_with(
|
||||
node_id="node-instance",
|
||||
node_data=tool_node.node_data,
|
||||
variable_pool=tool_node.graph_runtime_state.variable_pool,
|
||||
node_execution_id="node-execution-id",
|
||||
)
|
||||
|
||||
@ -147,6 +147,69 @@ def test_get_runtime_converts_graph_provider_type_for_tool_manager(runtime: Dify
|
||||
assert workflow_tool.provider_type == CoreToolProviderType.BUILT_IN
|
||||
|
||||
|
||||
def test_get_runtime_stores_parent_trace_context_for_workflow_tools(
|
||||
runtime: DifyToolNodeRuntime,
|
||||
) -> None:
|
||||
variable_pool: VariablePool = build_test_variable_pool(
|
||||
variables=build_system_variables(
|
||||
conversation_id="conversation-id",
|
||||
workflow_execution_id="workflow-run-id",
|
||||
)
|
||||
)
|
||||
workflow_runtime = MagicMock()
|
||||
workflow_runtime.runtime.runtime_parameters = {}
|
||||
node_data = ToolNodeData.model_validate(
|
||||
{
|
||||
"type": "tool",
|
||||
"title": "Tool",
|
||||
"provider_id": "provider",
|
||||
"provider_type": ToolProviderType.WORKFLOW,
|
||||
"provider_name": "provider",
|
||||
"tool_name": "lookup",
|
||||
"tool_label": "Lookup",
|
||||
"tool_configurations": {},
|
||||
"tool_parameters": {},
|
||||
}
|
||||
)
|
||||
|
||||
with patch.object(ToolManager, "get_workflow_tool_runtime", return_value=workflow_runtime):
|
||||
tool_runtime = runtime.get_runtime(
|
||||
node_id="node-id",
|
||||
node_data=node_data,
|
||||
variable_pool=variable_pool,
|
||||
node_execution_id="node-execution-id",
|
||||
)
|
||||
|
||||
assert tool_runtime.raw.parent_trace_context.model_dump() == {
|
||||
"parent_workflow_run_id": "workflow-run-id",
|
||||
"parent_node_execution_id": "node-execution-id",
|
||||
}
|
||||
assert workflow_runtime.runtime.runtime_parameters == {}
|
||||
|
||||
|
||||
def test_get_runtime_leaves_non_workflow_tool_runtime_parameters_unchanged(
|
||||
runtime: DifyToolNodeRuntime,
|
||||
) -> None:
|
||||
variable_pool: VariablePool = build_test_variable_pool(
|
||||
variables=build_system_variables(
|
||||
conversation_id="conversation-id",
|
||||
workflow_execution_id="workflow-run-id",
|
||||
)
|
||||
)
|
||||
builtin_runtime = MagicMock()
|
||||
builtin_runtime.runtime.runtime_parameters = {}
|
||||
|
||||
with patch.object(ToolManager, "get_workflow_tool_runtime", return_value=builtin_runtime):
|
||||
runtime.get_runtime(
|
||||
node_id="node-id",
|
||||
node_data=_build_tool_node_data(),
|
||||
variable_pool=variable_pool,
|
||||
node_execution_id="node-execution-id",
|
||||
)
|
||||
|
||||
assert builtin_runtime.runtime.runtime_parameters == {}
|
||||
|
||||
|
||||
def test_get_runtime_parameters_reads_required_flags(runtime: DifyToolNodeRuntime) -> None:
|
||||
tool_runtime = ToolRuntimeHandle(
|
||||
raw=SimpleNamespace(
|
||||
|
||||
@ -316,6 +316,81 @@ def test_dify_tool_file_manager_delegates_file_generator_lookup(monkeypatch: pyt
|
||||
get_file_generator.assert_called_once_with("tool-file-id")
|
||||
|
||||
|
||||
def test_dify_tool_node_runtime_injects_outer_workflow_run_id_for_workflow_tools(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
runtime_tool = SimpleNamespace(runtime=SimpleNamespace(runtime_parameters={}))
|
||||
get_runtime = MagicMock(return_value=runtime_tool)
|
||||
monkeypatch.setattr(node_runtime.ToolManager, "get_workflow_tool_runtime", get_runtime)
|
||||
monkeypatch.setattr(
|
||||
node_runtime,
|
||||
"get_system_text",
|
||||
lambda _pool, key: (
|
||||
"outer-workflow-run-id" if key == node_runtime.SystemVariableKey.WORKFLOW_EXECUTION_ID else None
|
||||
),
|
||||
)
|
||||
|
||||
runtime = node_runtime.DifyToolNodeRuntime(_build_run_context())
|
||||
node_data = ToolNodeData(
|
||||
title="Workflow Tool Node",
|
||||
desc=None,
|
||||
provider_id="workflow-provider-id",
|
||||
provider_type=ToolProviderType.WORKFLOW,
|
||||
provider_name="workflow-provider",
|
||||
tool_name="workflow-tool",
|
||||
tool_label="Workflow Tool",
|
||||
tool_configurations={},
|
||||
tool_parameters={},
|
||||
)
|
||||
|
||||
handle = runtime.get_runtime(
|
||||
node_id="tool-node",
|
||||
node_data=node_data,
|
||||
variable_pool=object(),
|
||||
node_execution_id="node-execution-id",
|
||||
)
|
||||
|
||||
assert handle.raw.tool is runtime_tool
|
||||
assert handle.raw.parent_trace_context.model_dump() == {
|
||||
"parent_workflow_run_id": "outer-workflow-run-id",
|
||||
"parent_node_execution_id": "node-execution-id",
|
||||
}
|
||||
assert runtime_tool.runtime.runtime_parameters == {}
|
||||
get_runtime.assert_called_once()
|
||||
|
||||
|
||||
def test_dify_tool_node_runtime_does_not_inject_outer_workflow_run_id_for_non_workflow_tools(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
runtime_tool = SimpleNamespace(runtime=SimpleNamespace(runtime_parameters={}))
|
||||
get_runtime = MagicMock(return_value=runtime_tool)
|
||||
monkeypatch.setattr(node_runtime.ToolManager, "get_workflow_tool_runtime", get_runtime)
|
||||
monkeypatch.setattr(node_runtime, "get_system_text", lambda _pool, _key: None)
|
||||
|
||||
runtime = node_runtime.DifyToolNodeRuntime(_build_run_context())
|
||||
node_data = ToolNodeData(
|
||||
title="Builtin Tool Node",
|
||||
desc=None,
|
||||
provider_id="builtin-provider-id",
|
||||
provider_type=ToolProviderType.BUILT_IN,
|
||||
provider_name="builtin-provider",
|
||||
tool_name="builtin-tool",
|
||||
tool_label="Builtin Tool",
|
||||
tool_configurations={},
|
||||
tool_parameters={},
|
||||
)
|
||||
|
||||
handle = runtime.get_runtime(
|
||||
node_id="tool-node",
|
||||
node_data=node_data,
|
||||
variable_pool=object(),
|
||||
)
|
||||
|
||||
assert handle.raw.tool is runtime_tool
|
||||
assert "outer_workflow_run_id" not in runtime_tool.runtime.runtime_parameters
|
||||
get_runtime.assert_called_once()
|
||||
|
||||
|
||||
def test_dify_human_input_runtime_builds_debug_repository(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
repository = MagicMock()
|
||||
repository_cls = MagicMock(return_value=repository)
|
||||
|
||||
@ -4,7 +4,15 @@ from models.comment import WorkflowComment, WorkflowCommentMention, WorkflowComm
|
||||
|
||||
|
||||
def test_workflow_comment_account_properties_and_cache() -> None:
|
||||
comment = WorkflowComment(created_by="user-1", resolved_by="user-2", content="hello", position_x=1, position_y=2)
|
||||
comment = WorkflowComment(
|
||||
created_by="user-1",
|
||||
resolved_by="user-2",
|
||||
content="hello",
|
||||
position_x=1,
|
||||
position_y=2,
|
||||
tenant_id="xxx",
|
||||
app_id="yyy",
|
||||
)
|
||||
created_account = Mock(id="user-1")
|
||||
resolved_account = Mock(id="user-2")
|
||||
|
||||
@ -21,6 +29,8 @@ def test_workflow_comment_account_properties_and_cache() -> None:
|
||||
get_mock.assert_not_called()
|
||||
|
||||
comment_without_resolver = WorkflowComment(
|
||||
tenant_id="xxx",
|
||||
app_id="yyy",
|
||||
created_by="user-1",
|
||||
resolved_by=None,
|
||||
content="hello",
|
||||
@ -37,7 +47,15 @@ def test_workflow_comment_counts_and_participants() -> None:
|
||||
reply_2 = WorkflowCommentReply(comment_id="comment-1", content="reply-2", created_by="user-2")
|
||||
mention_1 = WorkflowCommentMention(comment_id="comment-1", mentioned_user_id="user-3")
|
||||
mention_2 = WorkflowCommentMention(comment_id="comment-1", mentioned_user_id="user-4")
|
||||
comment = WorkflowComment(created_by="user-1", resolved_by=None, content="hello", position_x=1, position_y=2)
|
||||
comment = WorkflowComment(
|
||||
created_by="user-1",
|
||||
resolved_by=None,
|
||||
content="hello",
|
||||
position_x=1,
|
||||
position_y=2,
|
||||
tenant_id="xxx",
|
||||
app_id="yyy",
|
||||
)
|
||||
comment.replies = [reply_1, reply_2]
|
||||
comment.mentions = [mention_1, mention_2]
|
||||
|
||||
@ -63,7 +81,15 @@ def test_workflow_comment_counts_and_participants() -> None:
|
||||
def test_workflow_comment_participants_use_cached_accounts() -> None:
|
||||
reply = WorkflowCommentReply(comment_id="comment-1", content="reply-1", created_by="user-2")
|
||||
mention = WorkflowCommentMention(comment_id="comment-1", mentioned_user_id="user-3")
|
||||
comment = WorkflowComment(created_by="user-1", resolved_by=None, content="hello", position_x=1, position_y=2)
|
||||
comment = WorkflowComment(
|
||||
created_by="user-1",
|
||||
resolved_by=None,
|
||||
content="hello",
|
||||
position_x=1,
|
||||
position_y=2,
|
||||
tenant_id="xxx",
|
||||
app_id="yyy",
|
||||
)
|
||||
comment.replies = [reply]
|
||||
comment.mentions = [mention]
|
||||
|
||||
|
||||
301
api/tests/unit_tests/tasks/test_ops_trace_task.py
Normal file
301
api/tests/unit_tests/tasks/test_ops_trace_task.py
Normal file
@ -0,0 +1,301 @@
|
||||
import json
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
from types import ModuleType
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from celery.exceptions import Retry
|
||||
|
||||
from core.ops.entities.config_entity import OPS_TRACE_FAILED_KEY
|
||||
from core.ops.exceptions import RetryableTraceDispatchError
|
||||
from tasks.ops_trace_task import process_trace_tasks
|
||||
|
||||
|
||||
@contextmanager
|
||||
def fake_app_context():
|
||||
yield
|
||||
|
||||
|
||||
class FakeCurrentApp:
|
||||
def app_context(self):
|
||||
return fake_app_context()
|
||||
|
||||
|
||||
def _install_trace_manager(
|
||||
trace_instance: MagicMock,
|
||||
*,
|
||||
enterprise_enabled: bool = False,
|
||||
enterprise_trace_cls: MagicMock | None = None,
|
||||
) -> dict[str, ModuleType]:
|
||||
ops_trace_manager_module = ModuleType("core.ops.ops_trace_manager")
|
||||
|
||||
class StubOpsTraceManager:
|
||||
@staticmethod
|
||||
def get_ops_trace_instance(app_id: str) -> MagicMock:
|
||||
return trace_instance
|
||||
|
||||
telemetry_module = ModuleType("extensions.ext_enterprise_telemetry")
|
||||
telemetry_module.is_enabled = lambda: enterprise_enabled
|
||||
|
||||
ops_trace_manager_module.OpsTraceManager = StubOpsTraceManager
|
||||
modules = {
|
||||
"core.ops.ops_trace_manager": ops_trace_manager_module,
|
||||
"extensions.ext_enterprise_telemetry": telemetry_module,
|
||||
}
|
||||
if enterprise_trace_cls is not None:
|
||||
enterprise_module = ModuleType("enterprise")
|
||||
enterprise_telemetry_module = ModuleType("enterprise.telemetry")
|
||||
enterprise_trace_module = ModuleType("enterprise.telemetry.enterprise_trace")
|
||||
enterprise_trace_module.EnterpriseOtelTrace = enterprise_trace_cls
|
||||
modules.update(
|
||||
{
|
||||
"enterprise": enterprise_module,
|
||||
"enterprise.telemetry": enterprise_telemetry_module,
|
||||
"enterprise.telemetry.enterprise_trace": enterprise_trace_module,
|
||||
}
|
||||
)
|
||||
return modules
|
||||
|
||||
|
||||
def _make_payload() -> str:
|
||||
return json.dumps({"trace_info": {}, "trace_info_type": None})
|
||||
|
||||
|
||||
def _decode_saved_payload(payload: bytes | str) -> dict[str, object]:
|
||||
if isinstance(payload, bytes):
|
||||
payload = payload.decode("utf-8")
|
||||
return json.loads(payload)
|
||||
|
||||
|
||||
def _retryable_dispatch_error() -> RetryableTraceDispatchError:
|
||||
return RetryableTraceDispatchError("transient trace dispatch failure")
|
||||
|
||||
|
||||
def _run_task(file_info: dict[str, str], retries: int = 0) -> None:
|
||||
process_trace_tasks.push_request(retries=retries)
|
||||
try:
|
||||
process_trace_tasks.run(file_info)
|
||||
finally:
|
||||
process_trace_tasks.pop_request()
|
||||
|
||||
|
||||
def test_process_trace_tasks_retries_retryable_dispatch_failure_and_preserves_payload():
|
||||
file_info = {"app_id": "app-id", "file_id": "file-id"}
|
||||
trace_instance = MagicMock()
|
||||
pending_error = _retryable_dispatch_error()
|
||||
trace_instance.trace.side_effect = pending_error
|
||||
retry_error = Retry()
|
||||
|
||||
with (
|
||||
patch.dict(sys.modules, _install_trace_manager(trace_instance)),
|
||||
patch("tasks.ops_trace_task.current_app", FakeCurrentApp()),
|
||||
patch("tasks.ops_trace_task.storage.load", return_value=_make_payload()),
|
||||
patch("tasks.ops_trace_task.storage.delete") as mock_delete,
|
||||
patch("tasks.ops_trace_task.redis_client.incr") as mock_incr,
|
||||
patch.object(process_trace_tasks, "retry", side_effect=retry_error) as mock_retry,
|
||||
pytest.raises(Retry),
|
||||
):
|
||||
_run_task(file_info)
|
||||
|
||||
mock_retry.assert_called_once_with(
|
||||
exc=pending_error,
|
||||
countdown=process_trace_tasks.default_retry_delay,
|
||||
)
|
||||
mock_delete.assert_not_called()
|
||||
mock_incr.assert_not_called()
|
||||
|
||||
|
||||
def test_process_trace_tasks_marks_enterprise_trace_dispatched_before_retryable_dispatch_retry():
|
||||
file_info = {"app_id": "app-id", "file_id": "file-id"}
|
||||
trace_instance = MagicMock()
|
||||
pending_error = _retryable_dispatch_error()
|
||||
trace_instance.trace.side_effect = pending_error
|
||||
retry_error = Retry()
|
||||
enterprise_tracer = MagicMock()
|
||||
enterprise_trace_cls = MagicMock(return_value=enterprise_tracer)
|
||||
|
||||
with (
|
||||
patch.dict(
|
||||
sys.modules,
|
||||
_install_trace_manager(
|
||||
trace_instance,
|
||||
enterprise_enabled=True,
|
||||
enterprise_trace_cls=enterprise_trace_cls,
|
||||
),
|
||||
),
|
||||
patch("tasks.ops_trace_task.current_app", FakeCurrentApp()),
|
||||
patch("tasks.ops_trace_task.storage.load", return_value=_make_payload()),
|
||||
patch("tasks.ops_trace_task.storage.save") as mock_save,
|
||||
patch("tasks.ops_trace_task.storage.delete") as mock_delete,
|
||||
patch("tasks.ops_trace_task.redis_client.incr") as mock_incr,
|
||||
patch.object(process_trace_tasks, "retry", side_effect=retry_error) as mock_retry,
|
||||
pytest.raises(Retry),
|
||||
):
|
||||
_run_task(file_info)
|
||||
|
||||
enterprise_tracer.trace.assert_called_once_with({})
|
||||
saved_path, saved_payload = mock_save.call_args.args
|
||||
assert saved_path == "ops_trace/app-id/file-id.json"
|
||||
assert _decode_saved_payload(saved_payload)["_enterprise_trace_dispatched"] is True
|
||||
mock_retry.assert_called_once_with(
|
||||
exc=pending_error,
|
||||
countdown=process_trace_tasks.default_retry_delay,
|
||||
)
|
||||
mock_delete.assert_not_called()
|
||||
mock_incr.assert_not_called()
|
||||
|
||||
|
||||
def test_process_trace_tasks_does_not_mark_failed_enterprise_trace_as_dispatched_before_retry():
|
||||
file_info = {"app_id": "app-id", "file_id": "file-id"}
|
||||
trace_instance = MagicMock()
|
||||
pending_error = _retryable_dispatch_error()
|
||||
trace_instance.trace.side_effect = pending_error
|
||||
retry_error = Retry()
|
||||
enterprise_tracer = MagicMock()
|
||||
enterprise_tracer.trace.side_effect = RuntimeError("enterprise trace failed")
|
||||
enterprise_trace_cls = MagicMock(return_value=enterprise_tracer)
|
||||
|
||||
with (
|
||||
patch.dict(
|
||||
sys.modules,
|
||||
_install_trace_manager(
|
||||
trace_instance,
|
||||
enterprise_enabled=True,
|
||||
enterprise_trace_cls=enterprise_trace_cls,
|
||||
),
|
||||
),
|
||||
patch("tasks.ops_trace_task.current_app", FakeCurrentApp()),
|
||||
patch("tasks.ops_trace_task.storage.load", return_value=_make_payload()),
|
||||
patch("tasks.ops_trace_task.storage.save") as mock_save,
|
||||
patch("tasks.ops_trace_task.storage.delete") as mock_delete,
|
||||
patch("tasks.ops_trace_task.redis_client.incr") as mock_incr,
|
||||
patch.object(process_trace_tasks, "retry", side_effect=retry_error) as mock_retry,
|
||||
pytest.raises(Retry),
|
||||
):
|
||||
_run_task(file_info)
|
||||
|
||||
enterprise_tracer.trace.assert_called_once_with({})
|
||||
mock_save.assert_not_called()
|
||||
mock_retry.assert_called_once_with(
|
||||
exc=pending_error,
|
||||
countdown=process_trace_tasks.default_retry_delay,
|
||||
)
|
||||
mock_delete.assert_not_called()
|
||||
mock_incr.assert_not_called()
|
||||
|
||||
|
||||
def test_process_trace_tasks_skips_enterprise_trace_when_retry_payload_was_already_dispatched():
|
||||
file_info = {"app_id": "app-id", "file_id": "file-id"}
|
||||
trace_instance = MagicMock()
|
||||
enterprise_trace_cls = MagicMock()
|
||||
payload = json.dumps({"trace_info": {}, "trace_info_type": None, "_enterprise_trace_dispatched": True})
|
||||
|
||||
with (
|
||||
patch.dict(
|
||||
sys.modules,
|
||||
_install_trace_manager(
|
||||
trace_instance,
|
||||
enterprise_enabled=True,
|
||||
enterprise_trace_cls=enterprise_trace_cls,
|
||||
),
|
||||
),
|
||||
patch("tasks.ops_trace_task.current_app", FakeCurrentApp()),
|
||||
patch("tasks.ops_trace_task.storage.load", return_value=payload),
|
||||
patch("tasks.ops_trace_task.storage.save") as mock_save,
|
||||
patch("tasks.ops_trace_task.storage.delete") as mock_delete,
|
||||
patch("tasks.ops_trace_task.redis_client.incr") as mock_incr,
|
||||
):
|
||||
_run_task(file_info)
|
||||
|
||||
enterprise_trace_cls.assert_not_called()
|
||||
trace_instance.trace.assert_called_once_with({})
|
||||
mock_save.assert_not_called()
|
||||
mock_delete.assert_called_once_with("ops_trace/app-id/file-id.json")
|
||||
mock_incr.assert_not_called()
|
||||
|
||||
|
||||
def test_process_trace_tasks_default_retry_window_covers_parent_span_context_ttl():
|
||||
assert process_trace_tasks.max_retries * process_trace_tasks.default_retry_delay >= 300
|
||||
|
||||
|
||||
def test_process_trace_tasks_deletes_payload_on_success():
|
||||
file_info = {"app_id": "app-id", "file_id": "file-id"}
|
||||
trace_instance = MagicMock()
|
||||
|
||||
with (
|
||||
patch.dict(sys.modules, _install_trace_manager(trace_instance)),
|
||||
patch("tasks.ops_trace_task.current_app", FakeCurrentApp()),
|
||||
patch("tasks.ops_trace_task.storage.load", return_value=_make_payload()),
|
||||
patch("tasks.ops_trace_task.storage.delete") as mock_delete,
|
||||
patch("tasks.ops_trace_task.redis_client.incr") as mock_incr,
|
||||
):
|
||||
_run_task(file_info)
|
||||
|
||||
trace_instance.trace.assert_called_once_with({})
|
||||
mock_delete.assert_called_once_with("ops_trace/app-id/file-id.json")
|
||||
mock_incr.assert_not_called()
|
||||
|
||||
|
||||
def test_process_trace_tasks_deletes_payload_and_counts_terminal_failure():
|
||||
file_info = {"app_id": "app-id", "file_id": "file-id"}
|
||||
trace_instance = MagicMock()
|
||||
trace_instance.trace.side_effect = RuntimeError("trace failed")
|
||||
|
||||
with (
|
||||
patch.dict(sys.modules, _install_trace_manager(trace_instance)),
|
||||
patch("tasks.ops_trace_task.current_app", FakeCurrentApp()),
|
||||
patch("tasks.ops_trace_task.storage.load", return_value=_make_payload()),
|
||||
patch("tasks.ops_trace_task.storage.delete") as mock_delete,
|
||||
patch("tasks.ops_trace_task.redis_client.incr") as mock_incr,
|
||||
):
|
||||
_run_task(file_info)
|
||||
|
||||
mock_delete.assert_called_once_with("ops_trace/app-id/file-id.json")
|
||||
mock_incr.assert_called_once_with(f"{OPS_TRACE_FAILED_KEY}_app-id")
|
||||
|
||||
|
||||
def test_process_trace_tasks_treats_retry_enqueue_failure_as_terminal_failure():
|
||||
file_info = {"app_id": "app-id", "file_id": "file-id"}
|
||||
trace_instance = MagicMock()
|
||||
pending_error = _retryable_dispatch_error()
|
||||
retry_enqueue_error = RuntimeError("retry enqueue failed")
|
||||
trace_instance.trace.side_effect = pending_error
|
||||
|
||||
with (
|
||||
patch.dict(sys.modules, _install_trace_manager(trace_instance)),
|
||||
patch("tasks.ops_trace_task.current_app", FakeCurrentApp()),
|
||||
patch("tasks.ops_trace_task.storage.load", return_value=_make_payload()),
|
||||
patch("tasks.ops_trace_task.storage.delete") as mock_delete,
|
||||
patch("tasks.ops_trace_task.redis_client.incr") as mock_incr,
|
||||
patch.object(process_trace_tasks, "retry", side_effect=retry_enqueue_error) as mock_retry,
|
||||
):
|
||||
_run_task(file_info)
|
||||
|
||||
mock_retry.assert_called_once_with(
|
||||
exc=pending_error,
|
||||
countdown=process_trace_tasks.default_retry_delay,
|
||||
)
|
||||
mock_delete.assert_called_once_with("ops_trace/app-id/file-id.json")
|
||||
mock_incr.assert_called_once_with(f"{OPS_TRACE_FAILED_KEY}_app-id")
|
||||
|
||||
|
||||
def test_process_trace_tasks_deletes_payload_and_counts_exhausted_retryable_dispatch_failure():
|
||||
file_info = {"app_id": "app-id", "file_id": "file-id"}
|
||||
trace_instance = MagicMock()
|
||||
pending_error = _retryable_dispatch_error()
|
||||
trace_instance.trace.side_effect = pending_error
|
||||
|
||||
with (
|
||||
patch.dict(sys.modules, _install_trace_manager(trace_instance)),
|
||||
patch("tasks.ops_trace_task.current_app", FakeCurrentApp()),
|
||||
patch("tasks.ops_trace_task.storage.load", return_value=_make_payload()),
|
||||
patch("tasks.ops_trace_task.storage.delete") as mock_delete,
|
||||
patch("tasks.ops_trace_task.redis_client.incr") as mock_incr,
|
||||
patch.object(process_trace_tasks, "retry") as mock_retry,
|
||||
):
|
||||
_run_task(file_info, retries=process_trace_tasks.max_retries)
|
||||
|
||||
mock_retry.assert_not_called()
|
||||
mock_delete.assert_called_once_with("ops_trace/app-id/file-id.json")
|
||||
mock_incr.assert_called_once_with(f"{OPS_TRACE_FAILED_KEY}_app-id")
|
||||
@ -1,5 +1,6 @@
|
||||
# ------------------------------------------------------------------
|
||||
# Essential defaults for Docker Compose deployments.
|
||||
# Only include variables required for services to start.
|
||||
#
|
||||
# For a default deployment, copy this file to .env and run:
|
||||
# docker compose up -d
|
||||
|
||||
@ -71,6 +71,8 @@ LOG_TZ=UTC
|
||||
DEBUG=false
|
||||
FLASK_DEBUG=false
|
||||
ENABLE_REQUEST_LOGGING=False
|
||||
OPS_TRACE_RETRYABLE_DISPATCH_MAX_RETRIES=60
|
||||
OPS_TRACE_RETRYABLE_DISPATCH_DELAY_SECONDS=5
|
||||
WORKFLOW_LOG_CLEANUP_ENABLED=false
|
||||
WORKFLOW_LOG_RETENTION_DAYS=30
|
||||
WORKFLOW_LOG_CLEANUP_BATCH_SIZE=100
|
||||
|
||||
@ -48,16 +48,24 @@ vi.mock('@/context/app-context', () => ({
|
||||
}),
|
||||
}))
|
||||
|
||||
const mockSetQuery = vi.fn()
|
||||
const mockSetKeywords = vi.fn()
|
||||
const mockSetTagIDs = vi.fn()
|
||||
const mockSetIsCreatedByMe = vi.fn()
|
||||
const mockSetCategory = vi.fn()
|
||||
const mockQueryState = {
|
||||
category: 'all',
|
||||
tagIDs: [] as string[],
|
||||
keywords: '',
|
||||
isCreatedByMe: false,
|
||||
}
|
||||
vi.mock('../hooks/use-apps-query-state', () => ({
|
||||
default: () => ({
|
||||
isAppListCategory: (value: string) => value === 'all' || Object.values(AppModeEnum).includes(value as AppModeEnum),
|
||||
useAppsQueryState: () => ({
|
||||
query: mockQueryState,
|
||||
setQuery: mockSetQuery,
|
||||
setCategory: mockSetCategory,
|
||||
setKeywords: mockSetKeywords,
|
||||
setTagIDs: mockSetTagIDs,
|
||||
setIsCreatedByMe: mockSetIsCreatedByMe,
|
||||
}),
|
||||
}))
|
||||
|
||||
@ -244,6 +252,7 @@ describe('List', () => {
|
||||
mockServiceState.hasNextPage = false
|
||||
mockServiceState.isLoading = false
|
||||
mockServiceState.isFetchingNextPage = false
|
||||
mockQueryState.category = 'all'
|
||||
mockQueryState.tagIDs = []
|
||||
mockQueryState.keywords = ''
|
||||
mockQueryState.isCreatedByMe = false
|
||||
@ -317,25 +326,21 @@ describe('List', () => {
|
||||
})
|
||||
|
||||
describe('Tab Navigation', () => {
|
||||
it('should update URL when workflow tab is clicked', async () => {
|
||||
const { onUrlUpdate } = renderList()
|
||||
it('should update category when workflow tab is clicked', () => {
|
||||
renderList()
|
||||
|
||||
fireEvent.click(screen.getByText('app.types.workflow'))
|
||||
|
||||
await vi.waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
|
||||
const lastCall = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1]![0]
|
||||
expect(lastCall.searchParams.get('category')).toBe(AppModeEnum.WORKFLOW)
|
||||
expect(mockSetCategory).toHaveBeenCalledWith(AppModeEnum.WORKFLOW)
|
||||
})
|
||||
|
||||
it('should update URL when all tab is clicked', async () => {
|
||||
const { onUrlUpdate } = renderList('?category=workflow')
|
||||
it('should update category when all tab is clicked', () => {
|
||||
mockQueryState.category = AppModeEnum.WORKFLOW
|
||||
renderList()
|
||||
|
||||
fireEvent.click(screen.getByText('app.types.all'))
|
||||
|
||||
await vi.waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
|
||||
const lastCall = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1]![0]
|
||||
// nuqs removes the default value ('all') from URL params
|
||||
expect(lastCall.searchParams.has('category')).toBe(false)
|
||||
expect(mockSetCategory).toHaveBeenCalledWith('all')
|
||||
})
|
||||
})
|
||||
|
||||
@ -351,7 +356,7 @@ describe('List', () => {
|
||||
const input = screen.getByRole('textbox')
|
||||
fireEvent.change(input, { target: { value: 'test search' } })
|
||||
|
||||
expect(mockSetQuery).toHaveBeenCalled()
|
||||
expect(mockSetKeywords).toHaveBeenCalledWith('test search')
|
||||
})
|
||||
|
||||
it('should handle search clear button click', () => {
|
||||
@ -364,7 +369,7 @@ describe('List', () => {
|
||||
if (clearButton)
|
||||
fireEvent.click(clearButton)
|
||||
|
||||
expect(mockSetQuery).toHaveBeenCalled()
|
||||
expect(mockSetKeywords).toHaveBeenCalledWith('')
|
||||
})
|
||||
})
|
||||
|
||||
@ -373,8 +378,9 @@ describe('List', () => {
|
||||
mockQueryState.tagIDs = ['tag-1']
|
||||
mockQueryState.keywords = 'sales'
|
||||
mockQueryState.isCreatedByMe = true
|
||||
mockQueryState.category = AppModeEnum.WORKFLOW
|
||||
|
||||
renderList('?category=workflow')
|
||||
renderList()
|
||||
|
||||
const options = mockAppListInfiniteOptions.mock.calls.at(-1)?.[0] as AppListInfiniteOptions
|
||||
|
||||
@ -412,7 +418,7 @@ describe('List', () => {
|
||||
const checkbox = screen.getByTestId('checkbox-undefined')
|
||||
fireEvent.click(checkbox)
|
||||
|
||||
expect(mockSetQuery).toHaveBeenCalled()
|
||||
expect(mockSetIsCreatedByMe).toHaveBeenCalledWith(true)
|
||||
})
|
||||
})
|
||||
|
||||
@ -506,8 +512,8 @@ describe('List', () => {
|
||||
expect(screen.getByText('app.types.completion'))!.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should update URL for each app type tab click', async () => {
|
||||
const { onUrlUpdate } = renderList()
|
||||
it('should update category for each app type tab click', () => {
|
||||
renderList()
|
||||
|
||||
const appTypeTexts = [
|
||||
{ mode: AppModeEnum.WORKFLOW, text: 'app.types.workflow' },
|
||||
@ -518,11 +524,9 @@ describe('List', () => {
|
||||
]
|
||||
|
||||
for (const { mode, text } of appTypeTexts) {
|
||||
onUrlUpdate.mockClear()
|
||||
mockSetCategory.mockClear()
|
||||
fireEvent.click(screen.getByText(text))
|
||||
await vi.waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
|
||||
const lastCall = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1]![0]
|
||||
expect(lastCall.searchParams.get('category')).toBe(mode)
|
||||
expect(mockSetCategory).toHaveBeenCalledWith(mode)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
1
web/app/components/apps/constants.ts
Normal file
1
web/app/components/apps/constants.ts
Normal file
@ -0,0 +1 @@
|
||||
export const APP_LIST_SEARCH_DEBOUNCE_MS = 500
|
||||
@ -1,6 +1,8 @@
|
||||
import { act, waitFor } from '@testing-library/react'
|
||||
import { renderHookWithNuqs } from '@/test/nuqs-testing'
|
||||
import useAppsQueryState from '../use-apps-query-state'
|
||||
import { AppModeEnum } from '@/types/app'
|
||||
import { APP_LIST_SEARCH_DEBOUNCE_MS } from '../../constants'
|
||||
import { useAppsQueryState } from '../use-apps-query-state'
|
||||
|
||||
const renderWithAdapter = (searchParams = '') => {
|
||||
return renderHookWithNuqs(() => useAppsQueryState(), { searchParams })
|
||||
@ -11,214 +13,161 @@ describe('useAppsQueryState', () => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('Initialization', () => {
|
||||
it('should expose query and setQuery when initialized', () => {
|
||||
const { result } = renderWithAdapter()
|
||||
it('should expose app list query state actions', () => {
|
||||
const { result } = renderWithAdapter()
|
||||
|
||||
expect(result.current.query).toBeDefined()
|
||||
expect(typeof result.current.setQuery).toBe('function')
|
||||
expect(result.current.query).toEqual({
|
||||
category: 'all',
|
||||
tagIDs: [],
|
||||
keywords: '',
|
||||
isCreatedByMe: false,
|
||||
})
|
||||
expect(typeof result.current.setCategory).toBe('function')
|
||||
expect(typeof result.current.setKeywords).toBe('function')
|
||||
expect(typeof result.current.setTagIDs).toBe('function')
|
||||
expect(typeof result.current.setIsCreatedByMe).toBe('function')
|
||||
})
|
||||
|
||||
it('should default to empty filters when search params are missing', () => {
|
||||
const { result } = renderWithAdapter()
|
||||
it('should parse app list filters from URL', () => {
|
||||
const { result } = renderWithAdapter(
|
||||
'?category=workflow&tagIDs=tag1;tag2&keywords=search+term&isCreatedByMe=true',
|
||||
)
|
||||
|
||||
expect(result.current.query.tagIDs).toBeUndefined()
|
||||
expect(result.current.query.keywords).toBeUndefined()
|
||||
expect(result.current.query.isCreatedByMe).toBe(false)
|
||||
expect(result.current.query).toEqual({
|
||||
category: AppModeEnum.WORKFLOW,
|
||||
tagIDs: ['tag1', 'tag2'],
|
||||
keywords: 'search term',
|
||||
isCreatedByMe: true,
|
||||
})
|
||||
})
|
||||
|
||||
describe('Parsing search params', () => {
|
||||
it('should parse tagIDs when URL includes tagIDs', () => {
|
||||
const { result } = renderWithAdapter('?tagIDs=tag1;tag2;tag3')
|
||||
it('should update category URL state', async () => {
|
||||
const { result, onUrlUpdate } = renderWithAdapter()
|
||||
|
||||
expect(result.current.query.tagIDs).toEqual(['tag1', 'tag2', 'tag3'])
|
||||
act(() => {
|
||||
result.current.setCategory(AppModeEnum.WORKFLOW)
|
||||
})
|
||||
|
||||
it('should parse keywords when URL includes keywords', () => {
|
||||
const { result } = renderWithAdapter('?keywords=search+term')
|
||||
|
||||
expect(result.current.query.keywords).toBe('search term')
|
||||
})
|
||||
|
||||
it('should parse isCreatedByMe when URL includes true value', () => {
|
||||
const { result } = renderWithAdapter('?isCreatedByMe=true')
|
||||
|
||||
expect(result.current.query.isCreatedByMe).toBe(true)
|
||||
})
|
||||
|
||||
it('should parse all params when URL includes multiple filters', () => {
|
||||
const { result } = renderWithAdapter(
|
||||
'?tagIDs=tag1;tag2&keywords=test&isCreatedByMe=true',
|
||||
)
|
||||
|
||||
expect(result.current.query.tagIDs).toEqual(['tag1', 'tag2'])
|
||||
expect(result.current.query.keywords).toBe('test')
|
||||
expect(result.current.query.isCreatedByMe).toBe(true)
|
||||
})
|
||||
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
|
||||
const update = onUrlUpdate.mock.calls.at(-1)![0]
|
||||
expect(result.current.query.category).toBe(AppModeEnum.WORKFLOW)
|
||||
expect(update.searchParams.get('category')).toBe(AppModeEnum.WORKFLOW)
|
||||
expect(update.options.history).toBe('push')
|
||||
})
|
||||
|
||||
describe('Updating query state', () => {
|
||||
it('should update keywords when setQuery receives keywords', () => {
|
||||
const { result } = renderWithAdapter()
|
||||
it('should remove category from URL when set to all', async () => {
|
||||
const { result, onUrlUpdate } = renderWithAdapter('?category=workflow')
|
||||
|
||||
act(() => {
|
||||
result.current.setQuery({ keywords: 'new search' })
|
||||
})
|
||||
|
||||
expect(result.current.query.keywords).toBe('new search')
|
||||
act(() => {
|
||||
result.current.setCategory('all')
|
||||
})
|
||||
|
||||
it('should update tagIDs when setQuery receives tagIDs', () => {
|
||||
const { result } = renderWithAdapter()
|
||||
|
||||
act(() => {
|
||||
result.current.setQuery({ tagIDs: ['tag1', 'tag2'] })
|
||||
})
|
||||
|
||||
expect(result.current.query.tagIDs).toEqual(['tag1', 'tag2'])
|
||||
})
|
||||
|
||||
it('should update isCreatedByMe when setQuery receives true', () => {
|
||||
const { result } = renderWithAdapter()
|
||||
|
||||
act(() => {
|
||||
result.current.setQuery({ isCreatedByMe: true })
|
||||
})
|
||||
|
||||
expect(result.current.query.isCreatedByMe).toBe(true)
|
||||
})
|
||||
|
||||
it('should support partial updates when setQuery uses callback', () => {
|
||||
const { result } = renderWithAdapter()
|
||||
|
||||
act(() => {
|
||||
result.current.setQuery({ keywords: 'initial' })
|
||||
})
|
||||
|
||||
act(() => {
|
||||
result.current.setQuery(prev => ({ ...prev, isCreatedByMe: true }))
|
||||
})
|
||||
|
||||
expect(result.current.query.keywords).toBe('initial')
|
||||
expect(result.current.query.isCreatedByMe).toBe(true)
|
||||
})
|
||||
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
|
||||
const update = onUrlUpdate.mock.calls.at(-1)![0]
|
||||
expect(result.current.query.category).toBe('all')
|
||||
expect(update.searchParams.has('category')).toBe(false)
|
||||
})
|
||||
|
||||
describe('URL synchronization', () => {
|
||||
it('should sync keywords to URL when keywords change', async () => {
|
||||
it('should update keywords state immediately while debouncing URL writes', async () => {
|
||||
vi.useFakeTimers()
|
||||
try {
|
||||
const { result, onUrlUpdate } = renderWithAdapter()
|
||||
|
||||
act(() => {
|
||||
result.current.setQuery({ keywords: 'search' })
|
||||
result.current.setKeywords('search')
|
||||
})
|
||||
|
||||
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
|
||||
const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1]![0]
|
||||
expect(result.current.query.keywords).toBe('search')
|
||||
expect(onUrlUpdate).not.toHaveBeenCalled()
|
||||
|
||||
await act(async () => {
|
||||
await vi.advanceTimersByTimeAsync(APP_LIST_SEARCH_DEBOUNCE_MS + 100)
|
||||
})
|
||||
|
||||
expect(onUrlUpdate).toHaveBeenCalled()
|
||||
const update = onUrlUpdate.mock.calls.at(-1)![0]
|
||||
expect(update.searchParams.get('keywords')).toBe('search')
|
||||
expect(update.options.history).toBe('push')
|
||||
})
|
||||
}
|
||||
finally {
|
||||
vi.useRealTimers()
|
||||
}
|
||||
})
|
||||
|
||||
it('should sync tagIDs to URL when tagIDs change', async () => {
|
||||
const { result, onUrlUpdate } = renderWithAdapter()
|
||||
|
||||
act(() => {
|
||||
result.current.setQuery({ tagIDs: ['tag1', 'tag2'] })
|
||||
})
|
||||
|
||||
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
|
||||
const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1]![0]
|
||||
expect(update.searchParams.get('tagIDs')).toBe('tag1;tag2')
|
||||
})
|
||||
|
||||
it('should sync isCreatedByMe to URL when enabled', async () => {
|
||||
const { result, onUrlUpdate } = renderWithAdapter()
|
||||
|
||||
act(() => {
|
||||
result.current.setQuery({ isCreatedByMe: true })
|
||||
})
|
||||
|
||||
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
|
||||
const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1]![0]
|
||||
expect(update.searchParams.get('isCreatedByMe')).toBe('true')
|
||||
})
|
||||
|
||||
it('should remove keywords from URL when keywords are cleared', async () => {
|
||||
it('should remove keywords from URL when cleared', async () => {
|
||||
vi.useFakeTimers()
|
||||
try {
|
||||
const { result, onUrlUpdate } = renderWithAdapter('?keywords=existing')
|
||||
|
||||
act(() => {
|
||||
result.current.setQuery({ keywords: '' })
|
||||
result.current.setKeywords('')
|
||||
})
|
||||
|
||||
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
|
||||
const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1]![0]
|
||||
expect(result.current.query.keywords).toBe('')
|
||||
|
||||
await act(async () => {
|
||||
await vi.advanceTimersByTimeAsync(APP_LIST_SEARCH_DEBOUNCE_MS + 100)
|
||||
})
|
||||
|
||||
expect(onUrlUpdate).toHaveBeenCalled()
|
||||
const update = onUrlUpdate.mock.calls.at(-1)![0]
|
||||
expect(update.searchParams.has('keywords')).toBe(false)
|
||||
})
|
||||
|
||||
it('should remove tagIDs from URL when tagIDs are empty', async () => {
|
||||
const { result, onUrlUpdate } = renderWithAdapter('?tagIDs=tag1;tag2')
|
||||
|
||||
act(() => {
|
||||
result.current.setQuery({ tagIDs: [] })
|
||||
})
|
||||
|
||||
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
|
||||
const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1]![0]
|
||||
expect(update.searchParams.has('tagIDs')).toBe(false)
|
||||
})
|
||||
|
||||
it('should remove isCreatedByMe from URL when disabled', async () => {
|
||||
const { result, onUrlUpdate } = renderWithAdapter('?isCreatedByMe=true')
|
||||
|
||||
act(() => {
|
||||
result.current.setQuery({ isCreatedByMe: false })
|
||||
})
|
||||
|
||||
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
|
||||
const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1]![0]
|
||||
expect(update.searchParams.has('isCreatedByMe')).toBe(false)
|
||||
})
|
||||
}
|
||||
finally {
|
||||
vi.useRealTimers()
|
||||
}
|
||||
})
|
||||
|
||||
describe('Edge cases', () => {
|
||||
it('should treat empty tagIDs as empty list when URL param is empty', () => {
|
||||
const { result } = renderWithAdapter('?tagIDs=')
|
||||
it('should update tag filter URL state', async () => {
|
||||
const { result, onUrlUpdate } = renderWithAdapter()
|
||||
|
||||
expect(result.current.query.tagIDs).toEqual([])
|
||||
act(() => {
|
||||
result.current.setTagIDs(['tag1', 'tag2'])
|
||||
})
|
||||
|
||||
it('should treat empty keywords as undefined when URL param is empty', () => {
|
||||
const { result } = renderWithAdapter('?keywords=')
|
||||
|
||||
expect(result.current.query.keywords).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should decode keywords with spaces when URL contains encoded spaces', () => {
|
||||
const { result } = renderWithAdapter('?keywords=test+with+spaces')
|
||||
|
||||
expect(result.current.query.keywords).toBe('test with spaces')
|
||||
})
|
||||
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
|
||||
const update = onUrlUpdate.mock.calls.at(-1)![0]
|
||||
expect(result.current.query.tagIDs).toEqual(['tag1', 'tag2'])
|
||||
expect(update.searchParams.get('tagIDs')).toBe('tag1;tag2')
|
||||
expect(update.options.history).toBe('push')
|
||||
})
|
||||
|
||||
describe('Integration scenarios', () => {
|
||||
it('should keep accumulated filters when updates are sequential', () => {
|
||||
const { result } = renderWithAdapter()
|
||||
it('should remove tagIDs from URL when empty', async () => {
|
||||
const { result, onUrlUpdate } = renderWithAdapter('?tagIDs=tag1;tag2')
|
||||
|
||||
act(() => {
|
||||
result.current.setQuery({ keywords: 'first' })
|
||||
})
|
||||
|
||||
act(() => {
|
||||
result.current.setQuery(prev => ({ ...prev, tagIDs: ['tag1'] }))
|
||||
})
|
||||
|
||||
act(() => {
|
||||
result.current.setQuery(prev => ({ ...prev, isCreatedByMe: true }))
|
||||
})
|
||||
|
||||
expect(result.current.query.keywords).toBe('first')
|
||||
expect(result.current.query.tagIDs).toEqual(['tag1'])
|
||||
expect(result.current.query.isCreatedByMe).toBe(true)
|
||||
act(() => {
|
||||
result.current.setTagIDs([])
|
||||
})
|
||||
|
||||
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
|
||||
const update = onUrlUpdate.mock.calls.at(-1)![0]
|
||||
expect(result.current.query.tagIDs).toEqual([])
|
||||
expect(update.searchParams.has('tagIDs')).toBe(false)
|
||||
})
|
||||
|
||||
it('should update created-by-me URL state', async () => {
|
||||
const { result, onUrlUpdate } = renderWithAdapter()
|
||||
|
||||
act(() => {
|
||||
result.current.setIsCreatedByMe(true)
|
||||
})
|
||||
|
||||
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
|
||||
const update = onUrlUpdate.mock.calls.at(-1)![0]
|
||||
expect(result.current.query.isCreatedByMe).toBe(true)
|
||||
expect(update.searchParams.get('isCreatedByMe')).toBe('true')
|
||||
expect(update.options.history).toBe('push')
|
||||
})
|
||||
|
||||
it('should remove isCreatedByMe from URL when disabled', async () => {
|
||||
const { result, onUrlUpdate } = renderWithAdapter('?isCreatedByMe=true')
|
||||
|
||||
act(() => {
|
||||
result.current.setIsCreatedByMe(false)
|
||||
})
|
||||
|
||||
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
|
||||
const update = onUrlUpdate.mock.calls.at(-1)![0]
|
||||
expect(result.current.query.isCreatedByMe).toBe(false)
|
||||
expect(update.searchParams.has('isCreatedByMe')).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
@ -1,57 +1,56 @@
|
||||
import { parseAsArrayOf, parseAsBoolean, parseAsString, useQueryStates } from 'nuqs'
|
||||
import { debounce, parseAsArrayOf, parseAsBoolean, parseAsString, parseAsStringLiteral, useQueryStates } from 'nuqs'
|
||||
import { useCallback, useMemo } from 'react'
|
||||
import { AppModes } from '@/types/app'
|
||||
import { APP_LIST_SEARCH_DEBOUNCE_MS } from '../constants'
|
||||
|
||||
type AppsQuery = {
|
||||
tagIDs?: string[]
|
||||
keywords?: string
|
||||
isCreatedByMe?: boolean
|
||||
const APP_LIST_CATEGORY_VALUES = ['all', ...AppModes] as const
|
||||
export type AppListCategory = typeof APP_LIST_CATEGORY_VALUES[number]
|
||||
|
||||
const appListCategorySet = new Set<string>(APP_LIST_CATEGORY_VALUES)
|
||||
|
||||
export const isAppListCategory = (value: string): value is AppListCategory => {
|
||||
return appListCategorySet.has(value)
|
||||
}
|
||||
|
||||
const normalizeKeywords = (value: string | null) => value || undefined
|
||||
|
||||
function useAppsQueryState() {
|
||||
const [urlQuery, setUrlQuery] = useQueryStates(
|
||||
{
|
||||
tagIDs: parseAsArrayOf(parseAsString, ';'),
|
||||
keywords: parseAsString,
|
||||
isCreatedByMe: parseAsBoolean,
|
||||
},
|
||||
{
|
||||
history: 'push',
|
||||
},
|
||||
)
|
||||
|
||||
const query = useMemo<AppsQuery>(() => ({
|
||||
tagIDs: urlQuery.tagIDs ?? undefined,
|
||||
keywords: normalizeKeywords(urlQuery.keywords),
|
||||
isCreatedByMe: urlQuery.isCreatedByMe ?? false,
|
||||
}), [urlQuery.isCreatedByMe, urlQuery.keywords, urlQuery.tagIDs])
|
||||
|
||||
const setQuery = useCallback((next: AppsQuery | ((prev: AppsQuery) => AppsQuery)) => {
|
||||
const buildPatch = (patch: AppsQuery) => {
|
||||
const result: Partial<typeof urlQuery> = {}
|
||||
if ('tagIDs' in patch)
|
||||
result.tagIDs = patch.tagIDs && patch.tagIDs.length > 0 ? patch.tagIDs : null
|
||||
if ('keywords' in patch)
|
||||
result.keywords = patch.keywords ? patch.keywords : null
|
||||
if ('isCreatedByMe' in patch)
|
||||
result.isCreatedByMe = patch.isCreatedByMe ? true : null
|
||||
return result
|
||||
}
|
||||
|
||||
if (typeof next === 'function') {
|
||||
setUrlQuery(prev => buildPatch(next({
|
||||
tagIDs: prev.tagIDs ?? undefined,
|
||||
keywords: normalizeKeywords(prev.keywords),
|
||||
isCreatedByMe: prev.isCreatedByMe ?? false,
|
||||
})))
|
||||
return
|
||||
}
|
||||
|
||||
setUrlQuery(buildPatch(next))
|
||||
}, [setUrlQuery])
|
||||
|
||||
return useMemo(() => ({ query, setQuery }), [query, setQuery])
|
||||
const appListQueryParsers = {
|
||||
category: parseAsStringLiteral(APP_LIST_CATEGORY_VALUES)
|
||||
.withDefault('all')
|
||||
.withOptions({ history: 'push' }),
|
||||
tagIDs: parseAsArrayOf(parseAsString, ';')
|
||||
.withDefault([])
|
||||
.withOptions({ history: 'push' }),
|
||||
keywords: parseAsString.withDefault('').withOptions({
|
||||
limitUrlUpdates: debounce(APP_LIST_SEARCH_DEBOUNCE_MS),
|
||||
}),
|
||||
isCreatedByMe: parseAsBoolean
|
||||
.withDefault(false)
|
||||
.withOptions({ history: 'push' }),
|
||||
}
|
||||
|
||||
export default useAppsQueryState
|
||||
export function useAppsQueryState() {
|
||||
const [query, setQuery] = useQueryStates(appListQueryParsers)
|
||||
|
||||
const setCategory = useCallback((category: AppListCategory) => {
|
||||
setQuery({ category })
|
||||
}, [setQuery])
|
||||
|
||||
const setKeywords = useCallback((keywords: string) => {
|
||||
setQuery({ keywords })
|
||||
}, [setQuery])
|
||||
|
||||
const setTagIDs = useCallback((tagIDs: string[]) => {
|
||||
setQuery({ tagIDs })
|
||||
}, [setQuery])
|
||||
|
||||
const setIsCreatedByMe = useCallback((isCreatedByMe: boolean) => {
|
||||
setQuery({ isCreatedByMe })
|
||||
}, [setQuery])
|
||||
|
||||
return useMemo(() => ({
|
||||
query,
|
||||
setCategory,
|
||||
setKeywords,
|
||||
setTagIDs,
|
||||
setIsCreatedByMe,
|
||||
}), [query, setCategory, setKeywords, setTagIDs, setIsCreatedByMe])
|
||||
}
|
||||
|
||||
@ -4,8 +4,7 @@ import type { FC } from 'react'
|
||||
import type { AppListQuery } from '@/contract/console/apps'
|
||||
import { cn } from '@langgenius/dify-ui/cn'
|
||||
import { keepPreviousData, useInfiniteQuery, useSuspenseQuery } from '@tanstack/react-query'
|
||||
import { useDebounceFn } from 'ahooks'
|
||||
import { parseAsStringLiteral, useQueryState } from 'nuqs'
|
||||
import { useDebounce } from 'ahooks'
|
||||
import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Checkbox from '@/app/components/base/checkbox'
|
||||
@ -18,12 +17,13 @@ import { CheckModal } from '@/hooks/use-pay'
|
||||
import dynamic from '@/next/dynamic'
|
||||
import { consoleQuery } from '@/service/client'
|
||||
import { systemFeaturesQueryOptions } from '@/service/system-features'
|
||||
import { AppModeEnum, AppModes } from '@/types/app'
|
||||
import { AppModeEnum } from '@/types/app'
|
||||
import AppCard from './app-card'
|
||||
import { AppCardSkeleton } from './app-card-skeleton'
|
||||
import { APP_LIST_SEARCH_DEBOUNCE_MS } from './constants'
|
||||
import Empty from './empty'
|
||||
import Footer from './footer'
|
||||
import useAppsQueryStateHook from './hooks/use-apps-query-state'
|
||||
import { isAppListCategory, useAppsQueryState } from './hooks/use-apps-query-state'
|
||||
import { useDSLDragDrop } from './hooks/use-dsl-drag-drop'
|
||||
import { useWorkflowOnlineUsers } from './hooks/use-workflow-online-users'
|
||||
import NewAppCard from './new-app-card'
|
||||
@ -35,18 +35,6 @@ const CreateFromDSLModal = dynamic(() => import('@/app/components/app/create-fro
|
||||
ssr: false,
|
||||
})
|
||||
|
||||
const APP_LIST_CATEGORY_VALUES = ['all', ...AppModes] as const
|
||||
type AppListCategory = typeof APP_LIST_CATEGORY_VALUES[number]
|
||||
const appListCategorySet = new Set<string>(APP_LIST_CATEGORY_VALUES)
|
||||
|
||||
const isAppListCategory = (value: string): value is AppListCategory => {
|
||||
return appListCategorySet.has(value)
|
||||
}
|
||||
|
||||
const parseAsAppListCategory = parseAsStringLiteral(APP_LIST_CATEGORY_VALUES)
|
||||
.withDefault('all')
|
||||
.withOptions({ history: 'push' })
|
||||
|
||||
type Props = {
|
||||
controlRefreshList?: number
|
||||
}
|
||||
@ -56,28 +44,21 @@ const List: FC<Props> = ({
|
||||
const { t } = useTranslation()
|
||||
const { data: systemFeatures } = useSuspenseQuery(systemFeaturesQueryOptions())
|
||||
const { isCurrentWorkspaceEditor, isCurrentWorkspaceDatasetOperator, isLoadingCurrentWorkspace } = useAppContext()
|
||||
const [activeTab, setActiveTab] = useQueryState(
|
||||
'category',
|
||||
parseAsAppListCategory,
|
||||
)
|
||||
|
||||
// eslint-disable-next-line react/use-state -- custom URL query hook, not React.useState
|
||||
const appsQuery = useAppsQueryStateHook()
|
||||
const { query: { tagIDs = [], keywords = '', isCreatedByMe: queryIsCreatedByMe = false }, setQuery } = appsQuery
|
||||
const [isCreatedByMe, setIsCreatedByMe] = useState(queryIsCreatedByMe)
|
||||
const [tagFilterValue, setTagFilterValue] = useState<string[]>(tagIDs)
|
||||
const [searchKeywords, setSearchKeywords] = useState(keywords)
|
||||
const {
|
||||
query: { category, tagIDs, keywords, isCreatedByMe },
|
||||
setCategory,
|
||||
setKeywords,
|
||||
setTagIDs,
|
||||
setIsCreatedByMe,
|
||||
} = useAppsQueryState()
|
||||
const debouncedKeywords = useDebounce(keywords, { wait: APP_LIST_SEARCH_DEBOUNCE_MS })
|
||||
const newAppCardRef = useRef<HTMLDivElement>(null)
|
||||
const containerRef = useRef<HTMLDivElement>(null)
|
||||
const [showTagManagementModal, setShowTagManagementModal] = useState(false)
|
||||
const [showCreateFromDSLModal, setShowCreateFromDSLModal] = useState(false)
|
||||
const [droppedDSLFile, setDroppedDSLFile] = useState<File | undefined>()
|
||||
const setKeywords = useCallback((keywords: string) => {
|
||||
setQuery(prev => ({ ...prev, keywords }))
|
||||
}, [setQuery])
|
||||
const setTagIDs = useCallback((tagIDs: string[]) => {
|
||||
setQuery(prev => ({ ...prev, tagIDs }))
|
||||
}, [setQuery])
|
||||
|
||||
const handleDSLFileDropped = useCallback((file: File) => {
|
||||
setDroppedDSLFile(file)
|
||||
@ -93,11 +74,11 @@ const List: FC<Props> = ({
|
||||
const appListQuery = useMemo<AppListQuery>(() => ({
|
||||
page: 1,
|
||||
limit: 30,
|
||||
name: searchKeywords,
|
||||
name: debouncedKeywords,
|
||||
...(tagIDs.length ? { tag_ids: tagIDs } : {}),
|
||||
...(isCreatedByMe ? { is_created_by_me: isCreatedByMe } : {}),
|
||||
...(activeTab !== 'all' ? { mode: activeTab } : {}),
|
||||
}), [activeTab, isCreatedByMe, searchKeywords, tagIDs])
|
||||
...(category !== 'all' ? { mode: category } : {}),
|
||||
}), [category, debouncedKeywords, isCreatedByMe, tagIDs])
|
||||
|
||||
const {
|
||||
data,
|
||||
@ -177,27 +158,9 @@ const List: FC<Props> = ({
|
||||
return () => observer?.disconnect()
|
||||
}, [isLoading, isFetchingNextPage, fetchNextPage, error, hasNextPage, isCurrentWorkspaceDatasetOperator])
|
||||
|
||||
const { run: handleSearch } = useDebounceFn(() => {
|
||||
setSearchKeywords(keywords)
|
||||
}, { wait: 500 })
|
||||
const handleKeywordsChange = (value: string) => {
|
||||
setKeywords(value)
|
||||
handleSearch()
|
||||
}
|
||||
|
||||
const { run: handleTagsUpdate } = useDebounceFn(() => {
|
||||
setTagIDs(tagFilterValue)
|
||||
}, { wait: 500 })
|
||||
const handleTagsChange = (value: string[]) => {
|
||||
setTagFilterValue(value)
|
||||
handleTagsUpdate()
|
||||
}
|
||||
|
||||
const handleCreatedByMeChange = useCallback(() => {
|
||||
const newValue = !isCreatedByMe
|
||||
setIsCreatedByMe(newValue)
|
||||
setQuery(prev => ({ ...prev, isCreatedByMe: newValue }))
|
||||
}, [isCreatedByMe, setQuery])
|
||||
setIsCreatedByMe(!isCreatedByMe)
|
||||
}, [isCreatedByMe, setIsCreatedByMe])
|
||||
|
||||
const pages = useMemo(() => data?.pages ?? [], [data?.pages])
|
||||
const apps = useMemo(() => pages.flatMap(({ data: pageApps }) => pageApps), [pages])
|
||||
@ -232,10 +195,10 @@ const List: FC<Props> = ({
|
||||
|
||||
<div className="sticky top-0 z-10 flex flex-wrap items-center justify-between gap-y-2 bg-background-body px-12 pt-7 pb-5">
|
||||
<TabSliderNew
|
||||
value={activeTab}
|
||||
value={category}
|
||||
onChange={(nextValue) => {
|
||||
if (isAppListCategory(nextValue))
|
||||
setActiveTab(nextValue)
|
||||
setCategory(nextValue)
|
||||
}}
|
||||
options={options}
|
||||
/>
|
||||
@ -246,14 +209,14 @@ const List: FC<Props> = ({
|
||||
{t('showMyCreatedAppsOnly', { ns: 'app' })}
|
||||
</div>
|
||||
</label>
|
||||
<TagFilter type="app" value={tagFilterValue} onChange={handleTagsChange} onOpenTagManagement={() => setShowTagManagementModal(true)} />
|
||||
<TagFilter type="app" value={tagIDs} onChange={setTagIDs} onOpenTagManagement={() => setShowTagManagementModal(true)} />
|
||||
<Input
|
||||
showLeftIcon
|
||||
showClearIcon
|
||||
wrapperClassName="w-[200px]"
|
||||
value={keywords}
|
||||
onChange={e => handleKeywordsChange(e.target.value)}
|
||||
onClear={() => handleKeywordsChange('')}
|
||||
onChange={e => setKeywords(e.target.value)}
|
||||
onClear={() => setKeywords('')}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
@ -267,7 +230,7 @@ const List: FC<Props> = ({
|
||||
ref={newAppCardRef}
|
||||
isLoading={isLoadingCurrentWorkspace}
|
||||
onSuccess={refetch}
|
||||
selectedAppType={activeTab}
|
||||
selectedAppType={category}
|
||||
className={cn(!hasAnyApp && 'z-10')}
|
||||
/>
|
||||
)}
|
||||
|
||||
@ -78,11 +78,11 @@ export const TagFilter = ({
|
||||
!!value.length && 'pr-6 shadow-xs',
|
||||
)}
|
||||
>
|
||||
<span className="flex min-w-0 items-center gap-1">
|
||||
<span className="flex w-full min-w-0 items-center gap-1">
|
||||
<span className="p-px">
|
||||
<Tag01Icon className="h-3.5 w-3.5 text-text-tertiary" aria-hidden="true" />
|
||||
</span>
|
||||
<span className="min-w-0 truncate text-[13px] leading-4.5 text-text-secondary">
|
||||
<span className="min-w-0 grow truncate text-[13px] leading-4.5 text-text-secondary">
|
||||
{!value.length && t('tag.placeholder', { ns: 'common' })}
|
||||
{!!value.length && currentTagName}
|
||||
</span>
|
||||
|
||||
Loading…
Reference in New Issue
Block a user