diff --git a/api/core/agent/entities.py b/api/core/agent/entities.py
index 56319a14a3..46af4d2d72 100644
--- a/api/core/agent/entities.py
+++ b/api/core/agent/entities.py
@@ -160,6 +160,8 @@ class AgentLog(BaseModel):
         PROVIDER = "provider"
         CURRENCY = "currency"
         LLM_USAGE = "llm_usage"
+        ICON = "icon"
+        ICON_DARK = "icon_dark"

     class LogStatus(StrEnum):
         START = "start"
diff --git a/api/core/agent/patterns/base.py b/api/core/agent/patterns/base.py
index d797586e5f..d98fa005a3 100644
--- a/api/core/agent/patterns/base.py
+++ b/api/core/agent/patterns/base.py
@@ -155,6 +155,35 @@ class AgentPattern(ABC):
             return " ".join(text_parts)
         return ""

+    def _get_tool_metadata(self, tool_instance: Tool) -> dict[AgentLog.LogMetadata, Any]:
+        """Get metadata for a tool including provider and icon info."""
+        from core.tools.tool_manager import ToolManager
+
+        metadata: dict[AgentLog.LogMetadata, Any] = {}
+        if tool_instance.entity and tool_instance.entity.identity:
+            identity = tool_instance.entity.identity
+            if identity.provider:
+                metadata[AgentLog.LogMetadata.PROVIDER] = identity.provider
+
+            # Get icon using ToolManager for proper URL generation
+            tenant_id = self.context.tenant_id
+            if tenant_id and identity.provider:
+                try:
+                    provider_type = tool_instance.tool_provider_type()
+                    icon = ToolManager.get_tool_icon(tenant_id, provider_type, identity.provider)
+                    if isinstance(icon, str):
+                        metadata[AgentLog.LogMetadata.ICON] = icon
+                    elif isinstance(icon, dict):
+                        # Handle icon dict with background/content or light/dark variants
+                        metadata[AgentLog.LogMetadata.ICON] = icon
+                except Exception:
+                    # Fallback to identity.icon if ToolManager fails
+                    if identity.icon:
+                        metadata[AgentLog.LogMetadata.ICON] = identity.icon
+            elif identity.icon:
+                metadata[AgentLog.LogMetadata.ICON] = identity.icon
+        return metadata
+
     def _create_log(
         self,
         label: str,
@@ -165,7 +194,7 @@ class AgentPattern(ABC):
         extra_metadata: dict[AgentLog.LogMetadata, Any] | None = None,
     ) -> AgentLog:
         """Create a new AgentLog with standard metadata."""
-        metadata = {
+        metadata: dict[AgentLog.LogMetadata, Any] = {
             AgentLog.LogMetadata.STARTED_AT: time.perf_counter(),
         }
         if extra_metadata:
diff --git a/api/core/agent/patterns/function_call.py b/api/core/agent/patterns/function_call.py
index a46c5d77f9..a0f0d30ab3 100644
--- a/api/core/agent/patterns/function_call.py
+++ b/api/core/agent/patterns/function_call.py
@@ -235,6 +235,9 @@ class FunctionCallStrategy(AgentPattern):
             if not tool_instance:
                 raise ValueError(f"Tool {tool_name} not found")

+            # Get tool metadata (provider, icon, etc.)
+            tool_metadata = self._get_tool_metadata(tool_instance)
+
             # Create tool call log
             tool_call_log = self._create_log(
                 label=f"CALL {tool_name}",
@@ -246,6 +249,7 @@ class FunctionCallStrategy(AgentPattern):
                     "tool_args": tool_args,
                 },
                 parent_id=round_log.id,
+                extra_metadata=tool_metadata,
             )
             yield tool_call_log

diff --git a/api/core/agent/patterns/react.py b/api/core/agent/patterns/react.py
index 81aa7fe3b1..87a9fa9b65 100644
--- a/api/core/agent/patterns/react.py
+++ b/api/core/agent/patterns/react.py
@@ -347,7 +347,11 @@ class ReActStrategy(AgentPattern):
         tool_name = action.action_name
         tool_args: dict[str, Any] | str = action.action_input

-        # Start tool log
+        # Find tool instance first to get metadata
+        tool_instance = self._find_tool_by_name(tool_name)
+        tool_metadata = self._get_tool_metadata(tool_instance) if tool_instance else {}
+
+        # Start tool log with tool metadata
         tool_log = self._create_log(
             label=f"CALL {tool_name}",
             log_type=AgentLog.LogType.TOOL_CALL,
@@ -357,11 +361,10 @@ class ReActStrategy(AgentPattern):
                 "tool_args": tool_args,
             },
             parent_id=round_log.id,
+            extra_metadata=tool_metadata,
         )
         yield tool_log

-        # Find tool instance
-        tool_instance = self._find_tool_by_name(tool_name)
         if not tool_instance:
             # Finish tool log with error
             yield self._finish_log(
diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
index a7c6426dc3..c4d89b8b2f 100644
--- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py
+++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
@@ -526,6 +526,8 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
         tool_arguments = tool_call.arguments if tool_call and tool_call.arguments else ""
         tool_files = tool_result.files if tool_result else []
         tool_elapsed_time = tool_result.elapsed_time if tool_result else None
+        tool_icon = tool_payload.icon if tool_payload else None
+        tool_icon_dark = tool_payload.icon_dark if tool_payload else None
         # Record stream event based on chunk type
         chunk_type = event.chunk_type or ChunkType.TEXT
         match chunk_type:
@@ -548,7 +550,8 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
                     tool_elapsed_time=tool_elapsed_time,
                 )
                 self._task_state.answer += delta_text
-
+            case _:
+                pass
         yield self._message_cycle_manager.message_to_stream_response(
             answer=delta_text,
             message_id=self._message_id,
@@ -559,6 +562,8 @@
             tool_arguments=tool_arguments or None,
             tool_files=tool_files,
             tool_elapsed_time=tool_elapsed_time,
+            tool_icon=tool_icon,
+            tool_icon_dark=tool_icon_dark,
         )

     def _handle_iteration_start_event(
diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py
index d9e1dfb474..2b8ed38c63 100644
--- a/api/core/app/apps/workflow/generate_task_pipeline.py
+++ b/api/core/app/apps/workflow/generate_task_pipeline.py
@@ -492,6 +492,8 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
         tool_arguments = tool_call.arguments if tool_call else None
         tool_elapsed_time = tool_result.elapsed_time if tool_result else None
         tool_files = tool_result.files if tool_result else []
+        tool_icon = tool_payload.icon if tool_payload else None
+        tool_icon_dark = tool_payload.icon_dark if tool_payload else None

         # only publish tts message at text chunk streaming
         if tts_publisher and queue_message:
@@ -506,6 +508,8 @@
             tool_arguments=tool_arguments,
             tool_files=tool_files,
             tool_elapsed_time=tool_elapsed_time,
+            tool_icon=tool_icon,
+            tool_icon_dark=tool_icon_dark,
         )

     def _handle_agent_log_event(self, event: QueueAgentLogEvent, **kwargs) -> Generator[StreamResponse, None, None]:
@@ -679,6 +683,8 @@
         tool_files: list[str] | None = None,
         tool_error: str | None = None,
         tool_elapsed_time: float | None = None,
+        tool_icon: str | dict | None = None,
+        tool_icon_dark: str | dict | None = None,
     ) -> TextChunkStreamResponse:
         """
        Handle completed event.
@@ -701,6 +707,8 @@
                     "tool_call_id": tool_call_id,
                     "tool_name": tool_name,
                     "tool_arguments": tool_arguments,
+                    "tool_icon": tool_icon,
+                    "tool_icon_dark": tool_icon_dark,
                 }
             )
         elif response_chunk_type == ResponseChunkType.TOOL_RESULT:
@@ -712,6 +720,8 @@
                     "tool_files": tool_files,
                     "tool_error": tool_error,
                     "tool_elapsed_time": tool_elapsed_time,
+                    "tool_icon": tool_icon,
+                    "tool_icon_dark": tool_icon_dark,
                 }
             )

diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py
index 068856b947..0998510b60 100644
--- a/api/core/app/entities/task_entities.py
+++ b/api/core/app/entities/task_entities.py
@@ -132,6 +132,10 @@ class MessageStreamResponse(StreamResponse):
     """error message if tool failed"""
     tool_elapsed_time: float | None = None
     """elapsed time spent executing the tool"""
+    tool_icon: str | dict | None = None
+    """icon of the tool"""
+    tool_icon_dark: str | dict | None = None
+    """dark theme icon of the tool"""

     def model_dump(self, *args, **kwargs) -> dict[str, object]:
         kwargs.setdefault("exclude_none", True)
diff --git a/api/core/app/task_pipeline/message_cycle_manager.py b/api/core/app/task_pipeline/message_cycle_manager.py
index f07d42ebb9..f9f341fcea 100644
--- a/api/core/app/task_pipeline/message_cycle_manager.py
+++ b/api/core/app/task_pipeline/message_cycle_manager.py
@@ -239,6 +239,8 @@ class MessageCycleManager:
         tool_files: list[str] | None = None,
         tool_error: str | None = None,
         tool_elapsed_time: float | None = None,
+        tool_icon: str | dict | None = None,
+        tool_icon_dark: str | dict | None = None,
         event_type: StreamEvent | None = None,
     ) -> MessageStreamResponse:
         """
@@ -271,6 +273,8 @@
                     "tool_call_id": tool_call_id,
                     "tool_name": tool_name,
                     "tool_arguments": tool_arguments,
+                    "tool_icon": tool_icon,
+                    "tool_icon_dark": tool_icon_dark,
                 }
             )
         elif chunk_type == "tool_result":
@@ -282,6 +286,8 @@
                     "tool_files": tool_files,
                     "tool_error": tool_error,
                     "tool_elapsed_time": tool_elapsed_time,
+                    "tool_icon": tool_icon,
+                    "tool_icon_dark": tool_icon_dark,
                 }
             )

diff --git a/api/core/workflow/entities/tool_entities.py b/api/core/workflow/entities/tool_entities.py
index 9fdd895517..7e71a86849 100644
--- a/api/core/workflow/entities/tool_entities.py
+++ b/api/core/workflow/entities/tool_entities.py
@@ -14,6 +14,8 @@ class ToolCall(BaseModel):
     id: str | None = Field(default=None, description="Unique identifier for this tool call")
     name: str | None = Field(default=None, description="Name of the tool being called")
     arguments: str | None = Field(default=None, description="Accumulated tool arguments JSON")
+    icon: str | dict | None = Field(default=None, description="Icon of the tool")
+    icon_dark: str | dict | None = Field(default=None, description="Dark theme icon of the tool")


 class ToolResult(BaseModel):
@@ -23,6 +25,8 @@ class ToolResult(BaseModel):
     files: list[str] = Field(default_factory=list, description="File produced by tool")
     status: ToolResultStatus | None = Field(default=ToolResultStatus.SUCCESS, description="Tool execution status")
     elapsed_time: float | None = Field(default=None, description="Elapsed seconds spent executing the tool")
+    icon: str | dict | None = Field(default=None, description="Icon of the tool")
+    icon_dark: str | dict | None = Field(default=None, description="Dark theme icon of the tool")


 class ToolCallResult(BaseModel):
diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py
index 1dbd2408d0..f441efa8e2 100644
--- a/api/core/workflow/nodes/llm/entities.py
+++ b/api/core/workflow/nodes/llm/entities.py
@@ -10,7 +10,7 @@ from core.model_runtime.entities import ImagePromptMessageContent, LLMMode
 from core.model_runtime.entities.llm_entities import LLMUsage
 from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
 from core.tools.entities.tool_entities import ToolProviderType
-from core.workflow.entities import ToolCallResult
+from core.workflow.entities import ToolCall, ToolCallResult
 from core.workflow.node_events import AgentLogEvent
 from core.workflow.nodes.base import BaseNodeData
 from core.workflow.nodes.base.entities import VariableSelector
@@ -89,24 +89,44 @@ class ToolMetadata(BaseModel):
     extra: dict[str, Any] = Field(default_factory=dict, description="Extra tool configuration like custom description")


+class ModelTraceSegment(BaseModel):
+    """Model invocation trace segment with token usage and output."""
+
+    text: str | None = Field(None, description="Model output text content")
+    reasoning: str | None = Field(None, description="Reasoning/thought content from model")
+    tool_calls: list[ToolCall] = Field(default_factory=list, description="Tool calls made by the model")
+
+
+class ToolTraceSegment(BaseModel):
+    """Tool invocation trace segment with call details and result."""
+
+    id: str | None = Field(default=None, description="Unique identifier for this tool call")
+    name: str | None = Field(default=None, description="Name of the tool being called")
+    arguments: str | None = Field(default=None, description="Accumulated tool arguments JSON")
+    output: str | None = Field(default=None, description="Tool call result")
+
+
 class LLMTraceSegment(BaseModel):
     """
     Streaming trace segment for LLM tool-enabled runs.

-    Order is preserved for replay. Tool calls are single entries containing both
-    arguments and results.
+    Represents alternating model and tool invocations in sequence:
+    model -> tool -> model -> tool -> ...
+
+    Each segment records its execution duration.
""" - type: Literal["thought", "content", "tool_call"] + type: Literal["model", "tool"] + duration: float = Field(..., description="Execution duration in seconds") + usage: LLMUsage | None = Field(default=None, description="Token usage statistics for this model call") + output: ModelTraceSegment | ToolTraceSegment = Field(..., description="Output of the segment") - # Common optional fields - text: str | None = Field(None, description="Text chunk for thought/content") - - # Tool call fields (combined start + result) - tool_call: ToolCallResult | None = Field( - default=None, - description="Combined tool call arguments and result for this segment", - ) + # Common metadata for both model and tool segments + provider: str | None = Field(default=None, description="Model or tool provider identifier") + icon: str | None = Field(default=None, description="Icon for the provider") + icon_dark: str | None = Field(default=None, description="Dark theme icon for the provider") + error: str | None = Field(default=None, description="Error message if segment failed") + status: Literal["success", "error"] | None = Field(default=None, description="Tool execution status") class LLMGenerationData(BaseModel): @@ -233,6 +253,7 @@ class StreamBuffers(BaseModel): think_parser: ThinkTagStreamParser = Field(default_factory=ThinkTagStreamParser) pending_thought: list[str] = Field(default_factory=list) pending_content: list[str] = Field(default_factory=list) + pending_tool_calls: list[ToolCall] = Field(default_factory=list) current_turn_reasoning: list[str] = Field(default_factory=list) reasoning_per_turn: list[str] = Field(default_factory=list) @@ -241,6 +262,8 @@ class TraceState(BaseModel): trace_segments: list[LLMTraceSegment] = Field(default_factory=list) tool_trace_map: dict[str, LLMTraceSegment] = Field(default_factory=dict) tool_call_index_map: dict[str, int] = Field(default_factory=dict) + model_segment_start_time: float | None = Field(default=None, description="Start time for current model segment") + pending_usage: LLMUsage | None = Field(default=None, description="Pending usage for current model segment") class AggregatedResult(BaseModel): diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 798ae1376f..a8a9c9063c 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -99,10 +99,12 @@ from .entities import ( LLMNodeData, LLMTraceSegment, ModelConfig, + ModelTraceSegment, StreamBuffers, ThinkTagStreamParser, ToolLogPayload, ToolOutputState, + ToolTraceSegment, TraceState, ) from .exc import ( @@ -1678,17 +1680,71 @@ class LLMNode(Node[LLMNodeData]): "elapsed_time": tool_call.elapsed_time, } - def _flush_thought_segment(self, buffers: StreamBuffers, trace_state: TraceState) -> None: - if not buffers.pending_thought: - return - trace_state.trace_segments.append(LLMTraceSegment(type="thought", text="".join(buffers.pending_thought))) - buffers.pending_thought.clear() + def _generate_model_provider_icon_url(self, provider: str, dark: bool = False) -> str | None: + """Generate icon URL for model provider.""" + from yarl import URL - def _flush_content_segment(self, buffers: StreamBuffers, trace_state: TraceState) -> None: - if not buffers.pending_content: + from configs import dify_config + + icon_type = "icon_small_dark" if dark else "icon_small" + try: + return str( + URL(dify_config.CONSOLE_API_URL or "/") + / "console" + / "api" + / "workspaces" + / "current" + / "model-providers" + / provider + / icon_type + / "en_US" + ) + except 
Exception: + return None + + def _flush_model_segment( + self, + buffers: StreamBuffers, + trace_state: TraceState, + error: str | None = None, + ) -> None: + """Flush pending thought/content buffers into a single model trace segment.""" + if not buffers.pending_thought and not buffers.pending_content and not buffers.pending_tool_calls: return - trace_state.trace_segments.append(LLMTraceSegment(type="content", text="".join(buffers.pending_content))) + + now = time.perf_counter() + duration = now - trace_state.model_segment_start_time if trace_state.model_segment_start_time else 0.0 + + # Use pending_usage from trace_state (captured from THOUGHT log) + usage = trace_state.pending_usage + + # Generate model provider icon URL + provider = self._node_data.model.provider + model_icon = self._generate_model_provider_icon_url(provider) + model_icon_dark = self._generate_model_provider_icon_url(provider, dark=True) + + trace_state.trace_segments.append( + LLMTraceSegment( + type="model", + duration=duration, + usage=usage, + output=ModelTraceSegment( + text="".join(buffers.pending_content) if buffers.pending_content else None, + reasoning="".join(buffers.pending_thought) if buffers.pending_thought else None, + tool_calls=list(buffers.pending_tool_calls), + ), + provider=provider, + icon=model_icon, + icon_dark=model_icon_dark, + error=error, + status="error" if error else "success", + ) + ) + buffers.pending_thought.clear() buffers.pending_content.clear() + buffers.pending_tool_calls.clear() + trace_state.model_segment_start_time = None + trace_state.pending_usage = None def _handle_agent_log_output( self, output: AgentLog, buffers: StreamBuffers, trace_state: TraceState, agent_context: AgentContext @@ -1716,30 +1772,26 @@ class LLMNode(Node[LLMNodeData]): else: agent_context.agent_logs.append(agent_log_event) + # Handle THOUGHT log completion - capture usage for model segment + if output.log_type == AgentLog.LogType.THOUGHT and output.status == AgentLog.LogStatus.SUCCESS: + llm_usage = output.metadata.get(AgentLog.LogMetadata.LLM_USAGE) if output.metadata else None + if llm_usage: + trace_state.pending_usage = llm_usage + if output.log_type == AgentLog.LogType.TOOL_CALL and output.status == AgentLog.LogStatus.START: tool_name = payload.tool_name tool_call_id = payload.tool_call_id tool_arguments = json.dumps(payload.tool_args) if payload.tool_args else "" + # Get icon from metadata (available at START) + tool_icon = output.metadata.get(AgentLog.LogMetadata.ICON) if output.metadata else None + tool_icon_dark = output.metadata.get(AgentLog.LogMetadata.ICON_DARK) if output.metadata else None + if tool_call_id and tool_call_id not in trace_state.tool_call_index_map: trace_state.tool_call_index_map[tool_call_id] = len(trace_state.tool_call_index_map) - self._flush_thought_segment(buffers, trace_state) - self._flush_content_segment(buffers, trace_state) - - tool_call_segment = LLMTraceSegment( - type="tool_call", - text=None, - tool_call=ToolCallResult( - id=tool_call_id, - name=tool_name, - arguments=tool_arguments, - elapsed_time=output.metadata.get(AgentLog.LogMetadata.ELAPSED_TIME) if output.metadata else None, - ), - ) - trace_state.trace_segments.append(tool_call_segment) - if tool_call_id: - trace_state.tool_trace_map[tool_call_id] = tool_call_segment + # Add tool call to pending list for model segment + buffers.pending_tool_calls.append(ToolCall(id=tool_call_id, name=tool_name, arguments=tool_arguments)) yield ToolCallChunkEvent( selector=[self._node_id, "generation", "tool_calls"], @@ -1748,6 
+1800,8 @@ class LLMNode(Node[LLMNodeData]): id=tool_call_id, name=tool_name, arguments=tool_arguments, + icon=tool_icon, + icon_dark=tool_icon_dark, ), is_final=False, ) @@ -1758,12 +1812,13 @@ class LLMNode(Node[LLMNodeData]): tool_call_id = payload.tool_call_id tool_files = payload.files if isinstance(payload.files, list) else [] tool_error = payload.tool_error + tool_arguments = json.dumps(payload.tool_args) if payload.tool_args else "" if tool_call_id and tool_call_id not in trace_state.tool_call_index_map: trace_state.tool_call_index_map[tool_call_id] = len(trace_state.tool_call_index_map) - self._flush_thought_segment(buffers, trace_state) - self._flush_content_segment(buffers, trace_state) + # Flush model segment before tool result processing + self._flush_model_segment(buffers, trace_state) if output.status == AgentLog.LogStatus.ERROR: tool_error = output.error or payload.tool_error @@ -1775,47 +1830,48 @@ class LLMNode(Node[LLMNodeData]): if meta_error: tool_error = meta_error - existing_tool_segment = trace_state.tool_trace_map.get(tool_call_id) - tool_call_segment = existing_tool_segment or LLMTraceSegment( - type="tool_call", - text=None, - tool_call=ToolCallResult( + elapsed_time = output.metadata.get(AgentLog.LogMetadata.ELAPSED_TIME) if output.metadata else None + tool_provider = output.metadata.get(AgentLog.LogMetadata.PROVIDER) if output.metadata else None + tool_icon = output.metadata.get(AgentLog.LogMetadata.ICON) if output.metadata else None + tool_icon_dark = output.metadata.get(AgentLog.LogMetadata.ICON_DARK) if output.metadata else None + result_str = str(tool_output) if tool_output is not None else None + + tool_status: Literal["success", "error"] = "error" if tool_error else "success" + tool_call_segment = LLMTraceSegment( + type="tool", + duration=elapsed_time or 0.0, + usage=None, + output=ToolTraceSegment( id=tool_call_id, name=tool_name, - arguments=None, - elapsed_time=output.metadata.get(AgentLog.LogMetadata.ELAPSED_TIME) if output.metadata else None, + arguments=tool_arguments, + output=result_str, ), + provider=tool_provider, + icon=tool_icon, + icon_dark=tool_icon_dark, + error=str(tool_error) if tool_error else None, + status=tool_status, ) - if existing_tool_segment is None: - trace_state.trace_segments.append(tool_call_segment) - if tool_call_id: - trace_state.tool_trace_map[tool_call_id] = tool_call_segment + trace_state.trace_segments.append(tool_call_segment) + if tool_call_id: + trace_state.tool_trace_map[tool_call_id] = tool_call_segment - if tool_call_segment.tool_call is None: - tool_call_segment.tool_call = ToolCallResult( - id=tool_call_id, - name=tool_name, - arguments=None, - elapsed_time=output.metadata.get(AgentLog.LogMetadata.ELAPSED_TIME) if output.metadata else None, - ) - tool_call_segment.tool_call.output = ( - str(tool_output) if tool_output is not None else str(tool_error) if tool_error is not None else None - ) - tool_call_segment.tool_call.files = [] - tool_call_segment.tool_call.status = ToolResultStatus.ERROR if tool_error else ToolResultStatus.SUCCESS - - result_output = str(tool_output) if tool_output is not None else str(tool_error) if tool_error else None + # Start new model segment tracking + trace_state.model_segment_start_time = time.perf_counter() yield ToolResultChunkEvent( selector=[self._node_id, "generation", "tool_results"], - chunk=result_output or "", + chunk=result_str or "", tool_result=ToolResult( id=tool_call_id, name=tool_name, - output=result_output, + output=result_str, files=tool_files, 
                     status=ToolResultStatus.ERROR if tool_error else ToolResultStatus.SUCCESS,
-                    elapsed_time=output.metadata.get(AgentLog.LogMetadata.ELAPSED_TIME) if output.metadata else None,
+                    elapsed_time=elapsed_time,
+                    icon=tool_icon,
+                    icon_dark=tool_icon_dark,
                 ),
                 is_final=False,
             )
@@ -1840,15 +1896,17 @@
             if not segment and kind not in {"thought_start", "thought_end"}:
                 continue

+            # Start tracking model segment time on first output
+            if trace_state.model_segment_start_time is None:
+                trace_state.model_segment_start_time = time.perf_counter()
+
             if kind == "thought_start":
-                self._flush_content_segment(buffers, trace_state)
                 yield ThoughtStartChunkEvent(
                     selector=[self._node_id, "generation", "thought"],
                     chunk="",
                     is_final=False,
                 )
             elif kind == "thought":
-                self._flush_content_segment(buffers, trace_state)
                 buffers.current_turn_reasoning.append(segment)
                 buffers.pending_thought.append(segment)
                 yield ThoughtChunkEvent(
@@ -1857,14 +1915,12 @@
                     is_final=False,
                 )
             elif kind == "thought_end":
-                self._flush_thought_segment(buffers, trace_state)
                 yield ThoughtEndChunkEvent(
                     selector=[self._node_id, "generation", "thought"],
                     chunk="",
                     is_final=False,
                 )
             else:
-                self._flush_thought_segment(buffers, trace_state)
                 aggregate.text += segment
                 buffers.pending_content.append(segment)
                 yield StreamChunkEvent(
@@ -1890,15 +1946,18 @@
         for kind, segment in buffers.think_parser.flush():
             if not segment and kind not in {"thought_start", "thought_end"}:
                 continue
+
+            # Start tracking model segment time on first output
+            if trace_state.model_segment_start_time is None:
+                trace_state.model_segment_start_time = time.perf_counter()
+
             if kind == "thought_start":
-                self._flush_content_segment(buffers, trace_state)
                 yield ThoughtStartChunkEvent(
                     selector=[self._node_id, "generation", "thought"],
                     chunk="",
                     is_final=False,
                 )
             elif kind == "thought":
-                self._flush_content_segment(buffers, trace_state)
                 buffers.current_turn_reasoning.append(segment)
                 buffers.pending_thought.append(segment)
                 yield ThoughtChunkEvent(
@@ -1907,14 +1966,12 @@
                     is_final=False,
                 )
             elif kind == "thought_end":
-                self._flush_thought_segment(buffers, trace_state)
                 yield ThoughtEndChunkEvent(
                     selector=[self._node_id, "generation", "thought"],
                     chunk="",
                     is_final=False,
                 )
             else:
-                self._flush_thought_segment(buffers, trace_state)
                 aggregate.text += segment
                 buffers.pending_content.append(segment)
                 yield StreamChunkEvent(
@@ -1931,8 +1988,13 @@
         if buffers.current_turn_reasoning:
             buffers.reasoning_per_turn.append("".join(buffers.current_turn_reasoning))

-        self._flush_thought_segment(buffers, trace_state)
-        self._flush_content_segment(buffers, trace_state)
+        # For final flush, use aggregate.usage if pending_usage is not set
+        # (e.g., for simple LLM calls without tool invocations)
+        if trace_state.pending_usage is None:
+            trace_state.pending_usage = aggregate.usage
+
+        # Flush final model segment
+        self._flush_model_segment(buffers, trace_state)

     def _close_streams(self) -> Generator[NodeEventBase, None, None]:
         yield StreamChunkEvent(
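
Note (not part of the patch): a minimal sketch of the alternating model/tool trace shape this diff introduces, assuming the Pydantic models defined above in api/core/workflow/nodes/llm/entities.py and api/core/workflow/entities/tool_entities.py. All concrete values below are invented for illustration.

# Illustrative only -- constructs the kind of trace that _flush_model_segment
# and the TOOL_CALL handling above would accumulate in trace_state.trace_segments.
from core.workflow.entities import ToolCall
from core.workflow.nodes.llm.entities import (
    LLMTraceSegment,
    ModelTraceSegment,
    ToolTraceSegment,
)

# Model turn: reasoning plus one tool call, captured as a single "model" segment.
model_segment = LLMTraceSegment(
    type="model",
    duration=1.42,  # seconds; the real value is derived from trace_state.model_segment_start_time
    usage=None,  # filled from the THOUGHT log's LLM_USAGE metadata when present
    output=ModelTraceSegment(
        text=None,
        reasoning="I should search for this.",
        tool_calls=[ToolCall(id="call_1", name="web_search", arguments='{"query": "dify"}')],
    ),
    provider="openai",  # hypothetical provider identifier
    icon="https://example.com/icon_small/en_US",  # hypothetical icon URL
    status="success",
)

# Tool turn: the matching "tool" segment appended when the TOOL_CALL log finishes.
tool_segment = LLMTraceSegment(
    type="tool",
    duration=0.31,
    usage=None,
    output=ToolTraceSegment(
        id="call_1",
        name="web_search",
        arguments='{"query": "dify"}',
        output="top results",
    ),
    provider="web_search",  # hypothetical
    status="success",
)

trace_segments = [model_segment, tool_segment]  # model -> tool -> model -> ...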