diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py
index 068ea5ecf0..82f146414b 100644
--- a/api/core/workflow/nodes/llm/entities.py
+++ b/api/core/workflow/nodes/llm/entities.py
@@ -52,6 +52,7 @@ class LLMGenerationData(BaseModel):
     text: str = Field(..., description="Accumulated text content from all turns")
     reasoning_contents: list[str] = Field(default_factory=list, description="Reasoning content per turn")
     tool_calls: list[dict[str, Any]] = Field(default_factory=list, description="Tool calls with results")
+    sequence: list[dict[str, Any]] = Field(default_factory=list, description="Ordered segments for rendering")
     usage: LLMUsage = Field(..., description="LLM usage statistics")
     finish_reason: str | None = Field(None, description="Finish reason from LLM")
     files: list[File] = Field(default_factory=list, description="Generated files")
diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py
index 5f4d938773..7fd74babe2 100644
--- a/api/core/workflow/nodes/llm/node.py
+++ b/api/core/workflow/nodes/llm/node.py
@@ -392,6 +392,7 @@ class LLMNode(Node[LLMNodeData]):
                 "content": generation_data.text,
                 "reasoning_content": generation_data.reasoning_contents,  # [thought1, thought2, ...]
                 "tool_calls": generation_data.tool_calls,
+                "sequence": generation_data.sequence,
             }
             files_to_output = generation_data.files
         else:
@@ -400,6 +401,7 @@
                 "content": clean_text,
                 "reasoning_content": [reasoning_content] if reasoning_content else [],
                 "tool_calls": [],
+                "sequence": [],
             }
             files_to_output = self._file_outputs
 
@@ -428,22 +430,24 @@
             is_final=True,
         )
 
+        metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = {
+            WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
+            WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
+            WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
+        }
+
+        if generation_data and generation_data.trace:
+            metadata[WorkflowNodeExecutionMetadataKey.LLM_TRACE] = [
+                segment.model_dump() for segment in generation_data.trace
+            ]
+
         yield StreamCompletedEvent(
             node_run_result=NodeRunResult(
                 status=WorkflowNodeExecutionStatus.SUCCEEDED,
                 inputs=node_inputs,
                 process_data=process_data,
                 outputs=outputs,
-                metadata={
-                    WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
-                    WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
-                    WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
-                    WorkflowNodeExecutionMetadataKey.LLM_TRACE: [
-                        segment.model_dump() for segment in generation_data.trace
-                    ]
-                    if generation_data
-                    else [],
-                },
+                metadata=metadata,
                 llm_usage=usage,
             )
         )
@@ -1783,6 +1787,27 @@
         _flush_thought()
         _flush_content()
 
+        # Build sequence from trace_segments for rendering
+        sequence: list[dict[str, Any]] = []
+        reasoning_index = 0
+        content_position = 0
+        tool_call_seen_index: dict[str, int] = {}
+        for segment in trace_segments:
+            if segment.type == "thought":
+                sequence.append({"type": "reasoning", "index": reasoning_index})
+                reasoning_index += 1
+            elif segment.type == "content":
+                segment_text = segment.text or ""
+                start = content_position
+                end = start + len(segment_text)
+                sequence.append({"type": "content", "start": start, "end": end})
+                content_position = end
+            elif segment.type == "tool_call":
+                tool_id = segment.tool_call_id or ""
+                if tool_id not in tool_call_seen_index:
+                    tool_call_seen_index[tool_id] = len(tool_call_seen_index)
+                sequence.append({"type": "tool_call", "index": tool_call_seen_index[tool_id]})
+
         # Send final events for all streams
         yield StreamChunkEvent(
             selector=[self._node_id, "text"],
@@ -1850,6 +1875,7 @@
             text=text,
             reasoning_contents=reasoning_per_turn,
             tool_calls=tool_calls_for_generation,
+            sequence=sequence,
             usage=usage,
             finish_reason=finish_reason,
             files=files,
diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py
index 6305d8d9d5..7b878e05c8 100644
--- a/api/fields/workflow_run_fields.py
+++ b/api/fields/workflow_run_fields.py
@@ -81,6 +81,7 @@ workflow_run_detail_fields = {
     "inputs": fields.Raw(attribute="inputs_dict"),
     "status": fields.String,
     "outputs": fields.Raw(attribute="outputs_dict"),
+    "outputs_as_generation": fields.Boolean,
     "error": fields.String,
     "elapsed_time": fields.Float,
     "total_tokens": fields.Integer,
diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py
index b903d8df5f..3b6998d0b2 100644
--- a/api/services/workflow_run_service.py
+++ b/api/services/workflow_run_service.py
@@ -1,5 +1,6 @@
 import threading
-from collections.abc import Sequence
+from collections.abc import Mapping, Sequence
+from typing import Any
 
 from sqlalchemy import Engine
 from sqlalchemy.orm import sessionmaker
@@ -102,12 +103,17 @@ class WorkflowRunService:
         :param app_model: app model
         :param run_id: workflow run id
         """
-        return self._workflow_run_repo.get_workflow_run_by_id(
+        workflow_run = self._workflow_run_repo.get_workflow_run_by_id(
             tenant_id=app_model.tenant_id,
             app_id=app_model.id,
             run_id=run_id,
         )
 
+        if workflow_run:
+            workflow_run.outputs_as_generation = self._are_all_generation_outputs(workflow_run.outputs_dict)
+
+        return workflow_run
+
     def get_workflow_runs_count(
         self,
         app_model: App,
@@ -159,3 +165,32 @@
             app_id=app_model.id,
             workflow_run_id=run_id,
         )
+
+    @staticmethod
+    def _are_all_generation_outputs(outputs: Mapping[str, Any]) -> bool:
+        if not outputs:
+            return False
+
+        allowed_sequence_types = {"reasoning", "content", "tool_call"}
+
+        for value in outputs.values():
+            if not isinstance(value, Mapping):
+                return False
+
+            content = value.get("content")
+            reasoning_content = value.get("reasoning_content")
+            tool_calls = value.get("tool_calls")
+            sequence = value.get("sequence")
+
+            if not isinstance(content, str):
+                return False
+            if not isinstance(reasoning_content, list) or any(not isinstance(item, str) for item in reasoning_content):
+                return False
+            if not isinstance(tool_calls, list) or any(not isinstance(item, Mapping) for item in tool_calls):
+                return False
+            if not isinstance(sequence, list) or any(
+                not isinstance(item, Mapping) or item.get("type") not in allowed_sequence_types for item in sequence
+            ):
+                return False
+
+        return True