mirror of https://github.com/langgenius/dify.git

feat: generation stream output.

parent 2b23c43434
commit 2d2ce5df85
@@ -51,7 +51,7 @@ class FunctionCallStrategy(AgentPattern):
                 label=f"ROUND {iteration_step}",
                 log_type=AgentLog.LogType.ROUND,
                 status=AgentLog.LogStatus.START,
-                data={"round_index": iteration_step},
+                data={},
             )
             yield round_log
             # On last iteration, remove tools to force final answer
@@ -249,25 +249,47 @@ class FunctionCallStrategy(AgentPattern):
             )
             yield tool_call_log

-            # Invoke tool using base class method
-            response_content, tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args, tool_name)
-
-            yield self._finish_log(
-                tool_call_log,
-                data={
-                    **tool_call_log.data,
-                    "output": response_content,
-                    "files": len(tool_files),
-                    "meta": tool_invoke_meta.to_dict() if tool_invoke_meta else None,
-                },
-            )
-            final_content = response_content or "Tool executed successfully"
-            # Add tool response to messages
-            messages.append(
-                ToolPromptMessage(
-                    content=final_content,
-                    tool_call_id=tool_call_id,
-                    name=tool_name,
-                )
-            )
-            return response_content, tool_files, tool_invoke_meta
+            # Invoke tool using base class method with error handling
+            try:
+                response_content, tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args, tool_name)
+
+                yield self._finish_log(
+                    tool_call_log,
+                    data={
+                        **tool_call_log.data,
+                        "output": response_content,
+                        "files": len(tool_files),
+                        "meta": tool_invoke_meta.to_dict() if tool_invoke_meta else None,
+                    },
+                )
+                final_content = response_content or "Tool executed successfully"
+                # Add tool response to messages
+                messages.append(
+                    ToolPromptMessage(
+                        content=final_content,
+                        tool_call_id=tool_call_id,
+                        name=tool_name,
+                    )
+                )
+                return response_content, tool_files, tool_invoke_meta
+            except Exception as e:
+                # Tool invocation failed, yield error log
+                error_message = str(e)
+                tool_call_log.status = AgentLog.LogStatus.ERROR
+                tool_call_log.error = error_message
+                tool_call_log.data = {
+                    **tool_call_log.data,
+                    "error": error_message,
+                }
+                yield tool_call_log
+
+                # Add error message to conversation
+                error_content = f"Tool execution failed: {error_message}"
+                messages.append(
+                    ToolPromptMessage(
+                        content=error_content,
+                        tool_call_id=tool_call_id,
+                        name=tool_name,
+                    )
+                )
+                return error_content, [], None
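The shape of this change is worth noting: a failed tool call no longer propagates out of the agent loop; it marks the log as an error and hands the failure text back to the model as a regular tool message. A minimal standalone sketch of that pattern, with invoke_tool, the dict-shaped log, and OpenAI-style message dicts as stand-ins for Dify's actual types:

    # Sketch of the try/except tool-call pattern above (stand-in names,
    # not Dify's real API).
    def run_tool_call(invoke_tool, log, messages, tool_call_id, tool_name, args):
        try:
            output = invoke_tool(args)
            log["status"] = "success"
            log["data"]["output"] = output
            content = output or "Tool executed successfully"
        except Exception as e:
            error_message = str(e)
            log["status"] = "error"
            log["error"] = error_message
            # The failure is surfaced as an ordinary tool message so the
            # LLM can read it next round and decide how to recover.
            content = f"Tool execution failed: {error_message}"
        messages.append({
            "role": "tool",
            "tool_call_id": tool_call_id,
            "name": tool_name,
            "content": content,
        })
        return content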
@@ -80,7 +80,7 @@ class ReActStrategy(AgentPattern):
                 label=f"ROUND {iteration_step}",
                 log_type=AgentLog.LogType.ROUND,
                 status=AgentLog.LogStatus.START,
-                data={"round_index": iteration_step},
+                data={},
             )
             yield round_log

@@ -385,18 +385,31 @@ class ReActStrategy(AgentPattern):
         else:
             tool_args_dict = tool_args

-        # Invoke tool using base class method
-        response_content, tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args_dict, tool_name)
-
-        # Finish tool log
-        yield self._finish_log(
-            tool_log,
-            data={
-                **tool_log.data,
-                "output": response_content,
-                "files": len(tool_files),
-                "meta": tool_invoke_meta.to_dict() if tool_invoke_meta else None,
-            },
-        )
-
-        return response_content or "Tool executed successfully", tool_files
+        # Invoke tool using base class method with error handling
+        try:
+            response_content, tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args_dict, tool_name)
+
+            # Finish tool log
+            yield self._finish_log(
+                tool_log,
+                data={
+                    **tool_log.data,
+                    "output": response_content,
+                    "files": len(tool_files),
+                    "meta": tool_invoke_meta.to_dict() if tool_invoke_meta else None,
+                },
+            )
+
+            return response_content or "Tool executed successfully", tool_files
+        except Exception as e:
+            # Tool invocation failed, yield error log
+            error_message = str(e)
+            tool_log.status = AgentLog.LogStatus.ERROR
+            tool_log.error = error_message
+            tool_log.data = {
+                **tool_log.data,
+                "error": error_message,
+            }
+            yield tool_log
+
+            return f"Tool execution failed: {error_message}", []
@@ -542,7 +542,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
                 tool_arguments=event.tool_arguments,
                 tool_files=event.tool_files,
                 tool_error=event.tool_error,
-                round_index=event.round_index,
             )

     def _handle_iteration_start_event(
@@ -497,7 +497,6 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
                 tool_arguments=event.tool_arguments,
                 tool_files=event.tool_files,
                 tool_error=event.tool_error,
-                round_index=event.round_index,
             )

     def _handle_agent_log_event(self, event: QueueAgentLogEvent, **kwargs) -> Generator[StreamResponse, None, None]:
@@ -670,7 +669,6 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
         tool_arguments: str | None = None,
         tool_files: list[str] | None = None,
         tool_error: str | None = None,
-        round_index: int | None = None,
     ) -> TextChunkStreamResponse:
         """
         Handle completed event.
@@ -690,7 +688,6 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
                 tool_arguments=tool_arguments,
                 tool_files=tool_files or [],
                 tool_error=tool_error,
-                round_index=round_index,
             ),
         )

@@ -469,7 +469,6 @@ class WorkflowBasedAppRunner:
                     tool_arguments=event.tool_arguments,
                     tool_files=event.tool_files,
                     tool_error=event.tool_error,
-                    round_index=event.round_index,
                 )
             )
         elif isinstance(event, NodeRunRetrieverResourceEvent):
@@ -218,10 +218,6 @@ class QueueTextChunkEvent(AppQueueEvent):
     tool_error: str | None = None
     """error message if tool failed"""

-    # Thought fields (when chunk_type == THOUGHT)
-    round_index: int | None = None
-    """current iteration round"""
-

 class QueueAgentMessageEvent(AppQueueEvent):
     """
@@ -134,10 +134,6 @@ class MessageStreamResponse(StreamResponse):
     tool_error: str | None = None
     """error message if tool failed"""

-    # Thought fields (when chunk_type == "thought")
-    round_index: int | None = None
-    """current iteration round"""
-

 class MessageAudioStreamResponse(StreamResponse):
     """
@@ -647,10 +643,6 @@ class TextChunkStreamResponse(StreamResponse):
         tool_error: str | None = None
         """error message if tool failed"""

-        # Thought fields (when chunk_type == THOUGHT)
-        round_index: int | None = None
-        """current iteration round"""
-
     event: StreamEvent = StreamEvent.TEXT_CHUNK
     data: Data

@@ -224,7 +224,6 @@ class MessageCycleManager:
         tool_arguments: str | None = None,
         tool_files: list[str] | None = None,
         tool_error: str | None = None,
-        round_index: int | None = None,
     ) -> MessageStreamResponse:
         """
         Message to stream response.
@@ -237,7 +236,6 @@ class MessageCycleManager:
         :param tool_arguments: accumulated tool arguments JSON
         :param tool_files: file IDs produced by tool
         :param tool_error: error message if tool failed
-        :param round_index: current iteration round
         :return:
         """
         with Session(db.engine, expire_on_commit=False) as session:
@@ -256,7 +254,6 @@ class MessageCycleManager:
                 tool_arguments=tool_arguments,
                 tool_files=tool_files,
                 tool_error=tool_error,
-                round_index=round_index,
             )

     def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:
@@ -441,7 +441,6 @@ class ResponseStreamCoordinator:
                     tool_arguments=event.tool_arguments,
                     tool_files=event.tool_files,
                     tool_error=event.tool_error,
-                    round_index=event.round_index,
                 )
                 events.append(updated_event)
             else:
@@ -51,9 +51,6 @@ class NodeRunStreamChunkEvent(GraphNodeEventBase):
     tool_files: list[str] = Field(default_factory=list, description="file IDs produced by tool")
     tool_error: str | None = Field(default=None, description="error message if tool failed")

-    # Thought fields (when chunk_type == THOUGHT)
-    round_index: int | None = Field(default=None, description="current iteration round")
-

 class NodeRunRetrieverResourceEvent(GraphNodeEventBase):
     retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
@@ -74,7 +74,6 @@ class ThoughtChunkEvent(StreamChunkEvent):
     """Agent thought streaming event - Agent thinking process (ReAct)."""

     chunk_type: ChunkType = Field(default=ChunkType.THOUGHT, frozen=True)
-    round_index: int = Field(default=1, description="current iteration round")


 class StreamCompletedEvent(NodeEventBase):
@@ -598,7 +598,6 @@ class Node(Generic[NodeDataT]):
             chunk=event.chunk,
             is_final=event.is_final,
             chunk_type=ChunkType.THOUGHT,
-            round_index=event.round_index,
         )

     @_dispatch.register
@@ -277,7 +277,7 @@ class LLMNode(Node[LLMNodeData]):
         structured_output: LLMStructuredOutput | None = None

         for event in generator:
-            if isinstance(event, StreamChunkEvent):
+            if isinstance(event, (StreamChunkEvent, ThoughtChunkEvent)):
                 yield event
             elif isinstance(event, ModelInvokeCompletedEvent):
                 # Raw text
@@ -340,6 +340,16 @@ class LLMNode(Node[LLMNodeData]):
                 chunk="",
                 is_final=True,
             )
+            yield StreamChunkEvent(
+                selector=[self._node_id, "generation", "content"],
+                chunk="",
+                is_final=True,
+            )
+            yield ThoughtChunkEvent(
+                selector=[self._node_id, "generation", "thought"],
+                chunk="",
+                is_final=True,
+            )

             yield StreamCompletedEvent(
                 node_run_result=NodeRunResult(
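Both added yields exist to close the "generation" streams explicitly: downstream consumers treat each selector as an independent stream and only consider it finished once a chunk arrives with is_final=True, so even a stream that produced no content gets a final empty chunk. A hypothetical consumer-side sketch (plain dicts stand in for the real event classes):

    from collections import defaultdict

    # Collect chunks per selector until every stream is closed by an
    # is_final=True chunk (possibly an empty one).
    def collect_streams(events):
        buffers = defaultdict(list)
        closed = set()
        for event in events:
            key = tuple(event["selector"])
            buffers[key].append(event["chunk"])
            if event["is_final"]:
                closed.add(key)
        # A well-behaved producer closes every stream it opened.
        assert closed == set(buffers)
        return {key: "".join(parts) for key, parts in buffers.items()}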
@@ -470,6 +480,8 @@ class LLMNode(Node[LLMNodeData]):
         usage = LLMUsage.empty_usage()
         finish_reason = None
         full_text_buffer = io.StringIO()
+        think_parser = llm_utils.ThinkTagStreamParser()
+        reasoning_chunks: list[str] = []

         # Initialize streaming metrics tracking
         start_time = request_start_time if request_start_time is not None else time.perf_counter()
@@ -498,12 +510,32 @@ class LLMNode(Node[LLMNodeData]):
                     has_content = True

                     full_text_buffer.write(text_part)
+                    # Text output: always forward raw chunk (keep <think> tags intact)
                     yield StreamChunkEvent(
                         selector=[node_id, "text"],
                         chunk=text_part,
                         is_final=False,
                     )

+                    # Generation output: split out thoughts, forward only non-thought content chunks
+                    for kind, segment in think_parser.process(text_part):
+                        if not segment:
+                            continue
+
+                        if kind == "thought":
+                            reasoning_chunks.append(segment)
+                            yield ThoughtChunkEvent(
+                                selector=[node_id, "generation", "thought"],
+                                chunk=segment,
+                                is_final=False,
+                            )
+                        else:
+                            yield StreamChunkEvent(
+                                selector=[node_id, "generation", "content"],
+                                chunk=segment,
+                                is_final=False,
+                            )
+
             # Update the whole metadata
             if not model and result.model:
                 model = result.model
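ThinkTagStreamParser is an internal helper; the diff shows only that process() and flush() yield (kind, segment) pairs. As a rough mental model, a splitter like this has to buffer across chunk boundaries so a <think> tag split over two chunks is still recognized. A simplified stand-in (not Dify's implementation, and it assumes well-formed, non-nested tags):

    # Simplified think-tag stream splitter. Returns ("thought", text) for
    # text inside <think>...</think> and ("content", text) otherwise.
    class SimpleThinkSplitter:
        def __init__(self):
            self._buf = ""
            self._in_think = False

        def process(self, chunk: str):
            self._buf += chunk
            out = []
            while True:
                tag = "</think>" if self._in_think else "<think>"
                idx = self._buf.find(tag)
                if idx == -1:
                    break
                kind = "thought" if self._in_think else "content"
                out.append((kind, self._buf[:idx]))
                self._buf = self._buf[idx + len(tag):]
                self._in_think = not self._in_think
            # Hold back a possible partial tag at the end of the buffer.
            keep = 0
            for n in range(min(len(tag) - 1, len(self._buf)), 0, -1):
                if tag.startswith(self._buf[-n:]):
                    keep = n
                    break
            cut = len(self._buf) - keep
            emit, self._buf = self._buf[:cut], self._buf[cut:]
            if emit:
                out.append(("thought" if self._in_think else "content", emit))
            return out

        def flush(self):
            emit, self._buf = self._buf, ""
            kind = "thought" if self._in_think else "content"
            return [(kind, emit)] if emit else []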
@@ -518,16 +550,35 @@ class LLMNode(Node[LLMNodeData]):
                 except OutputParserError as e:
                     raise LLMNodeError(f"Failed to parse structured output: {e}")

+        for kind, segment in think_parser.flush():
+            if not segment:
+                continue
+            if kind == "thought":
+                reasoning_chunks.append(segment)
+                yield ThoughtChunkEvent(
+                    selector=[node_id, "generation", "thought"],
+                    chunk=segment,
+                    is_final=False,
+                )
+            else:
+                yield StreamChunkEvent(
+                    selector=[node_id, "generation", "content"],
+                    chunk=segment,
+                    is_final=False,
+                )
+
         # Extract reasoning content from <think> tags in the main text
         full_text = full_text_buffer.getvalue()

         if reasoning_format == "tagged":
             # Keep <think> tags in text for backward compatibility
             clean_text = full_text
-            reasoning_content = ""
+            reasoning_content = "".join(reasoning_chunks)
         else:
             # Extract clean text and reasoning from <think> tags
             clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format)
+            if reasoning_chunks and not reasoning_content:
+                reasoning_content = "".join(reasoning_chunks)

         # Calculate streaming metrics
         end_time = time.perf_counter()
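For orientation: the "tagged" branch now fills reasoning_content from the thought chunks the parser already produced instead of leaving it empty, while the non-tagged branch keeps those chunks as a fallback when _split_reasoning finds nothing. A toy model of the split itself, assuming simple well-formed <think> blocks ("separated" is an illustrative format name; this is not the project's _split_reasoning):

    import re

    def split_reasoning(full_text: str, reasoning_format: str) -> tuple[str, str]:
        if reasoning_format == "tagged":
            # Backward-compatible mode: leave <think> markup in the text and
            # take reasoning_content from the streamed thought chunks.
            return full_text, ""
        # Other modes: strip the tags and return the reasoning separately.
        thoughts = re.findall(r"<think>(.*?)</think>", full_text, flags=re.DOTALL)
        clean = re.sub(r"<think>.*?</think>", "", full_text, flags=re.DOTALL)
        return clean, "".join(thoughts)

    print(split_reasoning("<think>plan</think>answer", "separated"))
    # -> ('answer', 'plan')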
@@ -1398,8 +1449,6 @@ class LLMNode(Node[LLMNodeData]):
         finish_reason = None
         agent_result: AgentResult | None = None

-        # Track current round for ThoughtChunkEvent
-        current_round = 1
         think_parser = llm_utils.ThinkTagStreamParser()
         reasoning_chunks: list[str] = []

@@ -1431,12 +1480,6 @@ class LLMNode(Node[LLMNodeData]):
             else:
                 agent_logs.append(agent_log_event)

-                # Extract round number from ROUND log type
-                if output.log_type == AgentLog.LogType.ROUND:
-                    round_index = output.data.get("round_index")
-                    if isinstance(round_index, int):
-                        current_round = round_index
-
                 # Emit tool call events when tool call starts
                 if output.log_type == AgentLog.LogType.TOOL_CALL and output.status == AgentLog.LogStatus.START:
                     tool_name = output.data.get("tool_name", "")
@@ -1450,26 +1493,34 @@ class LLMNode(Node[LLMNodeData]):
                         tool_call_id=tool_call_id,
                         tool_name=tool_name,
                         tool_arguments=tool_arguments,
-                        is_final=True,
+                        is_final=False,
                     )

-                # Emit tool result events when tool call completes
-                if output.log_type == AgentLog.LogType.TOOL_CALL and output.status == AgentLog.LogStatus.SUCCESS:
+                # Emit tool result events when tool call completes (both success and error)
+                if output.log_type == AgentLog.LogType.TOOL_CALL and output.status != AgentLog.LogStatus.START:
                     tool_name = output.data.get("tool_name", "")
                     tool_output = output.data.get("output", "")
                     tool_call_id = output.data.get("tool_call_id", "")
                     tool_files = []
                     tool_error = None

-                    # Extract file IDs if present
+                    # Extract file IDs if present (only for success case)
                     files_data = output.data.get("files")
                     if files_data and isinstance(files_data, list):
                         tool_files = files_data

-                    # Check for error in meta
-                    meta = output.data.get("meta")
-                    if meta and isinstance(meta, dict) and meta.get("error"):
-                        tool_error = meta.get("error")
+                    # Check for error from multiple sources
+                    if output.status == AgentLog.LogStatus.ERROR:
+                        # Priority: output.error > data.error > meta.error
+                        tool_error = output.error or output.data.get("error")
+                        meta = output.data.get("meta")
+                        if not tool_error and meta and isinstance(meta, dict):
+                            tool_error = meta.get("error")
+                    else:
+                        # For success case, check meta for potential errors
+                        meta = output.data.get("meta")
+                        if meta and isinstance(meta, dict) and meta.get("error"):
+                            tool_error = meta.get("error")

                     yield ToolResultChunkEvent(
                         selector=[self._node_id, "generation", "tool_results"],
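The replacement is a small priority chain: on an error log, the log's own error field wins, then data["error"], then meta["error"]; on success, meta may still carry a soft error worth surfacing. A standalone restatement of that logic (plain strings and dicts stand in for AgentLog):

    def extract_tool_error(status: str, error: str | None, data: dict) -> str | None:
        meta = data.get("meta")
        meta_error = meta.get("error") if isinstance(meta, dict) else None
        if status == "error":
            # Priority: log.error > data["error"] > meta["error"]
            return error or data.get("error") or meta_error
        # Success path: a tool may still report a soft error via meta.
        return meta_error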
@@ -1478,7 +1529,7 @@ class LLMNode(Node[LLMNodeData]):
                         tool_name=tool_name,
                         tool_files=tool_files,
                         tool_error=tool_error,
-                        is_final=True,
+                        is_final=False,
                     )

             elif isinstance(output, LLMResultChunk):
@@ -1502,7 +1553,6 @@ class LLMNode(Node[LLMNodeData]):
                             yield ThoughtChunkEvent(
                                 selector=[self._node_id, "generation", "thought"],
                                 chunk=segment,
-                                round_index=current_round,
                                 is_final=False,
                             )
                         else:
@@ -1548,7 +1598,6 @@ class LLMNode(Node[LLMNodeData]):
                             yield ThoughtChunkEvent(
                                 selector=[self._node_id, "generation", "thought"],
                                 chunk=segment,
-                                round_index=current_round,
                                 is_final=False,
                             )
                         else:
@@ -1580,7 +1629,27 @@ class LLMNode(Node[LLMNodeData]):
             yield ThoughtChunkEvent(
                 selector=[self._node_id, "generation", "thought"],
                 chunk="",
-                round_index=current_round,
                 is_final=True,
             )

+            # Close tool_calls stream (already sent via ToolCallChunkEvent)
+            yield ToolCallChunkEvent(
+                selector=[self._node_id, "generation", "tool_calls"],
+                chunk="",
+                tool_call_id="",
+                tool_name="",
+                tool_arguments="",
+                is_final=True,
+            )
+
+            # Close tool_results stream (already sent via ToolResultChunkEvent)
+            yield ToolResultChunkEvent(
+                selector=[self._node_id, "generation", "tool_results"],
+                chunk="",
+                tool_call_id="",
+                tool_name="",
+                tool_files=[],
+                tool_error=None,
+                is_final=True,
+            )

@@ -239,7 +239,6 @@ class TestThoughtChunkEvent:
         assert event.selector == ["node1", "thought"]
         assert event.chunk == "I need to query the weather..."
         assert event.chunk_type == ChunkType.THOUGHT
-        assert event.round_index == 1  # default

     def test_chunk_type_is_thought(self):
         """Test that chunk_type is always THOUGHT."""
@@ -250,38 +249,17 @@ class TestThoughtChunkEvent:

         assert event.chunk_type == ChunkType.THOUGHT

-    def test_round_index_default(self):
-        """Test that round_index defaults to 1."""
-        event = ThoughtChunkEvent(
-            selector=["node1", "thought"],
-            chunk="thinking...",
-        )
-
-        assert event.round_index == 1
-
-    def test_round_index_custom(self):
-        """Test custom round_index."""
-        event = ThoughtChunkEvent(
-            selector=["node1", "thought"],
-            chunk="Second round thinking...",
-            round_index=2,
-        )
-
-        assert event.round_index == 2
-
     def test_serialization(self):
         """Test that event can be serialized to dict."""
         event = ThoughtChunkEvent(
             selector=["node1", "thought"],
             chunk="I need to analyze this...",
-            round_index=3,
             is_final=False,
         )

         data = event.model_dump()

         assert data["chunk_type"] == "thought"
-        assert data["round_index"] == 3
         assert data["chunk"] == "I need to analyze this..."
         assert data["is_final"] is False
