From 1e10bf525c199422cb22b6cfda0fad46201e499c Mon Sep 17 00:00:00 2001 From: -LAN- Date: Sat, 10 Jan 2026 16:17:45 +0800 Subject: [PATCH] refactor(models): Refine MessageAgentThought SQLAlchemy typing (#27749) Co-authored-by: Asuka Minato Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/core/agent/base_agent_runner.py | 53 ++++++++++------ api/models/model.py | 62 +++++++++++-------- .../services/test_agent_service.py | 7 --- 3 files changed, 68 insertions(+), 54 deletions(-) diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index c196dbbdf1..3c6d36afe4 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -1,6 +1,7 @@ import json import logging import uuid +from decimal import Decimal from typing import Union, cast from sqlalchemy import select @@ -41,6 +42,7 @@ from core.tools.tool_manager import ToolManager from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool from extensions.ext_database import db from factories import file_factory +from models.enums import CreatorUserRole from models.model import Conversation, Message, MessageAgentThought, MessageFile logger = logging.getLogger(__name__) @@ -289,6 +291,7 @@ class BaseAgentRunner(AppRunner): thought = MessageAgentThought( message_id=message_id, message_chain_id=None, + tool_process_data=None, thought="", tool=tool_name, tool_labels_str="{}", @@ -296,20 +299,20 @@ class BaseAgentRunner(AppRunner): tool_input=tool_input, message=message, message_token=0, - message_unit_price=0, - message_price_unit=0, + message_unit_price=Decimal(0), + message_price_unit=Decimal("0.001"), message_files=json.dumps(messages_ids) if messages_ids else "", answer="", observation="", answer_token=0, - answer_unit_price=0, - answer_price_unit=0, + answer_unit_price=Decimal(0), + answer_price_unit=Decimal("0.001"), tokens=0, - total_price=0, + total_price=Decimal(0), position=self.agent_thought_count + 1, currency="USD", latency=0, - created_by_role="account", + created_by_role=CreatorUserRole.ACCOUNT, created_by=self.user_id, ) @@ -342,7 +345,8 @@ class BaseAgentRunner(AppRunner): raise ValueError("agent thought not found") if thought: - agent_thought.thought += thought + existing_thought = agent_thought.thought or "" + agent_thought.thought = f"{existing_thought}{thought}" if tool_name: agent_thought.tool = tool_name @@ -440,21 +444,30 @@ class BaseAgentRunner(AppRunner): agent_thoughts: list[MessageAgentThought] = message.agent_thoughts if agent_thoughts: for agent_thought in agent_thoughts: - tools = agent_thought.tool - if tools: - tools = tools.split(";") + tool_names_raw = agent_thought.tool + if tool_names_raw: + tool_names = tool_names_raw.split(";") tool_calls: list[AssistantPromptMessage.ToolCall] = [] tool_call_response: list[ToolPromptMessage] = [] - try: - tool_inputs = json.loads(agent_thought.tool_input) - except Exception: - tool_inputs = {tool: {} for tool in tools} - try: - tool_responses = json.loads(agent_thought.observation) - except Exception: - tool_responses = dict.fromkeys(tools, agent_thought.observation) + tool_input_payload = agent_thought.tool_input + if tool_input_payload: + try: + tool_inputs = json.loads(tool_input_payload) + except Exception: + tool_inputs = {tool: {} for tool in tool_names} + else: + tool_inputs = {tool: {} for tool in tool_names} - for tool in tools: + observation_payload = agent_thought.observation + if observation_payload: + try: + tool_responses = json.loads(observation_payload) + except Exception: + tool_responses = dict.fromkeys(tool_names, observation_payload) + else: + tool_responses = dict.fromkeys(tool_names, observation_payload) + + for tool in tool_names: # generate a uuid for tool call tool_call_id = str(uuid.uuid4()) tool_calls.append( @@ -484,7 +497,7 @@ class BaseAgentRunner(AppRunner): *tool_call_response, ] ) - if not tools: + if not tool_names_raw: result.append(AssistantPromptMessage(content=agent_thought.thought)) else: if message.answer: diff --git a/api/models/model.py b/api/models/model.py index c791ae15b0..a48f4d34d4 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1843,7 +1843,7 @@ class MessageChain(TypeBase): ) -class MessageAgentThought(Base): +class MessageAgentThought(TypeBase): __tablename__ = "message_agent_thoughts" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="message_agent_thought_pkey"), @@ -1851,34 +1851,42 @@ class MessageAgentThought(Base): sa.Index("message_agent_thought_message_chain_id_idx", "message_chain_id"), ) - id = mapped_column(StringUUID, default=lambda: str(uuid4())) - message_id = mapped_column(StringUUID, nullable=False) - message_chain_id = mapped_column(StringUUID, nullable=True) + id: Mapped[str] = mapped_column( + StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False + ) + message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) position: Mapped[int] = mapped_column(sa.Integer, nullable=False) - thought = mapped_column(LongText, nullable=True) - tool = mapped_column(LongText, nullable=True) - tool_labels_str = mapped_column(LongText, nullable=False, default=sa.text("'{}'")) - tool_meta_str = mapped_column(LongText, nullable=False, default=sa.text("'{}'")) - tool_input = mapped_column(LongText, nullable=True) - observation = mapped_column(LongText, nullable=True) + created_by_role: Mapped[str] = mapped_column(String(255), nullable=False) + created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) + message_chain_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) + thought: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) + tool: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) + tool_labels_str: Mapped[str] = mapped_column(LongText, nullable=False, default=sa.text("'{}'")) + tool_meta_str: Mapped[str] = mapped_column(LongText, nullable=False, default=sa.text("'{}'")) + tool_input: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) + observation: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) # plugin_id = mapped_column(StringUUID, nullable=True) ## for future design - tool_process_data = mapped_column(LongText, nullable=True) - message = mapped_column(LongText, nullable=True) - message_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) - message_unit_price = mapped_column(sa.Numeric, nullable=True) - message_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001")) - message_files = mapped_column(LongText, nullable=True) - answer = mapped_column(LongText, nullable=True) - answer_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) - answer_unit_price = mapped_column(sa.Numeric, nullable=True) - answer_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001")) - tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) - total_price = mapped_column(sa.Numeric, nullable=True) - currency = mapped_column(String(255), nullable=True) - latency: Mapped[float | None] = mapped_column(sa.Float, nullable=True) - created_by_role = mapped_column(String(255), nullable=False) - created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.func.current_timestamp()) + tool_process_data: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) + message: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) + message_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None) + message_unit_price: Mapped[Decimal | None] = mapped_column(sa.Numeric, nullable=True, default=None) + message_price_unit: Mapped[Decimal] = mapped_column( + sa.Numeric(10, 7), nullable=False, default=Decimal("0.001"), server_default=sa.text("0.001") + ) + message_files: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) + answer: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) + answer_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None) + answer_unit_price: Mapped[Decimal | None] = mapped_column(sa.Numeric, nullable=True, default=None) + answer_price_unit: Mapped[Decimal] = mapped_column( + sa.Numeric(10, 7), nullable=False, default=Decimal("0.001"), server_default=sa.text("0.001") + ) + tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None) + total_price: Mapped[Decimal | None] = mapped_column(sa.Numeric, nullable=True, default=None) + currency: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None) + latency: Mapped[float | None] = mapped_column(sa.Float, nullable=True, default=None) + created_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, init=False, server_default=sa.func.current_timestamp() + ) @property def files(self) -> list[Any]: diff --git a/api/tests/test_containers_integration_tests/services/test_agent_service.py b/api/tests/test_containers_integration_tests/services/test_agent_service.py index 3be2798085..a22d6f8fbf 100644 --- a/api/tests/test_containers_integration_tests/services/test_agent_service.py +++ b/api/tests/test_containers_integration_tests/services/test_agent_service.py @@ -230,7 +230,6 @@ class TestAgentService: # Create first agent thought thought1 = MessageAgentThought( - id=fake.uuid4(), message_id=message.id, position=1, thought="I need to analyze the user's request", @@ -257,7 +256,6 @@ class TestAgentService: # Create second agent thought thought2 = MessageAgentThought( - id=fake.uuid4(), message_id=message.id, position=2, thought="Based on the analysis, I can provide a response", @@ -545,7 +543,6 @@ class TestAgentService: # Create agent thought with tool error thought_with_error = MessageAgentThought( - id=fake.uuid4(), message_id=message.id, position=1, thought="I need to analyze the user's request", @@ -759,7 +756,6 @@ class TestAgentService: # Create agent thought with multiple tools complex_thought = MessageAgentThought( - id=fake.uuid4(), message_id=message.id, position=1, thought="I need to use multiple tools to complete this task", @@ -877,7 +873,6 @@ class TestAgentService: # Create agent thought with files thought_with_files = MessageAgentThought( - id=fake.uuid4(), message_id=message.id, position=1, thought="I need to process some files", @@ -957,7 +952,6 @@ class TestAgentService: # Create agent thought with empty tool data empty_thought = MessageAgentThought( - id=fake.uuid4(), message_id=message.id, position=1, thought="I need to analyze the user's request", @@ -999,7 +993,6 @@ class TestAgentService: # Create agent thought with malformed JSON malformed_thought = MessageAgentThought( - id=fake.uuid4(), message_id=message.id, position=1, thought="I need to analyze the user's request",