diff --git a/api/core/agent/agent_app_runner.py b/api/core/agent/agent_app_runner.py index e15ede15d2..2ee0a23aab 100644 --- a/api/core/agent/agent_app_runner.py +++ b/api/core/agent/agent_app_runner.py @@ -183,7 +183,24 @@ class AgentAppRunner(BaseAgentRunner): elif output.status == AgentLog.LogStatus.SUCCESS: if output.log_type == AgentLog.LogType.THOUGHT: - pass + if current_agent_thought_id is None: + continue + + thought_text = output.data.get("thought") + self.save_agent_thought( + agent_thought_id=current_agent_thought_id, + tool_name=None, + tool_input=None, + thought=thought_text, + observation=None, + tool_invoke_meta=None, + answer=None, + messages_ids=[], + ) + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id), + PublishFrom.APPLICATION_MANAGER, + ) elif output.log_type == AgentLog.LogType.TOOL_CALL: if current_agent_thought_id is None: @@ -269,15 +286,20 @@ class AgentAppRunner(BaseAgentRunner): """ Initialize system message """ - if not prompt_messages and prompt_template: - return [ - SystemPromptMessage(content=prompt_template), - ] + if not prompt_template: + return prompt_messages or [] - if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template: - prompt_messages.insert(0, SystemPromptMessage(content=prompt_template)) + prompt_messages = prompt_messages or [] - return prompt_messages or [] + if prompt_messages and isinstance(prompt_messages[0], SystemPromptMessage): + prompt_messages[0] = SystemPromptMessage(content=prompt_template) + return prompt_messages + + if not prompt_messages: + return [SystemPromptMessage(content=prompt_template)] + + prompt_messages.insert(0, SystemPromptMessage(content=prompt_template)) + return prompt_messages def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: """ diff --git a/api/core/agent/patterns/README.md b/api/core/agent/patterns/README.md index f6437ba05a..95b1bf87fa 100644 --- a/api/core/agent/patterns/README.md +++ b/api/core/agent/patterns/README.md @@ -1,67 +1,55 @@ # Agent Patterns -A unified agent pattern module that provides common agent execution strategies for both Agent V2 nodes and Agent Applications in Dify. +A unified agent pattern module that powers both Agent V2 workflow nodes and agent applications. Strategies share a common execution contract while adapting to model capabilities and tool availability. ## Overview -This module implements a strategy pattern for agent execution, automatically selecting the appropriate strategy based on model capabilities. It serves as the core engine for agent-based interactions across different components of the Dify platform. +The module applies a strategy pattern around LLM/tool orchestration. `StrategyFactory` auto-selects the best implementation based on model features or an explicit agent strategy, and each strategy streams logs and usage consistently. ## Key Features -### 1. Multiple Agent Strategies - -- **Function Call Strategy**: Leverages native function/tool calling capabilities of advanced LLMs (e.g., GPT-4, Claude) -- **ReAct Strategy**: Implements the ReAct (Reasoning + Acting) approach for models without native function calling support - -### 2. Automatic Strategy Selection - -The `StrategyFactory` intelligently selects the optimal strategy based on model features: - -- Models with `TOOL_CALL`, `MULTI_TOOL_CALL`, or `STREAM_TOOL_CALL` capabilities → Function Call Strategy -- Other models → ReAct Strategy - -### 3. Unified Interface - -- Common base class (`AgentPattern`) ensures consistent behavior across strategies -- Seamless integration with both workflow nodes and standalone agent applications -- Standardized input/output formats for easy consumption - -### 4. Advanced Capabilities - -- **Streaming Support**: Real-time response streaming for better user experience -- **File Handling**: Built-in support for processing and managing files during agent execution -- **Iteration Control**: Configurable maximum iterations with safety limits (capped at 99) -- **Tool Management**: Flexible tool integration supporting various tool types -- **Context Propagation**: Execution context for tracing, auditing, and debugging +- **Dual strategies** + - `FunctionCallStrategy`: uses native LLM function/tool calling when the model exposes `TOOL_CALL`, `MULTI_TOOL_CALL`, or `STREAM_TOOL_CALL`. + - `ReActStrategy`: ReAct (reasoning + acting) flow driven by `CotAgentOutputParser`, used when function calling is unavailable or explicitly requested. +- **Explicit or auto selection** + - `StrategyFactory.create_strategy` prefers an explicit `AgentEntity.Strategy` (FUNCTION_CALLING or CHAIN_OF_THOUGHT). + - Otherwise it falls back to function calling when tool-call features exist, or ReAct when they do not. +- **Unified execution contract** + - `AgentPattern.run` yields streaming `AgentLog` entries and `LLMResultChunk` data, returning an `AgentResult` with text, files, usage, and `finish_reason`. + - Iterations are configurable and hard-capped at 99 rounds; the last round forces a final answer by withholding tools. +- **Tool handling and hooks** + - Tools convert to `PromptMessageTool` objects before invocation. + - Optional `tool_invoke_hook` lets callers override tool execution (e.g., agent apps) while workflow runs use `ToolEngine.generic_invoke`. + - Tool outputs support text, links, JSON, variables, blobs, retriever resources, and file attachments; `target=="self"` files are reloaded into model context, others are returned as outputs. +- **File-aware arguments** + - Tool args accept `[File: ]` or `[Files: ]` placeholders that resolve to `File` objects before invocation, enabling models to reference uploaded files safely. +- **ReAct prompt shaping** + - System prompts replace `{{instruction}}`, `{{tools}}`, and `{{tool_names}}` placeholders. + - Adds `Observation` to stop sequences and appends scratchpad text so the model sees prior Thought/Action/Observation history. +- **Observability and accounting** + - Standardized `AgentLog` entries for rounds, model thoughts, and tool calls, including usage aggregation (`LLMUsage`) across streaming and non-streaming paths. ## Architecture ``` agent/patterns/ -├── base.py # Abstract base class defining the agent pattern interface -├── function_call.py # Implementation using native LLM function calling -├── react.py # Implementation using ReAct prompting approach -└── strategy_factory.py # Factory for automatic strategy selection +├── base.py # Shared utilities: logging, usage, tool invocation, file handling +├── function_call.py # Native function-calling loop with tool execution +├── react.py # ReAct loop with CoT parsing and scratchpad wiring +└── strategy_factory.py # Strategy selection by model features or explicit override ``` ## Usage -The module is designed to be used by: - -1. **Agent V2 Nodes**: In workflow orchestration for complex agent tasks -1. **Agent Applications**: For standalone conversational agents -1. **Custom Implementations**: As a foundation for building specialized agent behaviors +- For auto-selection: + - Call `StrategyFactory.create_strategy(model_features, model_instance, context, tools, files, ...)` and run the returned strategy with prompt messages and model params. +- For explicit behavior: + - Pass `agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING` to force native calls (falls back to ReAct if unsupported), or `CHAIN_OF_THOUGHT` to force ReAct. +- Both strategies stream chunks and logs; collect the generator output until it returns an `AgentResult`. ## Integration Points -- **Model Runtime**: Interfaces with Dify's model runtime for LLM interactions -- **Tool System**: Integrates with the tool framework for external capabilities -- **Memory Management**: Compatible with conversation memory systems -- **File Management**: Handles file inputs/outputs during agent execution - -## Benefits - -1. **Consistency**: Unified implementation reduces code duplication and maintenance overhead -1. **Flexibility**: Easy to extend with new strategies or customize existing ones -1. **Performance**: Optimized for each model's capabilities to ensure best performance -1. **Reliability**: Built-in safety mechanisms and error handling +- **Model runtime**: delegates to `ModelInstance.invoke_llm` for both streaming and non-streaming calls. +- **Tool system**: defaults to `ToolEngine.generic_invoke`, with `tool_invoke_hook` for custom callers. +- **Files**: flows through `File` objects for tool inputs/outputs and model-context attachments. +- **Execution context**: `ExecutionContext` fields (user/app/conversation/message) propagate to tool invocations and logging. diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 6ce33c98ee..3b02683764 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -457,6 +457,9 @@ class WorkflowBasedAppRunner: elif isinstance(event, NodeRunStreamChunkEvent): from core.app.entities.queue_entities import ChunkType as QueueChunkType + if event.is_final and not event.chunk: + return + self._publish_event( QueueTextChunkEvent( text=event.chunk, diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 2a308c6ecd..6cbd48e27b 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -1,4 +1,5 @@ import logging +import re import time from collections.abc import Generator from threading import Thread @@ -68,6 +69,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application. """ + _THINK_PATTERN = re.compile(r"]*>(.*?)", re.IGNORECASE | re.DOTALL) + _task_state: EasyUITaskState _application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity] @@ -441,7 +444,13 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): for thought in agent_thoughts: # Add thought/reasoning if thought.thought: - reasoning_list.append(thought.thought) + reasoning_text = thought.thought + if " blocks and clean the final answer + clean_answer, reasoning_content = self._split_reasoning_from_answer(answer) + if reasoning_content: + answer = clean_answer + llm_result.message.content = clean_answer + llm_result.reasoning_content = reasoning_content + message.answer = clean_answer if reasoning_content: reasoning_list = [reasoning_content] # Content comes first, then reasoning @@ -493,6 +510,19 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): ) session.add(generation_detail) + @classmethod + def _split_reasoning_from_answer(cls, text: str) -> tuple[str, str]: + """ + Extract reasoning segments from blocks and return (clean_text, reasoning). + """ + matches = cls._THINK_PATTERN.findall(text) + reasoning_content = "\n".join(match.strip() for match in matches) if matches else "" + + clean_text = cls._THINK_PATTERN.sub("", text) + clean_text = re.sub(r"\n\s*\n", "\n\n", clean_text).strip() + + return clean_text, reasoning_content or "" + def _handle_stop(self, event: QueueStopEvent): """ Handle stop. diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index 0a3189f398..a45d1d1046 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -474,57 +474,67 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) outputs = execution.outputs or {} metadata = execution.metadata or {} - # Extract reasoning_content from outputs - reasoning_content = outputs.get("reasoning_content") - reasoning_list: list[str] = [] - if reasoning_content: - # reasoning_content could be a string or already a list - if isinstance(reasoning_content, str): - reasoning_list = [reasoning_content] if reasoning_content.strip() else [] - elif isinstance(reasoning_content, list): - # Filter out empty or whitespace-only strings - reasoning_list = [r.strip() for r in reasoning_content if isinstance(r, str) and r.strip()] + reasoning_list = self._extract_reasoning(outputs) + tool_calls_list = self._extract_tool_calls(metadata.get(WorkflowNodeExecutionMetadataKey.AGENT_LOG)) - # Extract tool_calls from metadata.agent_log - tool_calls_list: list[dict] = [] - agent_log = metadata.get(WorkflowNodeExecutionMetadataKey.AGENT_LOG) - if agent_log and isinstance(agent_log, list): - for log in agent_log: - # Each log entry has label, data, status, etc. - log_data = log.data if hasattr(log, "data") else log.get("data", {}) - tool_name = log_data.get("tool_name") - # Only include tool calls with valid tool_name - if tool_name and str(tool_name).strip(): - tool_calls_list.append( - { - "id": log_data.get("tool_call_id", ""), - "name": tool_name, - "arguments": json.dumps(log_data.get("tool_args", {})), - "result": str(log_data.get("output", "")), - } - ) - - # Only save if there's meaningful generation detail (reasoning or tool calls) - has_valid_reasoning = bool(reasoning_list) - has_valid_tool_calls = bool(tool_calls_list) - - if not has_valid_reasoning and not has_valid_tool_calls: + if not reasoning_list and not tool_calls_list: return - # Build sequence based on content, reasoning, and tool_calls - sequence: list[dict] = [] - text = outputs.get("text", "") + sequence = self._build_generation_sequence(outputs.get("text", ""), reasoning_list, tool_calls_list) + self._upsert_generation_detail(session, execution, reasoning_list, tool_calls_list, sequence) - # For now, use a simple sequence: content -> reasoning -> tool_calls - # This can be enhanced later to track actual streaming order + def _extract_reasoning(self, outputs: Mapping[str, Any]) -> list[str]: + """Extract reasoning_content as a clean list of non-empty strings.""" + reasoning_content = outputs.get("reasoning_content") + if isinstance(reasoning_content, str): + trimmed = reasoning_content.strip() + return [trimmed] if trimmed else [] + if isinstance(reasoning_content, list): + return [item.strip() for item in reasoning_content if isinstance(item, str) and item.strip()] + return [] + + def _extract_tool_calls(self, agent_log: Any) -> list[dict[str, str]]: + """Extract tool call records from agent logs.""" + if not agent_log or not isinstance(agent_log, list): + return [] + + tool_calls: list[dict[str, str]] = [] + for log in agent_log: + log_data = log.data if hasattr(log, "data") else (log.get("data", {}) if isinstance(log, dict) else {}) + tool_name = log_data.get("tool_name") + if tool_name and str(tool_name).strip(): + tool_calls.append( + { + "id": log_data.get("tool_call_id", ""), + "name": tool_name, + "arguments": json.dumps(log_data.get("tool_args", {})), + "result": str(log_data.get("output", "")), + } + ) + return tool_calls + + def _build_generation_sequence( + self, text: str, reasoning_list: list[str], tool_calls_list: list[dict[str, str]] + ) -> list[dict[str, Any]]: + """Build a simple content/reasoning/tool_call sequence.""" + sequence: list[dict[str, Any]] = [] if text: sequence.append({"type": "content", "start": 0, "end": len(text)}) - for i, _ in enumerate(reasoning_list): - sequence.append({"type": "reasoning", "index": i}) - for i in range(len(tool_calls_list)): - sequence.append({"type": "tool_call", "index": i}) + for index in range(len(reasoning_list)): + sequence.append({"type": "reasoning", "index": index}) + for index in range(len(tool_calls_list)): + sequence.append({"type": "tool_call", "index": index}) + return sequence - # Check if generation detail already exists for this node execution + def _upsert_generation_detail( + self, + session, + execution: WorkflowNodeExecution, + reasoning_list: list[str], + tool_calls_list: list[dict[str, str]], + sequence: list[dict[str, Any]], + ) -> None: + """Insert or update LLMGenerationDetail with serialized fields.""" existing = ( session.query(LLMGenerationDetail) .filter_by( @@ -534,23 +544,26 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) .first() ) + reasoning_json = json.dumps(reasoning_list) if reasoning_list else None + tool_calls_json = json.dumps(tool_calls_list) if tool_calls_list else None + sequence_json = json.dumps(sequence) if sequence else None + if existing: - # Update existing record - existing.reasoning_content = json.dumps(reasoning_list) if reasoning_list else None - existing.tool_calls = json.dumps(tool_calls_list) if tool_calls_list else None - existing.sequence = json.dumps(sequence) if sequence else None - else: - # Create new record - generation_detail = LLMGenerationDetail( - tenant_id=self._tenant_id, - app_id=self._app_id, - workflow_run_id=execution.workflow_execution_id, - node_id=execution.node_id, - reasoning_content=json.dumps(reasoning_list) if reasoning_list else None, - tool_calls=json.dumps(tool_calls_list) if tool_calls_list else None, - sequence=json.dumps(sequence) if sequence else None, - ) - session.add(generation_detail) + existing.reasoning_content = reasoning_json + existing.tool_calls = tool_calls_json + existing.sequence = sequence_json + return + + generation_detail = LLMGenerationDetail( + tenant_id=self._tenant_id, + app_id=self._app_id, + workflow_run_id=execution.workflow_execution_id, + node_id=execution.node_id, + reasoning_content=reasoning_json, + tool_calls=tool_calls_json, + sequence=sequence_json, + ) + session.add(generation_detail) def get_db_models_by_workflow_run( self, diff --git a/api/core/workflow/graph_engine/response_coordinator/coordinator.py b/api/core/workflow/graph_engine/response_coordinator/coordinator.py index 631440c6c1..c5ea94ba80 100644 --- a/api/core/workflow/graph_engine/response_coordinator/coordinator.py +++ b/api/core/workflow/graph_engine/response_coordinator/coordinator.py @@ -391,12 +391,9 @@ class ResponseStreamCoordinator: # Determine which node to attribute the output to # For special selectors (sys, env, conversation), use the active response node # For regular selectors, use the source node - if self._active_session and source_selector_prefix not in self._graph.nodes: - # Special selector - use active response node - output_node_id = self._active_session.node_id - else: - # Regular node selector - output_node_id = source_selector_prefix + active_session = self._active_session + special_selector = bool(active_session and source_selector_prefix not in self._graph.nodes) + output_node_id = active_session.node_id if special_selector and active_session else source_selector_prefix execution_id = self._get_or_create_execution_id(output_node_id) # Check if there's a direct stream for this selector @@ -404,65 +401,27 @@ class ResponseStreamCoordinator: tuple(segment.selector) in self._stream_buffers or tuple(segment.selector) in self._closed_streams ) - if has_direct_stream: - # Stream all available chunks for direct stream - while self._has_unread_stream(segment.selector): - if event := self._pop_stream_chunk(segment.selector): - # For special selectors, update the event to use active response node's information - if self._active_session and source_selector_prefix not in self._graph.nodes: - response_node = self._graph.nodes[self._active_session.node_id] - updated_event = NodeRunStreamChunkEvent( - id=execution_id, - node_id=response_node.id, - node_type=response_node.node_type, - selector=event.selector, - chunk=event.chunk, - is_final=event.is_final, + stream_targets = [segment.selector] if has_direct_stream else sorted(self._find_child_streams(segment.selector)) + + if stream_targets: + all_complete = True + + for target_selector in stream_targets: + while self._has_unread_stream(target_selector): + if event := self._pop_stream_chunk(target_selector): + events.append( + self._rewrite_stream_event( + event=event, + output_node_id=output_node_id, + execution_id=execution_id, + special_selector=bool(special_selector), + ) ) - events.append(updated_event) - else: - events.append(event) - # Check if stream is closed - if self._is_stream_closed(segment.selector): - is_complete = True + if not self._is_stream_closed(target_selector): + all_complete = False - else: - # No direct stream - check for child field streams (for object types) - child_streams = self._find_child_streams(segment.selector) - - if child_streams: - # Process all child streams - all_children_complete = True - - for child_selector in sorted(child_streams): - # Stream all available chunks from this child - while self._has_unread_stream(child_selector): - if event := self._pop_stream_chunk(child_selector): - # Forward child stream event - if self._active_session and source_selector_prefix not in self._graph.nodes: - response_node = self._graph.nodes[self._active_session.node_id] - updated_event = NodeRunStreamChunkEvent( - id=execution_id, - node_id=response_node.id, - node_type=response_node.node_type, - selector=event.selector, - chunk=event.chunk, - is_final=event.is_final, - chunk_type=event.chunk_type, - tool_call=event.tool_call, - tool_result=event.tool_result, - ) - events.append(updated_event) - else: - events.append(event) - - # Check if this child stream is complete - if not self._is_stream_closed(child_selector): - all_children_complete = False - - # Object segment is complete only when all children are complete - is_complete = all_children_complete + is_complete = all_complete # Fallback: check if scalar value exists in variable pool if not is_complete and not has_direct_stream: @@ -485,6 +444,28 @@ class ResponseStreamCoordinator: return events, is_complete + def _rewrite_stream_event( + self, + event: NodeRunStreamChunkEvent, + output_node_id: str, + execution_id: str, + special_selector: bool, + ) -> NodeRunStreamChunkEvent: + """Rewrite event to attribute to active response node when selector is special.""" + if not special_selector: + return event + + return self._create_stream_chunk_event( + node_id=output_node_id, + execution_id=execution_id, + selector=event.selector, + chunk=event.chunk, + is_final=event.is_final, + chunk_type=event.chunk_type, + tool_call=event.tool_call, + tool_result=event.tool_result, + ) + def _process_text_segment(self, segment: TextSegment) -> Sequence[NodeRunStreamChunkEvent]: """Process a text segment. Returns (events, is_complete).""" assert self._active_session is not None diff --git a/api/models/model.py b/api/models/model.py index ba075e2474..32be20e60a 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1203,6 +1203,7 @@ class Message(Base): .all() ) + # FIXME (Novice) -- It's easy to cause N+1 query problem here. @property def generation_detail(self) -> dict[str, Any] | None: """ diff --git a/api/models/workflow.py b/api/models/workflow.py index bc229fb4e4..5131177836 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -695,6 +695,10 @@ class WorkflowRun(Base): def workflow(self): return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first() + @property + def outputs_as_generation(self): + return is_generation_outputs(self.outputs_dict) + def to_dict(self): return { "id": self.id, @@ -708,7 +712,7 @@ class WorkflowRun(Base): "inputs": self.inputs_dict, "status": self.status, "outputs": self.outputs_dict, - "outputs_as_generation": is_generation_outputs(self.outputs_dict), + "outputs_as_generation": self.outputs_as_generation, "error": self.error, "elapsed_time": self.elapsed_time, "total_tokens": self.total_tokens, diff --git a/api/tests/unit_tests/core/agent/__init__.py b/api/tests/unit_tests/core/agent/__init__.py index a9ccd45f4b..e7c478bf83 100644 --- a/api/tests/unit_tests/core/agent/__init__.py +++ b/api/tests/unit_tests/core/agent/__init__.py @@ -1,3 +1,4 @@ """ Mark agent test modules as a package to avoid import name collisions. """ + diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_stream_chunk.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_stream_chunk.py new file mode 100644 index 0000000000..6a8a691a25 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_stream_chunk.py @@ -0,0 +1,48 @@ +from unittest.mock import MagicMock + +from core.app.apps.base_app_queue_manager import PublishFrom +from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner +from core.workflow.graph_events import NodeRunStreamChunkEvent +from core.workflow.nodes import NodeType + + +class DummyQueueManager: + def __init__(self) -> None: + self.published = [] + + def publish(self, event, publish_from: PublishFrom) -> None: + self.published.append((event, publish_from)) + + +def test_skip_empty_final_chunk() -> None: + queue_manager = DummyQueueManager() + runner = WorkflowBasedAppRunner(queue_manager=queue_manager, app_id="app") + + empty_final_event = NodeRunStreamChunkEvent( + id="exec", + node_id="node", + node_type=NodeType.LLM, + selector=["node", "text"], + chunk="", + is_final=True, + ) + + runner._handle_event(workflow_entry=MagicMock(), event=empty_final_event) + assert queue_manager.published == [] + + normal_event = NodeRunStreamChunkEvent( + id="exec", + node_id="node", + node_type=NodeType.LLM, + selector=["node", "text"], + chunk="hi", + is_final=False, + ) + + runner._handle_event(workflow_entry=MagicMock(), event=normal_event) + + assert len(queue_manager.published) == 1 + published_event, publish_from = queue_manager.published[0] + assert publish_from == PublishFrom.APPLICATION_MANAGER + assert published_event.text == "hi" + diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_response_coordinator.py b/api/tests/unit_tests/core/workflow/graph_engine/test_response_coordinator.py index c6b3797ce2..822b6a808f 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_response_coordinator.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_response_coordinator.py @@ -6,6 +6,7 @@ from core.workflow.entities.tool_entities import ToolResultStatus from core.workflow.enums import NodeType from core.workflow.graph.graph import Graph from core.workflow.graph_engine.response_coordinator.coordinator import ResponseStreamCoordinator +from core.workflow.graph_engine.response_coordinator.session import ResponseSession from core.workflow.graph_events import ( ChunkType, NodeRunStreamChunkEvent, @@ -13,6 +14,7 @@ from core.workflow.graph_events import ( ToolResult, ) from core.workflow.nodes.base.entities import BaseNodeData +from core.workflow.nodes.base.template import Template, VariableSegment from core.workflow.runtime import VariablePool @@ -186,3 +188,44 @@ class TestResponseCoordinatorObjectStreaming: assert ("node1", "generation", "content") in children assert ("node1", "generation", "tool_calls") in children assert ("node1", "generation", "thought") in children + + def test_special_selector_rewrites_to_active_response_node(self): + """Ensure special selectors attribute streams to the active response node.""" + graph = MagicMock(spec=Graph) + variable_pool = MagicMock(spec=VariablePool) + + response_node = MagicMock() + response_node.id = "response_node" + response_node.node_type = NodeType.ANSWER + graph.nodes = {"response_node": response_node} + graph.root_node = response_node + + coordinator = ResponseStreamCoordinator(variable_pool, graph) + coordinator.track_node_execution("response_node", "exec_resp") + + coordinator._active_session = ResponseSession( + node_id="response_node", + template=Template(segments=[VariableSegment(selector=["sys", "foo"])]), + ) + + event = NodeRunStreamChunkEvent( + id="stream_1", + node_id="llm_node", + node_type=NodeType.LLM, + selector=["sys", "foo"], + chunk="hi", + is_final=True, + chunk_type=ChunkType.TEXT, + ) + + coordinator._stream_buffers[("sys", "foo")] = [event] + coordinator._stream_positions[("sys", "foo")] = 0 + coordinator._closed_streams.add(("sys", "foo")) + + events, is_complete = coordinator._process_variable_segment(VariableSegment(selector=["sys", "foo"])) + + assert is_complete + assert len(events) == 1 + rewritten = events[0] + assert rewritten.node_id == "response_node" + assert rewritten.id == "exec_resp" diff --git a/api/tests/unit_tests/core/workflow/nodes/test_llm_node_streaming.py b/api/tests/unit_tests/core/workflow/nodes/test_llm_node_streaming.py index 9d793f804f..55f6525bcc 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_llm_node_streaming.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_llm_node_streaming.py @@ -146,3 +146,4 @@ def test_serialize_tool_call_strips_files_to_ids(): assert serialized["name"] == "do" assert serialized["arguments"] == '{"a":1}' assert serialized["output"] == "ok" +