From 85285931e2d5783eb2d9fff90f349a8fe7362fdd Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Thu, 28 Mar 2024 20:04:31 +0800 Subject: [PATCH] feat: add agent tool invoke meta --- api/core/agent/base_agent_runner.py | 19 +++++- api/core/agent/cot_agent_runner.py | 45 ++++++++++---- api/core/agent/fc_agent_runner.py | 19 ++++-- .../workflow_tool_callback_handler.py | 5 ++ api/core/tools/entities/tool_entities.py | 29 ++++++++- api/core/tools/errors.py | 8 ++- api/core/tools/tool_engine.py | 59 +++++++++++++++++-- api/core/workflow/nodes/tool/tool_node.py | 12 +++- .../versions/c3311b089690_add_tool_meta.py | 31 ++++++++++ api/models/model.py | 11 ++++ api/models/tools.py | 6 +- 11 files changed, 209 insertions(+), 35 deletions(-) create mode 100644 api/core/callback_handler/workflow_tool_callback_handler.py create mode 100644 api/migrations/versions/c3311b089690_add_tool_meta.py diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index eb3e5ac193..8955ad2d1a 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -279,6 +279,7 @@ class BaseAgentRunner(AppRunner): thought='', tool=tool_name, tool_labels_str='{}', + tool_meta_str='{}', tool_input=tool_input, message=message, message_token=0, @@ -313,7 +314,8 @@ class BaseAgentRunner(AppRunner): tool_name: str, tool_input: Union[str, dict], thought: str, - observation: str, + observation: Union[str, str], + tool_invoke_meta: Union[str, dict], answer: str, messages_ids: list[str], llm_usage: LLMUsage = None) -> MessageAgentThought: @@ -340,6 +342,12 @@ class BaseAgentRunner(AppRunner): agent_thought.tool_input = tool_input if observation is not None: + if isinstance(observation, dict): + try: + observation = json.dumps(observation, ensure_ascii=False) + except Exception as e: + observation = json.dumps(observation) + agent_thought.observation = observation if answer is not None: @@ -373,6 +381,15 @@ class BaseAgentRunner(AppRunner): agent_thought.tool_labels_str = json.dumps(labels) + if tool_invoke_meta is not None: + if isinstance(tool_invoke_meta, dict): + try: + tool_invoke_meta = json.dumps(tool_invoke_meta, ensure_ascii=False) + except Exception as e: + tool_invoke_meta = json.dumps(tool_invoke_meta) + + agent_thought.tool_meta_str = tool_invoke_meta + db.session.commit() db.session.close() diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index c8191552ee..d57f15638c 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -17,6 +17,7 @@ from core.model_runtime.entities.message_entities import ( UserPromptMessage, ) from core.model_runtime.utils.encoders import jsonable_encoder +from core.tools.entities.tool_entities import ToolInvokeMeta from core.tools.tool_engine import ToolEngine from models.model import Conversation, Message @@ -215,7 +216,10 @@ class CotAgentRunner(BaseAgentRunner): self.save_agent_thought(agent_thought=agent_thought, tool_name=scratchpad.action.action_name if scratchpad.action else '', - tool_input=scratchpad.action.action_input if scratchpad.action else '', + tool_input={ + scratchpad.action.action_name: scratchpad.action.action_input + } if scratchpad.action else '', + tool_invoke_meta={}, thought=scratchpad.thought, observation='', answer=scratchpad.agent_response, @@ -248,13 +252,20 @@ class CotAgentRunner(BaseAgentRunner): tool_instance = tool_instances.get(tool_call_name) if not tool_instance: answer = f"there is not a tool named {tool_call_name}" - self.save_agent_thought(agent_thought=agent_thought, - tool_name='', - tool_input='', - thought=None, - observation=answer, - answer=answer, - messages_ids=[]) + self.save_agent_thought( + agent_thought=agent_thought, + tool_name='', + tool_input='', + tool_invoke_meta=ToolInvokeMeta.error_instance( + f"there is not a tool named {tool_call_name}" + ).to_dict(), + thought=None, + observation={ + tool_call_name: answer + }, + answer=answer, + messages_ids=[] + ) self.queue_manager.publish(QueueAgentThoughtEvent( agent_thought_id=agent_thought.id ), PublishFrom.APPLICATION_MANAGER) @@ -266,7 +277,7 @@ class CotAgentRunner(BaseAgentRunner): pass # invoke tool - tool_invoke_response, message_files = ToolEngine.agent_invoke( + tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke( tool=tool_instance, tool_parameters=tool_call_args, user_id=self.user_id, @@ -308,9 +319,16 @@ class CotAgentRunner(BaseAgentRunner): self.save_agent_thought( agent_thought=agent_thought, tool_name=tool_call_name, - tool_input=tool_call_args, + tool_input={ + tool_call_name: tool_call_args + }, + tool_invoke_meta={ + tool_call_name: tool_invoke_meta.to_dict() + }, thought=None, - observation=observation, + observation={ + tool_call_name: observation + }, answer=scratchpad.agent_response, messages_ids=message_file_ids, ) @@ -341,9 +359,10 @@ class CotAgentRunner(BaseAgentRunner): self.save_agent_thought( agent_thought=agent_thought, tool_name='', - tool_input='', + tool_input={}, + tool_invoke_meta={}, thought=final_answer, - observation='', + observation={}, answer=final_answer, messages_ids=[] ) diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index 533aff46fd..e66500d327 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -15,6 +15,7 @@ from core.model_runtime.entities.message_entities import ( ToolPromptMessage, UserPromptMessage, ) +from core.tools.entities.tool_entities import ToolInvokeMeta from core.tools.tool_engine import ToolEngine from models.model import Conversation, Message, MessageAgentThought @@ -226,6 +227,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): tool_name=tool_call_names, tool_input=tool_call_inputs, thought=response, + tool_invoke_meta=None, observation=None, answer=response, messages_ids=[], @@ -251,11 +253,12 @@ class FunctionCallAgentRunner(BaseAgentRunner): tool_response = { "tool_call_id": tool_call_id, "tool_call_name": tool_call_name, - "tool_response": f"there is not a tool named {tool_call_name}" + "tool_response": f"there is not a tool named {tool_call_name}", + "meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict() } else: # invoke tool - tool_invoke_response, message_files = ToolEngine.agent_invoke( + tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke( tool=tool_instance, tool_parameters=tool_call_args, user_id=self.user_id, @@ -276,11 +279,11 @@ class FunctionCallAgentRunner(BaseAgentRunner): # add message file ids message_file_ids.append(message_file.id) - observation = tool_invoke_response tool_response = { "tool_call_id": tool_call_id, "tool_call_name": tool_call_name, - "tool_response": observation + "tool_response": tool_invoke_response, + "meta": tool_invoke_meta.to_dict() } tool_responses.append(tool_response) @@ -300,10 +303,14 @@ class FunctionCallAgentRunner(BaseAgentRunner): tool_name=None, tool_input=None, thought=None, - observation=json.dumps({ + tool_invoke_meta={ + tool_response['tool_call_name']: tool_response['meta'] + for tool_response in tool_responses + }, + observation={ tool_response['tool_call_name']: tool_response['tool_response'] for tool_response in tool_responses - }), + }, answer=None, messages_ids=message_file_ids ) diff --git a/api/core/callback_handler/workflow_tool_callback_handler.py b/api/core/callback_handler/workflow_tool_callback_handler.py new file mode 100644 index 0000000000..84bab7e1a3 --- /dev/null +++ b/api/core/callback_handler/workflow_tool_callback_handler.py @@ -0,0 +1,5 @@ +from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler + + +class DifyWorkflowCallbackHandler(DifyAgentCallbackHandler): + """Callback Handler that prints to std out.""" \ No newline at end of file diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 437f871864..1c0f476f7b 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -326,4 +326,31 @@ class ModelToolProviderConfiguration(BaseModel): """ provider: str = Field(..., description="The provider of the model tool") models: list[ModelToolConfiguration] = Field(..., description="The models of the model tool") - label: I18nObject = Field(..., description="The label of the model tool") \ No newline at end of file + label: I18nObject = Field(..., description="The label of the model tool") + +class ToolInvokeMeta(BaseModel): + """ + Tool invoke meta + """ + time_cost: float = Field(..., description="The time cost of the tool invoke") + error: Optional[str] = None + + @classmethod + def empty(cls) -> 'ToolInvokeMeta': + """ + Get an empty instance of ToolInvokeMeta + """ + return cls(time_cost=0.0, error=None) + + @classmethod + def error_instance(cls, error: str) -> 'ToolInvokeMeta': + """ + Get an instance of ToolInvokeMeta with error + """ + return cls(time_cost=0.0, error=error) + + def to_dict(self) -> dict: + return { + 'time_cost': self.time_cost, + 'error': self.error, + } \ No newline at end of file diff --git a/api/core/tools/errors.py b/api/core/tools/errors.py index d1acb073ac..9fd8322db1 100644 --- a/api/core/tools/errors.py +++ b/api/core/tools/errors.py @@ -1,3 +1,6 @@ +from core.tools.entities.tool_entities import ToolInvokeMeta + + class ToolProviderNotFoundError(ValueError): pass @@ -17,4 +20,7 @@ class ToolInvokeError(ValueError): pass class ToolApiSchemaError(ValueError): - pass \ No newline at end of file + pass + +class ToolEngineInvokeError(Exception): + meta: ToolInvokeMeta \ No newline at end of file diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 7d836b4695..33bfafb423 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -1,10 +1,13 @@ +from datetime import datetime, timezone from typing import Union from core.app.entities.app_invoke_entities import InvokeFrom from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler +from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.file.file_obj import FileTransferMethod -from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMessageBinary, ToolParameter +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMessageBinary, ToolInvokeMeta, ToolParameter from core.tools.errors import ( + ToolEngineInvokeError, ToolInvokeError, ToolNotFoundError, ToolNotSupportedError, @@ -26,7 +29,7 @@ class ToolEngine: def agent_invoke(tool: Tool, tool_parameters: Union[str, dict], user_id: str, tenant_id: str, message: Message, invoke_from: InvokeFrom, agent_tool_callback: DifyAgentCallbackHandler) \ - -> tuple[str, list[tuple[MessageFile, bool]]]: + -> tuple[str, list[tuple[MessageFile, bool]], ToolInvokeMeta]: """ Agent invokes the tool with the given arguments. """ @@ -52,7 +55,10 @@ class ToolEngine: tool_inputs=tool_parameters ) - response = tool.invoke(user_id, tool_parameters) + try: + meta, response = ToolEngine._invoke(tool, tool_parameters, user_id) + except ToolEngineInvokeError as e: + meta = e.meta response = ToolFileMessageTransformer.transform_tool_invoke_messages( messages=response, @@ -81,7 +87,7 @@ class ToolEngine: ) # transform tool invoke message to get LLM friendly message - return plain_text, message_files + return plain_text, message_files, meta except ToolProviderCredentialValidationError as e: error_response = "Please check your tool provider credentials" agent_tool_callback.on_tool_error(e) @@ -102,15 +108,56 @@ class ToolEngine: error_response = f"unknown error: {e}" agent_tool_callback.on_tool_error(e) - return error_response, [] + return error_response, [], meta @staticmethod def workflow_invoke(tool: Tool, tool_parameters: dict, - user_id: str, workflow_id: str) -> dict: + user_id: str, workflow_id: str, + workflow_tool_callback: DifyWorkflowCallbackHandler) \ + -> list[ToolInvokeMessage]: """ Workflow invokes the tool with the given arguments. """ + try: + # hit the callback handler + workflow_tool_callback.on_tool_start( + tool_name=tool.identity.name, + tool_inputs=tool_parameters + ) + response = tool.invoke(user_id, tool_parameters) + + # hit the callback handler + workflow_tool_callback.on_tool_end( + tool_name=tool.identity.name, + tool_inputs=tool_parameters, + tool_outputs=response + ) + + return response + except Exception as e: + workflow_tool_callback.on_tool_error(e) + raise e + + @staticmethod + def _invoke(tool: Tool, tool_parameters: dict, user_id: str) \ + -> tuple[ToolInvokeMeta, list[ToolInvokeMessage]]: + """ + Invoke the tool with the given arguments. + """ + started_at = datetime.now(timezone.utc) + meta = ToolInvokeMeta(time_cost=0.0, error=None) + try: + response = tool.invoke(user_id, tool_parameters) + except Exception as e: + meta.error = str(e) + raise ToolEngineInvokeError(meta=meta) + finally: + ended_at = datetime.now(timezone.utc) + meta.time_cost = (ended_at - started_at).total_seconds() + + return meta, response + @staticmethod def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str: """ diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 4f5b3332ae..003a259243 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -1,8 +1,10 @@ from os import path from typing import cast +from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.file.file_obj import FileTransferMethod, FileType, FileVar from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool_engine import ToolEngine from core.tools.tool_manager import ToolManager from core.tools.utils.message_transformer import ToolFileMessageTransformer from core.workflow.entities.node_entities import NodeRunResult, NodeType @@ -30,7 +32,7 @@ class ToolNode(BaseNode): parameters = self._generate_parameters(variable_pool, node_data) # get tool runtime try: - tool_runtime = ToolManager.get_workflow_tool_runtime(self.tenant_id, node_data, None) + tool_runtime = ToolManager.get_workflow_tool_runtime(self.tenant_id, node_data) except Exception as e: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, @@ -39,7 +41,13 @@ class ToolNode(BaseNode): ) try: - messages = tool_runtime.invoke(self.user_id, parameters) + messages = ToolEngine.workflow_invoke( + tool=tool_runtime, + tool_parameters=parameters, + user_id=self.user_id, + workflow_id=self.workflow_id, + workflow_tool_callback=DifyWorkflowCallbackHandler() + ) except Exception as e: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, diff --git a/api/migrations/versions/c3311b089690_add_tool_meta.py b/api/migrations/versions/c3311b089690_add_tool_meta.py new file mode 100644 index 0000000000..e075535b0d --- /dev/null +++ b/api/migrations/versions/c3311b089690_add_tool_meta.py @@ -0,0 +1,31 @@ +"""add tool meta + +Revision ID: c3311b089690 +Revises: e2eacc9a1b63 +Create Date: 2024-03-28 11:50:45.364875 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = 'c3311b089690' +down_revision = 'e2eacc9a1b63' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: + batch_op.add_column(sa.Column('tool_meta_str', sa.Text(), server_default=sa.text("'{}'::text"), nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: + batch_op.drop_column('tool_meta_str') + + # ### end Alembic commands ### diff --git a/api/models/model.py b/api/models/model.py index 9914658272..fc0d53bcde 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1134,6 +1134,7 @@ class MessageAgentThought(db.Model): thought = db.Column(db.Text, nullable=True) tool = db.Column(db.Text, nullable=True) tool_labels_str = db.Column(db.Text, nullable=False, server_default=db.text("'{}'::text")) + tool_meta_str = db.Column(db.Text, nullable=False, server_default=db.text("'{}'::text")) tool_input = db.Column(db.Text, nullable=True) observation = db.Column(db.Text, nullable=True) # plugin_id = db.Column(UUID, nullable=True) ## for future design @@ -1171,6 +1172,16 @@ class MessageAgentThought(db.Model): return {} except Exception as e: return {} + + @property + def tool_meta(self) -> dict: + try: + if self.tool_meta_str: + return json.loads(self.tool_meta_str) + else: + return {} + except Exception as e: + return {} class DatasetRetrieverResource(db.Model): __tablename__ = 'dataset_retriever_resources' diff --git a/api/models/tools.py b/api/models/tools.py index 4bdf2503ce..414d055780 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -58,7 +58,7 @@ class PublishedAppTool(db.Model): description = db.Column(db.Text, nullable=False) # llm_description of the tool, for LLM llm_description = db.Column(db.Text, nullable=False) - # query decription, query will be seem as a parameter of the tool, to describe this parameter to llm, we need this field + # query description, query will be seem as a parameter of the tool, to describe this parameter to llm, we need this field query_description = db.Column(db.Text, nullable=False) # query name, the name of the query parameter query_name = db.Column(db.String(40), nullable=False) @@ -123,10 +123,6 @@ class ApiToolProvider(db.Model): def credentials(self) -> dict: return json.loads(self.credentials_str) - @property - def is_taned(self) -> bool: - return self.tenant_id is not None - @property def user(self) -> Account: return db.session.query(Account).filter(Account.id == self.user_id).first()