diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 1f3c218d59..ad9b625350 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -280,7 +280,7 @@ class BaseAgentRunner(AppRunner): def create_agent_thought( self, message_id: str, message: str, tool_name: str, tool_input: str, messages_ids: list[str] - ) -> MessageAgentThought: + ) -> str: """ Create agent thought """ @@ -313,16 +313,15 @@ class BaseAgentRunner(AppRunner): db.session.add(thought) db.session.commit() - db.session.refresh(thought) + agent_thought_id = str(thought.id) + self.agent_thought_count += 1 db.session.close() - self.agent_thought_count += 1 - - return thought + return agent_thought_id def save_agent_thought( self, - agent_thought: MessageAgentThought, + agent_thought_id: str, tool_name: str | None, tool_input: Union[str, dict, None], thought: str | None, @@ -335,12 +334,9 @@ class BaseAgentRunner(AppRunner): """ Save agent thought """ - updated_agent_thought = ( - db.session.query(MessageAgentThought).where(MessageAgentThought.id == agent_thought.id).first() - ) - if not updated_agent_thought: + agent_thought = db.session.query(MessageAgentThought).where(MessageAgentThought.id == agent_thought_id).first() + if not agent_thought: raise ValueError("agent thought not found") - agent_thought = updated_agent_thought if thought: agent_thought.thought += thought @@ -355,7 +351,7 @@ class BaseAgentRunner(AppRunner): except Exception: tool_input = json.dumps(tool_input) - updated_agent_thought.tool_input = tool_input + agent_thought.tool_input = tool_input if observation: if isinstance(observation, dict): @@ -364,27 +360,27 @@ class BaseAgentRunner(AppRunner): except Exception: observation = json.dumps(observation) - updated_agent_thought.observation = observation + agent_thought.observation = observation if answer: agent_thought.answer = answer if messages_ids is not None and len(messages_ids) > 0: - updated_agent_thought.message_files = json.dumps(messages_ids) + agent_thought.message_files = json.dumps(messages_ids) if llm_usage: - updated_agent_thought.message_token = llm_usage.prompt_tokens - updated_agent_thought.message_price_unit = llm_usage.prompt_price_unit - updated_agent_thought.message_unit_price = llm_usage.prompt_unit_price - updated_agent_thought.answer_token = llm_usage.completion_tokens - updated_agent_thought.answer_price_unit = llm_usage.completion_price_unit - updated_agent_thought.answer_unit_price = llm_usage.completion_unit_price - updated_agent_thought.tokens = llm_usage.total_tokens - updated_agent_thought.total_price = llm_usage.total_price + agent_thought.message_token = llm_usage.prompt_tokens + agent_thought.message_price_unit = llm_usage.prompt_price_unit + agent_thought.message_unit_price = llm_usage.prompt_unit_price + agent_thought.answer_token = llm_usage.completion_tokens + agent_thought.answer_price_unit = llm_usage.completion_price_unit + agent_thought.answer_unit_price = llm_usage.completion_unit_price + agent_thought.tokens = llm_usage.total_tokens + agent_thought.total_price = llm_usage.total_price # check if tool labels is not empty - labels = updated_agent_thought.tool_labels or {} - tools = updated_agent_thought.tool.split(";") if updated_agent_thought.tool else [] + labels = agent_thought.tool_labels or {} + tools = agent_thought.tool.split(";") if agent_thought.tool else [] for tool in tools: if not tool: continue @@ -395,7 +391,7 @@ class BaseAgentRunner(AppRunner): else: labels[tool] = {"en_US": tool, "zh_Hans": tool} - updated_agent_thought.tool_labels_str = json.dumps(labels) + agent_thought.tool_labels_str = json.dumps(labels) if tool_invoke_meta is not None: if isinstance(tool_invoke_meta, dict): @@ -404,7 +400,7 @@ class BaseAgentRunner(AppRunner): except Exception: tool_invoke_meta = json.dumps(tool_invoke_meta) - updated_agent_thought.tool_meta_str = tool_invoke_meta + agent_thought.tool_meta_str = tool_invoke_meta db.session.commit() db.session.close() diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index 4979f63432..565fb42478 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -97,13 +97,13 @@ class CotAgentRunner(BaseAgentRunner, ABC): message_file_ids: list[str] = [] - agent_thought = self.create_agent_thought( + agent_thought_id = self.create_agent_thought( message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids ) if iteration_step > 1: self.queue_manager.publish( - QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER ) # recalc llm max tokens @@ -133,7 +133,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): # publish agent thought if it's first iteration if iteration_step == 1: self.queue_manager.publish( - QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER ) for chunk in react_chunks: @@ -168,7 +168,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): usage_dict["usage"] = LLMUsage.empty_usage() self.save_agent_thought( - agent_thought=agent_thought, + agent_thought_id=agent_thought_id, tool_name=(scratchpad.action.action_name if scratchpad.action and not scratchpad.is_final() else ""), tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {}, tool_invoke_meta={}, @@ -181,7 +181,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): if not scratchpad.is_final(): self.queue_manager.publish( - QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER ) if not scratchpad.action: @@ -212,7 +212,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): scratchpad.agent_response = tool_invoke_response self.save_agent_thought( - agent_thought=agent_thought, + agent_thought_id=agent_thought_id, tool_name=scratchpad.action.action_name, tool_input={scratchpad.action.action_name: scratchpad.action.action_input}, thought=scratchpad.thought or "", @@ -224,7 +224,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): ) self.queue_manager.publish( - QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER ) # update prompt tool message @@ -244,7 +244,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): # save agent thought self.save_agent_thought( - agent_thought=agent_thought, + agent_thought_id=agent_thought_id, tool_name="", tool_input={}, tool_invoke_meta={}, diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index 5491689ece..4df71ce9de 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -80,7 +80,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): prompt_messages_tools = [] message_file_ids: list[str] = [] - agent_thought = self.create_agent_thought( + agent_thought_id = self.create_agent_thought( message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids ) @@ -114,7 +114,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): for chunk in chunks: if is_first_chunk: self.queue_manager.publish( - QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER ) is_first_chunk = False # check if there is any tool call @@ -172,7 +172,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): result.message.content = "" self.queue_manager.publish( - QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER ) yield LLMResultChunk( @@ -205,7 +205,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): # save thought self.save_agent_thought( - agent_thought=agent_thought, + agent_thought_id=agent_thought_id, tool_name=tool_call_names, tool_input=tool_call_inputs, thought=response, @@ -216,7 +216,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): llm_usage=current_llm_usage, ) self.queue_manager.publish( - QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER ) final_answer += response + "\n" @@ -276,7 +276,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): if len(tool_responses) > 0: # save agent thought self.save_agent_thought( - agent_thought=agent_thought, + agent_thought_id=agent_thought_id, tool_name="", tool_input="", thought="", @@ -291,7 +291,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): messages_ids=message_file_ids, ) self.queue_manager.publish( - QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER ) # update prompt tool