diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index c196dbbdf1..825ac9babf 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -1,6 +1,7 @@ import json import logging import uuid +from decimal import Decimal from typing import Union, cast from sqlalchemy import select @@ -41,6 +42,7 @@ from core.tools.tool_manager import ToolManager from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool from extensions.ext_database import db from factories import file_factory +from models.enums import CreatorUserRole from models.model import Conversation, Message, MessageAgentThought, MessageFile logger = logging.getLogger(__name__) @@ -287,8 +289,10 @@ class BaseAgentRunner(AppRunner): Create agent thought """ thought = MessageAgentThought( + id=str(uuid.uuid4()), message_id=message_id, message_chain_id=None, + tool_process_data=None, thought="", tool=tool_name, tool_labels_str="{}", @@ -296,20 +300,20 @@ class BaseAgentRunner(AppRunner): tool_input=tool_input, message=message, message_token=0, - message_unit_price=0, - message_price_unit=0, + message_unit_price=Decimal(0), + message_price_unit=Decimal("0.001"), message_files=json.dumps(messages_ids) if messages_ids else "", answer="", observation="", answer_token=0, - answer_unit_price=0, - answer_price_unit=0, + answer_unit_price=Decimal(0), + answer_price_unit=Decimal("0.001"), tokens=0, - total_price=0, + total_price=Decimal(0), position=self.agent_thought_count + 1, currency="USD", latency=0, - created_by_role="account", + created_by_role=CreatorUserRole.ACCOUNT, created_by=self.user_id, ) diff --git a/api/models/model.py b/api/models/model.py index f698b79d32..a65f72ccf6 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1779,34 +1779,40 @@ class MessageAgentThought(Base): sa.Index("message_agent_thought_message_chain_id_idx", "message_chain_id"), ) - id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()")) - message_id = mapped_column(StringUUID, nullable=False) - message_chain_id = mapped_column(StringUUID, nullable=True) + id: Mapped[str] = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()")) + message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) position: Mapped[int] = mapped_column(sa.Integer, nullable=False) - thought = mapped_column(sa.Text, nullable=True) - tool = mapped_column(sa.Text, nullable=True) - tool_labels_str = mapped_column(sa.Text, nullable=False, server_default=sa.text("'{}'::text")) - tool_meta_str = mapped_column(sa.Text, nullable=False, server_default=sa.text("'{}'::text")) - tool_input = mapped_column(sa.Text, nullable=True) - observation = mapped_column(sa.Text, nullable=True) + tool_labels_str: Mapped[str] = mapped_column(sa.Text, nullable=False, server_default=sa.text("'{}'::text")) + tool_meta_str: Mapped[str] = mapped_column(sa.Text, nullable=False, server_default=sa.text("'{}'::text")) + created_by_role: Mapped[CreatorUserRole] = mapped_column(String, nullable=False) + created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, init=False, server_default=sa.func.current_timestamp() + ) + message_price_unit: Mapped[Decimal] = mapped_column( + sa.Numeric(10, 7), nullable=False, default=Decimal("0.001"), server_default=sa.text("0.001") + ) + answer_price_unit: Mapped[Decimal] = mapped_column( + sa.Numeric(10, 7), nullable=False, default=Decimal("0.001"), server_default=sa.text("0.001") + ) + message_chain_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) + thought: Mapped[str | None] = mapped_column(sa.Text, nullable=True, default=None) + tool: Mapped[str | None] = mapped_column(sa.Text, nullable=True, default=None) + tool_input: Mapped[str | None] = mapped_column(sa.Text, nullable=True, default=None) + observation: Mapped[str | None] = mapped_column(sa.Text, nullable=True, default=None) # plugin_id = mapped_column(StringUUID, nullable=True) ## for future design - tool_process_data = mapped_column(sa.Text, nullable=True) - message = mapped_column(sa.Text, nullable=True) - message_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) - message_unit_price = mapped_column(sa.Numeric, nullable=True) - message_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001")) - message_files = mapped_column(sa.Text, nullable=True) - answer = mapped_column(sa.Text, nullable=True) - answer_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) - answer_unit_price = mapped_column(sa.Numeric, nullable=True) - answer_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001")) - tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) - total_price = mapped_column(sa.Numeric, nullable=True) - currency = mapped_column(String, nullable=True) - latency: Mapped[float | None] = mapped_column(sa.Float, nullable=True) - created_by_role = mapped_column(String, nullable=False) - created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.func.current_timestamp()) + tool_process_data: Mapped[str | None] = mapped_column(sa.Text, nullable=True, default=None) + message: Mapped[str | None] = mapped_column(sa.Text, nullable=True, default=None) + message_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None) + message_unit_price: Mapped[Decimal | None] = mapped_column(sa.Numeric, nullable=True, default=None) + message_files: Mapped[str | None] = mapped_column(sa.Text, nullable=True, default=None) + answer: Mapped[str | None] = mapped_column(sa.Text, nullable=True, default=None) + answer_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None) + answer_unit_price: Mapped[Decimal | None] = mapped_column(sa.Numeric, nullable=True, default=None) + tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None) + total_price: Mapped[Decimal | None] = mapped_column(sa.Numeric, nullable=True, default=None) + currency: Mapped[str | None] = mapped_column(String, nullable=True, default=None) + latency: Mapped[float | None] = mapped_column(sa.Float, nullable=True, default=None) @property def files(self) -> list[Any]: