feat: generation stream output
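
Split the LLM node's generation output into dedicated stream channels
(generation/content, generation/thought, generation/tool_calls,
generation/tool_results) and close each channel with an explicit final
chunk. Wrap tool invocation in the function-call and ReAct strategies
with error handling so failures are logged and fed back to the model,
and retire the per-round round_index plumbing from queue events, stream
responses, and pipelines.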

Novice 2025-12-09 16:22:17 +08:00
parent 2b23c43434
commit 2d2ce5df85
GPG Key ID: EE3F68E3105DAAAB
14 changed files with 160 additions and 104 deletions


@@ -51,7 +51,7 @@ class FunctionCallStrategy(AgentPattern):
                 label=f"ROUND {iteration_step}",
                 log_type=AgentLog.LogType.ROUND,
                 status=AgentLog.LogStatus.START,
-                data={"round_index": iteration_step},
+                data={},
             )
             yield round_log
             # On last iteration, remove tools to force final answer
@@ -249,25 +249,47 @@ class FunctionCallStrategy(AgentPattern):
            )
            yield tool_call_log

-            # Invoke tool using base class method
-            response_content, tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args, tool_name)
-            yield self._finish_log(
-                tool_call_log,
-                data={
-                    **tool_call_log.data,
-                    "output": response_content,
-                    "files": len(tool_files),
-                    "meta": tool_invoke_meta.to_dict() if tool_invoke_meta else None,
-                },
-            )
-            final_content = response_content or "Tool executed successfully"
-            # Add tool response to messages
-            messages.append(
-                ToolPromptMessage(
-                    content=final_content,
-                    tool_call_id=tool_call_id,
-                    name=tool_name,
-                )
-            )
-            return response_content, tool_files, tool_invoke_meta
+            # Invoke tool using base class method with error handling
+            try:
+                response_content, tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args, tool_name)
+                yield self._finish_log(
+                    tool_call_log,
+                    data={
+                        **tool_call_log.data,
+                        "output": response_content,
+                        "files": len(tool_files),
+                        "meta": tool_invoke_meta.to_dict() if tool_invoke_meta else None,
+                    },
+                )
+                final_content = response_content or "Tool executed successfully"
+                # Add tool response to messages
+                messages.append(
+                    ToolPromptMessage(
+                        content=final_content,
+                        tool_call_id=tool_call_id,
+                        name=tool_name,
+                    )
+                )
+                return response_content, tool_files, tool_invoke_meta
+            except Exception as e:
+                # Tool invocation failed, yield error log
+                error_message = str(e)
+                tool_call_log.status = AgentLog.LogStatus.ERROR
+                tool_call_log.error = error_message
+                tool_call_log.data = {
+                    **tool_call_log.data,
+                    "error": error_message,
+                }
+                yield tool_call_log
+                # Add error message to conversation
+                error_content = f"Tool execution failed: {error_message}"
+                messages.append(
+                    ToolPromptMessage(
+                        content=error_content,
+                        tool_call_id=tool_call_id,
+                        name=tool_name,
+                    )
+                )
+                return error_content, [], None
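
Note: these strategy methods are generators that stream AgentLog events and still return a value to their driver; the return inside the generator surfaces as StopIteration.value, which yield from captures. A minimal self-contained sketch of the same error-handling shape (the Log type, invoke_tool, and drive below are illustrative stand-ins, not the project's API):

from collections.abc import Generator
from dataclasses import dataclass, field


@dataclass
class Log:
    label: str
    status: str = "start"
    data: dict = field(default_factory=dict)


def call_tool(name: str, args: dict) -> Generator[Log, None, str]:
    """Yield log events while invoking a tool; return the observation string."""
    yield Log(label=f"CALL {name}")
    try:
        result = invoke_tool(name, args)  # hypothetical helper: may raise
        yield Log(label=f"CALL {name}", status="success", data={"output": result})
        return result or "Tool executed successfully"
    except Exception as e:
        # Surface the failure as a log event and as the observation,
        # so the model can see the error and decide how to recover.
        yield Log(label=f"CALL {name}", status="error", data={"error": str(e)})
        return f"Tool execution failed: {e}"


def invoke_tool(name: str, args: dict) -> str:
    raise RuntimeError("boom")  # stand-in failure for the sketch


def drive() -> str:
    # `yield from` forwards log events upward and captures the return value.
    def runner() -> Generator[Log, None, str]:
        observation = yield from call_tool("weather", {"city": "Paris"})
        return observation

    gen = runner()
    try:
        while True:
            print(next(gen))  # stream each Log to the caller
    except StopIteration as stop:
        return stop.value


print(drive())  # -> "Tool execution failed: boom"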


@@ -80,7 +80,7 @@ class ReActStrategy(AgentPattern):
                 label=f"ROUND {iteration_step}",
                 log_type=AgentLog.LogType.ROUND,
                 status=AgentLog.LogStatus.START,
-                data={"round_index": iteration_step},
+                data={},
             )
             yield round_log
@@ -385,18 +385,31 @@ class ReActStrategy(AgentPattern):
            else:
                tool_args_dict = tool_args

-            # Invoke tool using base class method
-            response_content, tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args_dict, tool_name)
-            # Finish tool log
-            yield self._finish_log(
-                tool_log,
-                data={
-                    **tool_log.data,
-                    "output": response_content,
-                    "files": len(tool_files),
-                    "meta": tool_invoke_meta.to_dict() if tool_invoke_meta else None,
-                },
-            )
-            return response_content or "Tool executed successfully", tool_files
+            # Invoke tool using base class method with error handling
+            try:
+                response_content, tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args_dict, tool_name)
+                # Finish tool log
+                yield self._finish_log(
+                    tool_log,
+                    data={
+                        **tool_log.data,
+                        "output": response_content,
+                        "files": len(tool_files),
+                        "meta": tool_invoke_meta.to_dict() if tool_invoke_meta else None,
+                    },
+                )
+                return response_content or "Tool executed successfully", tool_files
+            except Exception as e:
+                # Tool invocation failed, yield error log
+                error_message = str(e)
+                tool_log.status = AgentLog.LogStatus.ERROR
+                tool_log.error = error_message
+                tool_log.data = {
+                    **tool_log.data,
+                    "error": error_message,
+                }
+                yield tool_log
+                return f"Tool execution failed: {error_message}", []


@@ -542,7 +542,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
                tool_arguments=event.tool_arguments,
                tool_files=event.tool_files,
                tool_error=event.tool_error,
-                round_index=event.round_index,
            )

    def _handle_iteration_start_event(


@@ -497,7 +497,6 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
                tool_arguments=event.tool_arguments,
                tool_files=event.tool_files,
                tool_error=event.tool_error,
-                round_index=event.round_index,
            )

    def _handle_agent_log_event(self, event: QueueAgentLogEvent, **kwargs) -> Generator[StreamResponse, None, None]:
@@ -670,7 +669,6 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
        tool_arguments: str | None = None,
        tool_files: list[str] | None = None,
        tool_error: str | None = None,
-        round_index: int | None = None,
    ) -> TextChunkStreamResponse:
        """
        Handle completed event.
@@ -690,7 +688,6 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
                tool_arguments=tool_arguments,
                tool_files=tool_files or [],
                tool_error=tool_error,
-                round_index=round_index,
            ),
        )


@@ -469,7 +469,6 @@ class WorkflowBasedAppRunner:
                    tool_arguments=event.tool_arguments,
                    tool_files=event.tool_files,
                    tool_error=event.tool_error,
-                    round_index=event.round_index,
                )
            )
        elif isinstance(event, NodeRunRetrieverResourceEvent):


@@ -218,10 +218,6 @@ class QueueTextChunkEvent(AppQueueEvent):
    tool_error: str | None = None
    """error message if tool failed"""

-    # Thought fields (when chunk_type == THOUGHT)
-    round_index: int | None = None
-    """current iteration round"""


class QueueAgentMessageEvent(AppQueueEvent):
    """

@@ -134,10 +134,6 @@ class MessageStreamResponse(StreamResponse):
    tool_error: str | None = None
    """error message if tool failed"""

-    # Thought fields (when chunk_type == "thought")
-    round_index: int | None = None
-    """current iteration round"""


class MessageAudioStreamResponse(StreamResponse):
    """
@@ -647,10 +643,6 @@ class TextChunkStreamResponse(StreamResponse):
        tool_error: str | None = None
        """error message if tool failed"""

-        # Thought fields (when chunk_type == THOUGHT)
-        round_index: int | None = None
-        """current iteration round"""

    event: StreamEvent = StreamEvent.TEXT_CHUNK
    data: Data


@@ -224,7 +224,6 @@ class MessageCycleManager:
        tool_arguments: str | None = None,
        tool_files: list[str] | None = None,
        tool_error: str | None = None,
-        round_index: int | None = None,
    ) -> MessageStreamResponse:
        """
        Message to stream response.
@@ -237,7 +236,6 @@ class MessageCycleManager:
        :param tool_arguments: accumulated tool arguments JSON
        :param tool_files: file IDs produced by tool
        :param tool_error: error message if tool failed
-        :param round_index: current iteration round
        :return:
        """
        with Session(db.engine, expire_on_commit=False) as session:
@@ -256,7 +254,6 @@ class MessageCycleManager:
                tool_arguments=tool_arguments,
                tool_files=tool_files,
                tool_error=tool_error,
-                round_index=round_index,
            )

    def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:


@@ -441,7 +441,6 @@ class ResponseStreamCoordinator:
                tool_arguments=event.tool_arguments,
                tool_files=event.tool_files,
                tool_error=event.tool_error,
-                round_index=event.round_index,
            )
            events.append(updated_event)
        else:


@@ -51,9 +51,6 @@ class NodeRunStreamChunkEvent(GraphNodeEventBase):
    tool_files: list[str] = Field(default_factory=list, description="file IDs produced by tool")
    tool_error: str | None = Field(default=None, description="error message if tool failed")

-    # Thought fields (when chunk_type == THOUGHT)
-    round_index: int | None = Field(default=None, description="current iteration round")


class NodeRunRetrieverResourceEvent(GraphNodeEventBase):
    retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")


@@ -74,7 +74,6 @@ class ThoughtChunkEvent(StreamChunkEvent):
    """Agent thought streaming event - Agent thinking process (ReAct)."""

    chunk_type: ChunkType = Field(default=ChunkType.THOUGHT, frozen=True)
-    round_index: int = Field(default=1, description="current iteration round")


class StreamCompletedEvent(NodeEventBase):
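
With this removal, thought chunks carry no round counter at any layer; ordering is implied by position in the stream. A minimal pydantic sketch of the slimmed event (field set abridged; the real StreamChunkEvent base carries more fields):

from enum import StrEnum

from pydantic import BaseModel, Field


class ChunkType(StrEnum):
    TEXT = "text"
    THOUGHT = "thought"


class StreamChunkEvent(BaseModel):
    selector: list[str] = Field(..., description="selector identifying the output stream")
    chunk: str = Field(..., description="the streamed fragment")
    is_final: bool = Field(default=False, description="whether this closes the stream")
    chunk_type: ChunkType = Field(default=ChunkType.TEXT)


class ThoughtChunkEvent(StreamChunkEvent):
    """Agent thought streaming event, with no per-round counter anymore."""
    chunk_type: ChunkType = Field(default=ChunkType.THOUGHT, frozen=True)


event = ThoughtChunkEvent(selector=["node1", "generation", "thought"], chunk="thinking...")
assert event.model_dump()["chunk_type"] == "thought"
assert "round_index" not in event.model_dump()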


@@ -598,7 +598,6 @@ class Node(Generic[NodeDataT]):
            chunk=event.chunk,
            is_final=event.is_final,
            chunk_type=ChunkType.THOUGHT,
-            round_index=event.round_index,
        )

    @_dispatch.register


@@ -277,7 +277,7 @@ class LLMNode(Node[LLMNodeData]):
        structured_output: LLMStructuredOutput | None = None

        for event in generator:
-            if isinstance(event, StreamChunkEvent):
+            if isinstance(event, (StreamChunkEvent, ThoughtChunkEvent)):
                yield event
            elif isinstance(event, ModelInvokeCompletedEvent):
                # Raw text
@@ -340,6 +340,16 @@ class LLMNode(Node[LLMNodeData]):
                chunk="",
                is_final=True,
            )
+            yield StreamChunkEvent(
+                selector=[self._node_id, "generation", "content"],
+                chunk="",
+                is_final=True,
+            )
+            yield ThoughtChunkEvent(
+                selector=[self._node_id, "generation", "thought"],
+                chunk="",
+                is_final=True,
+            )

            yield StreamCompletedEvent(
                node_run_result=NodeRunResult(
@@ -470,6 +480,8 @@ class LLMNode(Node[LLMNodeData]):
        usage = LLMUsage.empty_usage()
        finish_reason = None
        full_text_buffer = io.StringIO()
+        think_parser = llm_utils.ThinkTagStreamParser()
+        reasoning_chunks: list[str] = []

        # Initialize streaming metrics tracking
        start_time = request_start_time if request_start_time is not None else time.perf_counter()
@@ -498,12 +510,32 @@ class LLMNode(Node[LLMNodeData]):
                has_content = True
                full_text_buffer.write(text_part)

+                # Text output: always forward raw chunk (keep <think> tags intact)
                yield StreamChunkEvent(
                    selector=[node_id, "text"],
                    chunk=text_part,
                    is_final=False,
                )

+                # Generation output: split out thoughts, forward only non-thought content chunks
+                for kind, segment in think_parser.process(text_part):
+                    if not segment:
+                        continue
+                    if kind == "thought":
+                        reasoning_chunks.append(segment)
+                        yield ThoughtChunkEvent(
+                            selector=[node_id, "generation", "thought"],
+                            chunk=segment,
+                            is_final=False,
+                        )
+                    else:
+                        yield StreamChunkEvent(
+                            selector=[node_id, "generation", "content"],
+                            chunk=segment,
+                            is_final=False,
+                        )

            # Update the whole metadata
            if not model and result.model:
                model = result.model
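
llm_utils.ThinkTagStreamParser is consumed here as an incremental splitter: it is fed arbitrary text fragments and returns ("thought", ...) and ("content", ...) segments, even when a <think> tag is cut across chunk boundaries. The project's parser is not shown in this diff; a minimal compatible sketch, assuming only the process/flush interface used above, could look like:

class ThinkTagStreamParser:
    """Incrementally split a stream into ("thought", s) / ("content", s) pairs
    based on <think>...</think> tags, tolerating tags split across chunks."""

    OPEN, CLOSE = "<think>", "</think>"

    def __init__(self) -> None:
        self._buffer = ""
        self._in_think = False

    def process(self, chunk: str) -> list[tuple[str, str]]:
        self._buffer += chunk
        out: list[tuple[str, str]] = []
        while self._buffer:
            tag = self.CLOSE if self._in_think else self.OPEN
            idx = self._buffer.find(tag)
            if idx != -1:
                kind = "thought" if self._in_think else "content"
                out.append((kind, self._buffer[:idx]))
                self._buffer = self._buffer[idx + len(tag):]
                self._in_think = not self._in_think
                continue
            # Hold back a possible partial tag at the end of the buffer.
            keep = 0
            for n in range(1, len(tag)):
                if self._buffer.endswith(tag[:n]):
                    keep = n
            cut = len(self._buffer) - keep
            emit, self._buffer = self._buffer[:cut], self._buffer[cut:]
            if emit:
                out.append(("thought" if self._in_think else "content", emit))
            break
        return out

    def flush(self) -> list[tuple[str, str]]:
        kind = "thought" if self._in_think else "content"
        rest, self._buffer = self._buffer, ""
        return [(kind, rest)] if rest else []


p = ThinkTagStreamParser()
print(p.process("a<thi"))               # [("content", "a")]  ("<thi" held back)
print(p.process("nk>hidden</think>b"))  # [("content", ""), ("thought", "hidden"), ("content", "b")]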
@@ -518,16 +550,35 @@ class LLMNode(Node[LLMNodeData]):
            except OutputParserError as e:
                raise LLMNodeError(f"Failed to parse structured output: {e}")

+        for kind, segment in think_parser.flush():
+            if not segment:
+                continue
+            if kind == "thought":
+                reasoning_chunks.append(segment)
+                yield ThoughtChunkEvent(
+                    selector=[node_id, "generation", "thought"],
+                    chunk=segment,
+                    is_final=False,
+                )
+            else:
+                yield StreamChunkEvent(
+                    selector=[node_id, "generation", "content"],
+                    chunk=segment,
+                    is_final=False,
+                )

        # Extract reasoning content from <think> tags in the main text
        full_text = full_text_buffer.getvalue()
        if reasoning_format == "tagged":
            # Keep <think> tags in text for backward compatibility
            clean_text = full_text
-            reasoning_content = ""
+            reasoning_content = "".join(reasoning_chunks)
        else:
            # Extract clean text and reasoning from <think> tags
            clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format)
+            if reasoning_chunks and not reasoning_content:
+                reasoning_content = "".join(reasoning_chunks)

        # Calculate streaming metrics
        end_time = time.perf_counter()
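
Two aggregation modes result: tagged keeps the <think> markup in text for backward compatibility and now also exposes the joined thought chunks as reasoning_content, while the other formats strip tags via _split_reasoning and fall back to the streamed chunks when tag extraction yields nothing (for example, an unterminated <think>). A toy check of that fallback, with split_reasoning standing in for the node's private helper:

import re


def split_reasoning(full_text: str) -> tuple[str, str]:
    """Stand-in for LLMNode._split_reasoning: strip <think> blocks from the text."""
    reasoning = "".join(re.findall(r"<think>(.*?)</think>", full_text, re.DOTALL))
    clean = re.sub(r"<think>.*?</think>", "", full_text, flags=re.DOTALL)
    return clean, reasoning


reasoning_chunks = ["step one", " step two"]

# Unterminated tag: the regex extracts nothing, so the streamed chunks win.
clean_text, reasoning_content = split_reasoning("<think>step one step two")
if reasoning_chunks and not reasoning_content:
    reasoning_content = "".join(reasoning_chunks)
assert reasoning_content == "step one step two"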
@@ -1398,8 +1449,6 @@ class LLMNode(Node[LLMNodeData]):
        finish_reason = None
        agent_result: AgentResult | None = None

-        # Track current round for ThoughtChunkEvent
-        current_round = 1
        think_parser = llm_utils.ThinkTagStreamParser()
        reasoning_chunks: list[str] = []
@@ -1431,12 +1480,6 @@ class LLMNode(Node[LLMNodeData]):
                else:
                    agent_logs.append(agent_log_event)

-                # Extract round number from ROUND log type
-                if output.log_type == AgentLog.LogType.ROUND:
-                    round_index = output.data.get("round_index")
-                    if isinstance(round_index, int):
-                        current_round = round_index

                # Emit tool call events when tool call starts
                if output.log_type == AgentLog.LogType.TOOL_CALL and output.status == AgentLog.LogStatus.START:
                    tool_name = output.data.get("tool_name", "")
@@ -1450,26 +1493,34 @@ class LLMNode(Node[LLMNodeData]):
                        tool_call_id=tool_call_id,
                        tool_name=tool_name,
                        tool_arguments=tool_arguments,
-                        is_final=True,
+                        is_final=False,
                    )

-                # Emit tool result events when tool call completes
-                if output.log_type == AgentLog.LogType.TOOL_CALL and output.status == AgentLog.LogStatus.SUCCESS:
+                # Emit tool result events when tool call completes (both success and error)
+                if output.log_type == AgentLog.LogType.TOOL_CALL and output.status != AgentLog.LogStatus.START:
                    tool_name = output.data.get("tool_name", "")
                    tool_output = output.data.get("output", "")
                    tool_call_id = output.data.get("tool_call_id", "")
                    tool_files = []
                    tool_error = None

-                    # Extract file IDs if present
+                    # Extract file IDs if present (only for success case)
                    files_data = output.data.get("files")
                    if files_data and isinstance(files_data, list):
                        tool_files = files_data

-                    # Check for error in meta
-                    meta = output.data.get("meta")
-                    if meta and isinstance(meta, dict) and meta.get("error"):
-                        tool_error = meta.get("error")
+                    # Check for error from multiple sources
+                    if output.status == AgentLog.LogStatus.ERROR:
+                        # Priority: output.error > data.error > meta.error
+                        tool_error = output.error or output.data.get("error")
+                        meta = output.data.get("meta")
+                        if not tool_error and meta and isinstance(meta, dict):
+                            tool_error = meta.get("error")
+                    else:
+                        # For success case, check meta for potential errors
+                        meta = output.data.get("meta")
+                        if meta and isinstance(meta, dict) and meta.get("error"):
+                            tool_error = meta.get("error")

                    yield ToolResultChunkEvent(
                        selector=[self._node_id, "generation", "tool_results"],
@@ -1478,7 +1529,7 @@ class LLMNode(Node[LLMNodeData]):
                        tool_name=tool_name,
                        tool_files=tool_files,
                        tool_error=tool_error,
-                        is_final=True,
+                        is_final=False,
                    )

            elif isinstance(output, LLMResultChunk):
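
The error lookup above gives explicit precedence to the log's own error attribute, then the error key in its data, then meta.error; on success only meta.error is consulted. The same rule as a standalone function (log shape assumed):

def resolve_tool_error(status: str, error_attr: str | None, data: dict) -> str | None:
    """Mirror the precedence used above: output.error > data['error'] > meta['error']."""
    meta = data.get("meta")
    meta_error = meta.get("error") if isinstance(meta, dict) else None
    if status == "error":
        return error_attr or data.get("error") or meta_error
    # Success path: a tool may still report a soft failure through meta.
    return meta_error


assert resolve_tool_error("error", "boom", {}) == "boom"
assert resolve_tool_error("error", None, {"error": "bad args"}) == "bad args"
assert resolve_tool_error("success", None, {"meta": {"error": "quota"}}) == "quota"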
@@ -1502,7 +1553,6 @@ class LLMNode(Node[LLMNodeData]):
                        yield ThoughtChunkEvent(
                            selector=[self._node_id, "generation", "thought"],
                            chunk=segment,
-                            round_index=current_round,
                            is_final=False,
                        )
                    else:
@@ -1548,7 +1598,6 @@ class LLMNode(Node[LLMNodeData]):
                yield ThoughtChunkEvent(
                    selector=[self._node_id, "generation", "thought"],
                    chunk=segment,
-                    round_index=current_round,
                    is_final=False,
                )
            else:
@@ -1580,7 +1629,27 @@ class LLMNode(Node[LLMNodeData]):
            yield ThoughtChunkEvent(
                selector=[self._node_id, "generation", "thought"],
                chunk="",
-                round_index=current_round,
                is_final=True,
            )
+            # Close tool_calls stream (already sent via ToolCallChunkEvent)
+            yield ToolCallChunkEvent(
+                selector=[self._node_id, "generation", "tool_calls"],
+                chunk="",
+                tool_call_id="",
+                tool_name="",
+                tool_arguments="",
+                is_final=True,
+            )
+            # Close tool_results stream (already sent via ToolResultChunkEvent)
+            yield ToolResultChunkEvent(
+                selector=[self._node_id, "generation", "tool_results"],
+                chunk="",
+                tool_call_id="",
+                tool_name="",
+                tool_files=[],
+                tool_error=None,
+                is_final=True,
+            )
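
With tool call/result chunks now emitted with is_final=False during the run, each generation channel is closed exactly once when the agent loop ends, so is_final acts as a per-selector end-of-stream marker rather than a per-item flag. A sketch of a consumer that drains events until every watched channel closes (event shape assumed):

from collections.abc import Iterable


def drain(events: Iterable[dict], selectors: set[tuple[str, ...]]) -> dict[tuple[str, ...], str]:
    """Accumulate chunks per selector until every watched stream has closed."""
    open_streams = set(selectors)
    acc: dict[tuple[str, ...], str] = {s: "" for s in selectors}
    for ev in events:
        sel = tuple(ev["selector"])
        if sel not in selectors:
            continue
        acc[sel] += ev["chunk"]
        if ev["is_final"]:
            open_streams.discard(sel)
            if not open_streams:
                break  # all channels closed; safe to stop reading
    return acc


events = [
    {"selector": ["n1", "generation", "thought"], "chunk": "plan...", "is_final": False},
    {"selector": ["n1", "generation", "content"], "chunk": "Hi", "is_final": False},
    {"selector": ["n1", "generation", "thought"], "chunk": "", "is_final": True},
    {"selector": ["n1", "generation", "content"], "chunk": "", "is_final": True},
]
watched = {("n1", "generation", "thought"), ("n1", "generation", "content")}
print(drain(events, watched))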


@@ -239,7 +239,6 @@ class TestThoughtChunkEvent:
        assert event.selector == ["node1", "thought"]
        assert event.chunk == "I need to query the weather..."
        assert event.chunk_type == ChunkType.THOUGHT
-        assert event.round_index == 1  # default

    def test_chunk_type_is_thought(self):
        """Test that chunk_type is always THOUGHT."""
@@ -250,38 +249,17 @@ class TestThoughtChunkEvent:
        assert event.chunk_type == ChunkType.THOUGHT

-    def test_round_index_default(self):
-        """Test that round_index defaults to 1."""
-        event = ThoughtChunkEvent(
-            selector=["node1", "thought"],
-            chunk="thinking...",
-        )
-        assert event.round_index == 1
-
-    def test_round_index_custom(self):
-        """Test custom round_index."""
-        event = ThoughtChunkEvent(
-            selector=["node1", "thought"],
-            chunk="Second round thinking...",
-            round_index=2,
-        )
-        assert event.round_index == 2
-
    def test_serialization(self):
        """Test that event can be serialized to dict."""
        event = ThoughtChunkEvent(
            selector=["node1", "thought"],
            chunk="I need to analyze this...",
-            round_index=3,
            is_final=False,
        )
        data = event.model_dump()
        assert data["chunk_type"] == "thought"
-        assert data["round_index"] == 3
        assert data["chunk"] == "I need to analyze this..."
        assert data["is_final"] is False