From c8abb11bf0666eb9a5c888904cd343309ff7fe25 Mon Sep 17 00:00:00 2001 From: Blackoutta <37723456+Blackoutta@users.noreply.github.com> Date: Thu, 4 Jun 2026 16:42:03 +0800 Subject: [PATCH] feat: support custom trace session id for Phoenix tracing (#37056) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/controllers/service_api/app/completion.py | 16 +- api/controllers/service_api/app/workflow.py | 13 +- .../app/apps/advanced_chat/app_generator.py | 19 +- api/core/app/apps/advanced_chat/app_runner.py | 3 + api/core/app/apps/agent_app/app_generator.py | 4 +- api/core/app/apps/agent_chat/app_generator.py | 6 +- api/core/app/apps/chat/app_generator.py | 6 +- api/core/app/apps/completion/app_generator.py | 5 +- api/core/app/apps/workflow/app_generator.py | 23 ++- api/core/app/apps/workflow/app_runner.py | 3 + api/core/app/apps/workflow_app_runner.py | 7 + api/core/app/entities/app_invoke_entities.py | 3 + .../easy_ui_based_generate_task_pipeline.py | 5 +- api/core/app/workflow/layers/persistence.py | 3 + api/core/helper/trace_id_helper.py | 64 +++++++ api/core/ops/ops_trace_manager.py | 14 ++ api/core/tools/workflow_as_tool/tool.py | 19 +- api/core/workflow/node_runtime.py | 9 + api/openapi/markdown/service-swagger.md | 3 + .../arize_phoenix_trace.py | 40 +++- .../test_arize_phoenix_trace.py | 81 +++++++- .../test_trace_session_id_parsing.py | 180 ++++++++++++++++++ .../apps/advanced_chat/test_app_generator.py | 6 +- .../test_app_runner_conversation_variables.py | 3 + .../test_app_runner_input_moderation.py | 1 + .../app/apps/agent_app/test_app_generator.py | 36 ++++ .../test_agent_chat_app_generator.py | 4 +- .../chat/test_app_generator_and_runner.py | 8 +- ...est_completion_completion_app_generator.py | 8 +- .../test_trace_session_id_generate_extras.py | 11 ++ .../app/apps/test_workflow_app_generator.py | 2 + .../app/apps/test_workflow_app_runner_core.py | 82 +++++++- .../test_workflow_app_runner_single_node.py | 2 + .../apps/workflow/test_app_generator_extra.py | 94 +++++++++ .../layers/test_pause_state_persist_layer.py | 3 + ...sy_ui_based_generate_task_pipeline_core.py | 12 +- .../app/workflow/test_persistence_layer.py | 4 + .../core/helper/test_trace_id_helper.py | 87 +++++++++ .../core/ops/test_trace_session_metadata.py | 140 ++++++++++++++ .../core/tools/workflow_as_tool/test_tool.py | 52 +++++ .../nodes/tool/test_tool_node_runtime.py | 38 ++++ .../core/workflow/test_node_runtime.py | 40 ++++ .../generated/api/service/types.gen.ts | 3 + .../generated/api/service/zod.gen.ts | 3 + .../develop/template/template.en.mdx | 6 + .../develop/template/template.ja.mdx | 6 + .../develop/template/template.zh.mdx | 6 + .../template/template_advanced_chat.en.mdx | 6 + .../template/template_advanced_chat.ja.mdx | 6 + .../template/template_advanced_chat.zh.mdx | 6 + .../develop/template/template_chat.en.mdx | 6 + .../develop/template/template_chat.ja.mdx | 6 + .../develop/template/template_chat.zh.mdx | 6 + .../develop/template/template_workflow.en.mdx | 10 + .../develop/template/template_workflow.ja.mdx | 10 + .../develop/template/template_workflow.zh.mdx | 10 + 56 files changed, 1214 insertions(+), 35 deletions(-) create mode 100644 api/tests/unit_tests/controllers/service_api/test_trace_session_id_parsing.py create mode 100644 api/tests/unit_tests/core/app/apps/test_trace_session_id_generate_extras.py create mode 100644 api/tests/unit_tests/core/ops/test_trace_session_metadata.py diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index fc5dd269d5..c2294a3fc1 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -28,7 +28,7 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from core.helper.trace_id_helper import get_external_trace_id +from core.helper.trace_id_helper import get_external_trace_id, get_trace_session_id, omit_trace_session_id_from_payload from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import UUIDStrOrEmpty @@ -56,6 +56,7 @@ class CompletionRequestPayload(BaseModel): files: list[dict[str, Any]] | None = None response_mode: Literal["blocking", "streaming"] | None = None retriever_from: str = Field(default="dev") + trace_session_id: str | None = Field(default=None, description="Trace session ID for observability grouping") class ChatRequestPayload(BaseModel): @@ -67,6 +68,7 @@ class ChatRequestPayload(BaseModel): retriever_from: str = Field(default="dev") auto_generate_name: bool = Field(default=True, description="Auto generate conversation name") workflow_id: str | None = Field(default=None, description="Workflow ID for advanced chat") + trace_session_id: str | None = Field(default=None, description="Trace session ID for observability grouping") @field_validator("conversation_id", mode="before") @classmethod @@ -112,9 +114,14 @@ class CompletionApi(Resource): if app_model.mode != AppMode.COMPLETION: raise AppUnavailableError() - payload = CompletionRequestPayload.model_validate(service_api_ns.payload or {}) + payload = CompletionRequestPayload.model_validate( + omit_trace_session_id_from_payload(service_api_ns.payload) or {} + ) external_trace_id = get_external_trace_id(request) args = payload.model_dump(exclude_none=True) + trace_session_id = get_trace_session_id(request) + if trace_session_id: + args["trace_session_id"] = trace_session_id if external_trace_id: args["external_trace_id"] = external_trace_id @@ -209,10 +216,13 @@ class ChatApi(Resource): if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}: raise NotChatAppError() - payload = ChatRequestPayload.model_validate(service_api_ns.payload or {}) + payload = ChatRequestPayload.model_validate(omit_trace_session_id_from_payload(service_api_ns.payload) or {}) external_trace_id = get_external_trace_id(request) args = payload.model_dump(exclude_none=True) + trace_session_id = get_trace_session_id(request) + if trace_session_id: + args["trace_session_id"] = trace_session_id if external_trace_id: args["external_trace_id"] = external_trace_id diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index d4bfe874ff..975fdf0cd9 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -30,7 +30,7 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from core.helper.trace_id_helper import get_external_trace_id +from core.helper.trace_id_helper import get_external_trace_id, get_trace_session_id, omit_trace_session_id_from_payload from extensions.ext_database import db from extensions.ext_redis import redis_client from fields.base import ResponseModel @@ -54,6 +54,7 @@ logger = logging.getLogger(__name__) class WorkflowRunPayload(WorkflowRunPayloadBase): response_mode: Literal["blocking", "streaming"] | None = None + trace_session_id: str | None = Field(default=None, description="Trace session ID for observability grouping") class WorkflowLogQuery(BaseModel): @@ -272,8 +273,11 @@ class WorkflowRunApi(Resource): if app_mode != AppMode.WORKFLOW: raise NotWorkflowAppError() - payload = WorkflowRunPayload.model_validate(service_api_ns.payload or {}) + payload = WorkflowRunPayload.model_validate(omit_trace_session_id_from_payload(service_api_ns.payload) or {}) args = payload.model_dump(exclude_none=True) + trace_session_id = get_trace_session_id(request) + if trace_session_id: + args["trace_session_id"] = trace_session_id external_trace_id = get_external_trace_id(request) if external_trace_id: args["external_trace_id"] = external_trace_id @@ -328,8 +332,11 @@ class WorkflowRunByIdApi(Resource): if app_mode != AppMode.WORKFLOW: raise NotWorkflowAppError() - payload = WorkflowRunPayload.model_validate(service_api_ns.payload or {}) + payload = WorkflowRunPayload.model_validate(omit_trace_session_id_from_payload(service_api_ns.payload) or {}) args = payload.model_dump(exclude_none=True) + trace_session_id = get_trace_session_id(request) + if trace_session_id: + args["trace_session_id"] = trace_session_id # Add workflow_id to args for AppGenerateService args["workflow_id"] = workflow_id diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index ae11daebb6..ee7ade9e45 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -40,7 +40,7 @@ from core.app.entities.task_entities import ( ChatbotAppStreamResponse, ) from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer -from core.helper.trace_id_helper import extract_external_trace_id_from_args +from core.helper.trace_id_helper import extract_external_trace_id_from_args, extract_trace_session_id_from_args from core.ops.ops_trace_manager import TraceQueueManager from core.prompt.utils.get_thread_messages_length import get_thread_messages_length from core.repositories import DifyCoreRepositoryFactory @@ -64,6 +64,12 @@ from services.workflow_draft_variable_service import ( logger = logging.getLogger(__name__) +def _extract_trace_session_id_from_debug_args(args: Mapping[str, Any] | Any) -> dict[str, str]: + if isinstance(args, Mapping): + return extract_trace_session_id_from_args(args) + return extract_trace_session_id_from_args({"trace_session_id": getattr(args, "trace_session_id", None)}) + + class AdvancedChatAppGenerator(MessageBasedAppGenerator): _dialogue_count: int @@ -140,6 +146,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): extras = { "auto_generate_conversation_name": args.get("auto_generate_name", False), **extract_external_trace_id_from_args(args), + **extract_trace_session_id_from_args(args), } # get conversation @@ -331,7 +338,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): user_id=user.id, stream=streaming, invoke_from=InvokeFrom.DEBUGGER, - extras={"auto_generate_conversation_name": False}, + extras={ + "auto_generate_conversation_name": False, + **_extract_trace_session_id_from_debug_args(args), + }, single_iteration_run=AdvancedChatAppGenerateEntity.SingleIterationRunEntity( node_id=node_id, inputs=args["inputs"] ), @@ -417,7 +427,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): user_id=user.id, stream=streaming, invoke_from=InvokeFrom.DEBUGGER, - extras={"auto_generate_conversation_name": False}, + extras={ + "auto_generate_conversation_name": False, + **_extract_trace_session_id_from_debug_args(args), + }, single_loop_run=AdvancedChatAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args.inputs), ) contexts.plugin_tool_providers.set({}) diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index db66e9f592..256521ab65 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -131,6 +131,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): user_id=self.application_generate_entity.user_id, invoke_from=invoke_from, user_from=user_from, + trace_session_id=self.application_generate_entity.extras.get("trace_session_id"), ) elif self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run: # Handle single iteration or single loop run @@ -139,6 +140,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): single_iteration_run=self.application_generate_entity.single_iteration_run, single_loop_run=self.application_generate_entity.single_loop_run, user_id=self.application_generate_entity.user_id, + trace_session_id=self.application_generate_entity.extras.get("trace_session_id"), ) else: inputs = self.application_generate_entity.inputs @@ -199,6 +201,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): user_from=user_from, invoke_from=invoke_from, root_node_id=root_node_id, + trace_session_id=self.application_generate_entity.extras.get("trace_session_id"), ) db.session.close() diff --git a/api/core/app/apps/agent_app/app_generator.py b/api/core/app/apps/agent_app/app_generator.py index 7cbe62bcac..467afb891d 100644 --- a/api/core/app/apps/agent_app/app_generator.py +++ b/api/core/app/apps/agent_app/app_generator.py @@ -113,7 +113,9 @@ class AgentAppGenerator(MessageBasedAppGenerator): user_id=user.id, stream=streaming, invoke_from=invoke_from, - extras={"auto_generate_conversation_name": args.get("auto_generate_name", True)}, + extras={ + "auto_generate_conversation_name": args.get("auto_generate_name", True), + }, call_depth=0, trace_manager=trace_manager, agent_id=agent.id, diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index 9df959e6f3..9ad724cc89 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -20,6 +20,7 @@ from core.app.apps.exc import GenerateTaskStoppedError from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom +from core.helper.trace_id_helper import extract_trace_session_id_from_args from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db from factories import file_factory @@ -96,7 +97,10 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): query = query.replace("\x00", "") inputs = args["inputs"] - extras = {"auto_generate_conversation_name": args.get("auto_generate_name", True)} + extras = { + "auto_generate_conversation_name": args.get("auto_generate_name", True), + **extract_trace_session_id_from_args(args), + } # get conversation conversation = None diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index 254382f4bd..4b8701da8e 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -20,6 +20,7 @@ from core.app.apps.exc import GenerateTaskStoppedError from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom +from core.helper.trace_id_helper import extract_trace_session_id_from_args from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db from factories import file_factory @@ -89,7 +90,10 @@ class ChatAppGenerator(MessageBasedAppGenerator): query = query.replace("\x00", "") inputs = args["inputs"] - extras = {"auto_generate_conversation_name": args.get("auto_generate_name", True)} + extras = { + "auto_generate_conversation_name": args.get("auto_generate_name", True), + **extract_trace_session_id_from_args(args), + } # get conversation conversation = None diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index 423bfdac51..9f29b8df29 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -20,6 +20,7 @@ from core.app.apps.exc import GenerateTaskStoppedError from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom +from core.helper.trace_id_helper import extract_trace_session_id_from_args from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db from factories import file_factory @@ -148,7 +149,9 @@ class CompletionAppGenerator(MessageBasedAppGenerator): user_id=user.id, stream=streaming, invoke_from=invoke_from, - extras={}, + extras={ + **extract_trace_session_id_from_args(args), + }, trace_manager=trace_manager, ) diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 78255c0512..fb4f94bc87 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -32,7 +32,11 @@ from core.app.entities.task_entities import ( ) from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer from core.db.session_factory import session_factory -from core.helper.trace_id_helper import extract_external_trace_id_from_args, extract_parent_trace_context_from_args +from core.helper.trace_id_helper import ( + extract_external_trace_id_from_args, + extract_parent_trace_context_from_args, + extract_trace_session_id_from_args, +) from core.ops.ops_trace_manager import TraceQueueManager from core.repositories import DifyCoreRepositoryFactory from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository @@ -57,6 +61,12 @@ SKIP_PREPARE_USER_INPUTS_KEY = "_skip_prepare_user_inputs" logger = logging.getLogger(__name__) +def _extract_trace_session_id_from_debug_args(args: Mapping[str, Any] | Any) -> dict[str, str]: + if isinstance(args, Mapping): + return extract_trace_session_id_from_args(args) + return extract_trace_session_id_from_args({"trace_session_id": getattr(args, "trace_session_id", None)}) + + class WorkflowAppGenerator(BaseAppGenerator): @staticmethod def _should_prepare_user_inputs(args: Mapping[str, Any]) -> bool: @@ -167,6 +177,7 @@ class WorkflowAppGenerator(BaseAppGenerator): extras = { **extract_external_trace_id_from_args(args), **extract_parent_trace_context_from_args(args), + **extract_trace_session_id_from_args(args), } workflow_run_id = str(workflow_run_id or uuid.uuid4()) # FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args @@ -410,7 +421,10 @@ class WorkflowAppGenerator(BaseAppGenerator): user_id=user.id, stream=streaming, invoke_from=InvokeFrom.DEBUGGER, - extras={"auto_generate_conversation_name": False}, + extras={ + "auto_generate_conversation_name": False, + **_extract_trace_session_id_from_debug_args(args), + }, single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity( node_id=node_id, inputs=args["inputs"] ), @@ -496,7 +510,10 @@ class WorkflowAppGenerator(BaseAppGenerator): user_id=user.id, stream=streaming, invoke_from=InvokeFrom.DEBUGGER, - extras={"auto_generate_conversation_name": False}, + extras={ + "auto_generate_conversation_name": False, + **_extract_trace_session_id_from_debug_args(args), + }, single_loop_run=WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args.inputs or {}), workflow_execution_id=str(uuid.uuid4()), ) diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 9d8a3eb1b1..ecb485885f 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -87,6 +87,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): user_from=user_from, invoke_from=invoke_from, root_node_id=self._root_node_id, + trace_session_id=self.application_generate_entity.extras.get("trace_session_id"), ) elif self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run: graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution( @@ -94,6 +95,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): single_iteration_run=self.application_generate_entity.single_iteration_run, single_loop_run=self.application_generate_entity.single_loop_run, user_id=self.application_generate_entity.user_id, + trace_session_id=self.application_generate_entity.extras.get("trace_session_id"), ) else: inputs = self.application_generate_entity.inputs @@ -128,6 +130,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): user_from=user_from, invoke_from=invoke_from, root_node_id=root_node_id, + trace_session_id=self.application_generate_entity.extras.get("trace_session_id"), ) # RUN WORKFLOW diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index c7af606419..944860ee39 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -118,6 +118,7 @@ class WorkflowBasedAppRunner: tenant_id: str = "", user_id: str = "", root_node_id: str | None = None, + trace_session_id: str | None = None, ) -> Graph: """ Init graph @@ -138,6 +139,7 @@ class WorkflowBasedAppRunner: user_id=user_id, user_from=user_from, invoke_from=invoke_from, + trace_session_id=trace_session_id, ) graph_init_context = DifyGraphInitContext( workflow_id=workflow_id, @@ -171,6 +173,7 @@ class WorkflowBasedAppRunner: single_loop_run: Any | None = None, *, user_id: str, + trace_session_id: str | None = None, ) -> tuple[Graph, VariablePool, GraphRuntimeState]: """ Prepare graph, variable pool, and runtime state for single node execution @@ -208,6 +211,7 @@ class WorkflowBasedAppRunner: node_type_filter_key="iteration_id", node_type_label="iteration", user_id=user_id, + trace_session_id=trace_session_id, ) elif single_loop_run: graph, variable_pool = self._get_graph_and_variable_pool_for_single_node_run( @@ -218,6 +222,7 @@ class WorkflowBasedAppRunner: node_type_filter_key="loop_id", node_type_label="loop", user_id=user_id, + trace_session_id=trace_session_id, ) else: raise ValueError("Neither single_iteration_run nor single_loop_run is specified") @@ -236,6 +241,7 @@ class WorkflowBasedAppRunner: node_type_label: str = "node", # 'iteration' or 'loop' for error messages *, user_id: str = "", + trace_session_id: str | None = None, ) -> tuple[Graph, VariablePool]: """ Get graph and variable pool for single node execution (iteration or loop). @@ -301,6 +307,7 @@ class WorkflowBasedAppRunner: user_id=user_id, user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, + trace_session_id=trace_session_id, ) graph_init_context = DifyGraphInitContext( workflow_id=workflow.id, diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index debda2da19..08ecc2097b 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -54,6 +54,7 @@ class DifyRunContext(BaseModel): user_id: str user_from: UserFrom invoke_from: InvokeFrom + trace_session_id: str | None = None def build_dify_run_context( @@ -63,6 +64,7 @@ def build_dify_run_context( user_id: str, user_from: UserFrom, invoke_from: InvokeFrom, + trace_session_id: str | None = None, extra_context: Mapping[str, Any] | None = None, ) -> dict[str, Any]: """ @@ -78,6 +80,7 @@ def build_dify_run_context( user_id=user_id, user_from=user_from, invoke_from=invoke_from, + trace_session_id=trace_session_id, ) return run_context diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 171d5ab342..bea50ea269 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -413,7 +413,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): if trace_manager: trace_manager.add_trace_task( TraceTask( - TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation_id, message_id=self._message_id + TraceTaskName.MESSAGE_TRACE, + conversation_id=self._conversation_id, + message_id=self._message_id, + trace_session_id=self._application_generate_entity.extras.get("trace_session_id"), ) ) diff --git a/api/core/app/workflow/layers/persistence.py b/api/core/app/workflow/layers/persistence.py index 4561388f8b..1708b39b7c 100644 --- a/api/core/app/workflow/layers/persistence.py +++ b/api/core/app/workflow/layers/persistence.py @@ -417,10 +417,12 @@ class WorkflowPersistenceLayer(GraphEngineLayer): conversation_id = self._system_variables().get(SystemVariableKey.CONVERSATION_ID.value) external_trace_id = None + trace_session_id = None parent_trace_context = None if isinstance(self._application_generate_entity, (WorkflowAppGenerateEntity, AdvancedChatAppGenerateEntity)): extras = self._application_generate_entity.extras external_trace_id = extras.get("external_trace_id") + trace_session_id = extras.get("trace_session_id") parent_trace_context = extras.get("parent_trace_context") if isinstance(parent_trace_context, ParentTraceContext): parent_trace_context = parent_trace_context.model_dump(exclude_none=True) @@ -431,6 +433,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer): conversation_id=conversation_id, user_id=self._trace_manager.user_id, external_trace_id=external_trace_id, + trace_session_id=trace_session_id, parent_trace_context=parent_trace_context, ) self._trace_manager.add_trace_task(trace_task) diff --git a/api/core/helper/trace_id_helper.py b/api/core/helper/trace_id_helper.py index e4890c8d4d..82b5f42885 100644 --- a/api/core/helper/trace_id_helper.py +++ b/api/core/helper/trace_id_helper.py @@ -4,6 +4,7 @@ from collections.abc import Mapping from typing import Any from pydantic import BaseModel, ConfigDict, StrictStr, ValidationError +from werkzeug.exceptions import BadRequest class ParentTraceContext(BaseModel): @@ -72,6 +73,69 @@ def extract_external_trace_id_from_args(args: Mapping[str, Any]): return {} +TRACE_SESSION_ID_HEADER = "X-Trace-Session-Id" +TRACE_SESSION_ID_ARG = "trace_session_id" +TRACE_SESSION_ID_MAX_LENGTH = 200 + + +def _validate_trace_session_id(value: Any) -> str: + if not isinstance(value, str): + raise BadRequest("trace_session_id must be a string.") + + normalized = value.strip() + if not normalized: + raise BadRequest("trace_session_id must be 1 to 200 characters after trimming.") + if len(normalized) > TRACE_SESSION_ID_MAX_LENGTH: + raise BadRequest("trace_session_id must be 1 to 200 characters after trimming.") + return normalized + + +def get_trace_session_id(request: Any) -> str | None: + """ + Resolve the Service API trace session ID from explicit request inputs. + + Priority is ``X-Trace-Session-Id`` header, then ``trace_session_id`` query + parameter, then ``trace_session_id`` JSON body field. Only the resolved + highest-priority input is validated; lower-priority values are ignored. + """ + if TRACE_SESSION_ID_HEADER in request.headers: + return _validate_trace_session_id(request.headers.get(TRACE_SESSION_ID_HEADER)) + + if TRACE_SESSION_ID_ARG in request.args: + return _validate_trace_session_id(request.args.get(TRACE_SESSION_ID_ARG)) + + if getattr(request, "is_json", False): + json_data = getattr(request, "json", None) + if isinstance(json_data, Mapping) and TRACE_SESSION_ID_ARG in json_data: + return _validate_trace_session_id(json_data.get(TRACE_SESSION_ID_ARG)) + + return None + + +def extract_trace_session_id_from_args(args: Mapping[str, Any]) -> dict[str, str]: + """ + Extract normalized ``trace_session_id`` from generation args for entity extras. + """ + trace_session_id = args.get(TRACE_SESSION_ID_ARG) + if isinstance(trace_session_id, str): + normalized = trace_session_id.strip() + if normalized: + return {TRACE_SESSION_ID_ARG: normalized} + return {} + + +def omit_trace_session_id_from_payload(payload: Any) -> Any: + """ + Return a payload copy without transport-level ``trace_session_id``. + + Controllers validate this field through :func:`get_trace_session_id` so lower-priority + body values cannot fail DTO validation before header/query priority is applied. + """ + if isinstance(payload, Mapping) and TRACE_SESSION_ID_ARG in payload: + return {key: value for key, value in payload.items() if key != TRACE_SESSION_ID_ARG} + return payload + + def extract_parent_trace_context_from_args(args: Mapping[str, Any]) -> dict[str, ParentTraceContext]: """ Extract 'parent_trace_context' from args. diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index 61fd0e5c1f..6afd22a6cc 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -5,6 +5,7 @@ import os import queue import threading import time +from collections.abc import Mapping from datetime import timedelta from typing import TYPE_CHECKING, Any, TypedDict from uuid import UUID, uuid4 @@ -64,6 +65,11 @@ def _dump_parent_trace_context(parent_trace_context: Any) -> dict[str, str] | No return None +def _get_trace_session_id(kwargs: Mapping[str, Any]) -> str | None: + value = kwargs.get("trace_session_id") + return value if isinstance(value, str) and value else None + + class _AppTracingConfig(TypedDict, total=False): enabled: bool tracing_provider: str | None @@ -873,6 +879,10 @@ class TraceTask: if dumped_parent_trace_context: metadata["parent_trace_context"] = dumped_parent_trace_context + trace_session_id = _get_trace_session_id(self.kwargs) + if trace_session_id: + metadata["trace_session_id"] = trace_session_id + workflow_trace_info = WorkflowTraceInfo( trace_id=self.trace_id, workflow_data=workflow_run.to_dict(), @@ -956,6 +966,10 @@ class TraceTask: if node_execution_id := kwargs.get("node_execution_id"): metadata["node_execution_id"] = node_execution_id + trace_session_id = _get_trace_session_id(kwargs) + if trace_session_id: + metadata["trace_session_id"] = trace_session_id + message_tokens = message_data.message_tokens message_trace_info = MessageTraceInfo( diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 3fbd456fe5..dfcdcffca8 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -9,7 +9,11 @@ from sqlalchemy import select from core.app.file_access import DatabaseFileAccessController from core.db.session_factory import session_factory -from core.helper.trace_id_helper import ParentTraceContext, extract_parent_trace_context_from_args +from core.helper.trace_id_helper import ( + ParentTraceContext, + extract_parent_trace_context_from_args, + extract_trace_session_id_from_args, +) from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ( @@ -38,6 +42,7 @@ class WorkflowTool(Tool): """ _parent_trace_context: ParentTraceContext | None + _trace_session_id: str | None def __init__( self, @@ -58,6 +63,7 @@ class WorkflowTool(Tool): self.label = label self._latest_usage = LLMUsage.empty_usage() self._parent_trace_context = None + self._trace_session_id = None super().__init__(entity=entity, runtime=runtime) @@ -103,6 +109,8 @@ class WorkflowTool(Tool): generator_args.update( extract_parent_trace_context_from_args({"parent_trace_context": self._parent_trace_context}) ) + if self._trace_session_id: + generator_args.update(extract_trace_session_id_from_args({"trace_session_id": self._trace_session_id})) result = generator.generate( app_model=app, @@ -215,6 +223,7 @@ class WorkflowTool(Tool): label=self.label, ) forked._parent_trace_context = self._parent_trace_context.model_copy() if self._parent_trace_context else None + forked._trace_session_id = self._trace_session_id return forked def set_parent_trace_context( @@ -233,6 +242,14 @@ class WorkflowTool(Tool): """Remove parent trace context before invoking this tool outside a nested workflow.""" self._parent_trace_context = None + def set_trace_session_id(self, trace_session_id: str) -> None: + """Attach parent trace session ID without exposing it as tool input.""" + self._trace_session_id = trace_session_id + + def clear_trace_session_id(self) -> None: + """Remove trace session ID before invoking this tool outside a traced session.""" + self._trace_session_id = None + def _resolve_user(self, user_id: str) -> Account | EndUser | None: """ Resolve user object in both HTTP and worker contexts. diff --git a/api/core/workflow/node_runtime.py b/api/core/workflow/node_runtime.py index be14aa92ec..ef3f3d5a6b 100644 --- a/api/core/workflow/node_runtime.py +++ b/api/core/workflow/node_runtime.py @@ -382,6 +382,7 @@ class _WorkflowToolRuntimeBinding: tool: Tool conversation_id: str | None = None parent_trace_context: ParentTraceContext | None = None + trace_session_id: str | None = None class DifyToolNodeRuntime(ToolNodeRuntimeProtocol): @@ -423,6 +424,7 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol): None if variable_pool is None else get_system_text(variable_pool, SystemVariableKey.CONVERSATION_ID) ) parent_trace_context: ParentTraceContext | None = None + trace_session_id: str | None = None if self._is_workflow_tool_provider(node_data): outer_workflow_run_id = ( None @@ -434,11 +436,14 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol): parent_workflow_run_id=outer_workflow_run_id, parent_node_execution_id=node_execution_id, ) + if isinstance(self._run_context.trace_session_id, str) and self._run_context.trace_session_id: + trace_session_id = self._run_context.trace_session_id return ToolRuntimeHandle( raw=_WorkflowToolRuntimeBinding( tool=tool_runtime, conversation_id=conversation_id, parent_trace_context=parent_trace_context, + trace_session_id=trace_session_id, ) ) @@ -471,6 +476,10 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol): ) elif hasattr(tool, "clear_parent_trace_context"): tool.clear_parent_trace_context() + if runtime_binding.trace_session_id and hasattr(tool, "set_trace_session_id"): + tool.set_trace_session_id(runtime_binding.trace_session_id) + elif hasattr(tool, "clear_trace_session_id"): + tool.clear_trace_session_id() try: messages = ToolEngine.generic_invoke( diff --git a/api/openapi/markdown/service-swagger.md b/api/openapi/markdown/service-swagger.md index 3de701f3f3..735d610eff 100644 --- a/api/openapi/markdown/service-swagger.md +++ b/api/openapi/markdown/service-swagger.md @@ -2233,6 +2233,7 @@ Returns a list of available models for the specified model type. | query | string | | Yes | | response_mode | string | *Enum:* `"blocking"`, `"streaming"` | No | | retriever_from | string | | No | +| trace_session_id | string | Trace session ID for observability grouping | No | | workflow_id | string | Workflow ID for advanced chat | No | #### ChildChunkCreatePayload @@ -2293,6 +2294,7 @@ Returns a list of available models for the specified model type. | query | string | | No | | response_mode | string | *Enum:* `"blocking"`, `"streaming"` | No | | retriever_from | string | | No | +| trace_session_id | string | Trace session ID for observability grouping | No | #### Condition @@ -3381,6 +3383,7 @@ Accept the legacy single-tag Service API payload while exposing a normalized tag | files | [ object ] | | No | | inputs | object | | Yes | | response_mode | string | *Enum:* `"blocking"`, `"streaming"` | No | +| trace_session_id | string | Trace session ID for observability grouping | No | #### WorkflowRunResponse diff --git a/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py index 9ac0a50a97..f2217dc0a2 100644 --- a/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py +++ b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py @@ -300,8 +300,12 @@ def _get_node_span_kind(node_type: str) -> OpenInferenceSpanKindValues: return _NODE_TYPE_TO_SPAN_KIND.get(node_type, OpenInferenceSpanKindValues.CHAIN) -def _resolve_workflow_session_id(trace_info: WorkflowTraceInfo) -> str: - """Resolve the workflow session ID for Phoenix workflow spans.""" +def _metadata_trace_session_id(trace_info: BaseTraceInfo) -> str | None: + value = trace_info.metadata.get("trace_session_id") + return value if isinstance(value, str) and value else None + + +def _resolve_workflow_session_fallback(trace_info: WorkflowTraceInfo) -> str: if trace_info.conversation_id: return trace_info.conversation_id @@ -312,6 +316,28 @@ def _resolve_workflow_session_id(trace_info: WorkflowTraceInfo) -> str: return trace_info.workflow_run_id +def _resolve_message_session_fallback(trace_info: MessageTraceInfo) -> str: + if trace_info.message_data is not None: + conversation_id = getattr(trace_info.message_data, "conversation_id", None) + if conversation_id: + return conversation_id + return "" + + +def _resolve_trace_session_id(trace_info: WorkflowTraceInfo | MessageTraceInfo) -> str: + trace_session_id = _metadata_trace_session_id(trace_info) + if trace_session_id: + return trace_session_id + if isinstance(trace_info, WorkflowTraceInfo): + return _resolve_workflow_session_fallback(trace_info) + return _resolve_message_session_fallback(trace_info) + + +def _resolve_workflow_session_id(trace_info: WorkflowTraceInfo) -> str: + """Resolve the workflow session ID for Phoenix workflow spans.""" + return _resolve_trace_session_id(trace_info) + + def _resolve_workflow_parent_context(trace_info: BaseTraceInfo) -> tuple[str | None, str | None]: """Expose the typed parent context already resolved on the trace info.""" return trace_info.resolved_parent_context @@ -752,7 +778,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): file_list=safe_json_dumps(file_list), query=trace_info.query or "", ) - workflow_session_id = _resolve_workflow_session_id(trace_info) + workflow_session_id = _resolve_trace_session_id(trace_info) parent_workflow_run_id, parent_node_execution_id = _resolve_workflow_parent_context(trace_info) logger.info( "[Arize/Phoenix] Workflow session resolution: workflow_run_id=%s conversation_id=%s " @@ -781,6 +807,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, SpanAttributes.OUTPUT_VALUE: safe_json_dumps(trace_info.workflow_run_outputs), SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + SpanAttributes.SESSION_ID: workflow_session_id or "", }, } if trace_info.error: @@ -1090,6 +1117,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): model_provider=trace_info.message_data.model_provider or "", model_id=trace_info.message_data.model_id or "", ) + message_session_id = _resolve_trace_session_id(trace_info) # Add end user data if available if trace_info.message_data.from_end_user_id: @@ -1104,7 +1132,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): SpanAttributes.OUTPUT_VALUE: trace_info.message_data.answer, SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.TEXT.value, SpanAttributes.METADATA: safe_json_dumps(metadata), - SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id or "", + SpanAttributes.SESSION_ID: message_session_id or "", } dify_trace_id = trace_info.trace_id or trace_info.message_id @@ -1129,14 +1157,14 @@ class ArizePhoenixDataTrace(BaseTraceInstance): else: outputs_str = str(trace_info.outputs) - llm_attributes = { + llm_attributes: dict[str, Any] = { SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.LLM.value, SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.inputs), SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, SpanAttributes.OUTPUT_VALUE: outputs_str, SpanAttributes.OUTPUT_MIME_TYPE: outputs_mime_type, SpanAttributes.METADATA: safe_json_dumps(metadata), - SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id or "", + SpanAttributes.SESSION_ID: message_session_id or "", } llm_attributes.update(self._construct_llm_attributes(trace_info.inputs)) if trace_info.total_tokens is not None and trace_info.total_tokens > 0: diff --git a/api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py index 0c62880715..f75bf11530 100644 --- a/api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py +++ b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py @@ -20,6 +20,7 @@ from dify_trace_arize_phoenix.arize_phoenix_trace import ( _resolve_node_parent, _resolve_published_parent_span_context, _resolve_structured_parent_execution_id, + _resolve_trace_session_id, _resolve_workflow_parent_context, _resolve_workflow_session_id, datetime_to_nanos, @@ -106,6 +107,14 @@ def _get_start_span_call(start_span_mock, *, span_name: str): raise AssertionError(f"Could not find start_span call with name={span_name!r}") +def _get_start_span_call_by_kind(start_span_mock, *, span_kind: str): + for call in start_span_mock.call_args_list: + attributes = call.kwargs.get("attributes", {}) + if attributes.get(SpanAttributes.OPENINFERENCE_SPAN_KIND) == span_kind: + return call + raise AssertionError(f"Could not find start_span call with span kind={span_kind!r}") + + class _FakeQuery: def __init__(self, result): self._result = result @@ -358,6 +367,34 @@ class TestGetNodeSpanKind: class TestWorkflowSessionResolution: + def test_resolve_workflow_session_id_prefers_trace_session_id_metadata(self): + trace_info = _make_workflow_info( + conversation_id="conversation-1", + workflow_run_id="workflow-run-1", + metadata={"app_id": "app-1", "trace_session_id": "session-1"}, + ) + + assert _resolve_trace_session_id(trace_info) == "session-1" + assert _resolve_workflow_session_id(trace_info) == "session-1" + + def test_resolve_workflow_session_id_falls_back_to_existing_workflow_behavior(self): + trace_info = _make_workflow_info( + conversation_id="conversation-1", + workflow_run_id="workflow-run-1", + metadata={"app_id": "app-1"}, + ) + + assert _resolve_trace_session_id(trace_info) == "conversation-1" + + def test_resolve_message_session_id_prefers_trace_session_id_metadata(self): + message_data = SimpleNamespace(conversation_id="conversation-1") + trace_info = _make_message_info( + message_data=message_data, + metadata={"app_id": "app-1", "trace_session_id": "session-1"}, + ) + + assert _resolve_trace_session_id(trace_info) == "session-1" + def test_prefers_conversation_id(self): info = _make_workflow_trace_info(conversation_id="conversation-1") @@ -780,7 +817,11 @@ def test_workflow_trace_uses_canonical_root_context_for_top_level_workflow( mock_sessionmaker, mock_repo_factory, mock_db, trace_instance ): mock_db.engine = MagicMock() - info = _make_workflow_info(message_id="message-1", workflow_run_id="workflow-run-1") + info = _make_workflow_info( + message_id="message-1", + workflow_run_id="workflow-run-1", + metadata={"app_id": "app1", "trace_session_id": "trace-session-1"}, + ) repo = MagicMock() repo.get_by_workflow_execution.return_value = [] mock_repo_factory.create_workflow_node_execution_repository.return_value = repo @@ -803,6 +844,7 @@ def test_workflow_trace_uses_canonical_root_context_for_top_level_workflow( SpanAttributes.INPUT_MIME_TYPE: "application/json", SpanAttributes.OUTPUT_VALUE: safe_json_dumps(info.workflow_run_outputs), SpanAttributes.OUTPUT_MIME_TYPE: "application/json", + SpanAttributes.SESSION_ID: "trace-session-1", }, ) mock_extract.assert_called_once_with(carrier=root_carrier) @@ -940,6 +982,7 @@ def test_workflow_trace_reuses_upstream_parent_workflow_context_when_no_parent_n SpanAttributes.INPUT_MIME_TYPE: "application/json", SpanAttributes.OUTPUT_VALUE: safe_json_dumps(info.workflow_run_outputs), SpanAttributes.OUTPUT_MIME_TYPE: "application/json", + SpanAttributes.SESSION_ID: "outer-workflow-run-1", }, ) mock_extract.assert_called_once_with(carrier=parent_carrier) @@ -1085,6 +1128,7 @@ def test_workflow_trace_falls_back_when_parent_app_tracing_cannot_publish_parent SpanAttributes.INPUT_MIME_TYPE: "application/json", SpanAttributes.OUTPUT_VALUE: safe_json_dumps(info.workflow_run_outputs), SpanAttributes.OUTPUT_MIME_TYPE: "application/json", + SpanAttributes.SESSION_ID: "outer-workflow-run-1", }, ) mock_extract.assert_called_once_with(carrier=parent_carrier) @@ -1287,6 +1331,7 @@ def test_workflow_trace_keeps_nested_conversation_session_while_reusing_parent_r SpanAttributes.INPUT_MIME_TYPE: "application/json", SpanAttributes.OUTPUT_VALUE: safe_json_dumps(info.workflow_run_outputs), SpanAttributes.OUTPUT_MIME_TYPE: "application/json", + SpanAttributes.SESSION_ID: "conversation-1", }, ) mock_extract.assert_called_once_with(carrier=parent_carrier) @@ -1926,6 +1971,40 @@ def test_message_trace_keeps_conversation_id_as_session(mock_db, trace_instance) assert message_span_call.kwargs["attributes"][SpanAttributes.SESSION_ID] == "conversation-2" +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") +def test_message_trace_uses_trace_session_id_metadata_as_session(mock_db, trace_instance): + mock_db.engine = MagicMock() + info = _make_message_info(metadata={"app_id": "app-1", "trace_session_id": "session-1"}) + info.message_data = MagicMock() + info.message_data.conversation_id = "conversation-2" + info.message_data.from_account_id = "acc2" + info.message_data.from_end_user_id = None + info.message_data.query = "q2" + info.message_data.answer = "a2" + info.message_data.status = "s2" + info.message_data.model_id = "m2" + info.message_data.model_provider = "p2" + info.message_data.message_metadata = "{}" + info.message_data.error = None + info.error = None + + root_span = MagicMock() + message_span = MagicMock() + llm_span = MagicMock() + trace_instance.tracer.start_span.side_effect = [root_span, message_span, llm_span] + + trace_instance.message_trace(info) + + message_span_call = _get_start_span_call( + trace_instance.tracer.start_span, span_name=TraceTaskName.MESSAGE_TRACE.value + ) + llm_span_call = _get_start_span_call_by_kind( + trace_instance.tracer.start_span, span_kind=OpenInferenceSpanKindValues.LLM.value + ) + assert message_span_call.kwargs["attributes"][SpanAttributes.SESSION_ID] == "session-1" + assert llm_span_call.kwargs["attributes"][SpanAttributes.SESSION_ID] == "session-1" + + @patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") def test_message_trace_with_error(mock_db, trace_instance): mock_db.engine = MagicMock() diff --git a/api/tests/unit_tests/controllers/service_api/test_trace_session_id_parsing.py b/api/tests/unit_tests/controllers/service_api/test_trace_session_id_parsing.py new file mode 100644 index 0000000000..61af8fcfc1 --- /dev/null +++ b/api/tests/unit_tests/controllers/service_api/test_trace_session_id_parsing.py @@ -0,0 +1,180 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask +from werkzeug.exceptions import BadRequest + +from controllers.service_api.app import completion as completion_module +from controllers.service_api.app import workflow as workflow_module +from core.helper.trace_id_helper import get_trace_session_id +from models.model import AppMode + + +class _Request: + def __init__(self, *, headers=None, args=None, json=None, is_json=True): + self.headers = headers or {} + self.args = args or {} + self.json = json + self.is_json = is_json + + +def test_trace_session_id_header_query_body_priority_matches_service_api_contract(): + req = _Request( + headers={"X-Trace-Session-Id": "header"}, + args={"trace_session_id": "query"}, + json={"trace_session_id": "body"}, + ) + + assert get_trace_session_id(req) == "header" + + +def test_trace_session_id_invalid_highest_priority_raises_bad_request(): + req = _Request( + headers={"X-Trace-Session-Id": " "}, + args={"trace_session_id": "query"}, + json={"trace_session_id": "body"}, + ) + + with pytest.raises(BadRequest): + get_trace_session_id(req) + + +def _app(mode: AppMode) -> SimpleNamespace: + return SimpleNamespace(id="app-1", mode=mode, tenant_id="tenant-1") + + +def _end_user() -> SimpleNamespace: + return SimpleNamespace(id="user-1") + + +def _assert_generate_trace_session_id(mock_generate_service: MagicMock, expected: str) -> None: + _, kwargs = mock_generate_service.generate.call_args + assert kwargs["args"]["trace_session_id"] == expected + + +@patch("controllers.service_api.app.completion.AppGenerateService") +@patch("controllers.service_api.app.completion.service_api_ns") +def test_chat_api_rejects_invalid_highest_priority_query_trace_session_id_without_generating( + mock_service_api_ns: MagicMock, + mock_generate_service: MagicMock, + app: Flask, +): + payload = {"inputs": {}, "query": "hello", "trace_session_id": "body-session"} + mock_service_api_ns.payload = payload + + with app.test_request_context( + "/chat-messages?trace_session_id=%20%20%20", + method="POST", + json=payload, + ): + with pytest.raises(BadRequest): + completion_module.ChatApi().post.__wrapped__( + completion_module.ChatApi(), + _app(AppMode.CHAT), + _end_user(), + ) + + mock_generate_service.generate.assert_not_called() + + +@patch("controllers.service_api.app.workflow.AppGenerateService") +@patch("controllers.service_api.app.workflow.service_api_ns") +def test_workflow_run_api_rejects_invalid_highest_priority_body_trace_session_id_without_generating( + mock_service_api_ns: MagicMock, + mock_generate_service: MagicMock, + app: Flask, +): + payload = {"inputs": {}, "trace_session_id": 123} + mock_service_api_ns.payload = payload + + with app.test_request_context("/workflows/run", method="POST", json=payload): + with pytest.raises(BadRequest): + workflow_module.WorkflowRunApi().post.__wrapped__( + workflow_module.WorkflowRunApi(), + _app(AppMode.WORKFLOW), + _end_user(), + ) + + mock_generate_service.generate.assert_not_called() + + +@patch("controllers.service_api.app.completion.helper.compact_generate_response", return_value={"answer": "ok"}) +@patch("controllers.service_api.app.completion.AppGenerateService") +@patch("controllers.service_api.app.completion.service_api_ns") +def test_completion_api_passes_header_trace_session_id_when_body_value_is_invalid_lower_priority( + mock_service_api_ns: MagicMock, + mock_generate_service: MagicMock, + mock_compact: MagicMock, + app: Flask, +): + payload = {"inputs": {}, "trace_session_id": 123} + mock_service_api_ns.payload = payload + mock_generate_service.generate.return_value = "response" + + with app.test_request_context( + "/completion-messages", + method="POST", + json=payload, + headers={"X-Trace-Session-Id": " header-session "}, + ): + response = completion_module.CompletionApi().post.__wrapped__( + completion_module.CompletionApi(), + _app(AppMode.COMPLETION), + _end_user(), + ) + + assert response == {"answer": "ok"} + _assert_generate_trace_session_id(mock_generate_service, "header-session") + + +@patch("controllers.service_api.app.completion.helper.compact_generate_response", return_value={"answer": "ok"}) +@patch("controllers.service_api.app.completion.AppGenerateService") +@patch("controllers.service_api.app.completion.service_api_ns") +def test_chat_api_passes_query_trace_session_id_when_body_value_is_invalid_lower_priority( + mock_service_api_ns: MagicMock, + mock_generate_service: MagicMock, + mock_compact: MagicMock, + app: Flask, +): + payload = {"inputs": {}, "query": "hello", "trace_session_id": 123} + mock_service_api_ns.payload = payload + mock_generate_service.generate.return_value = "response" + + with app.test_request_context( + "/chat-messages?trace_session_id=query-session", + method="POST", + json=payload, + ): + response = completion_module.ChatApi().post.__wrapped__( + completion_module.ChatApi(), + _app(AppMode.CHAT), + _end_user(), + ) + + assert response == {"answer": "ok"} + _assert_generate_trace_session_id(mock_generate_service, "query-session") + + +@patch("controllers.service_api.app.workflow.helper.compact_generate_response", return_value={"result": "ok"}) +@patch("controllers.service_api.app.workflow.AppGenerateService") +@patch("controllers.service_api.app.workflow.service_api_ns") +def test_workflow_run_api_passes_body_trace_session_id( + mock_service_api_ns: MagicMock, + mock_generate_service: MagicMock, + mock_compact: MagicMock, + app: Flask, +): + payload = {"inputs": {}, "trace_session_id": " body-session "} + mock_service_api_ns.payload = payload + mock_generate_service.generate.return_value = "response" + + with app.test_request_context("/workflows/run", method="POST", json=payload): + response = workflow_module.WorkflowRunApi().post.__wrapped__( + workflow_module.WorkflowRunApi(), + _app(AppMode.WORKFLOW), + _end_user(), + ) + + assert response == {"result": "ok"} + _assert_generate_trace_session_id(mock_generate_service, "body-session") diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_generator.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_generator.py index 5df064030b..b644af083e 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_generator.py @@ -290,7 +290,7 @@ class TestAdvancedChatAppGeneratorInternals: workflow=workflow, node_id="node-1", user=SimpleNamespace(id="user-id"), - args={"inputs": {"foo": "bar"}}, + args={"inputs": {"foo": "bar"}, "trace_session_id": "session-1"}, streaming=False, ) @@ -298,6 +298,7 @@ class TestAdvancedChatAppGeneratorInternals: assert prefill_calls == [(workflow, "user-id")] assert captured["variable_loader"] is var_loader assert captured["application_generate_entity"].single_iteration_run.node_id == "node-1" + assert captured["application_generate_entity"].extras["trace_session_id"] == "session-1" def test_single_loop_generate_builds_debug_task(self, monkeypatch: pytest.MonkeyPatch): generator = AdvancedChatAppGenerator() @@ -348,7 +349,7 @@ class TestAdvancedChatAppGeneratorInternals: workflow=workflow, node_id="node-2", user=SimpleNamespace(id="user-id"), - args=SimpleNamespace(inputs={"foo": "bar"}), + args=SimpleNamespace(inputs={"foo": "bar"}, trace_session_id="session-1"), streaming=False, ) @@ -356,6 +357,7 @@ class TestAdvancedChatAppGeneratorInternals: assert prefill_calls == [(workflow, "user-id")] assert captured["variable_loader"] is var_loader assert captured["application_generate_entity"].single_loop_run.node_id == "node-2" + assert captured["application_generate_entity"].extras["trace_session_id"] == "session-1" def test_generate_internal_flow_initial_conversation_with_pause_layer(self, monkeypatch: pytest.MonkeyPatch): generator = AdvancedChatAppGenerator() diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py index 370f7abb8b..ac9eddb680 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py @@ -99,6 +99,7 @@ class TestAdvancedChatAppRunnerConversationVariables: mock_app_generate_entity.call_depth = 0 mock_app_generate_entity.single_iteration_run = None mock_app_generate_entity.single_loop_run = None + mock_app_generate_entity.extras = {} mock_app_generate_entity.trace_manager = None # Create runner @@ -244,6 +245,7 @@ class TestAdvancedChatAppRunnerConversationVariables: mock_app_generate_entity.call_depth = 0 mock_app_generate_entity.single_iteration_run = None mock_app_generate_entity.single_loop_run = None + mock_app_generate_entity.extras = {} mock_app_generate_entity.trace_manager = None # Create runner @@ -404,6 +406,7 @@ class TestAdvancedChatAppRunnerConversationVariables: mock_app_generate_entity.call_depth = 0 mock_app_generate_entity.single_iteration_run = None mock_app_generate_entity.single_loop_run = None + mock_app_generate_entity.extras = {} mock_app_generate_entity.trace_manager = None # Create runner diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py index 5d8faee897..2076e42e9f 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py @@ -63,6 +63,7 @@ def build_runner(): gen.call_depth = 0 gen.single_iteration_run = None gen.single_loop_run = None + gen.extras = {} gen.trace_manager = None runner = AdvancedChatAppRunner( diff --git a/api/tests/unit_tests/core/app/apps/agent_app/test_app_generator.py b/api/tests/unit_tests/core/app/apps/agent_app/test_app_generator.py index a80db85416..dcaa31e15e 100644 --- a/api/tests/unit_tests/core/app/apps/agent_app/test_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/agent_app/test_app_generator.py @@ -134,6 +134,42 @@ class TestGenerateSuccess: get_conv.assert_called_once() + def test_generate_does_not_include_trace_session_id_in_extras(self, generator, mocker: MockerFixture): + app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent") + user = DummyAccount("user") + + generator._resolve_agent = mocker.MagicMock( + return_value=(mocker.MagicMock(id="agent1"), mocker.MagicMock(id="snap1"), mocker.MagicMock()) + ) + generator._prepare_user_inputs = mocker.MagicMock(return_value={}) + generator._init_generate_records = mocker.MagicMock( + return_value=(mocker.MagicMock(id="conv", mode="agent"), mocker.MagicMock(id="msg")) + ) + generator._handle_response = mocker.MagicMock(return_value="raw-response") + + mocker.patch( + f"{MODULE}.AgentAppConfigManager.get_app_config", + return_value=mocker.MagicMock(variables=[], tenant_id="tenant", app_id="app1"), + ) + mocker.patch(f"{MODULE}.ModelConfigConverter.convert", return_value=mocker.MagicMock(model="gpt-4o-mini")) + mocker.patch(f"{MODULE}.TraceQueueManager", return_value=mocker.MagicMock()) + generate_entity = mocker.patch( + f"{MODULE}.AgentAppGenerateEntity", return_value=mocker.MagicMock(task_id="t", user_id="user") + ) + mocker.patch(f"{MODULE}.MessageBasedAppQueueManager", return_value=mocker.MagicMock()) + mocker.patch(f"{MODULE}.threading.Thread", return_value=mocker.MagicMock()) + mocker.patch(f"{MODULE}.AgentAppGenerateResponseConverter.convert", return_value={"result": "ok"}) + + generator.generate( + app_model=app_model, + user=user, + args={"query": "hello", "inputs": {}, "trace_session_id": "session-1"}, + invoke_from=InvokeFrom.WEB_APP, + streaming=True, + ) + + assert generate_entity.call_args.kwargs["extras"] == {"auto_generate_conversation_name": True} + class TestGenerateWorker: @pytest.fixture(autouse=True) diff --git a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py index 6cd62c933a..9ed7acd39c 100644 --- a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py @@ -125,7 +125,7 @@ class TestAgentChatAppGeneratorGenerate: return_value={"result": "ok"}, ) app_entity = mocker.MagicMock(task_id="task", user_id="user", invoke_from=invoke_from) - mocker.patch( + generate_entity = mocker.patch( "core.app.apps.agent_chat.app_generator.AgentChatAppGenerateEntity", return_value=app_entity, ) @@ -136,11 +136,13 @@ class TestAgentChatAppGeneratorGenerate: "conversation_id": "conv", "model_config": {"model": {"provider": "p"}}, "files": [{"id": "f1"}], + "trace_session_id": "session-1", } result = generator.generate(app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=True) assert result == {"result": "ok"} + assert generate_entity.call_args.kwargs["extras"]["trace_session_id"] == "session-1" thread_obj.start.assert_called_once() def test_generate_without_file_config(self, generator, mocker: MockerFixture): diff --git a/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py b/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py index 8f3c41701b..6f104a5eaa 100644 --- a/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py +++ b/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py @@ -56,7 +56,7 @@ class TestChatAppGenerator: generator = ChatAppGenerator() app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") user = SimpleNamespace(id="user-1", session_id="session-1") - args = {"query": "hi", "inputs": {}, "model_config": {"foo": "bar"}} + args = {"query": "hi", "inputs": {}, "model_config": {"foo": "bar"}, "trace_session_id": "session-1"} with ( patch("core.app.apps.chat.app_generator.ConversationService.get_conversation", return_value=None), @@ -70,7 +70,10 @@ class TestChatAppGenerator: patch("core.app.apps.chat.app_generator.ModelConfigConverter.convert", return_value=SimpleNamespace()), patch("core.app.apps.chat.app_generator.FileUploadConfigManager.convert", return_value=None), patch("core.app.apps.chat.app_generator.file_factory.build_from_mappings", return_value=[]), - patch("core.app.apps.chat.app_generator.ChatAppGenerateEntity", DummyGenerateEntity), + patch( + "core.app.apps.chat.app_generator.ChatAppGenerateEntity", + Mock(side_effect=DummyGenerateEntity), + ) as generate_entity, patch("core.app.apps.chat.app_generator.TraceQueueManager", return_value=SimpleNamespace()), patch("core.app.apps.chat.app_generator.MessageBasedAppQueueManager", DummyQueueManager), patch( @@ -91,6 +94,7 @@ class TestChatAppGenerator: result = generator.generate(app_model, user, args, InvokeFrom.DEBUGGER, streaming=False) assert result == {"ok": True} + assert generate_entity.call_args.kwargs["extras"]["trace_session_id"] == "session-1" def test_generate_rejects_model_config_override_for_non_debugger(self): generator = ChatAppGenerator() diff --git a/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py b/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py index de20dde677..3acf0d4652 100644 --- a/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py @@ -30,7 +30,10 @@ def generator(mocker: MockerFixture): mocker.patch.object(module, "MessageBasedAppQueueManager", return_value=MagicMock()) mocker.patch.object(module, "TraceQueueManager", return_value=MagicMock()) - mocker.patch.object(module, "CompletionAppGenerateEntity", side_effect=lambda **kwargs: SimpleNamespace(**kwargs)) + generate_entity = mocker.patch.object( + module, "CompletionAppGenerateEntity", side_effect=lambda **kwargs: SimpleNamespace(**kwargs) + ) + gen.generate_entity = generate_entity return gen @@ -92,12 +95,13 @@ class TestCompletionAppGenerator: result = generator.generate( app_model=_build_app_model(), user=_build_user(), - args={"query": "q", "inputs": {"a": 1}, "files": []}, + args={"query": "q", "inputs": {"a": 1}, "files": [], "trace_session_id": "session-1"}, invoke_from=InvokeFrom.WEB_APP, streaming=True, ) assert result == "converted" + assert generator.generate_entity.call_args.kwargs["extras"]["trace_session_id"] == "session-1" module.file_factory.build_from_mappings.assert_not_called() def test_generate_success_with_files(self, generator, mocker: MockerFixture): diff --git a/api/tests/unit_tests/core/app/apps/test_trace_session_id_generate_extras.py b/api/tests/unit_tests/core/app/apps/test_trace_session_id_generate_extras.py new file mode 100644 index 0000000000..ef26168975 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/test_trace_session_id_generate_extras.py @@ -0,0 +1,11 @@ +from core.helper.trace_id_helper import extract_trace_session_id_from_args + + +def test_extract_trace_session_id_from_args_for_generator_extras(): + assert extract_trace_session_id_from_args({"trace_session_id": "session-1"}) == { + "trace_session_id": "session-1", + } + + +def test_extract_trace_session_id_from_args_missing_value_keeps_extras_clean(): + assert extract_trace_session_id_from_args({"inputs": {}}) == {} diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_generator.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_generator.py index 44c34a0142..154693810b 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_generator.py @@ -80,6 +80,7 @@ def test_generate_includes_parent_trace_context_in_extras(monkeypatch): "parent_workflow_run_id": "outer-workflow-run-1", "parent_node_execution_id": "outer-node-execution-1", }, + "trace_session_id": "session-1", }, invoke_from="service-api", streaming=False, @@ -93,6 +94,7 @@ def test_generate_includes_parent_trace_context_in_extras(monkeypatch): "parent_workflow_run_id": "outer-workflow-run-1", "parent_node_execution_id": "outer-node-execution-1", } + assert extras["trace_session_id"] == "session-1" def test_resume_delegates_to_generate(mocker: MockerFixture): diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py index 77cd81db58..c463c155a5 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py @@ -6,7 +6,7 @@ from types import SimpleNamespace import pytest from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from core.app.entities.queue_entities import ( QueueAgentLogEvent, QueueHumanInputFormFilledEvent, @@ -85,6 +85,35 @@ class TestWorkflowBasedAppRunner: invoke_from=InvokeFrom.DEBUGGER, ) + def test_init_graph_includes_trace_session_id_in_run_context(self, monkeypatch: pytest.MonkeyPatch): + runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app") + runtime_state = GraphRuntimeState( + variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()), + start_at=0.0, + ) + captured = {} + + def fake_from_graph_init_context(**kwargs): + captured["run_context"] = kwargs["graph_init_context"].run_context + return SimpleNamespace() + + monkeypatch.setattr( + "core.app.apps.workflow_app_runner.DifyNodeFactory.from_graph_init_context", + fake_from_graph_init_context, + ) + monkeypatch.setattr("core.app.apps.workflow_app_runner.Graph.init", lambda **_kwargs: SimpleNamespace()) + + runner._init_graph( + graph_config={"nodes": [], "edges": []}, + graph_runtime_state=runtime_state, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + root_node_id="root", + trace_session_id="session-1", + ) + + assert captured["run_context"][DIFY_RUN_CONTEXT_KEY].trace_session_id == "session-1" + def test_prepare_single_node_execution_requires_run(self): runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app") @@ -145,6 +174,57 @@ class TestWorkflowBasedAppRunner: assert graph is not None assert variable_pool is graph_runtime_state.variable_pool + def test_get_graph_and_variable_pool_for_single_node_run_includes_trace_session_id( + self, monkeypatch: pytest.MonkeyPatch + ): + runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app") + graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()), + start_at=0.0, + ) + graph_config = { + "nodes": [{"id": "node-1", "data": {"type": "start", "version": "1"}}], + "edges": [], + } + workflow = SimpleNamespace(tenant_id="tenant", id="workflow", graph_dict=graph_config) + captured = {} + + def fake_from_graph_init_context(**kwargs): + captured["run_context"] = kwargs["graph_init_context"].run_context + return SimpleNamespace() + + class _NodeCls: + @staticmethod + def extract_variable_selector_to_variable_mapping(graph_config, config): + return {} + + from core.app.apps import workflow_app_runner + + monkeypatch.setattr( + "core.app.apps.workflow_app_runner.DifyNodeFactory.from_graph_init_context", + fake_from_graph_init_context, + ) + monkeypatch.setattr("core.app.apps.workflow_app_runner.Graph.init", lambda **kwargs: SimpleNamespace()) + monkeypatch.setattr(workflow_app_runner, "resolve_workflow_node_class", lambda **_kwargs: _NodeCls) + monkeypatch.setattr("core.app.apps.workflow_app_runner.load_into_variable_pool", lambda **kwargs: None) + monkeypatch.setattr( + "core.app.apps.workflow_app_runner.WorkflowEntry.mapping_user_inputs_to_variable_pool", + lambda **kwargs: None, + ) + + runner._get_graph_and_variable_pool_for_single_node_run( + workflow=workflow, + node_id="node-1", + user_inputs={}, + graph_runtime_state=graph_runtime_state, + node_type_filter_key="iteration_id", + node_type_label="iteration", + user_id="00000000-0000-0000-0000-000000000001", + trace_session_id="session-1", + ) + + assert captured["run_context"][DIFY_RUN_CONTEXT_KEY].trace_session_id == "session-1" + def test_get_graph_and_variable_pool_preloads_constructor_variables_before_graph_init( self, monkeypatch: pytest.MonkeyPatch ): diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py index 248fed5388..4daff81732 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py @@ -51,6 +51,7 @@ def test_run_uses_single_node_execution_branch( app_generate_entity.task_id = "task-id" app_generate_entity.call_depth = 0 app_generate_entity.trace_manager = None + app_generate_entity.extras = {"trace_session_id": "session-1"} app_generate_entity.single_iteration_run = single_iteration_run app_generate_entity.single_loop_run = single_loop_run @@ -101,6 +102,7 @@ def test_run_uses_single_node_execution_branch( single_iteration_run=single_iteration_run, single_loop_run=single_loop_run, user_id="user", + trace_session_id="session-1", ) init_graph.assert_not_called() diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_app_generator_extra.py b/api/tests/unit_tests/core/app/apps/workflow/test_app_generator_extra.py index 941a47b572..0ad941ce3b 100644 --- a/api/tests/unit_tests/core/app/apps/workflow/test_app_generator_extra.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_app_generator_extra.py @@ -55,6 +55,100 @@ class TestWorkflowAppGeneratorValidation: streaming=False, ) + def test_single_iteration_generate_includes_trace_session_id_in_extras(self, monkeypatch: pytest.MonkeyPatch): + generator = WorkflowAppGenerator() + app_config = WorkflowUIBasedAppConfig( + tenant_id="tenant", + app_id="app", + app_mode=AppMode.WORKFLOW, + additional_features=AppAdditionalFeatures(), + variables=[], + workflow_id="workflow-id", + ) + captured: dict[str, object] = {} + + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.WorkflowAppConfigManager.get_app_config", + lambda **kwargs: app_config, + ) + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_execution_repository", + lambda **kwargs: SimpleNamespace(), + ) + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", + lambda **kwargs: SimpleNamespace(), + ) + monkeypatch.setattr("core.app.apps.workflow.app_generator.DraftVarLoader", lambda **kwargs: SimpleNamespace()) + monkeypatch.setattr("core.app.apps.workflow.app_generator.sessionmaker", lambda **kwargs: SimpleNamespace()) + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.db", + SimpleNamespace(engine=object(), session=lambda: SimpleNamespace()), + ) + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.WorkflowDraftVariableService", + lambda session: SimpleNamespace(prefill_conversation_variable_default_values=lambda *args, **kwargs: None), + ) + monkeypatch.setattr(generator, "_generate", lambda **kwargs: captured.update(kwargs) or {"ok": True}) + + generator.single_iteration_generate( + app_model=SimpleNamespace(id="app", tenant_id="tenant"), + workflow=SimpleNamespace(id="workflow-id"), + node_id="node-1", + user=SimpleNamespace(id="user-id"), + args={"inputs": {"foo": "bar"}, "trace_session_id": "session-1"}, + streaming=False, + ) + + assert captured["application_generate_entity"].extras["trace_session_id"] == "session-1" + + def test_single_loop_generate_includes_trace_session_id_in_extras(self, monkeypatch: pytest.MonkeyPatch): + generator = WorkflowAppGenerator() + app_config = WorkflowUIBasedAppConfig( + tenant_id="tenant", + app_id="app", + app_mode=AppMode.WORKFLOW, + additional_features=AppAdditionalFeatures(), + variables=[], + workflow_id="workflow-id", + ) + captured: dict[str, object] = {} + + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.WorkflowAppConfigManager.get_app_config", + lambda **kwargs: app_config, + ) + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_execution_repository", + lambda **kwargs: SimpleNamespace(), + ) + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", + lambda **kwargs: SimpleNamespace(), + ) + monkeypatch.setattr("core.app.apps.workflow.app_generator.DraftVarLoader", lambda **kwargs: SimpleNamespace()) + monkeypatch.setattr("core.app.apps.workflow.app_generator.sessionmaker", lambda **kwargs: SimpleNamespace()) + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.db", + SimpleNamespace(engine=object(), session=lambda: SimpleNamespace()), + ) + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.WorkflowDraftVariableService", + lambda session: SimpleNamespace(prefill_conversation_variable_default_values=lambda *args, **kwargs: None), + ) + monkeypatch.setattr(generator, "_generate", lambda **kwargs: captured.update(kwargs) or {"ok": True}) + + generator.single_loop_generate( + app_model=SimpleNamespace(id="app", tenant_id="tenant"), + workflow=SimpleNamespace(id="workflow-id"), + node_id="node-2", + user=SimpleNamespace(id="user-id"), + args=SimpleNamespace(inputs={"foo": "bar"}, trace_session_id="session-1"), + streaming=False, + ) + + assert captured["application_generate_entity"].extras["trace_session_id"] == "session-1" + with pytest.raises(ValueError, match="inputs is required"): generator.single_loop_generate( app_model=SimpleNamespace(), diff --git a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py index 539944d683..3cb180b0cc 100644 --- a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py @@ -351,6 +351,7 @@ def _build_workflow_generate_entity_for_roundtrip() -> WorkflowResumptionContext stream=False, invoke_from=InvokeFrom.DEBUGGER, workflow_execution_id="workflow-exec-roundtrip", + extras={"trace_session_id": "session-1"}, ) ), ) @@ -379,6 +380,7 @@ def _build_advanced_chat_generate_entity_for_roundtrip() -> WorkflowResumptionCo invoke_from=InvokeFrom.DEBUGGER, workflow_run_id="advanced-run-id", query="Explain serialization behavior", + extras={"trace_session_id": "session-1"}, ) ), ) @@ -406,3 +408,4 @@ def test_workflow_resumption_context_dumps_loads_roundtrip(state: WorkflowResump assert loaded.serialized_graph_runtime_state == state.serialized_graph_runtime_state restored_entity = loaded.get_generate_entity() assert isinstance(restored_entity, type(state.generate_entity.entity)) + assert restored_entity.extras["trace_session_id"] == "session-1" diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py index f10e0084d0..0f2c79f9fc 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py @@ -38,6 +38,7 @@ from core.app.entities.task_entities import ( ) from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline from core.base.tts import AudioTrunk +from core.ops.entities.trace_entity import TraceTaskName from graphon.file import FileTransferMethod from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, TextPromptMessageContent @@ -899,8 +900,10 @@ class TestEasyUiBasedGenerateTaskPipeline: def test_save_message_persists_fields_and_emits_trace(self, monkeypatch: pytest.MonkeyPatch): conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + application_generate_entity = _make_entity(ChatAppGenerateEntity, AppMode.CHAT) + application_generate_entity.extras = {"trace_session_id": "session-1"} pipeline = EasyUIBasedGenerateTaskPipeline( - application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + application_generate_entity=application_generate_entity, queue_manager=SimpleNamespace(), conversation=conversation, message=message, @@ -946,7 +949,12 @@ class TestEasyUiBasedGenerateTaskPipeline: assert message_obj.message == "serialized-prompt" assert message_obj.answer == "hello" assert message_obj.provider_response_latency == 5.0 - assert trace_manager.add_trace_task.called + trace_manager.add_trace_task.assert_called_once() + trace_task = trace_manager.add_trace_task.call_args.args[0] + assert trace_task.trace_type == TraceTaskName.MESSAGE_TRACE + assert trace_task.conversation_id == "conv" + assert trace_task.message_id == "msg" + assert trace_task.kwargs["trace_session_id"] == "session-1" assert len(sent_payloads) == 1 def test_save_message_raises_when_message_not_found(self): diff --git a/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py b/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py index 9cefa97bef..f8e13ca808 100644 --- a/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py +++ b/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py @@ -224,6 +224,7 @@ class TestWorkflowPersistenceLayer: layer, _, _, _ = _make_layer( extras={ "external_trace_id": "trace", + "trace_session_id": "session-1", "parent_trace_context": { "parent_workflow_run_id": "outer-workflow-run-1", "parent_node_execution_id": "outer-node-execution-1", @@ -245,6 +246,7 @@ class TestWorkflowPersistenceLayer: ): captured["trace_type"] = self.trace_type captured["external_trace_id"] = self.kwargs.get("external_trace_id") + captured["trace_session_id"] = self.kwargs.get("trace_session_id") captured["parent_trace_context"] = self.kwargs.get("parent_trace_context") captured["workflow_run_id"] = workflow_run_id return {"ok": True} @@ -257,6 +259,7 @@ class TestWorkflowPersistenceLayer: trace_task = trace_tasks[0] assert trace_task.trace_type == TraceTaskName.WORKFLOW_TRACE assert trace_task.kwargs["external_trace_id"] == "trace" + assert trace_task.kwargs["trace_session_id"] == "session-1" assert trace_task.kwargs["parent_trace_context"] == { "parent_workflow_run_id": "outer-workflow-run-1", "parent_node_execution_id": "outer-node-execution-1", @@ -266,6 +269,7 @@ class TestWorkflowPersistenceLayer: assert captured["trace_type"] == TraceTaskName.WORKFLOW_TRACE assert captured["external_trace_id"] == "trace" + assert captured["trace_session_id"] == "session-1" assert captured["parent_trace_context"] == { "parent_workflow_run_id": "outer-workflow-run-1", "parent_node_execution_id": "outer-node-execution-1", diff --git a/api/tests/unit_tests/core/helper/test_trace_id_helper.py b/api/tests/unit_tests/core/helper/test_trace_id_helper.py index 96e2d44730..59b5eb0870 100644 --- a/api/tests/unit_tests/core/helper/test_trace_id_helper.py +++ b/api/tests/unit_tests/core/helper/test_trace_id_helper.py @@ -1,10 +1,13 @@ import pytest +from werkzeug.exceptions import BadRequest from core.helper.trace_id_helper import ( ParentTraceContext, extract_external_trace_id_from_args, extract_parent_trace_context_from_args, + extract_trace_session_id_from_args, get_external_trace_id, + get_trace_session_id, is_valid_trace_id, ) @@ -17,6 +20,90 @@ class DummyRequest: self.is_json = is_json +class _Request: + def __init__(self, *, headers=None, args=None, json=None, is_json=True): + self.headers = headers or {} + self.args = args or {} + self.json = json + self.is_json = is_json + + +def test_get_trace_session_id_prefers_header_over_query_and_body(): + request = _Request( + headers={"X-Trace-Session-Id": " header-session "}, + args={"trace_session_id": "query-session"}, + json={"trace_session_id": "body-session"}, + ) + + assert get_trace_session_id(request) == "header-session" + + +def test_get_trace_session_id_prefers_query_over_body(): + request = _Request( + args={"trace_session_id": " query-session "}, + json={"trace_session_id": "body-session"}, + ) + + assert get_trace_session_id(request) == "query-session" + + +def test_get_trace_session_id_reads_body_when_no_higher_priority_input(): + request = _Request(json={"trace_session_id": " body/session:123 "}) + + assert get_trace_session_id(request) == "body/session:123" + + +def test_get_trace_session_id_ignores_invalid_lower_priority_value(): + request = _Request( + headers={"X-Trace-Session-Id": "header-session"}, + json={"trace_session_id": " "}, + ) + + assert get_trace_session_id(request) == "header-session" + + +@pytest.mark.parametrize( + "trace_session_request", + [ + _Request(headers={"X-Trace-Session-Id": " "}, json={"trace_session_id": "body-session"}), + _Request(headers={"X-Trace-Session-Id": 123}), + _Request(headers={"X-Trace-Session-Id": "x" * 201}), + ], +) +def test_get_trace_session_id_rejects_invalid_highest_priority_input(trace_session_request): + with pytest.raises(BadRequest) as exc_info: + get_trace_session_id(trace_session_request) + + assert "trace_session_id" in str(exc_info.value) + + +def test_get_trace_session_id_does_not_read_trace_id_or_traceparent(): + request = _Request( + headers={ + "X-Trace-Id": "trace-id", + "traceparent": "00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01", + }, + args={"trace_id": "query-trace-id"}, + json={"trace_id": "body-trace-id"}, + ) + + assert get_trace_session_id(request) is None + + +def test_extract_trace_session_id_from_args_returns_trimmed_value(): + args = {"trace_session_id": " session-1 "} + + assert extract_trace_session_id_from_args(args) == {"trace_session_id": "session-1"} + + +def test_extract_trace_session_id_from_args_returns_empty_dict_when_missing(): + assert extract_trace_session_id_from_args({}) == {} + + +def test_extract_trace_session_id_from_args_returns_empty_dict_when_blank_after_trim(): + assert extract_trace_session_id_from_args({"trace_session_id": " "}) == {} + + class TestTraceIdHelper: """Test cases for trace_id_helper.py""" diff --git a/api/tests/unit_tests/core/ops/test_trace_session_metadata.py b/api/tests/unit_tests/core/ops/test_trace_session_metadata.py new file mode 100644 index 0000000000..9a5cb0b65f --- /dev/null +++ b/api/tests/unit_tests/core/ops/test_trace_session_metadata.py @@ -0,0 +1,140 @@ +import json +from datetime import datetime, timedelta +from types import SimpleNamespace +from unittest.mock import MagicMock + +from core.ops.entities.trace_entity import TraceTaskName +from core.ops.ops_trace_manager import TraceTask + + +class _DummySession: + scalar_values: list[object | None] = [] + + def __init__(self, engine): + self._values = list(self.scalar_values) + self._index = 0 + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + return False + + def execute(self, *args, **kwargs): + return self + + def scalar(self, *args, **kwargs): + if self._index >= len(self._values): + return None + value = self._values[self._index] + self._index += 1 + return value + + def scalars(self, *args, **kwargs): + return self + + def all(self): + return [] + + +def _make_workflow_run(): + return SimpleNamespace( + workflow_id="wf-1", + tenant_id="tenant-1", + id="run-1", + elapsed_time=1, + status="succeeded", + inputs_dict={}, + outputs_dict={}, + version="1", + error=None, + total_tokens=0, + created_at=datetime(2026, 1, 1, 0, 0, 0), + finished_at=datetime(2026, 1, 1, 0, 0, 1), + triggered_from="user", + app_id="app-1", + to_dict=lambda self=None: {"id": "run-1"}, + ) + + +def _make_message_data(): + created_at = datetime(2026, 1, 1, 0, 0, 0) + data = { + "id": "message-1", + "app_id": "app-1", + "conversation_id": "conv-1", + "created_at": created_at, + "updated_at": created_at + timedelta(seconds=1), + "message": "hello", + "provider_response_latency": 1, + "message_tokens": 0, + "answer_tokens": 0, + "answer": "world", + "error": "", + "status": "normal", + "model_provider": "provider", + "model_id": "model", + "from_end_user_id": "end-user-1", + "from_account_id": None, + "agent_based": False, + "workflow_run_id": None, + "from_source": "api", + "message_metadata": json.dumps({"usage": {}}), + } + + class _MessageData: + def __init__(self, values): + self.__dict__.update(values) + + def to_dict(self): + return dict(self.__dict__) + + return _MessageData(data) + + +def test_workflow_trace_metadata_includes_trace_session_id(monkeypatch): + repo = MagicMock() + repo.get_workflow_run_by_id_without_tenant.return_value = _make_workflow_run() + monkeypatch.setattr(TraceTask, "_get_workflow_run_repo", classmethod(lambda cls: repo)) + monkeypatch.setattr("core.ops.ops_trace_manager.Session", _DummySession) + monkeypatch.setattr("core.ops.ops_trace_manager.db", SimpleNamespace(engine=MagicMock())) + monkeypatch.setattr("core.telemetry.gateway.is_enterprise_telemetry_enabled", lambda: False) + _DummySession.scalar_values = [None, None] + + task = TraceTask( + TraceTaskName.WORKFLOW_TRACE, + workflow_execution=SimpleNamespace(id_="run-1", total_tokens=0), + conversation_id="conv-1", + user_id="user-1", + trace_session_id="session-1", + ) + + trace_info = task.workflow_trace(workflow_run_id="run-1", conversation_id="conv-1", user_id="user-1") + + assert task.kwargs["trace_session_id"] == "session-1" + assert trace_info.metadata["trace_session_id"] == "session-1" + + +def test_message_trace_metadata_includes_trace_session_id(monkeypatch): + db_session = MagicMock() + db_session.scalars.return_value.all.return_value = ["chat"] + db_session.scalar.return_value = None + monkeypatch.setattr( + "core.ops.ops_trace_manager.db", + SimpleNamespace(engine=MagicMock(), session=db_session), + ) + monkeypatch.setattr("core.ops.ops_trace_manager.Session", _DummySession) + monkeypatch.setattr("core.ops.ops_trace_manager.get_message_data", lambda message_id: _make_message_data()) + monkeypatch.setattr("core.telemetry.gateway.is_enterprise_telemetry_enabled", lambda: False) + _DummySession.scalar_values = ["tenant-1"] + + task = TraceTask( + TraceTaskName.MESSAGE_TRACE, + message_id="message-1", + trace_session_id="session-1", + ) + + trace_info = task.message_trace(message_id="message-1", **task.kwargs) + + assert task.kwargs["trace_session_id"] == "session-1" + assert trace_info.metadata["trace_session_id"] == "session-1" diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py index 6c563b0912..b35df9239c 100644 --- a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py +++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py @@ -174,6 +174,36 @@ def test_workflow_tool_passes_parent_trace_context_from_runtime(monkeypatch: pyt } +def test_workflow_tool_passes_parent_trace_session_id(monkeypatch: pytest.MonkeyPatch): + """Ensure nested workflows inherit the parent observability session ID.""" + tool = _build_tool() + tool.entity.parameters = [ + ToolParameter.get_simple_instance( + name="trace_session_id", + llm_description="User workflow input", + typ=ToolParameter.ToolParameterType.STRING, + required=False, + ), + ] + tool.set_trace_session_id("session-1") + + monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None) + monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None) + + mock_user = Mock() + monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user) + + generate_mock = MagicMock(return_value={"data": {}}) + monkeypatch.setattr("core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate", generate_mock) + monkeypatch.setattr("libs.login.current_user", lambda *args, **kwargs: None) + + list(tool.invoke("test_user", {"trace_session_id": "user-input-session"})) + + call_kwargs = generate_mock.call_args.kwargs + assert call_kwargs["args"]["inputs"]["trace_session_id"] == "user-input-session" + assert call_kwargs["args"]["trace_session_id"] == "session-1" + + def test_workflow_tool_keeps_user_inputs_named_like_trace_runtime_keys(monkeypatch: pytest.MonkeyPatch): """Ensure private trace context does not overwrite same-named workflow inputs.""" tool = _build_tool() @@ -250,6 +280,28 @@ def test_workflow_tool_can_clear_parent_trace_context(monkeypatch: pytest.Monkey assert "parent_trace_context" not in call_kwargs["args"] +def test_workflow_tool_can_clear_trace_session_id(monkeypatch: pytest.MonkeyPatch): + """Ensure reused WorkflowTool instances do not keep stale trace session IDs.""" + tool = _build_tool() + tool.set_trace_session_id("session-1") + tool.clear_trace_session_id() + + monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None) + monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None) + + mock_user = Mock() + monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user) + + generate_mock = MagicMock(return_value={"data": {}}) + monkeypatch.setattr("core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate", generate_mock) + monkeypatch.setattr("libs.login.current_user", lambda *args, **kwargs: None) + + list(tool.invoke("test_user", {})) + + call_kwargs = generate_mock.call_args.kwargs + assert "trace_session_id" not in call_kwargs["args"] + + @pytest.mark.parametrize( "runtime_parameters", [ diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py index aece73ce8c..f043837643 100644 --- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py @@ -187,6 +187,44 @@ def test_get_runtime_stores_parent_trace_context_for_workflow_tools( assert workflow_runtime.runtime.runtime_parameters == {} +def test_get_runtime_stores_trace_session_id_for_workflow_tools( + runtime: DifyToolNodeRuntime, +) -> None: + variable_pool: VariablePool = build_test_variable_pool( + variables=build_system_variables( + conversation_id="conversation-id", + workflow_execution_id="workflow-run-id", + ) + ) + workflow_runtime = MagicMock() + workflow_runtime.runtime.runtime_parameters = {} + runtime._run_context.trace_session_id = "session-1" + node_data = ToolNodeData.model_validate( + { + "type": "tool", + "title": "Tool", + "provider_id": "provider", + "provider_type": ToolProviderType.WORKFLOW, + "provider_name": "provider", + "tool_name": "lookup", + "tool_label": "Lookup", + "tool_configurations": {}, + "tool_parameters": {}, + } + ) + + with patch.object(ToolManager, "get_workflow_tool_runtime", return_value=workflow_runtime): + tool_runtime = runtime.get_runtime( + node_id="node-id", + node_data=node_data, + variable_pool=variable_pool, + node_execution_id="node-execution-id", + ) + + assert tool_runtime.raw.trace_session_id == "session-1" + assert workflow_runtime.runtime.runtime_parameters == {} + + def test_get_runtime_leaves_non_workflow_tool_runtime_parameters_unchanged( runtime: DifyToolNodeRuntime, ) -> None: diff --git a/api/tests/unit_tests/core/workflow/test_node_runtime.py b/api/tests/unit_tests/core/workflow/test_node_runtime.py index 244e22a867..216ce513f8 100644 --- a/api/tests/unit_tests/core/workflow/test_node_runtime.py +++ b/api/tests/unit_tests/core/workflow/test_node_runtime.py @@ -470,6 +470,46 @@ def test_dify_tool_node_runtime_injects_outer_workflow_run_id_for_workflow_tools get_runtime.assert_called_once() +def test_dify_tool_node_runtime_stores_trace_session_id_for_workflow_tools( + monkeypatch: pytest.MonkeyPatch, +) -> None: + runtime_tool = SimpleNamespace(runtime=SimpleNamespace(runtime_parameters={})) + get_runtime = MagicMock(return_value=runtime_tool) + monkeypatch.setattr(node_runtime.ToolManager, "get_workflow_tool_runtime", get_runtime) + monkeypatch.setattr( + node_runtime, + "get_system_text", + lambda _pool, key: ( + "outer-workflow-run-id" if key == node_runtime.SystemVariableKey.WORKFLOW_EXECUTION_ID else None + ), + ) + + run_context = _build_run_context() + run_context[DIFY_RUN_CONTEXT_KEY].trace_session_id = "session-1" + runtime = node_runtime.DifyToolNodeRuntime(run_context) + node_data = ToolNodeData( + title="Workflow Tool Node", + desc=None, + provider_id="workflow-provider-id", + provider_type=ToolProviderType.WORKFLOW, + provider_name="workflow-provider", + tool_name="workflow-tool", + tool_label="Workflow Tool", + tool_configurations={}, + tool_parameters={}, + ) + + handle = runtime.get_runtime( + node_id="tool-node", + node_data=node_data, + variable_pool=object(), + node_execution_id="node-execution-id", + ) + + assert handle.raw.trace_session_id == "session-1" + assert runtime_tool.runtime.runtime_parameters == {} + + def test_dify_tool_node_runtime_does_not_inject_outer_workflow_run_id_for_non_workflow_tools( monkeypatch: pytest.MonkeyPatch, ) -> None: diff --git a/packages/contracts/generated/api/service/types.gen.ts b/packages/contracts/generated/api/service/types.gen.ts index 381103b0d5..4cba0beceb 100644 --- a/packages/contracts/generated/api/service/types.gen.ts +++ b/packages/contracts/generated/api/service/types.gen.ts @@ -57,6 +57,7 @@ export type ChatRequestPayload = { query: string response_mode?: 'blocking' | 'streaming' | null retriever_from?: string + trace_session_id?: string | null workflow_id?: string | null } @@ -107,6 +108,7 @@ export type CompletionRequestPayload = { query?: string response_mode?: 'blocking' | 'streaming' | null retriever_from?: string + trace_session_id?: string | null } export type Condition = { @@ -995,6 +997,7 @@ export type WorkflowRunPayload = { [key: string]: unknown } response_mode?: 'blocking' | 'streaming' | null + trace_session_id?: string | null } export type WorkflowRunResponse = { diff --git a/packages/contracts/generated/api/service/zod.gen.ts b/packages/contracts/generated/api/service/zod.gen.ts index 378d368b12..89f4b0e81e 100644 --- a/packages/contracts/generated/api/service/zod.gen.ts +++ b/packages/contracts/generated/api/service/zod.gen.ts @@ -72,6 +72,7 @@ export const zChatRequestPayload = z.object({ query: z.string(), response_mode: z.enum(['blocking', 'streaming']).nullish(), retriever_from: z.string().optional().default('dev'), + trace_session_id: z.string().nullish(), workflow_id: z.string().nullish(), }) @@ -139,6 +140,7 @@ export const zCompletionRequestPayload = z.object({ query: z.string().optional().default(''), response_mode: z.enum(['blocking', 'streaming']).nullish(), retriever_from: z.string().optional().default('dev'), + trace_session_id: z.string().nullish(), }) /** @@ -1351,6 +1353,7 @@ export const zWorkflowRunPayload = z.object({ files: z.array(z.record(z.string(), z.unknown())).nullish(), inputs: z.record(z.string(), z.unknown()), response_mode: z.enum(['blocking', 'streaming']).nullish(), + trace_session_id: z.string().nullish(), }) /** diff --git a/web/app/components/develop/template/template.en.mdx b/web/app/components/develop/template/template.en.mdx index 4ca27f7f5b..ab8d3cb948 100755 --- a/web/app/components/develop/template/template.en.mdx +++ b/web/app/components/develop/template/template.en.mdx @@ -66,6 +66,12 @@ The text generation application offers non-session support and is ideal for tran - `url` File URL. (Only when transfer method is `remote_url`). - `upload_file_id` Upload file ID. (Only when transfer method is `local_file`). + + (Optional) Trace session ID for observability grouping. Tracing providers that support session grouping can use this value as the exported session identifier. It does not change conversation_id, workflow_run_id, trace_id, or span relationships. Supports the following three ways to pass, in order of priority:
+ - Header: via HTTP Header X-Trace-Session-Id, highest priority.
+ - Query parameter: via URL query parameter trace_session_id.
+ - Request Body: via request body field trace_session_id (i.e., this field).
+
### Response diff --git a/web/app/components/develop/template/template.ja.mdx b/web/app/components/develop/template/template.ja.mdx index b7ebb705f7..2dd311112e 100755 --- a/web/app/components/develop/template/template.ja.mdx +++ b/web/app/components/develop/template/template.ja.mdx @@ -66,6 +66,12 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' - `url` ファイルのURL。(転送方法が `remote_url` の場合のみ)。 - `upload_file_id` アップロードされたファイルID。(転送方法が `local_file` の場合のみ)。 + + (オプション)可観測性のグルーピングに使用するトレースセッションID。セッションのグルーピングに対応するトレースプロバイダーは、この値をエクスポートされるセッション識別子として使用できます。conversation_id、workflow_run_id、trace_id、または span の関係は変更しません。以下の3つの方法で渡すことができ、優先順位は次のとおりです:
+ - Header:HTTPヘッダー X-Trace-Session-Id で渡す(最優先)。
+ - クエリパラメータ:URLクエリパラメータ trace_session_id で渡す。
+ - リクエストボディ:リクエストボディの trace_session_id フィールドで渡す(本フィールド)。
+
### レスポンス diff --git a/web/app/components/develop/template/template.zh.mdx b/web/app/components/develop/template/template.zh.mdx index ae26b23d5c..6c412cc3cb 100755 --- a/web/app/components/develop/template/template.zh.mdx +++ b/web/app/components/develop/template/template.zh.mdx @@ -64,6 +64,12 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' - `url` 文件地址。(仅当传递方式为 `remote_url` 时)。 - `upload_file_id` 上传文件 ID。(仅当传递方式为 `local_file `时)。 + + (选填)用于可观测性分组的链路会话 ID。支持会话分组的追踪提供商可将该值用作导出的会话标识。它不会改变 conversation_id、workflow_run_id、trace_id 或 span 关系。支持以下三种方式传递,具体优先级依次为:
+ - Header:通过 HTTP Header X-Trace-Session-Id 传递,优先级最高。
+ - Query 参数:通过 URL 查询参数 trace_session_id 传递。
+ - Request Body:通过请求体字段 trace_session_id 传递(即本字段)。
+
### Response diff --git a/web/app/components/develop/template/template_advanced_chat.en.mdx b/web/app/components/develop/template/template_advanced_chat.en.mdx index d9ee9bcc1e..5bf401b334 100644 --- a/web/app/components/develop/template/template_advanced_chat.en.mdx +++ b/web/app/components/develop/template/template_advanced_chat.en.mdx @@ -83,6 +83,12 @@ Chat applications support session persistence, allowing previous chat history to - Query parameter: via URL query parameter trace_id.
- Request Body: via request body field trace_id (i.e., this field).
+ + (Optional) Trace session ID for observability grouping. Tracing providers that support session grouping can use this value as the exported session identifier. It does not change conversation_id, workflow_run_id, trace_id, or span relationships. Supports the following three ways to pass, in order of priority:
+ - Header: via HTTP Header X-Trace-Session-Id, highest priority.
+ - Query parameter: via URL query parameter trace_session_id.
+ - Request Body: via request body field trace_session_id (i.e., this field).
+
### Response diff --git a/web/app/components/develop/template/template_advanced_chat.ja.mdx b/web/app/components/develop/template/template_advanced_chat.ja.mdx index e7189df18c..bc96e326f5 100644 --- a/web/app/components/develop/template/template_advanced_chat.ja.mdx +++ b/web/app/components/develop/template/template_advanced_chat.ja.mdx @@ -83,6 +83,12 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' - クエリパラメータ:URLクエリパラメータ trace_id で渡す。
- リクエストボディ:リクエストボディの trace_id フィールドで渡す(本フィールド)。
+ + (オプション)可観測性のグルーピングに使用するトレースセッションID。セッションのグルーピングに対応するトレースプロバイダーは、この値をエクスポートされるセッション識別子として使用できます。conversation_id、workflow_run_id、trace_id、または span の関係は変更しません。以下の3つの方法で渡すことができ、優先順位は次のとおりです:
+ - Header:HTTPヘッダー X-Trace-Session-Id で渡す(最優先)。
+ - クエリパラメータ:URLクエリパラメータ trace_session_id で渡す。
+ - リクエストボディ:リクエストボディの trace_session_id フィールドで渡す(本フィールド)。
+
### 応答 diff --git a/web/app/components/develop/template/template_advanced_chat.zh.mdx b/web/app/components/develop/template/template_advanced_chat.zh.mdx index 58d7215a9c..c9a0ed2efe 100755 --- a/web/app/components/develop/template/template_advanced_chat.zh.mdx +++ b/web/app/components/develop/template/template_advanced_chat.zh.mdx @@ -80,6 +80,12 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' - Query 参数:通过 URL 查询参数 trace_id 传递。
- Request Body:通过请求体字段 trace_id 传递(即本字段)。
+ + (选填)用于可观测性分组的链路会话 ID。支持会话分组的追踪提供商可将该值用作导出的会话标识。它不会改变 conversation_id、workflow_run_id、trace_id 或 span 关系。支持以下三种方式传递,具体优先级依次为:
+ - Header:通过 HTTP Header X-Trace-Session-Id 传递,优先级最高。
+ - Query 参数:通过 URL 查询参数 trace_session_id 传递。
+ - Request Body:通过请求体字段 trace_session_id 传递(即本字段)。
+
### Response diff --git a/web/app/components/develop/template/template_chat.en.mdx b/web/app/components/develop/template/template_chat.en.mdx index 8567d06e29..a924670e52 100644 --- a/web/app/components/develop/template/template_chat.en.mdx +++ b/web/app/components/develop/template/template_chat.en.mdx @@ -83,6 +83,12 @@ Chat applications support session persistence, allowing previous chat history to - Query parameter: via URL query parameter trace_id.
- Request Body: via request body field trace_id (i.e., this field).
+ + (Optional) Trace session ID for observability grouping. Tracing providers that support session grouping can use this value as the exported session identifier. It does not change conversation_id, workflow_run_id, trace_id, or span relationships. Supports the following three ways to pass, in order of priority:
+ - Header: via HTTP Header X-Trace-Session-Id, highest priority.
+ - Query parameter: via URL query parameter trace_session_id.
+ - Request Body: via request body field trace_session_id (i.e., this field).
+
### Response diff --git a/web/app/components/develop/template/template_chat.ja.mdx b/web/app/components/develop/template/template_chat.ja.mdx index 5f2e185732..e418f1c8d7 100644 --- a/web/app/components/develop/template/template_chat.ja.mdx +++ b/web/app/components/develop/template/template_chat.ja.mdx @@ -83,6 +83,12 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' - クエリパラメータ:URLクエリパラメータ trace_id で渡す。
- リクエストボディ:リクエストボディの trace_id フィールドで渡す(本フィールド)。
+ + (オプション)可観測性のグルーピングに使用するトレースセッションID。セッションのグルーピングに対応するトレースプロバイダーは、この値をエクスポートされるセッション識別子として使用できます。conversation_id、workflow_run_id、trace_id、または span の関係は変更しません。以下の3つの方法で渡すことができ、優先順位は次のとおりです:
+ - Header:HTTPヘッダー X-Trace-Session-Id で渡す(最優先)。
+ - クエリパラメータ:URLクエリパラメータ trace_session_id で渡す。
+ - リクエストボディ:リクエストボディの trace_session_id フィールドで渡す(本フィールド)。
+
### 応答 diff --git a/web/app/components/develop/template/template_chat.zh.mdx b/web/app/components/develop/template/template_chat.zh.mdx index 2bd36c6b74..1ba4f13a66 100644 --- a/web/app/components/develop/template/template_chat.zh.mdx +++ b/web/app/components/develop/template/template_chat.zh.mdx @@ -81,6 +81,12 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' - Query 参数:通过 URL 查询参数 trace_id 传递。
- Request Body:通过请求体字段 trace_id 传递(即本字段)。
+ + (选填)用于可观测性分组的链路会话 ID。支持会话分组的追踪提供商可将该值用作导出的会话标识。它不会改变 conversation_id、workflow_run_id、trace_id 或 span 关系。支持以下三种方式传递,具体优先级依次为:
+ - Header:通过 HTTP Header X-Trace-Session-Id 传递,优先级最高。
+ - Query 参数:通过 URL 查询参数 trace_session_id 传递。
+ - Request Body:通过请求体字段 trace_session_id 传递(即本字段)。
+
### Response diff --git a/web/app/components/develop/template/template_workflow.en.mdx b/web/app/components/develop/template/template_workflow.en.mdx index f37f2cfeb1..a5b2d65100 100644 --- a/web/app/components/develop/template/template_workflow.en.mdx +++ b/web/app/components/develop/template/template_workflow.en.mdx @@ -66,6 +66,11 @@ Workflow applications offers non-session support and is ideal for translation, a 1. Header: via HTTP Header `X-Trace-Id`, highest priority. 2. Query parameter: via URL query parameter `trace_id`. 3. Request Body: via request body field `trace_id` (i.e., this field). + - `trace_session_id` (string) Optional + Trace session ID for observability grouping. Tracing providers that support session grouping can use this value as the exported session identifier. It does not change conversation_id, workflow_run_id, trace_id, or span relationships. Supports the following three ways to pass, in order of priority: + 1. Header: via HTTP Header `X-Trace-Session-Id`, highest priority. + 2. Query parameter: via URL query parameter `trace_session_id`. + 3. Request Body: via request body field `trace_session_id` (i.e., this field). ### Response When `response_mode` is `blocking`, return a CompletionResponse object. @@ -680,6 +685,11 @@ Workflow applications offers non-session support and is ideal for translation, a 1. Header: via HTTP Header `X-Trace-Id`, highest priority. 2. Query parameter: via URL query parameter `trace_id`. 3. Request Body: via request body field `trace_id` (i.e., this field). + - `trace_session_id` (string) Optional + Trace session ID for observability grouping. Tracing providers that support session grouping can use this value as the exported session identifier. It does not change conversation_id, workflow_run_id, trace_id, or span relationships. Supports the following three ways to pass, in order of priority: + 1. Header: via HTTP Header `X-Trace-Session-Id`, highest priority. + 2. Query parameter: via URL query parameter `trace_session_id`. + 3. Request Body: via request body field `trace_session_id` (i.e., this field). ### Response When `response_mode` is `blocking`, return a CompletionResponse object. diff --git a/web/app/components/develop/template/template_workflow.ja.mdx b/web/app/components/develop/template/template_workflow.ja.mdx index 108ebfea5d..881e5de7e7 100644 --- a/web/app/components/develop/template/template_workflow.ja.mdx +++ b/web/app/components/develop/template/template_workflow.ja.mdx @@ -65,6 +65,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' 1. Header:HTTPヘッダー `X-Trace-Id` で渡す(最優先)。 2. クエリパラメータ:URLクエリパラメータ `trace_id` で渡す。 3. リクエストボディ:リクエストボディの `trace_id` フィールドで渡す(本フィールド)。 + - `trace_session_id` (string) オプション + 可観測性のグルーピングに使用するトレースセッションID。セッションのグルーピングに対応するトレースプロバイダーは、この値をエクスポートされるセッション識別子として使用できます。conversation_id、workflow_run_id、trace_id、または span の関係は変更しません。以下の3つの方法で渡すことができ、優先順位は次のとおりです: + 1. ヘッダー:HTTP ヘッダー `X-Trace-Session-Id` で渡すことを推奨、最高優先度。 + 2. クエリパラメータ:URL クエリパラメータ `trace_session_id` で渡す。 + 3. リクエストボディ:リクエストボディフィールド `trace_session_id` で渡す(つまり、このフィールド)。 ### 応答 @@ -675,6 +680,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' 1. ヘッダー:HTTP ヘッダー `X-Trace-Id` で渡すことを推奨、最高優先度。 2. クエリパラメータ:URL クエリパラメータ `trace_id` で渡す。 3. リクエストボディ:リクエストボディフィールド `trace_id` で渡す(つまり、このフィールド)。 + - `trace_session_id` (string) オプション + 可観測性のグルーピングに使用するトレースセッションID。セッションのグルーピングに対応するトレースプロバイダーは、この値をエクスポートされるセッション識別子として使用できます。conversation_id、workflow_run_id、trace_id、または span の関係は変更しません。以下の3つの方法で渡すことができ、優先順位は以下の通りです: + 1. ヘッダー:HTTP ヘッダー `X-Trace-Session-Id` で渡すことを推奨、最高優先度。 + 2. クエリパラメータ:URL クエリパラメータ `trace_session_id` で渡す。 + 3. リクエストボディ:リクエストボディフィールド `trace_session_id` で渡す(つまり、このフィールド)。 ### 応答 `response_mode` が `blocking` の場合、CompletionResponse オブジェクトを返します。 diff --git a/web/app/components/develop/template/template_workflow.zh.mdx b/web/app/components/develop/template/template_workflow.zh.mdx index eed8736acb..602ff8779d 100644 --- a/web/app/components/develop/template/template_workflow.zh.mdx +++ b/web/app/components/develop/template/template_workflow.zh.mdx @@ -58,6 +58,11 @@ Workflow 应用无会话支持,适合用于翻译/文章写作/总结 AI 等 1. Header:推荐通过 HTTP Header `X-Trace-Id` 传递,优先级最高。 2. Query 参数:通过 URL 查询参数 `trace_id` 传递。 3. Request Body:通过请求体字段 `trace_id` 传递(即本字段)。 + - `trace_session_id` (string) Optional + 用于可观测性分组的链路会话 ID。支持会话分组的追踪提供商可将该值用作导出的会话标识。它不会改变 conversation_id、workflow_run_id、trace_id 或 span 关系。支持以下三种方式传递,具体优先级依次为: + 1. Header:通过 HTTP Header `X-Trace-Session-Id` 传递,优先级最高。 + 2. Query 参数:通过 URL 查询参数 `trace_session_id` 传递。 + 3. Request Body:通过请求体字段 `trace_session_id` 传递(即本字段)。 ### Response 当 `response_mode` 为 `blocking` 时,返回 CompletionResponse object。 @@ -668,6 +673,11 @@ Workflow 应用无会话支持,适合用于翻译/文章写作/总结 AI 等 1. Header:推荐通过 HTTP Header `X-Trace-Id` 传递,优先级最高。 2. Query 参数:通过 URL 查询参数 `trace_id` 传递。 3. Request Body:通过请求体字段 `trace_id` 传递(即本字段)。 + - `trace_session_id` (string) Optional + 用于可观测性分组的链路会话 ID。支持会话分组的追踪提供商可将该值用作导出的会话标识。它不会改变 conversation_id、workflow_run_id、trace_id 或 span 关系。支持以下三种方式传递,具体优先级依次为: + 1. Header:通过 HTTP Header `X-Trace-Session-Id` 传递,优先级最高。 + 2. Query 参数:通过 URL 查询参数 `trace_session_id` 传递。 + 3. Request Body:通过请求体字段 `trace_session_id` 传递(即本字段)。 ### Response 当 `response_mode` 为 `blocking` 时,返回 CompletionResponse object。