diff --git a/api/core/agent/agent_app_runner.py b/api/core/agent/agent_app_runner.py index 9be5be5c7c..e15ede15d2 100644 --- a/api/core/agent/agent_app_runner.py +++ b/api/core/agent/agent_app_runner.py @@ -108,7 +108,7 @@ class AgentAppRunner(BaseAgentRunner): current_agent_thought_id = None has_published_thought = False current_tool_name: str | None = None - self._current_message_file_ids = [] + self._current_message_file_ids: list[str] = [] # organize prompt messages prompt_messages = self._organize_prompt_messages() diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 54abef0552..3b82ec4a05 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -1720,15 +1720,15 @@ class LLMNode(Node[LLMNodeData]): if meta and isinstance(meta, dict) and meta.get("error"): tool_error = meta.get("error") - tool_call_segment = tool_trace_map.get(tool_call_id) - if tool_call_segment is None: - tool_call_segment = LLMTraceSegment( - type="tool_call", - text=None, - tool_call_id=tool_call_id, - tool_name=tool_name, - tool_arguments=None, - ) + existing_tool_segment = tool_trace_map.get(tool_call_id) + tool_call_segment = existing_tool_segment or LLMTraceSegment( + type="tool_call", + text=None, + tool_call_id=tool_call_id, + tool_name=tool_name, + tool_arguments=None, + ) + if existing_tool_segment is None: trace_segments.append(tool_call_segment) if tool_call_id: tool_trace_map[tool_call_id] = tool_call_segment @@ -1854,18 +1854,18 @@ class LLMNode(Node[LLMNodeData]): reasoning_index = 0 content_position = 0 tool_call_seen_index: dict[str, int] = {} - for segment in trace_segments: - if segment.type == "thought": + for trace_segment in trace_segments: + if trace_segment.type == "thought": sequence.append({"type": "reasoning", "index": reasoning_index}) reasoning_index += 1 - elif segment.type == "content": - segment_text = segment.text or "" + elif trace_segment.type == "content": + segment_text = trace_segment.text or "" start = content_position end = start + len(segment_text) sequence.append({"type": "content", "start": start, "end": end}) content_position = end - elif segment.type == "tool_call": - tool_id = segment.tool_call_id or "" + elif trace_segment.type == "tool_call": + tool_id = trace_segment.tool_call_id or "" if tool_id not in tool_call_seen_index: tool_call_seen_index[tool_id] = len(tool_call_seen_index) sequence.append({"type": "tool_call", "index": tool_call_seen_index[tool_id]}) diff --git a/api/migrations/versions/2025_12_10_1617-85c8b4a64f53_add_llm_generation_detail_table.py b/api/migrations/versions/2025_12_10_1617-85c8b4a64f53_add_llm_generation_detail_table.py index 340cc82bb5..700f9ea80b 100644 --- a/api/migrations/versions/2025_12_10_1617-85c8b4a64f53_add_llm_generation_detail_table.py +++ b/api/migrations/versions/2025_12_10_1617-85c8b4a64f53_add_llm_generation_detail_table.py @@ -12,7 +12,7 @@ from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. revision = '85c8b4a64f53' -down_revision = '7bb281b7a422' +down_revision = 'd57accd375ae' branch_labels = None depends_on = None diff --git a/api/models/workflow.py b/api/models/workflow.py index 853d5afefc..89ec0352df 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -57,6 +57,37 @@ from .types import EnumText, LongText, StringUUID logger = logging.getLogger(__name__) +def is_generation_outputs(outputs: Mapping[str, Any]) -> bool: + if not outputs: + return False + + allowed_sequence_types = {"reasoning", "content", "tool_call"} + + def valid_sequence_item(item: Mapping[str, Any]) -> bool: + return isinstance(item, Mapping) and item.get("type") in allowed_sequence_types + + def valid_value(value: Any) -> bool: + if not isinstance(value, Mapping): + return False + + content = value.get("content") + reasoning_content = value.get("reasoning_content") + tool_calls = value.get("tool_calls") + sequence = value.get("sequence") + + return ( + isinstance(content, str) + and isinstance(reasoning_content, list) + and all(isinstance(item, str) for item in reasoning_content) + and isinstance(tool_calls, list) + and all(isinstance(item, Mapping) for item in tool_calls) + and isinstance(sequence, list) + and all(valid_sequence_item(item) for item in sequence) + ) + + return all(valid_value(value) for value in outputs.values()) + + class WorkflowType(StrEnum): """ Workflow Type Enum @@ -652,6 +683,10 @@ class WorkflowRun(Base): def outputs_dict(self) -> Mapping[str, Any]: return json.loads(self.outputs) if self.outputs else {} + @property + def outputs_as_generation(self) -> bool: + return is_generation_outputs(self.outputs_dict) + @property def message(self): from .model import Message diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py index 3b6998d0b2..1bc821c43d 100644 --- a/api/services/workflow_run_service.py +++ b/api/services/workflow_run_service.py @@ -1,6 +1,5 @@ import threading -from collections.abc import Mapping, Sequence -from typing import Any +from collections.abc import Sequence from sqlalchemy import Engine from sqlalchemy.orm import sessionmaker @@ -109,9 +108,6 @@ class WorkflowRunService: run_id=run_id, ) - if workflow_run: - workflow_run.outputs_as_generation = self._are_all_generation_outputs(workflow_run.outputs_dict) - return workflow_run def get_workflow_runs_count( @@ -165,32 +161,3 @@ class WorkflowRunService: app_id=app_model.id, workflow_run_id=run_id, ) - - @staticmethod - def _are_all_generation_outputs(outputs: Mapping[str, Any]) -> bool: - if not outputs: - return False - - allowed_sequence_types = {"reasoning", "content", "tool_call"} - - for value in outputs.values(): - if not isinstance(value, Mapping): - return False - - content = value.get("content") - reasoning_content = value.get("reasoning_content") - tool_calls = value.get("tool_calls") - sequence = value.get("sequence") - - if not isinstance(content, str): - return False - if not isinstance(reasoning_content, list) or any(not isinstance(item, str) for item in reasoning_content): - return False - if not isinstance(tool_calls, list) or any(not isinstance(item, Mapping) for item in tool_calls): - return False - if not isinstance(sequence, list) or any( - not isinstance(item, Mapping) or item.get("type") not in allowed_sequence_types for item in sequence - ): - return False - - return True diff --git a/api/tests/unit_tests/core/agent/__init__.py b/api/tests/unit_tests/core/agent/__init__.py new file mode 100644 index 0000000000..a9ccd45f4b --- /dev/null +++ b/api/tests/unit_tests/core/agent/__init__.py @@ -0,0 +1,3 @@ +""" +Mark agent test modules as a package to avoid import name collisions. +"""