refactor(models): Refine MessageAgentThought SQLAlchemy typing (#27749)

Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
-LAN- 2026-01-10 16:17:45 +08:00 committed by GitHub
parent 8b1af36d94
commit 1e10bf525c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 68 additions and 54 deletions

View File

@ -1,6 +1,7 @@
import json
import logging
import uuid
from decimal import Decimal
from typing import Union, cast
from sqlalchemy import select
@ -41,6 +42,7 @@ from core.tools.tool_manager import ToolManager
from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
from factories import file_factory
from models.enums import CreatorUserRole
from models.model import Conversation, Message, MessageAgentThought, MessageFile
logger = logging.getLogger(__name__)
@ -289,6 +291,7 @@ class BaseAgentRunner(AppRunner):
thought = MessageAgentThought(
message_id=message_id,
message_chain_id=None,
tool_process_data=None,
thought="",
tool=tool_name,
tool_labels_str="{}",
@ -296,20 +299,20 @@ class BaseAgentRunner(AppRunner):
tool_input=tool_input,
message=message,
message_token=0,
message_unit_price=0,
message_price_unit=0,
message_unit_price=Decimal(0),
message_price_unit=Decimal("0.001"),
message_files=json.dumps(messages_ids) if messages_ids else "",
answer="",
observation="",
answer_token=0,
answer_unit_price=0,
answer_price_unit=0,
answer_unit_price=Decimal(0),
answer_price_unit=Decimal("0.001"),
tokens=0,
total_price=0,
total_price=Decimal(0),
position=self.agent_thought_count + 1,
currency="USD",
latency=0,
created_by_role="account",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=self.user_id,
)
@ -342,7 +345,8 @@ class BaseAgentRunner(AppRunner):
raise ValueError("agent thought not found")
if thought:
agent_thought.thought += thought
existing_thought = agent_thought.thought or ""
agent_thought.thought = f"{existing_thought}{thought}"
if tool_name:
agent_thought.tool = tool_name
@ -440,21 +444,30 @@ class BaseAgentRunner(AppRunner):
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
if agent_thoughts:
for agent_thought in agent_thoughts:
tools = agent_thought.tool
if tools:
tools = tools.split(";")
tool_names_raw = agent_thought.tool
if tool_names_raw:
tool_names = tool_names_raw.split(";")
tool_calls: list[AssistantPromptMessage.ToolCall] = []
tool_call_response: list[ToolPromptMessage] = []
try:
tool_inputs = json.loads(agent_thought.tool_input)
except Exception:
tool_inputs = {tool: {} for tool in tools}
try:
tool_responses = json.loads(agent_thought.observation)
except Exception:
tool_responses = dict.fromkeys(tools, agent_thought.observation)
tool_input_payload = agent_thought.tool_input
if tool_input_payload:
try:
tool_inputs = json.loads(tool_input_payload)
except Exception:
tool_inputs = {tool: {} for tool in tool_names}
else:
tool_inputs = {tool: {} for tool in tool_names}
for tool in tools:
observation_payload = agent_thought.observation
if observation_payload:
try:
tool_responses = json.loads(observation_payload)
except Exception:
tool_responses = dict.fromkeys(tool_names, observation_payload)
else:
tool_responses = dict.fromkeys(tool_names, observation_payload)
for tool in tool_names:
# generate a uuid for tool call
tool_call_id = str(uuid.uuid4())
tool_calls.append(
@ -484,7 +497,7 @@ class BaseAgentRunner(AppRunner):
*tool_call_response,
]
)
if not tools:
if not tool_names_raw:
result.append(AssistantPromptMessage(content=agent_thought.thought))
else:
if message.answer:

View File

