diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py index f57e251cdf..c1938fb5e3 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ b/api/core/workflow/nodes/llm/entities.py @@ -1,14 +1,17 @@ +import re from collections.abc import Mapping, Sequence from typing import Any, Literal -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, ConfigDict, Field, field_validator +from core.agent.entities import AgentLog, AgentResult from core.file import File from core.model_runtime.entities import ImagePromptMessageContent, LLMMode from core.model_runtime.entities.llm_entities import LLMUsage from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from core.tools.entities.tool_entities import ToolProviderType from core.workflow.entities import ToolCallResult +from core.workflow.node_events import AgentLogEvent from core.workflow.nodes.base import BaseNodeData from core.workflow.nodes.base.entities import VariableSelector @@ -20,44 +23,6 @@ class ModelConfig(BaseModel): completion_params: dict[str, Any] = Field(default_factory=dict) -class LLMTraceSegment(BaseModel): - """ - Streaming trace segment for LLM tool-enabled runs. - - Order is preserved for replay. Tool calls are single entries containing both - arguments and results. - """ - - type: Literal["thought", "content", "tool_call"] - - # Common optional fields - text: str | None = Field(None, description="Text chunk for thought/content") - - # Tool call fields (combined start + result) - tool_call: ToolCallResult | None = Field( - default=None, - description="Combined tool call arguments and result for this segment", - ) - - -class LLMGenerationData(BaseModel): - """Generation data from LLM invocation with tools. - - For multi-turn tool calls like: thought1 -> text1 -> tool_call1 -> thought2 -> text2 -> tool_call2 - - reasoning_contents: [thought1, thought2, ...] - one element per turn - - tool_calls: [{id, name, arguments, result}, ...] - all tool calls with results - """ - - text: str = Field(..., description="Accumulated text content from all turns") - reasoning_contents: list[str] = Field(default_factory=list, description="Reasoning content per turn") - tool_calls: list[ToolCallResult] = Field(default_factory=list, description="Tool calls with results") - sequence: list[dict[str, Any]] = Field(default_factory=list, description="Ordered segments for rendering") - usage: LLMUsage = Field(..., description="LLM usage statistics") - finish_reason: str | None = Field(None, description="Finish reason from LLM") - files: list[File] = Field(default_factory=list, description="Generated files") - trace: list[LLMTraceSegment] = Field(default_factory=list, description="Streaming trace in emitted order") - - class ContextConfig(BaseModel): enabled: bool variable_selector: list[str] | None = None @@ -124,6 +89,211 @@ class ToolMetadata(BaseModel): extra: dict[str, Any] = Field(default_factory=dict, description="Extra tool configuration like custom description") +class LLMTraceSegment(BaseModel): + """ + Streaming trace segment for LLM tool-enabled runs. + + Order is preserved for replay. Tool calls are single entries containing both + arguments and results. 
+ """ + + type: Literal["thought", "content", "tool_call"] + + # Common optional fields + text: str | None = Field(None, description="Text chunk for thought/content") + + # Tool call fields (combined start + result) + tool_call: ToolCallResult | None = Field( + default=None, + description="Combined tool call arguments and result for this segment", + ) + + +class LLMGenerationData(BaseModel): + """Generation data from LLM invocation with tools. + + For multi-turn tool calls like: thought1 -> text1 -> tool_call1 -> thought2 -> text2 -> tool_call2 + - reasoning_contents: [thought1, thought2, ...] - one element per turn + - tool_calls: [{id, name, arguments, result}, ...] - all tool calls with results + """ + + text: str = Field(..., description="Accumulated text content from all turns") + reasoning_contents: list[str] = Field(default_factory=list, description="Reasoning content per turn") + tool_calls: list[ToolCallResult] = Field(default_factory=list, description="Tool calls with results") + sequence: list[dict[str, Any]] = Field(default_factory=list, description="Ordered segments for rendering") + usage: LLMUsage = Field(..., description="LLM usage statistics") + finish_reason: str | None = Field(None, description="Finish reason from LLM") + files: list[File] = Field(default_factory=list, description="Generated files") + trace: list[LLMTraceSegment] = Field(default_factory=list, description="Streaming trace in emitted order") + + +class ThinkTagStreamParser: + """Lightweight state machine to split streaming chunks by tags.""" + + _START_PATTERN = re.compile(r"]*)?>", re.IGNORECASE) + _END_PATTERN = re.compile(r"", re.IGNORECASE) + _START_PREFIX = " int: + """Return length of the longest suffix of `text` that is a prefix of `prefix`.""" + max_len = min(len(text), len(prefix) - 1) + for i in range(max_len, 0, -1): + if text[-i:].lower() == prefix[:i].lower(): + return i + return 0 + + def process(self, chunk: str) -> list[tuple[str, str]]: + """ + Split incoming chunk into ('thought' | 'text', content) tuples. + Content excludes the tags themselves and handles split tags across chunks. + """ + parts: list[tuple[str, str]] = [] + self._buffer += chunk + + while self._buffer: + if self._in_think: + end_match = self._END_PATTERN.search(self._buffer) + if end_match: + thought_text = self._buffer[: end_match.start()] + if thought_text: + parts.append(("thought", thought_text)) + self._buffer = self._buffer[end_match.end() :] + self._in_think = False + continue + + hold_len = self._suffix_prefix_len(self._buffer, self._END_PREFIX) + emit = self._buffer[: len(self._buffer) - hold_len] + if emit: + parts.append(("thought", emit)) + self._buffer = self._buffer[-hold_len:] if hold_len > 0 else "" + break + + start_match = self._START_PATTERN.search(self._buffer) + if start_match: + prefix = self._buffer[: start_match.start()] + if prefix: + parts.append(("text", prefix)) + self._buffer = self._buffer[start_match.end() :] + self._in_think = True + continue + + hold_len = self._suffix_prefix_len(self._buffer, self._START_PREFIX) + emit = self._buffer[: len(self._buffer) - hold_len] + if emit: + parts.append(("text", emit)) + self._buffer = self._buffer[-hold_len:] if hold_len > 0 else "" + break + + cleaned_parts: list[tuple[str, str]] = [] + for kind, content in parts: + # Extra safeguard: strip any stray tags that slipped through. 
+ content = self._START_PATTERN.sub("", content) + content = self._END_PATTERN.sub("", content) + if content: + cleaned_parts.append((kind, content)) + + return cleaned_parts + + def flush(self) -> list[tuple[str, str]]: + """Flush remaining buffer when the stream ends.""" + if not self._buffer: + return [] + kind = "thought" if self._in_think else "text" + content = self._buffer + # Drop dangling partial tags instead of emitting them + if content.lower().startswith(self._START_PREFIX) or content.lower().startswith(self._END_PREFIX): + content = "" + self._buffer = "" + if not content: + return [] + # Strip any complete tags that might still be present. + content = self._START_PATTERN.sub("", content) + content = self._END_PATTERN.sub("", content) + return [(kind, content)] if content else [] + + +class StreamBuffers(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + think_parser: ThinkTagStreamParser = Field(default_factory=ThinkTagStreamParser) + pending_thought: list[str] = Field(default_factory=list) + pending_content: list[str] = Field(default_factory=list) + current_turn_reasoning: list[str] = Field(default_factory=list) + reasoning_per_turn: list[str] = Field(default_factory=list) + + +class TraceState(BaseModel): + trace_segments: list[LLMTraceSegment] = Field(default_factory=list) + tool_trace_map: dict[str, LLMTraceSegment] = Field(default_factory=dict) + tool_call_index_map: dict[str, int] = Field(default_factory=dict) + + +class AggregatedResult(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + text: str = "" + files: list[File] = Field(default_factory=list) + usage: LLMUsage = Field(default_factory=LLMUsage.empty_usage) + finish_reason: str | None = None + + +class AgentContext(BaseModel): + agent_logs: list[AgentLogEvent] = Field(default_factory=list) + agent_result: AgentResult | None = None + + +class ToolOutputState(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + stream: StreamBuffers = Field(default_factory=StreamBuffers) + trace: TraceState = Field(default_factory=TraceState) + aggregate: AggregatedResult = Field(default_factory=AggregatedResult) + agent: AgentContext = Field(default_factory=AgentContext) + + +class ToolLogPayload(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + tool_name: str = "" + tool_call_id: str = "" + tool_args: dict[str, Any] = Field(default_factory=dict) + tool_output: Any = None + tool_error: Any = None + files: list[Any] = Field(default_factory=list) + meta: dict[str, Any] = Field(default_factory=dict) + + @classmethod + def from_log(cls, log: AgentLog) -> "ToolLogPayload": + data = log.data or {} + return cls( + tool_name=data.get("tool_name", ""), + tool_call_id=data.get("tool_call_id", ""), + tool_args=data.get("tool_args") or {}, + tool_output=data.get("output"), + tool_error=data.get("error"), + files=data.get("files") or [], + meta=data.get("meta") or {}, + ) + + @classmethod + def from_mapping(cls, data: Mapping[str, Any]) -> "ToolLogPayload": + return cls( + tool_name=data.get("tool_name", ""), + tool_call_id=data.get("tool_call_id", ""), + tool_args=data.get("tool_args") or {}, + tool_output=data.get("output"), + tool_error=data.get("error"), + files=data.get("files") or [], + meta=data.get("meta") or {}, + ) + + class LLMNodeData(BaseNodeData): model: ModelConfig prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate diff --git a/api/core/workflow/nodes/llm/llm_utils.py 
b/api/core/workflow/nodes/llm/llm_utils.py index e9c363851f..0c545469bc 100644 --- a/api/core/workflow/nodes/llm/llm_utils.py +++ b/api/core/workflow/nodes/llm/llm_utils.py @@ -1,4 +1,3 @@ -import re from collections.abc import Sequence from typing import cast @@ -155,94 +154,3 @@ def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUs ) session.execute(stmt) session.commit() - - -class ThinkTagStreamParser: - """Lightweight state machine to split streaming chunks by tags.""" - - _START_PATTERN = re.compile(r"]*)?>", re.IGNORECASE) - _END_PATTERN = re.compile(r"", re.IGNORECASE) - _START_PREFIX = " int: - """Return length of the longest suffix of `text` that is a prefix of `prefix`.""" - max_len = min(len(text), len(prefix) - 1) - for i in range(max_len, 0, -1): - if text[-i:].lower() == prefix[:i].lower(): - return i - return 0 - - def process(self, chunk: str) -> list[tuple[str, str]]: - """ - Split incoming chunk into ('thought' | 'text', content) tuples. - Content excludes the tags themselves and handles split tags across chunks. - """ - parts: list[tuple[str, str]] = [] - self._buffer += chunk - - while self._buffer: - if self._in_think: - end_match = self._END_PATTERN.search(self._buffer) - if end_match: - thought_text = self._buffer[: end_match.start()] - if thought_text: - parts.append(("thought", thought_text)) - self._buffer = self._buffer[end_match.end() :] - self._in_think = False - continue - - hold_len = self._suffix_prefix_len(self._buffer, self._END_PREFIX) - emit = self._buffer[: len(self._buffer) - hold_len] - if emit: - parts.append(("thought", emit)) - self._buffer = self._buffer[-hold_len:] if hold_len > 0 else "" - break - - start_match = self._START_PATTERN.search(self._buffer) - if start_match: - prefix = self._buffer[: start_match.start()] - if prefix: - parts.append(("text", prefix)) - self._buffer = self._buffer[start_match.end() :] - self._in_think = True - continue - - hold_len = self._suffix_prefix_len(self._buffer, self._START_PREFIX) - emit = self._buffer[: len(self._buffer) - hold_len] - if emit: - parts.append(("text", emit)) - self._buffer = self._buffer[-hold_len:] if hold_len > 0 else "" - break - - cleaned_parts: list[tuple[str, str]] = [] - for kind, content in parts: - # Extra safeguard: strip any stray tags that slipped through. - content = self._START_PATTERN.sub("", content) - content = self._END_PATTERN.sub("", content) - if content: - cleaned_parts.append((kind, content)) - - return cleaned_parts - - def flush(self) -> list[tuple[str, str]]: - """Flush remaining buffer when the stream ends.""" - if not self._buffer: - return [] - kind = "thought" if self._in_think else "text" - content = self._buffer - # Drop dangling partial tags instead of emitting them - if content.lower().startswith(self._START_PREFIX) or content.lower().startswith(self._END_PREFIX): - content = "" - self._buffer = "" - if not content: - return [] - # Strip any complete tags that might still be present. - content = self._START_PATTERN.sub("", content) - content = self._END_PATTERN.sub("", content) - return [(kind, content)] if content else [] diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 94b616bd34..408363d226 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -90,12 +90,19 @@ from models.model import UploadFile from . 
import llm_utils from .entities import ( + AgentContext, + AggregatedResult, LLMGenerationData, LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, LLMNodeData, LLMTraceSegment, ModelConfig, + StreamBuffers, + ThinkTagStreamParser, + ToolLogPayload, + ToolOutputState, + TraceState, ) from .exc import ( InvalidContextStructureError, @@ -582,7 +589,7 @@ class LLMNode(Node[LLMNodeData]): usage = LLMUsage.empty_usage() finish_reason = None full_text_buffer = io.StringIO() - think_parser = llm_utils.ThinkTagStreamParser() + think_parser = ThinkTagStreamParser() reasoning_chunks: list[str] = [] # Initialize streaming metrics tracking @@ -1495,7 +1502,7 @@ class LLMNode(Node[LLMNodeData]): ) # Process outputs and return generation result - result = yield from self._process_tool_outputs(outputs, strategy, node_inputs, process_data) + result = yield from self._process_tool_outputs(outputs) return result def _get_model_features(self, model_instance: ModelInstance) -> list[ModelFeature]: @@ -1587,278 +1594,213 @@ class LLMNode(Node[LLMNodeData]): return files - def _process_tool_outputs( - self, - outputs: Generator[LLMResultChunk | AgentLog, None, AgentResult], - strategy: Any, - node_inputs: dict[str, Any], - process_data: dict[str, Any], - ) -> Generator[NodeEventBase, None, LLMGenerationData]: - """Process strategy outputs and convert to node events. + def _flush_thought_segment(self, buffers: StreamBuffers, trace_state: TraceState) -> None: + if not buffers.pending_thought: + return + trace_state.trace_segments.append(LLMTraceSegment(type="thought", text="".join(buffers.pending_thought))) + buffers.pending_thought.clear() - Returns LLMGenerationData with text, reasoning_contents, tool_calls, usage, finish_reason, files - """ - text = "" - files: list[File] = [] - usage = LLMUsage.empty_usage() - agent_logs: list[AgentLogEvent] = [] - finish_reason = None - agent_result: AgentResult | None = None + def _flush_content_segment(self, buffers: StreamBuffers, trace_state: TraceState) -> None: + if not buffers.pending_content: + return + trace_state.trace_segments.append(LLMTraceSegment(type="content", text="".join(buffers.pending_content))) + buffers.pending_content.clear() - think_parser = llm_utils.ThinkTagStreamParser() - # Track reasoning per turn: each tool_call completion marks end of a turn - current_turn_reasoning: list[str] = [] # Buffer for current turn's thought chunks - reasoning_per_turn: list[str] = [] # Final list: one element per turn - tool_call_index_map: dict[str, int] = {} # tool_call_id -> index - trace_segments: list[LLMTraceSegment] = [] # Ordered trace for replay - tool_trace_map: dict[str, LLMTraceSegment] = {} - current_turn = 0 - pending_thought: list[str] = [] - pending_content: list[str] = [] + def _handle_agent_log_output( + self, output: AgentLog, buffers: StreamBuffers, trace_state: TraceState, agent_context: AgentContext + ) -> Generator[NodeEventBase, None, None]: + payload = ToolLogPayload.from_log(output) + agent_log_event = AgentLogEvent( + message_id=output.id, + label=output.label, + node_execution_id=self.id, + parent_id=output.parent_id, + error=output.error, + status=output.status.value, + data=output.data, + metadata={k.value: v for k, v in output.metadata.items()}, + node_id=self._node_id, + ) + for log in agent_context.agent_logs: + if log.message_id == agent_log_event.message_id: + log.data = agent_log_event.data + log.status = agent_log_event.status + log.error = agent_log_event.error + log.label = agent_log_event.label + log.metadata = 
agent_log_event.metadata + break + else: + agent_context.agent_logs.append(agent_log_event) - def _flush_thought() -> None: - if not pending_thought: - return - trace_segments.append(LLMTraceSegment(type="thought", text="".join(pending_thought))) - pending_thought.clear() + if output.log_type == AgentLog.LogType.TOOL_CALL and output.status == AgentLog.LogStatus.START: + tool_name = payload.tool_name + tool_call_id = payload.tool_call_id + tool_arguments = json.dumps(payload.tool_args) if payload.tool_args else "" - def _flush_content() -> None: - if not pending_content: - return - trace_segments.append(LLMTraceSegment(type="content", text="".join(pending_content))) - pending_content.clear() + if tool_call_id and tool_call_id not in trace_state.tool_call_index_map: + trace_state.tool_call_index_map[tool_call_id] = len(trace_state.tool_call_index_map) - # Process each output from strategy - try: - for output in outputs: - if isinstance(output, AgentLog): - # Store agent log event for metadata (no longer yielded, StreamChunkEvent contains the info) - agent_log_event = AgentLogEvent( - message_id=output.id, - label=output.label, - node_execution_id=self.id, - parent_id=output.parent_id, - error=output.error, - status=output.status.value, - data=output.data, - metadata={k.value: v for k, v in output.metadata.items()}, - node_id=self._node_id, + self._flush_thought_segment(buffers, trace_state) + self._flush_content_segment(buffers, trace_state) + + tool_call_segment = LLMTraceSegment( + type="tool_call", + text=None, + tool_call=ToolCallResult( + id=tool_call_id, + name=tool_name, + arguments=tool_arguments, + ), + ) + trace_state.trace_segments.append(tool_call_segment) + if tool_call_id: + trace_state.tool_trace_map[tool_call_id] = tool_call_segment + + yield ToolCallChunkEvent( + selector=[self._node_id, "generation", "tool_calls"], + chunk=tool_arguments, + tool_call=ToolCall( + id=tool_call_id, + name=tool_name, + arguments=tool_arguments, + ), + is_final=False, + ) + + if output.log_type == AgentLog.LogType.TOOL_CALL and output.status != AgentLog.LogStatus.START: + tool_name = payload.tool_name + tool_output = payload.tool_output + tool_call_id = payload.tool_call_id + tool_files = payload.files if isinstance(payload.files, list) else [] + tool_error = payload.tool_error + + if tool_call_id and tool_call_id not in trace_state.tool_call_index_map: + trace_state.tool_call_index_map[tool_call_id] = len(trace_state.tool_call_index_map) + + self._flush_thought_segment(buffers, trace_state) + self._flush_content_segment(buffers, trace_state) + + if output.status == AgentLog.LogStatus.ERROR: + tool_error = output.error or payload.tool_error + if not tool_error and payload.meta: + tool_error = payload.meta.get("error") + else: + if payload.meta: + meta_error = payload.meta.get("error") + if meta_error: + tool_error = meta_error + + existing_tool_segment = trace_state.tool_trace_map.get(tool_call_id) + tool_call_segment = existing_tool_segment or LLMTraceSegment( + type="tool_call", + text=None, + tool_call=ToolCallResult( + id=tool_call_id, + name=tool_name, + arguments=None, + ), + ) + if existing_tool_segment is None: + trace_state.trace_segments.append(tool_call_segment) + if tool_call_id: + trace_state.tool_trace_map[tool_call_id] = tool_call_segment + + if tool_call_segment.tool_call is None: + tool_call_segment.tool_call = ToolCallResult( + id=tool_call_id, + name=tool_name, + arguments=None, + ) + tool_call_segment.tool_call.output = ( + str(tool_output) if tool_output is not None else 
str(tool_error) if tool_error is not None else None + ) + tool_call_segment.tool_call.files = [] + tool_call_segment.tool_call.status = ToolResultStatus.ERROR if tool_error else ToolResultStatus.SUCCESS + + result_output = str(tool_output) if tool_output is not None else str(tool_error) if tool_error else None + + yield ToolResultChunkEvent( + selector=[self._node_id, "generation", "tool_results"], + chunk=result_output or "", + tool_result=ToolResult( + id=tool_call_id, + name=tool_name, + output=result_output, + files=tool_files, + status=ToolResultStatus.ERROR if tool_error else ToolResultStatus.SUCCESS, + ), + is_final=False, + ) + + if buffers.current_turn_reasoning: + buffers.reasoning_per_turn.append("".join(buffers.current_turn_reasoning)) + buffers.current_turn_reasoning.clear() + + def _handle_llm_chunk_output( + self, output: LLMResultChunk, buffers: StreamBuffers, trace_state: TraceState, aggregate: AggregatedResult + ) -> Generator[NodeEventBase, None, None]: + message = output.delta.message + + if message and message.content: + chunk_text = message.content + if isinstance(chunk_text, list): + chunk_text = "".join(getattr(content, "data", str(content)) for content in chunk_text) + else: + chunk_text = str(chunk_text) + + for kind, segment in buffers.think_parser.process(chunk_text): + if not segment: + continue + + if kind == "thought": + self._flush_content_segment(buffers, trace_state) + buffers.current_turn_reasoning.append(segment) + buffers.pending_thought.append(segment) + yield ThoughtChunkEvent( + selector=[self._node_id, "generation", "thought"], + chunk=segment, + is_final=False, + ) + else: + self._flush_thought_segment(buffers, trace_state) + aggregate.text += segment + buffers.pending_content.append(segment) + yield StreamChunkEvent( + selector=[self._node_id, "text"], + chunk=segment, + is_final=False, + ) + yield StreamChunkEvent( + selector=[self._node_id, "generation", "content"], + chunk=segment, + is_final=False, ) - for log in agent_logs: - if log.message_id == agent_log_event.message_id: - # update the log - log.data = agent_log_event.data - log.status = agent_log_event.status - log.error = agent_log_event.error - log.label = agent_log_event.label - log.metadata = agent_log_event.metadata - break - else: - agent_logs.append(agent_log_event) - # Emit tool call events when tool call starts - if output.log_type == AgentLog.LogType.TOOL_CALL and output.status == AgentLog.LogStatus.START: - tool_name = output.data.get("tool_name", "") - tool_call_id = output.data.get("tool_call_id", "") - tool_args = output.data.get("tool_args", {}) - tool_arguments = json.dumps(tool_args) if tool_args else "" + if output.delta.usage: + self._accumulate_usage(aggregate.usage, output.delta.usage) - if tool_call_id and tool_call_id not in tool_call_index_map: - tool_call_index_map[tool_call_id] = len(tool_call_index_map) + if output.delta.finish_reason: + aggregate.finish_reason = output.delta.finish_reason - _flush_thought() - _flush_content() - - tool_call_segment = LLMTraceSegment( - type="tool_call", - text=None, - tool_call=ToolCallResult( - id=tool_call_id, - name=tool_name, - arguments=tool_arguments, - ), - ) - trace_segments.append(tool_call_segment) - if tool_call_id: - tool_trace_map[tool_call_id] = tool_call_segment - - yield ToolCallChunkEvent( - selector=[self._node_id, "generation", "tool_calls"], - chunk=tool_arguments, - tool_call=ToolCall( - id=tool_call_id, - name=tool_name, - arguments=tool_arguments, - ), - is_final=False, - ) - - # Emit tool result events 
when tool call completes (both success and error) - if output.log_type == AgentLog.LogType.TOOL_CALL and output.status != AgentLog.LogStatus.START: - tool_name = output.data.get("tool_name", "") - tool_output = output.data.get("output", "") - tool_call_id = output.data.get("tool_call_id", "") - tool_files = [] - tool_error = None - - if tool_call_id and tool_call_id not in tool_call_index_map: - tool_call_index_map[tool_call_id] = len(tool_call_index_map) - - _flush_thought() - _flush_content() - - # Extract file IDs if present (only for success case) - files_data = output.data.get("files") - if files_data and isinstance(files_data, list): - tool_files = files_data - - # Check for error from multiple sources - if output.status == AgentLog.LogStatus.ERROR: - # Priority: output.error > data.error > meta.error - tool_error = output.error or output.data.get("error") - meta = output.data.get("meta") - if not tool_error and meta and isinstance(meta, dict): - tool_error = meta.get("error") - else: - # For success case, check meta for potential errors - meta = output.data.get("meta") - if meta and isinstance(meta, dict) and meta.get("error"): - tool_error = meta.get("error") - - existing_tool_segment = tool_trace_map.get(tool_call_id) - tool_call_segment = existing_tool_segment or LLMTraceSegment( - type="tool_call", - text=None, - tool_call=ToolCallResult( - id=tool_call_id, - name=tool_name, - arguments=None, - ), - ) - if existing_tool_segment is None: - trace_segments.append(tool_call_segment) - if tool_call_id: - tool_trace_map[tool_call_id] = tool_call_segment - - if tool_call_segment.tool_call is None: - tool_call_segment.tool_call = ToolCallResult( - id=tool_call_id, - name=tool_name, - arguments=None, - ) - tool_call_segment.tool_call.output = ( - str(tool_output) - if tool_output is not None - else str(tool_error) - if tool_error is not None - else None - ) - tool_call_segment.tool_call.files = [] - tool_call_segment.tool_call.status = ( - ToolResultStatus.ERROR if tool_error else ToolResultStatus.SUCCESS - ) - current_turn += 1 - - result_output = ( - str(tool_output) if tool_output is not None else str(tool_error) if tool_error else None - ) - - yield ToolResultChunkEvent( - selector=[self._node_id, "generation", "tool_results"], - chunk=result_output or "", - tool_result=ToolResult( - id=tool_call_id, - name=tool_name, - output=result_output, - files=tool_files, - status=ToolResultStatus.ERROR if tool_error else ToolResultStatus.SUCCESS, - ), - is_final=False, - ) - - # End of current turn: save accumulated thought as one element - if current_turn_reasoning: - reasoning_per_turn.append("".join(current_turn_reasoning)) - current_turn_reasoning.clear() - - elif isinstance(output, LLMResultChunk): - # Handle LLM result chunks - only process text content - message = output.delta.message - - # Handle text content - if message and message.content: - chunk_text = message.content - if isinstance(chunk_text, list): - # Extract text from content list - chunk_text = "".join(getattr(c, "data", str(c)) for c in chunk_text) - else: - chunk_text = str(chunk_text) - for kind, segment in think_parser.process(chunk_text): - if not segment: - continue - - if kind == "thought": - _flush_content() - current_turn_reasoning.append(segment) - pending_thought.append(segment) - yield ThoughtChunkEvent( - selector=[self._node_id, "generation", "thought"], - chunk=segment, - is_final=False, - ) - else: - _flush_thought() - text += segment - pending_content.append(segment) - yield StreamChunkEvent( - 
selector=[self._node_id, "text"], - chunk=segment, - is_final=False, - ) - yield StreamChunkEvent( - selector=[self._node_id, "generation", "content"], - chunk=segment, - is_final=False, - ) - - if output.delta.usage: - self._accumulate_usage(usage, output.delta.usage) - - # Capture finish reason - if output.delta.finish_reason: - finish_reason = output.delta.finish_reason - - except StopIteration as e: - # Get the return value from generator - if isinstance(getattr(e, "value", None), AgentResult): - agent_result = e.value - - # Use result from generator if available - if agent_result: - text = agent_result.text or text - files = agent_result.files - if agent_result.usage: - usage = agent_result.usage - if agent_result.finish_reason: - finish_reason = agent_result.finish_reason - - # Flush any remaining buffered content after streaming ends - for kind, segment in think_parser.flush(): + def _flush_remaining_stream( + self, buffers: StreamBuffers, trace_state: TraceState, aggregate: AggregatedResult + ) -> Generator[NodeEventBase, None, None]: + for kind, segment in buffers.think_parser.flush(): if not segment: continue if kind == "thought": - _flush_content() - current_turn_reasoning.append(segment) - pending_thought.append(segment) + self._flush_content_segment(buffers, trace_state) + buffers.current_turn_reasoning.append(segment) + buffers.pending_thought.append(segment) yield ThoughtChunkEvent( selector=[self._node_id, "generation", "thought"], chunk=segment, is_final=False, ) else: - _flush_thought() - text += segment - pending_content.append(segment) + self._flush_thought_segment(buffers, trace_state) + aggregate.text += segment + buffers.pending_content.append(segment) yield StreamChunkEvent( selector=[self._node_id, "text"], chunk=segment, @@ -1870,19 +1812,63 @@ class LLMNode(Node[LLMNodeData]): is_final=False, ) - # Save the last turn's thought if any - if current_turn_reasoning: - reasoning_per_turn.append("".join(current_turn_reasoning)) + if buffers.current_turn_reasoning: + buffers.reasoning_per_turn.append("".join(buffers.current_turn_reasoning)) - _flush_thought() - _flush_content() + self._flush_thought_segment(buffers, trace_state) + self._flush_content_segment(buffers, trace_state) - # Build sequence from trace_segments for rendering + def _close_streams(self) -> Generator[NodeEventBase, None, None]: + yield StreamChunkEvent( + selector=[self._node_id, "text"], + chunk="", + is_final=True, + ) + yield StreamChunkEvent( + selector=[self._node_id, "generation", "content"], + chunk="", + is_final=True, + ) + yield ThoughtChunkEvent( + selector=[self._node_id, "generation", "thought"], + chunk="", + is_final=True, + ) + yield ToolCallChunkEvent( + selector=[self._node_id, "generation", "tool_calls"], + chunk="", + tool_call=ToolCall( + id="", + name="", + arguments="", + ), + is_final=True, + ) + yield ToolResultChunkEvent( + selector=[self._node_id, "generation", "tool_results"], + chunk="", + tool_result=ToolResult( + id="", + name="", + output="", + files=[], + status=ToolResultStatus.SUCCESS, + ), + is_final=True, + ) + + def _build_generation_data( + self, + trace_state: TraceState, + agent_context: AgentContext, + aggregate: AggregatedResult, + buffers: StreamBuffers, + ) -> LLMGenerationData: sequence: list[dict[str, Any]] = [] reasoning_index = 0 content_position = 0 tool_call_seen_index: dict[str, int] = {} - for trace_segment in trace_segments: + for trace_segment in trace_state.trace_segments: if trace_segment.type == "thought": sequence.append({"type": 
"reasoning", "index": reasoning_index}) reasoning_index += 1 @@ -1898,67 +1884,22 @@ class LLMNode(Node[LLMNodeData]): tool_call_seen_index[tool_id] = len(tool_call_seen_index) sequence.append({"type": "tool_call", "index": tool_call_seen_index[tool_id]}) - # Send final events for all streams - yield StreamChunkEvent( - selector=[self._node_id, "text"], - chunk="", - is_final=True, - ) - - # Close generation sub-field streams - yield StreamChunkEvent( - selector=[self._node_id, "generation", "content"], - chunk="", - is_final=True, - ) - yield ThoughtChunkEvent( - selector=[self._node_id, "generation", "thought"], - chunk="", - is_final=True, - ) - - # Close tool_calls stream (already sent via ToolCallChunkEvent) - yield ToolCallChunkEvent( - selector=[self._node_id, "generation", "tool_calls"], - chunk="", - tool_call=ToolCall( - id="", - name="", - arguments="", - ), - is_final=True, - ) - - # Close tool_results stream (already sent via ToolResultChunkEvent) - yield ToolResultChunkEvent( - selector=[self._node_id, "generation", "tool_results"], - chunk="", - tool_result=ToolResult( - id="", - name="", - output="", - files=[], - status=ToolResultStatus.SUCCESS, - ), - is_final=True, - ) - - # Build tool_calls from agent_logs (with results) tool_calls_for_generation: list[ToolCallResult] = [] - for log in agent_logs: - tool_call_id = log.data.get("tool_call_id") + for log in agent_context.agent_logs: + payload = ToolLogPayload.from_mapping(log.data or {}) + tool_call_id = payload.tool_call_id if not tool_call_id or log.status == AgentLog.LogStatus.START.value: continue - tool_args = log.data.get("tool_args") or {} - log_error = log.data.get("error") - log_output = log.data.get("output") + tool_args = payload.tool_args + log_error = payload.tool_error + log_output = payload.tool_output result_text = log_output or log_error or "" status = ToolResultStatus.ERROR if log_error else ToolResultStatus.SUCCESS tool_calls_for_generation.append( ToolCallResult( id=tool_call_id, - name=log.data.get("tool_name", ""), + name=payload.tool_name, arguments=json.dumps(tool_args) if tool_args else "", output=result_text, status=status, @@ -1966,21 +1907,50 @@ class LLMNode(Node[LLMNodeData]): ) tool_calls_for_generation.sort( - key=lambda item: tool_call_index_map.get(item.id or "", len(tool_call_index_map)) + key=lambda item: trace_state.tool_call_index_map.get(item.id or "", len(trace_state.tool_call_index_map)) ) - # Return generation data for caller return LLMGenerationData( - text=text, - reasoning_contents=reasoning_per_turn, + text=aggregate.text, + reasoning_contents=buffers.reasoning_per_turn, tool_calls=tool_calls_for_generation, sequence=sequence, - usage=usage, - finish_reason=finish_reason, - files=files, - trace=trace_segments, + usage=aggregate.usage, + finish_reason=aggregate.finish_reason, + files=aggregate.files, + trace=trace_state.trace_segments, ) + def _process_tool_outputs( + self, + outputs: Generator[LLMResultChunk | AgentLog, None, AgentResult], + ) -> Generator[NodeEventBase, None, LLMGenerationData]: + """Process strategy outputs and convert to node events.""" + state = ToolOutputState() + + try: + for output in outputs: + if isinstance(output, AgentLog): + yield from self._handle_agent_log_output(output, state.stream, state.trace, state.agent) + else: + yield from self._handle_llm_chunk_output(output, state.stream, state.trace, state.aggregate) + except StopIteration as exception: + if isinstance(getattr(exception, "value", None), AgentResult): + state.agent.agent_result = 
exception.value + + if state.agent.agent_result: + state.aggregate.text = state.agent.agent_result.text or state.aggregate.text + state.aggregate.files = state.agent.agent_result.files + if state.agent.agent_result.usage: + state.aggregate.usage = state.agent.agent_result.usage + if state.agent.agent_result.finish_reason: + state.aggregate.finish_reason = state.agent.agent_result.finish_reason + + yield from self._flush_remaining_stream(state.stream, state.trace, state.aggregate) + yield from self._close_streams() + + return self._build_generation_data(state.trace, state.agent, state.aggregate, state.stream) + def _accumulate_usage(self, total_usage: LLMUsage, delta_usage: LLMUsage) -> None: """Accumulate LLM usage statistics.""" total_usage.prompt_tokens += delta_usage.prompt_tokens
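
For reference, a minimal usage sketch of the relocated `ThinkTagStreamParser`, assuming this diff is applied so the class is importable from `core.workflow.nodes.llm.entities`; the chunk contents are invented for illustration:

```python
# Illustrative only: splitting a <think> block whose closing tag is split across chunks.
from core.workflow.nodes.llm.entities import ThinkTagStreamParser

parser = ThinkTagStreamParser()
events: list[tuple[str, str]] = []

# The closing tag arrives in two pieces; the parser holds back the partial "</thi"
# suffix until the next chunk shows whether it completes "</think>".
for chunk in ["<think>plan the query</thi", "nk>Here is the answer."]:
    events.extend(parser.process(chunk))
events.extend(parser.flush())

print(events)
# Expected: [("thought", "plan the query"), ("text", "Here is the answer.")]
```

The `_suffix_prefix_len` hold-back is what keeps the dangling `</thi` out of the emitted thought text until the closing tag can be confirmed in a later chunk.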
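
`ToolLogPayload.from_log` / `from_mapping` centralize the `data.get(...)` fallbacks that the old `_process_tool_outputs` inlined. A small sketch with invented log data, under the same import-path assumption as above:

```python
# Illustrative only: normalizing a loosely-typed agent log dict into a typed payload.
from core.workflow.nodes.llm.entities import ToolLogPayload

raw_log_data = {
    "tool_name": "web_search",                 # hypothetical sample values
    "tool_call_id": "call_001",
    "tool_args": {"query": "dify workflow"},
    "output": "3 results found",
    "meta": {},                                # a meta["error"] here would mark the result as ERROR downstream
    # "files" and "error" omitted on purpose: from_mapping tolerates missing keys
}

payload = ToolLogPayload.from_mapping(raw_log_data)
assert payload.tool_args == {"query": "dify workflow"}
assert payload.files == []        # "or []" fallback kicks in for the missing key
assert payload.tool_error is None
```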
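
The `_flush_thought_segment` / `_flush_content_segment` helpers exist so that buffered thought and content chunks are collapsed into single trace segments before a `tool_call` segment is appended, keeping `TraceState.trace_segments` in emission order. A dependency-free sketch of that ordering rule, with plain dicts standing in for `LLMTraceSegment`:

```python
# Minimal sketch of the flush-before-tool-call ordering; not Dify code.
trace: list[dict] = []
pending_thought: list[str] = ["Let me ", "check the weather."]
pending_content: list[str] = []


def flush(kind: str, pending: list[str]) -> None:
    # Collapse buffered chunks into one segment, mirroring _flush_*_segment.
    if pending:
        trace.append({"type": kind, "text": "".join(pending)})
        pending.clear()


# A tool call is about to be recorded: flush both buffers first, then append the call.
flush("thought", pending_thought)
flush("content", pending_content)
trace.append({"type": "tool_call", "tool_call": {"id": "call_1", "name": "weather"}})

print(trace)
# [{'type': 'thought', 'text': 'Let me check the weather.'},
#  {'type': 'tool_call', 'tool_call': {'id': 'call_1', 'name': 'weather'}}]
```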
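
One behavior worth double-checking in `_process_tool_outputs`: the `except StopIteration` branch only fires if something inside the loop advances an iterator manually, because a plain `for` loop consumes the generator's terminating `StopIteration` itself; a generator's return value (here the `AgentResult`) is normally surfaced via `yield from` or an explicit `next()`. A standalone illustration of that Python semantics, with hypothetical names unrelated to Dify:

```python
# Illustration only: where a generator's return value actually becomes visible.
from collections.abc import Generator


def strategy() -> Generator[str, None, dict]:
    yield "chunk-1"
    yield "chunk-2"
    return {"finish_reason": "stop"}  # carried on StopIteration.value


def drive_manually(gen: Generator[str, None, dict]) -> dict:
    while True:
        try:
            print(next(gen))
        except StopIteration as stop:
            return stop.value  # the return value is only visible here (or via `yield from`)


print(drive_manually(strategy()))  # {'finish_reason': 'stop'}

for chunk in strategy():
    print(chunk)
# No StopIteration escapes the for loop, so an outer `except StopIteration`
# has nothing to catch once iteration ends normally.
```

This is why the fallback path that leaves `state.agent.agent_result` as `None` still matters: the aggregated text, usage, and finish reason accumulated during streaming remain the effective result in that case.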