mirror of https://github.com/langgenius/dify.git
feat: basic app add thought field
This commit is contained in:
parent 047ea8c143
commit 7fc25cafb2
@@ -183,7 +183,24 @@ class AgentAppRunner(BaseAgentRunner):
elif output.status == AgentLog.LogStatus.SUCCESS:
if output.log_type == AgentLog.LogType.THOUGHT:
pass
if current_agent_thought_id is None:
continue

thought_text = output.data.get("thought")
self.save_agent_thought(
agent_thought_id=current_agent_thought_id,
tool_name=None,
tool_input=None,
thought=thought_text,
observation=None,
tool_invoke_meta=None,
answer=None,
messages_ids=[],
)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
PublishFrom.APPLICATION_MANAGER,
)

elif output.log_type == AgentLog.LogType.TOOL_CALL:
if current_agent_thought_id is None:

@@ -269,15 +286,20 @@ class AgentAppRunner(BaseAgentRunner):
"""
Initialize system message
"""
if not prompt_messages and prompt_template:
return [
SystemPromptMessage(content=prompt_template),
]
if not prompt_template:
return prompt_messages or []

if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
prompt_messages = prompt_messages or []

return prompt_messages or []
if prompt_messages and isinstance(prompt_messages[0], SystemPromptMessage):
prompt_messages[0] = SystemPromptMessage(content=prompt_template)
return prompt_messages

if not prompt_messages:
return [SystemPromptMessage(content=prompt_template)]

prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
return prompt_messages

def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""

@@ -1,67 +1,55 @@
# Agent Patterns

A unified agent pattern module that provides common agent execution strategies for both Agent V2 nodes and Agent Applications in Dify.
A unified agent pattern module that powers both Agent V2 workflow nodes and agent applications. Strategies share a common execution contract while adapting to model capabilities and tool availability.

## Overview

This module implements a strategy pattern for agent execution, automatically selecting the appropriate strategy based on model capabilities. It serves as the core engine for agent-based interactions across different components of the Dify platform.
The module applies a strategy pattern around LLM/tool orchestration. `StrategyFactory` auto-selects the best implementation based on model features or an explicit agent strategy, and each strategy streams logs and usage consistently.
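
For illustration, a minimal standalone sketch of the selection rule described here (names such as `select_strategy` and `TOOL_CALL_FEATURES` are illustrative, not the module's actual API):

```python
# Illustrative sketch only: mirrors the documented selection rule, not the real StrategyFactory code.
from enum import Enum


class Strategy(Enum):
    FUNCTION_CALLING = "function_calling"
    CHAIN_OF_THOUGHT = "chain_of_thought"


# Features that, per the docs, indicate native tool-call support (assumed to arrive as plain strings here).
TOOL_CALL_FEATURES = {"TOOL_CALL", "MULTI_TOOL_CALL", "STREAM_TOOL_CALL"}


def select_strategy(model_features: set[str], explicit: Strategy | None = None) -> Strategy:
    """Prefer an explicit strategy; otherwise decide by model features."""
    supports_tool_calls = bool(model_features & TOOL_CALL_FEATURES)
    if explicit is Strategy.CHAIN_OF_THOUGHT:
        return Strategy.CHAIN_OF_THOUGHT
    if explicit is Strategy.FUNCTION_CALLING and not supports_tool_calls:
        # Forcing function calling still falls back to ReAct when the model lacks support.
        return Strategy.CHAIN_OF_THOUGHT
    return Strategy.FUNCTION_CALLING if supports_tool_calls else Strategy.CHAIN_OF_THOUGHT


assert select_strategy({"STREAM_TOOL_CALL"}) is Strategy.FUNCTION_CALLING
assert select_strategy(set()) is Strategy.CHAIN_OF_THOUGHT
```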

## Key Features

### 1. Multiple Agent Strategies

- **Function Call Strategy**: Leverages native function/tool calling capabilities of advanced LLMs (e.g., GPT-4, Claude)
- **ReAct Strategy**: Implements the ReAct (Reasoning + Acting) approach for models without native function calling support

### 2. Automatic Strategy Selection

The `StrategyFactory` intelligently selects the optimal strategy based on model features:

- Models with `TOOL_CALL`, `MULTI_TOOL_CALL`, or `STREAM_TOOL_CALL` capabilities → Function Call Strategy
- Other models → ReAct Strategy

### 3. Unified Interface

- Common base class (`AgentPattern`) ensures consistent behavior across strategies
- Seamless integration with both workflow nodes and standalone agent applications
- Standardized input/output formats for easy consumption

### 4. Advanced Capabilities

- **Streaming Support**: Real-time response streaming for better user experience
- **File Handling**: Built-in support for processing and managing files during agent execution
- **Iteration Control**: Configurable maximum iterations with safety limits (capped at 99)
- **Tool Management**: Flexible tool integration supporting various tool types
- **Context Propagation**: Execution context for tracing, auditing, and debugging
- **Dual strategies**
  - `FunctionCallStrategy`: uses native LLM function/tool calling when the model exposes `TOOL_CALL`, `MULTI_TOOL_CALL`, or `STREAM_TOOL_CALL`.
  - `ReActStrategy`: ReAct (reasoning + acting) flow driven by `CotAgentOutputParser`, used when function calling is unavailable or explicitly requested.
- **Explicit or auto selection**
  - `StrategyFactory.create_strategy` prefers an explicit `AgentEntity.Strategy` (FUNCTION_CALLING or CHAIN_OF_THOUGHT).
  - Otherwise it falls back to function calling when tool-call features exist, or ReAct when they do not.
- **Unified execution contract**
  - `AgentPattern.run` yields streaming `AgentLog` entries and `LLMResultChunk` data, returning an `AgentResult` with text, files, usage, and `finish_reason`.
  - Iterations are configurable and hard-capped at 99 rounds; the last round forces a final answer by withholding tools.
- **Tool handling and hooks**
  - Tools convert to `PromptMessageTool` objects before invocation.
  - Optional `tool_invoke_hook` lets callers override tool execution (e.g., agent apps) while workflow runs use `ToolEngine.generic_invoke`.
  - Tool outputs support text, links, JSON, variables, blobs, retriever resources, and file attachments; `target=="self"` files are reloaded into model context, others are returned as outputs.
- **File-aware arguments**
  - Tool args accept `[File: <id>]` or `[Files: <id1, id2>]` placeholders that resolve to `File` objects before invocation, enabling models to reference uploaded files safely.
- **ReAct prompt shaping**
  - System prompts replace `{{instruction}}`, `{{tools}}`, and `{{tool_names}}` placeholders.
  - Adds `Observation` to stop sequences and appends scratchpad text so the model sees prior Thought/Action/Observation history.
- **Observability and accounting**
  - Standardized `AgentLog` entries for rounds, model thoughts, and tool calls, including usage aggregation (`LLMUsage`) across streaming and non-streaming paths.

## Architecture

```
agent/patterns/
├── base.py # Abstract base class defining the agent pattern interface
├── function_call.py # Implementation using native LLM function calling
├── react.py # Implementation using ReAct prompting approach
└── strategy_factory.py # Factory for automatic strategy selection
├── base.py # Shared utilities: logging, usage, tool invocation, file handling
├── function_call.py # Native function-calling loop with tool execution
├── react.py # ReAct loop with CoT parsing and scratchpad wiring
└── strategy_factory.py # Strategy selection by model features or explicit override
```

## Usage

The module is designed to be used by:

1. **Agent V2 Nodes**: In workflow orchestration for complex agent tasks
1. **Agent Applications**: For standalone conversational agents
1. **Custom Implementations**: As a foundation for building specialized agent behaviors
- For auto-selection:
  - Call `StrategyFactory.create_strategy(model_features, model_instance, context, tools, files, ...)` and run the returned strategy with prompt messages and model params.
- For explicit behavior:
  - Pass `agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING` to force native calls (falls back to ReAct if unsupported), or `CHAIN_OF_THOUGHT` to force ReAct.
- Both strategies stream chunks and logs; collect the generator output until it returns an `AgentResult`, as sketched below.
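
For illustration, a minimal sketch of draining such a generator, assuming `strategy.run(...)` behaves like a plain Python generator whose return value is the `AgentResult` (the helper below is hypothetical, not part of the module):

```python
def collect_agent_result(stream):
    """Drain a generator, keeping streamed items and its final return value."""
    items = []
    while True:
        try:
            items.append(next(stream))  # streamed AgentLog / LLMResultChunk items
        except StopIteration as stop:
            return items, stop.value    # a generator's `return` value is the final result


def _demo():  # stand-in for strategy.run(...), purely illustrative
    yield "chunk-1"
    yield "chunk-2"
    return "final-result"


assert collect_agent_result(_demo()) == (["chunk-1", "chunk-2"], "final-result")
```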

## Integration Points

- **Model Runtime**: Interfaces with Dify's model runtime for LLM interactions
- **Tool System**: Integrates with the tool framework for external capabilities
- **Memory Management**: Compatible with conversation memory systems
- **File Management**: Handles file inputs/outputs during agent execution

## Benefits

1. **Consistency**: Unified implementation reduces code duplication and maintenance overhead
1. **Flexibility**: Easy to extend with new strategies or customize existing ones
1. **Performance**: Optimized for each model's capabilities to ensure best performance
1. **Reliability**: Built-in safety mechanisms and error handling
- **Model runtime**: delegates to `ModelInstance.invoke_llm` for both streaming and non-streaming calls.
- **Tool system**: defaults to `ToolEngine.generic_invoke`, with `tool_invoke_hook` for custom callers.
- **Files**: flows through `File` objects for tool inputs/outputs and model-context attachments.
- **Execution context**: `ExecutionContext` fields (user/app/conversation/message) propagate to tool invocations and logging.

@@ -457,6 +457,9 @@ class WorkflowBasedAppRunner:
elif isinstance(event, NodeRunStreamChunkEvent):
from core.app.entities.queue_entities import ChunkType as QueueChunkType

if event.is_final and not event.chunk:
return

self._publish_event(
QueueTextChunkEvent(
text=event.chunk,

@@ -1,4 +1,5 @@
import logging
import re
import time
from collections.abc import Generator
from threading import Thread

@@ -68,6 +69,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""

_THINK_PATTERN = re.compile(r"<think[^>]*>(.*?)</think>", re.IGNORECASE | re.DOTALL)

_task_state: EasyUITaskState
_application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity]

@@ -441,7 +444,13 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
for thought in agent_thoughts:
# Add thought/reasoning
if thought.thought:
reasoning_list.append(thought.thought)
reasoning_text = thought.thought
if "<think" in reasoning_text.lower():
clean_text, extracted_reasoning = self._split_reasoning_from_answer(reasoning_text)
if extracted_reasoning:
reasoning_text = extracted_reasoning
thought.thought = clean_text or extracted_reasoning
reasoning_list.append(reasoning_text)
sequence.append({"type": "reasoning", "index": len(reasoning_list) - 1})

# Add tool calls

@@ -464,6 +473,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
else:
# Completion/Chat mode: use reasoning_content from llm_result
reasoning_content = llm_result.reasoning_content
if not reasoning_content and answer:
# Extract reasoning from <think> blocks and clean the final answer
clean_answer, reasoning_content = self._split_reasoning_from_answer(answer)
if reasoning_content:
answer = clean_answer
llm_result.message.content = clean_answer
llm_result.reasoning_content = reasoning_content
message.answer = clean_answer
if reasoning_content:
reasoning_list = [reasoning_content]
# Content comes first, then reasoning

@@ -493,6 +510,19 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
)
session.add(generation_detail)

@classmethod
def _split_reasoning_from_answer(cls, text: str) -> tuple[str, str]:
"""
Extract reasoning segments from <think> blocks and return (clean_text, reasoning).
"""
matches = cls._THINK_PATTERN.findall(text)
reasoning_content = "\n".join(match.strip() for match in matches) if matches else ""

clean_text = cls._THINK_PATTERN.sub("", text)
clean_text = re.sub(r"\n\s*\n", "\n\n", clean_text).strip()

return clean_text, reasoning_content or ""

def _handle_stop(self, event: QueueStopEvent):
"""
Handle stop.
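
For reference, a small standalone illustration (not part of the diff) of how the `<think>` splitting above behaves; the sample text is made up:

```python
import re

# Same pattern as _THINK_PATTERN in the hunk above.
THINK_PATTERN = re.compile(r"<think[^>]*>(.*?)</think>", re.IGNORECASE | re.DOTALL)

text = "<think>plan the reply</think>\n\nHello!"
matches = THINK_PATTERN.findall(text)
reasoning = "\n".join(m.strip() for m in matches)   # -> "plan the reply"
clean = THINK_PATTERN.sub("", text)
clean = re.sub(r"\n\s*\n", "\n\n", clean).strip()   # -> "Hello!"
assert (clean, reasoning) == ("Hello!", "plan the reply")
```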

@@ -474,57 +474,67 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
outputs = execution.outputs or {}
metadata = execution.metadata or {}

# Extract reasoning_content from outputs
reasoning_content = outputs.get("reasoning_content")
reasoning_list: list[str] = []
if reasoning_content:
# reasoning_content could be a string or already a list
if isinstance(reasoning_content, str):
reasoning_list = [reasoning_content] if reasoning_content.strip() else []
elif isinstance(reasoning_content, list):
# Filter out empty or whitespace-only strings
reasoning_list = [r.strip() for r in reasoning_content if isinstance(r, str) and r.strip()]
reasoning_list = self._extract_reasoning(outputs)
tool_calls_list = self._extract_tool_calls(metadata.get(WorkflowNodeExecutionMetadataKey.AGENT_LOG))

# Extract tool_calls from metadata.agent_log
tool_calls_list: list[dict] = []
agent_log = metadata.get(WorkflowNodeExecutionMetadataKey.AGENT_LOG)
if agent_log and isinstance(agent_log, list):
for log in agent_log:
# Each log entry has label, data, status, etc.
log_data = log.data if hasattr(log, "data") else log.get("data", {})
tool_name = log_data.get("tool_name")
# Only include tool calls with valid tool_name
if tool_name and str(tool_name).strip():
tool_calls_list.append(
{
"id": log_data.get("tool_call_id", ""),
"name": tool_name,
"arguments": json.dumps(log_data.get("tool_args", {})),
"result": str(log_data.get("output", "")),
}
)

# Only save if there's meaningful generation detail (reasoning or tool calls)
has_valid_reasoning = bool(reasoning_list)
has_valid_tool_calls = bool(tool_calls_list)

if not has_valid_reasoning and not has_valid_tool_calls:
if not reasoning_list and not tool_calls_list:
return

# Build sequence based on content, reasoning, and tool_calls
sequence: list[dict] = []
text = outputs.get("text", "")
sequence = self._build_generation_sequence(outputs.get("text", ""), reasoning_list, tool_calls_list)
self._upsert_generation_detail(session, execution, reasoning_list, tool_calls_list, sequence)

# For now, use a simple sequence: content -> reasoning -> tool_calls
# This can be enhanced later to track actual streaming order
def _extract_reasoning(self, outputs: Mapping[str, Any]) -> list[str]:
"""Extract reasoning_content as a clean list of non-empty strings."""
reasoning_content = outputs.get("reasoning_content")
if isinstance(reasoning_content, str):
trimmed = reasoning_content.strip()
return [trimmed] if trimmed else []
if isinstance(reasoning_content, list):
return [item.strip() for item in reasoning_content if isinstance(item, str) and item.strip()]
return []

def _extract_tool_calls(self, agent_log: Any) -> list[dict[str, str]]:
"""Extract tool call records from agent logs."""
if not agent_log or not isinstance(agent_log, list):
return []

tool_calls: list[dict[str, str]] = []
for log in agent_log:
log_data = log.data if hasattr(log, "data") else (log.get("data", {}) if isinstance(log, dict) else {})
tool_name = log_data.get("tool_name")
if tool_name and str(tool_name).strip():
tool_calls.append(
{
"id": log_data.get("tool_call_id", ""),
"name": tool_name,
"arguments": json.dumps(log_data.get("tool_args", {})),
"result": str(log_data.get("output", "")),
}
)
return tool_calls

def _build_generation_sequence(
self, text: str, reasoning_list: list[str], tool_calls_list: list[dict[str, str]]
) -> list[dict[str, Any]]:
"""Build a simple content/reasoning/tool_call sequence."""
sequence: list[dict[str, Any]] = []
if text:
sequence.append({"type": "content", "start": 0, "end": len(text)})
for i, _ in enumerate(reasoning_list):
sequence.append({"type": "reasoning", "index": i})
for i in range(len(tool_calls_list)):
sequence.append({"type": "tool_call", "index": i})
for index in range(len(reasoning_list)):
sequence.append({"type": "reasoning", "index": index})
for index in range(len(tool_calls_list)):
sequence.append({"type": "tool_call", "index": index})
return sequence

# Check if generation detail already exists for this node execution
def _upsert_generation_detail(
self,
session,
execution: WorkflowNodeExecution,
reasoning_list: list[str],
tool_calls_list: list[dict[str, str]],
sequence: list[dict[str, Any]],
) -> None:
"""Insert or update LLMGenerationDetail with serialized fields."""
existing = (
session.query(LLMGenerationDetail)
.filter_by(

@@ -534,23 +544,26 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
.first()
)

reasoning_json = json.dumps(reasoning_list) if reasoning_list else None
tool_calls_json = json.dumps(tool_calls_list) if tool_calls_list else None
sequence_json = json.dumps(sequence) if sequence else None

if existing:
# Update existing record
existing.reasoning_content = json.dumps(reasoning_list) if reasoning_list else None
existing.tool_calls = json.dumps(tool_calls_list) if tool_calls_list else None
existing.sequence = json.dumps(sequence) if sequence else None
else:
# Create new record
generation_detail = LLMGenerationDetail(
tenant_id=self._tenant_id,
app_id=self._app_id,
workflow_run_id=execution.workflow_execution_id,
node_id=execution.node_id,
reasoning_content=json.dumps(reasoning_list) if reasoning_list else None,
tool_calls=json.dumps(tool_calls_list) if tool_calls_list else None,
sequence=json.dumps(sequence) if sequence else None,
)
session.add(generation_detail)
existing.reasoning_content = reasoning_json
existing.tool_calls = tool_calls_json
existing.sequence = sequence_json
return

generation_detail = LLMGenerationDetail(
tenant_id=self._tenant_id,
app_id=self._app_id,
workflow_run_id=execution.workflow_execution_id,
node_id=execution.node_id,
reasoning_content=reasoning_json,
tool_calls=tool_calls_json,
sequence=sequence_json,
)
session.add(generation_detail)

def get_db_models_by_workflow_run(
self,

@@ -391,12 +391,9 @@ class ResponseStreamCoordinator:
# Determine which node to attribute the output to
# For special selectors (sys, env, conversation), use the active response node
# For regular selectors, use the source node
if self._active_session and source_selector_prefix not in self._graph.nodes:
# Special selector - use active response node
output_node_id = self._active_session.node_id
else:
# Regular node selector
output_node_id = source_selector_prefix
active_session = self._active_session
special_selector = bool(active_session and source_selector_prefix not in self._graph.nodes)
output_node_id = active_session.node_id if special_selector and active_session else source_selector_prefix
execution_id = self._get_or_create_execution_id(output_node_id)

# Check if there's a direct stream for this selector

@@ -404,65 +401,27 @@ class ResponseStreamCoordinator:
tuple(segment.selector) in self._stream_buffers or tuple(segment.selector) in self._closed_streams
)

if has_direct_stream:
# Stream all available chunks for direct stream
while self._has_unread_stream(segment.selector):
if event := self._pop_stream_chunk(segment.selector):
# For special selectors, update the event to use active response node's information
if self._active_session and source_selector_prefix not in self._graph.nodes:
response_node = self._graph.nodes[self._active_session.node_id]
updated_event = NodeRunStreamChunkEvent(
id=execution_id,
node_id=response_node.id,
node_type=response_node.node_type,
selector=event.selector,
chunk=event.chunk,
is_final=event.is_final,
stream_targets = [segment.selector] if has_direct_stream else sorted(self._find_child_streams(segment.selector))

if stream_targets:
all_complete = True

for target_selector in stream_targets:
while self._has_unread_stream(target_selector):
if event := self._pop_stream_chunk(target_selector):
events.append(
self._rewrite_stream_event(
event=event,
output_node_id=output_node_id,
execution_id=execution_id,
special_selector=bool(special_selector),
)
)
events.append(updated_event)
else:
events.append(event)

# Check if stream is closed
if self._is_stream_closed(segment.selector):
is_complete = True
if not self._is_stream_closed(target_selector):
all_complete = False

else:
# No direct stream - check for child field streams (for object types)
child_streams = self._find_child_streams(segment.selector)

if child_streams:
# Process all child streams
all_children_complete = True

for child_selector in sorted(child_streams):
# Stream all available chunks from this child
while self._has_unread_stream(child_selector):
if event := self._pop_stream_chunk(child_selector):
# Forward child stream event
if self._active_session and source_selector_prefix not in self._graph.nodes:
response_node = self._graph.nodes[self._active_session.node_id]
updated_event = NodeRunStreamChunkEvent(
id=execution_id,
node_id=response_node.id,
node_type=response_node.node_type,
selector=event.selector,
chunk=event.chunk,
is_final=event.is_final,
chunk_type=event.chunk_type,
tool_call=event.tool_call,
tool_result=event.tool_result,
)
events.append(updated_event)
else:
events.append(event)

# Check if this child stream is complete
if not self._is_stream_closed(child_selector):
all_children_complete = False

# Object segment is complete only when all children are complete
is_complete = all_children_complete
is_complete = all_complete

# Fallback: check if scalar value exists in variable pool
if not is_complete and not has_direct_stream:

@@ -485,6 +444,28 @@ class ResponseStreamCoordinator:
return events, is_complete

def _rewrite_stream_event(
self,
event: NodeRunStreamChunkEvent,
output_node_id: str,
execution_id: str,
special_selector: bool,
) -> NodeRunStreamChunkEvent:
"""Rewrite event to attribute to active response node when selector is special."""
if not special_selector:
return event

return self._create_stream_chunk_event(
node_id=output_node_id,
execution_id=execution_id,
selector=event.selector,
chunk=event.chunk,
is_final=event.is_final,
chunk_type=event.chunk_type,
tool_call=event.tool_call,
tool_result=event.tool_result,
)

def _process_text_segment(self, segment: TextSegment) -> Sequence[NodeRunStreamChunkEvent]:
"""Process a text segment. Returns (events, is_complete)."""
assert self._active_session is not None

@@ -1203,6 +1203,7 @@ class Message(Base):
.all()
)

# FIXME (Novice) -- It's easy to cause N+1 query problem here.
@property
def generation_detail(self) -> dict[str, Any] | None:
"""

@@ -695,6 +695,10 @@ class WorkflowRun(Base):
def workflow(self):
return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first()

@property
def outputs_as_generation(self):
return is_generation_outputs(self.outputs_dict)

def to_dict(self):
return {
"id": self.id,

@@ -708,7 +712,7 @@ class WorkflowRun(Base):
"inputs": self.inputs_dict,
"status": self.status,
"outputs": self.outputs_dict,
"outputs_as_generation": is_generation_outputs(self.outputs_dict),
"outputs_as_generation": self.outputs_as_generation,
"error": self.error,
"elapsed_time": self.elapsed_time,
"total_tokens": self.total_tokens,

@@ -1,3 +1,4 @@
"""
Mark agent test modules as a package to avoid import name collisions.
"""

@@ -0,0 +1,48 @@
from unittest.mock import MagicMock

from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.workflow.graph_events import NodeRunStreamChunkEvent
from core.workflow.nodes import NodeType


class DummyQueueManager:
def __init__(self) -> None:
self.published = []

def publish(self, event, publish_from: PublishFrom) -> None:
self.published.append((event, publish_from))


def test_skip_empty_final_chunk() -> None:
queue_manager = DummyQueueManager()
runner = WorkflowBasedAppRunner(queue_manager=queue_manager, app_id="app")

empty_final_event = NodeRunStreamChunkEvent(
id="exec",
node_id="node",
node_type=NodeType.LLM,
selector=["node", "text"],
chunk="",
is_final=True,
)

runner._handle_event(workflow_entry=MagicMock(), event=empty_final_event)
assert queue_manager.published == []

normal_event = NodeRunStreamChunkEvent(
id="exec",
node_id="node",
node_type=NodeType.LLM,
selector=["node", "text"],
chunk="hi",
is_final=False,
)

runner._handle_event(workflow_entry=MagicMock(), event=normal_event)

assert len(queue_manager.published) == 1
published_event, publish_from = queue_manager.published[0]
assert publish_from == PublishFrom.APPLICATION_MANAGER
assert published_event.text == "hi"

@@ -6,6 +6,7 @@ from core.workflow.entities.tool_entities import ToolResultStatus
from core.workflow.enums import NodeType
from core.workflow.graph.graph import Graph
from core.workflow.graph_engine.response_coordinator.coordinator import ResponseStreamCoordinator
from core.workflow.graph_engine.response_coordinator.session import ResponseSession
from core.workflow.graph_events import (
ChunkType,
NodeRunStreamChunkEvent,

@@ -13,6 +14,7 @@ from core.workflow.graph_events import (
ToolResult,
)
from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.nodes.base.template import Template, VariableSegment
from core.workflow.runtime import VariablePool

@@ -186,3 +188,44 @@ class TestResponseCoordinatorObjectStreaming:
assert ("node1", "generation", "content") in children
assert ("node1", "generation", "tool_calls") in children
assert ("node1", "generation", "thought") in children

def test_special_selector_rewrites_to_active_response_node(self):
"""Ensure special selectors attribute streams to the active response node."""
graph = MagicMock(spec=Graph)
variable_pool = MagicMock(spec=VariablePool)

response_node = MagicMock()
response_node.id = "response_node"
response_node.node_type = NodeType.ANSWER
graph.nodes = {"response_node": response_node}
graph.root_node = response_node

coordinator = ResponseStreamCoordinator(variable_pool, graph)
coordinator.track_node_execution("response_node", "exec_resp")

coordinator._active_session = ResponseSession(
node_id="response_node",
template=Template(segments=[VariableSegment(selector=["sys", "foo"])]),
)

event = NodeRunStreamChunkEvent(
id="stream_1",
node_id="llm_node",
node_type=NodeType.LLM,
selector=["sys", "foo"],
chunk="hi",
is_final=True,
chunk_type=ChunkType.TEXT,
)

coordinator._stream_buffers[("sys", "foo")] = [event]
coordinator._stream_positions[("sys", "foo")] = 0
coordinator._closed_streams.add(("sys", "foo"))

events, is_complete = coordinator._process_variable_segment(VariableSegment(selector=["sys", "foo"]))

assert is_complete
assert len(events) == 1
rewritten = events[0]
assert rewritten.node_id == "response_node"
assert rewritten.id == "exec_resp"

@@ -146,3 +146,4 @@ def test_serialize_tool_call_strips_files_to_ids():
assert serialized["name"] == "do"
assert serialized["arguments"] == '{"a":1}'
assert serialized["output"] == "ok"