@ -1843,7 +1843,7 @@ class MessageChain(TypeBase):
)
class MessageAgentThought(Base):
class MessageAgentThought(TypeBase):
__tablename__ = "message_agent_thoughts"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="message_agent_thought_pkey"),
@ -1851,34 +1851,42 @@ class MessageAgentThought(Base):
sa.Index("message_agent_thought_message_chain_id_idx", "message_chain_id"),
)
id = mapped_column(StringUUID, default=lambda: str(uuid4()))
message_id = mapped_column(StringUUID, nullable=False)
message_chain_id = mapped_column(StringUUID, nullable=True)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
thought = mapped_column(LongText, nullable=True)
tool = mapped_column(LongText, nullable=True)
tool_labels_str = mapped_column(LongText, nullable=False, default=sa.text("'{}'"))
tool_meta_str = mapped_column(LongText, nullable=False, default=sa.text("'{}'"))
tool_input = mapped_column(LongText, nullable=True)
observation = mapped_column(LongText, nullable=True)
created_by_role: Mapped[str] = mapped_column(String(255), nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
message_chain_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
thought: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
tool: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
tool_labels_str: Mapped[str] = mapped_column(LongText, nullable=False, default=sa.text("'{}'"))
tool_meta_str: Mapped[str] = mapped_column(LongText, nullable=False, default=sa.text("'{}'"))
tool_input: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
observation: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
# plugin_id = mapped_column(StringUUID, nullable=True) ## for future design
tool_process_data = mapped_column(LongText, nullable=True)
message = mapped_column(LongText, nullable=True)
message_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
message_unit_price = mapped_column(sa.Numeric, nullable=True)
message_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
message_files = mapped_column(LongText, nullable=True)
answer = mapped_column(LongText, nullable=True)
answer_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
answer_unit_price = mapped_column(sa.Numeric, nullable=True)
answer_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
total_price = mapped_column(sa.Numeric, nullable=True)
currency = mapped_column(String(255), nullable=True)
latency: Mapped[float | None] = mapped_column(sa.Float, nullable=True)
created_by_role = mapped_column(String(255), nullable=False)
created_by = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.func.current_timestamp())
tool_process_data: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
message: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
message_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None)
message_unit_price: Mapped[Decimal | None] = mapped_column(sa.Numeric, nullable=True, default=None)
message_price_unit: Mapped[Decimal] = mapped_column(
sa.Numeric(10, 7), nullable=False, default=Decimal("0.001"), server_default=sa.text("0.001")
)
message_files: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
answer: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
answer_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None)
answer_unit_price: Mapped[Decimal | None] = mapped_column(sa.Numeric, nullable=True, default=None)
answer_price_unit: Mapped[Decimal] = mapped_column(
sa.Numeric(10, 7), nullable=False, default=Decimal("0.001"), server_default=sa.text("0.001")
)
tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None)
total_price: Mapped[Decimal | None] = mapped_column(sa.Numeric, nullable=True, default=None)
currency: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
latency: Mapped[float | None] = mapped_column(sa.Float, nullable=True, default=None)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, init=False, server_default=sa.func.current_timestamp()
)
@property
def files(self) -> list[Any]:

View File

@ -230,7 +230,6 @@ class TestAgentService:
# Create first agent thought
thought1 = MessageAgentThought(
id=fake.uuid4(),
message_id=message.id,
position=1,
thought="I need to analyze the user's request",
@ -257,7 +256,6 @@ class TestAgentService:
# Create second agent thought
thought2 = MessageAgentThought(
id=fake.uuid4(),
message_id=message.id,
position=2,
thought="Based on the analysis, I can provide a response",
@ -545,7 +543,6 @@ class TestAgentService:
# Create agent thought with tool error
thought_with_error = MessageAgentThought(
id=fake.uuid4(),
message_id=message.id,
position=1,
thought="I need to analyze the user's request",
@ -759,7 +756,6 @@ class TestAgentService:
# Create agent thought with multiple tools
complex_thought = MessageAgentThought(
id=fake.uuid4(),
message_id=message.id,
position=1,
thought="I need to use multiple tools to complete this task",
@ -877,7 +873,6 @@ class TestAgentService:
# Create agent thought with files
thought_with_files = MessageAgentThought(
id=fake.uuid4(),
message_id=message.id,
position=1,
thought="I need to process some files",
@ -957,7 +952,6 @@ class TestAgentService:
# Create agent thought with empty tool data
empty_thought = MessageAgentThought(
id=fake.uuid4(),
message_id=message.id,
position=1,
thought="I need to analyze the user's request",
@ -999,7 +993,6 @@ class TestAgentService:
# Create agent thought with malformed JSON
malformed_thought = MessageAgentThought(
id=fake.uuid4(),
message_id=message.id,
position=1,
thought="I need to analyze the user's request",