feat: add agent tool invoke meta

This commit is contained in:
Yeuoly 2024-03-28 20:04:31 +08:00
parent d7c4032917
commit 85285931e2
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61
11 changed files with 209 additions and 35 deletions

View File

@ -279,6 +279,7 @@ class BaseAgentRunner(AppRunner):
thought='', thought='',
tool=tool_name, tool=tool_name,
tool_labels_str='{}', tool_labels_str='{}',
tool_meta_str='{}',
tool_input=tool_input, tool_input=tool_input,
message=message, message=message,
message_token=0, message_token=0,
@ -313,7 +314,8 @@ class BaseAgentRunner(AppRunner):
tool_name: str, tool_name: str,
tool_input: Union[str, dict], tool_input: Union[str, dict],
thought: str, thought: str,
observation: str, observation: Union[str, str],
tool_invoke_meta: Union[str, dict],
answer: str, answer: str,
messages_ids: list[str], messages_ids: list[str],
llm_usage: LLMUsage = None) -> MessageAgentThought: llm_usage: LLMUsage = None) -> MessageAgentThought:
@ -340,6 +342,12 @@ class BaseAgentRunner(AppRunner):
agent_thought.tool_input = tool_input agent_thought.tool_input = tool_input
if observation is not None: if observation is not None:
if isinstance(observation, dict):
try:
observation = json.dumps(observation, ensure_ascii=False)
except Exception as e:
observation = json.dumps(observation)
agent_thought.observation = observation agent_thought.observation = observation
if answer is not None: if answer is not None:
@ -373,6 +381,15 @@ class BaseAgentRunner(AppRunner):
agent_thought.tool_labels_str = json.dumps(labels) agent_thought.tool_labels_str = json.dumps(labels)
if tool_invoke_meta is not None:
if isinstance(tool_invoke_meta, dict):
try:
tool_invoke_meta = json.dumps(tool_invoke_meta, ensure_ascii=False)
except Exception as e:
tool_invoke_meta = json.dumps(tool_invoke_meta)
agent_thought.tool_meta_str = tool_invoke_meta
db.session.commit() db.session.commit()
db.session.close() db.session.close()

View File

@ -17,6 +17,7 @@ from core.model_runtime.entities.message_entities import (
UserPromptMessage, UserPromptMessage,
) )
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine from core.tools.tool_engine import ToolEngine
from models.model import Conversation, Message from models.model import Conversation, Message
@ -215,7 +216,10 @@ class CotAgentRunner(BaseAgentRunner):
self.save_agent_thought(agent_thought=agent_thought, self.save_agent_thought(agent_thought=agent_thought,
tool_name=scratchpad.action.action_name if scratchpad.action else '', tool_name=scratchpad.action.action_name if scratchpad.action else '',
tool_input=scratchpad.action.action_input if scratchpad.action else '', tool_input={
scratchpad.action.action_name: scratchpad.action.action_input
} if scratchpad.action else '',
tool_invoke_meta={},
thought=scratchpad.thought, thought=scratchpad.thought,
observation='', observation='',
answer=scratchpad.agent_response, answer=scratchpad.agent_response,
@ -248,13 +252,20 @@ class CotAgentRunner(BaseAgentRunner):
tool_instance = tool_instances.get(tool_call_name) tool_instance = tool_instances.get(tool_call_name)
if not tool_instance: if not tool_instance:
answer = f"there is not a tool named {tool_call_name}" answer = f"there is not a tool named {tool_call_name}"
self.save_agent_thought(agent_thought=agent_thought, self.save_agent_thought(
tool_name='', agent_thought=agent_thought,
tool_input='', tool_name='',
thought=None, tool_input='',
observation=answer, tool_invoke_meta=ToolInvokeMeta.error_instance(
answer=answer, f"there is not a tool named {tool_call_name}"
messages_ids=[]) ).to_dict(),
thought=None,
observation={
tool_call_name: answer
},
answer=answer,
messages_ids=[]
)
self.queue_manager.publish(QueueAgentThoughtEvent( self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER) ), PublishFrom.APPLICATION_MANAGER)
@ -266,7 +277,7 @@ class CotAgentRunner(BaseAgentRunner):
pass pass
# invoke tool # invoke tool
tool_invoke_response, message_files = ToolEngine.agent_invoke( tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
tool=tool_instance, tool=tool_instance,
tool_parameters=tool_call_args, tool_parameters=tool_call_args,
user_id=self.user_id, user_id=self.user_id,
@ -308,9 +319,16 @@ class CotAgentRunner(BaseAgentRunner):
self.save_agent_thought( self.save_agent_thought(
agent_thought=agent_thought, agent_thought=agent_thought,
tool_name=tool_call_name, tool_name=tool_call_name,
tool_input=tool_call_args, tool_input={
tool_call_name: tool_call_args
},
tool_invoke_meta={
tool_call_name: tool_invoke_meta.to_dict()
},
thought=None, thought=None,
observation=observation, observation={
tool_call_name: observation
},
answer=scratchpad.agent_response, answer=scratchpad.agent_response,
messages_ids=message_file_ids, messages_ids=message_file_ids,
) )
@ -341,9 +359,10 @@ class CotAgentRunner(BaseAgentRunner):
self.save_agent_thought( self.save_agent_thought(
agent_thought=agent_thought, agent_thought=agent_thought,
tool_name='', tool_name='',
tool_input='', tool_input={},
tool_invoke_meta={},
thought=final_answer, thought=final_answer,
observation='', observation={},
answer=final_answer, answer=final_answer,
messages_ids=[] messages_ids=[]
) )

View File

@ -15,6 +15,7 @@ from core.model_runtime.entities.message_entities import (
ToolPromptMessage, ToolPromptMessage,
UserPromptMessage, UserPromptMessage,
) )
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine from core.tools.tool_engine import ToolEngine
from models.model import Conversation, Message, MessageAgentThought from models.model import Conversation, Message, MessageAgentThought
@ -226,6 +227,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
tool_name=tool_call_names, tool_name=tool_call_names,
tool_input=tool_call_inputs, tool_input=tool_call_inputs,
thought=response, thought=response,
tool_invoke_meta=None,
observation=None, observation=None,
answer=response, answer=response,
messages_ids=[], messages_ids=[],
@ -251,11 +253,12 @@ class FunctionCallAgentRunner(BaseAgentRunner):
tool_response = { tool_response = {
"tool_call_id": tool_call_id, "tool_call_id": tool_call_id,
"tool_call_name": tool_call_name, "tool_call_name": tool_call_name,
"tool_response": f"there is not a tool named {tool_call_name}" "tool_response": f"there is not a tool named {tool_call_name}",
"meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict()
} }
else: else:
# invoke tool # invoke tool
tool_invoke_response, message_files = ToolEngine.agent_invoke( tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
tool=tool_instance, tool=tool_instance,
tool_parameters=tool_call_args, tool_parameters=tool_call_args,
user_id=self.user_id, user_id=self.user_id,
@ -276,11 +279,11 @@ class FunctionCallAgentRunner(BaseAgentRunner):
# add message file ids # add message file ids
message_file_ids.append(message_file.id) message_file_ids.append(message_file.id)
observation = tool_invoke_response
tool_response = { tool_response = {
"tool_call_id": tool_call_id, "tool_call_id": tool_call_id,
"tool_call_name": tool_call_name, "tool_call_name": tool_call_name,
"tool_response": observation "tool_response": tool_invoke_response,
"meta": tool_invoke_meta.to_dict()
} }
tool_responses.append(tool_response) tool_responses.append(tool_response)
@ -300,10 +303,14 @@ class FunctionCallAgentRunner(BaseAgentRunner):
tool_name=None, tool_name=None,
tool_input=None, tool_input=None,
thought=None, thought=None,
observation=json.dumps({ tool_invoke_meta={
tool_response['tool_call_name']: tool_response['meta']
for tool_response in tool_responses
},
observation={
tool_response['tool_call_name']: tool_response['tool_response'] tool_response['tool_call_name']: tool_response['tool_response']
for tool_response in tool_responses for tool_response in tool_responses
}), },
answer=None, answer=None,
messages_ids=message_file_ids messages_ids=message_file_ids
) )

View File

@ -0,0 +1,5 @@
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
class DifyWorkflowCallbackHandler(DifyAgentCallbackHandler):
"""Callback Handler that prints to std out."""

View File

@ -326,4 +326,31 @@ class ModelToolProviderConfiguration(BaseModel):
""" """
provider: str = Field(..., description="The provider of the model tool") provider: str = Field(..., description="The provider of the model tool")
models: list[ModelToolConfiguration] = Field(..., description="The models of the model tool") models: list[ModelToolConfiguration] = Field(..., description="The models of the model tool")
label: I18nObject = Field(..., description="The label of the model tool") label: I18nObject = Field(..., description="The label of the model tool")
class ToolInvokeMeta(BaseModel):
"""
Tool invoke meta
"""
time_cost: float = Field(..., description="The time cost of the tool invoke")
error: Optional[str] = None
@classmethod
def empty(cls) -> 'ToolInvokeMeta':
"""
Get an empty instance of ToolInvokeMeta
"""
return cls(time_cost=0.0, error=None)
@classmethod
def error_instance(cls, error: str) -> 'ToolInvokeMeta':
"""
Get an instance of ToolInvokeMeta with error
"""
return cls(time_cost=0.0, error=error)
def to_dict(self) -> dict:
return {
'time_cost': self.time_cost,
'error': self.error,
}

View File

@ -1,3 +1,6 @@
from core.tools.entities.tool_entities import ToolInvokeMeta
class ToolProviderNotFoundError(ValueError): class ToolProviderNotFoundError(ValueError):
pass pass
@ -17,4 +20,7 @@ class ToolInvokeError(ValueError):
pass pass
class ToolApiSchemaError(ValueError): class ToolApiSchemaError(ValueError):
pass pass
class ToolEngineInvokeError(Exception):
meta: ToolInvokeMeta

View File

@ -1,10 +1,13 @@
from datetime import datetime, timezone
from typing import Union from typing import Union
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file.file_obj import FileTransferMethod from core.file.file_obj import FileTransferMethod
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMessageBinary, ToolParameter from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMessageBinary, ToolInvokeMeta, ToolParameter
from core.tools.errors import ( from core.tools.errors import (
ToolEngineInvokeError,
ToolInvokeError, ToolInvokeError,
ToolNotFoundError, ToolNotFoundError,
ToolNotSupportedError, ToolNotSupportedError,
@ -26,7 +29,7 @@ class ToolEngine:
def agent_invoke(tool: Tool, tool_parameters: Union[str, dict], def agent_invoke(tool: Tool, tool_parameters: Union[str, dict],
user_id: str, tenant_id: str, message: Message, invoke_from: InvokeFrom, user_id: str, tenant_id: str, message: Message, invoke_from: InvokeFrom,
agent_tool_callback: DifyAgentCallbackHandler) \ agent_tool_callback: DifyAgentCallbackHandler) \
-> tuple[str, list[tuple[MessageFile, bool]]]: -> tuple[str, list[tuple[MessageFile, bool]], ToolInvokeMeta]:
""" """
Agent invokes the tool with the given arguments. Agent invokes the tool with the given arguments.
""" """
@ -52,7 +55,10 @@ class ToolEngine:
tool_inputs=tool_parameters tool_inputs=tool_parameters
) )
response = tool.invoke(user_id, tool_parameters) try:
meta, response = ToolEngine._invoke(tool, tool_parameters, user_id)
except ToolEngineInvokeError as e:
meta = e.meta
response = ToolFileMessageTransformer.transform_tool_invoke_messages( response = ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=response, messages=response,
@ -81,7 +87,7 @@ class ToolEngine:
) )
# transform tool invoke message to get LLM friendly message # transform tool invoke message to get LLM friendly message
return plain_text, message_files return plain_text, message_files, meta
except ToolProviderCredentialValidationError as e: except ToolProviderCredentialValidationError as e:
error_response = "Please check your tool provider credentials" error_response = "Please check your tool provider credentials"
agent_tool_callback.on_tool_error(e) agent_tool_callback.on_tool_error(e)
@ -102,15 +108,56 @@ class ToolEngine:
error_response = f"unknown error: {e}" error_response = f"unknown error: {e}"
agent_tool_callback.on_tool_error(e) agent_tool_callback.on_tool_error(e)
return error_response, [] return error_response, [], meta
@staticmethod @staticmethod
def workflow_invoke(tool: Tool, tool_parameters: dict, def workflow_invoke(tool: Tool, tool_parameters: dict,
user_id: str, workflow_id: str) -> dict: user_id: str, workflow_id: str,
workflow_tool_callback: DifyWorkflowCallbackHandler) \
-> list[ToolInvokeMessage]:
""" """
Workflow invokes the tool with the given arguments. Workflow invokes the tool with the given arguments.
""" """
try:
# hit the callback handler
workflow_tool_callback.on_tool_start(
tool_name=tool.identity.name,
tool_inputs=tool_parameters
)
response = tool.invoke(user_id, tool_parameters)
# hit the callback handler
workflow_tool_callback.on_tool_end(
tool_name=tool.identity.name,
tool_inputs=tool_parameters,
tool_outputs=response
)
return response
except Exception as e:
workflow_tool_callback.on_tool_error(e)
raise e
@staticmethod
def _invoke(tool: Tool, tool_parameters: dict, user_id: str) \
-> tuple[ToolInvokeMeta, list[ToolInvokeMessage]]:
"""
Invoke the tool with the given arguments.
"""
started_at = datetime.now(timezone.utc)
meta = ToolInvokeMeta(time_cost=0.0, error=None)
try:
response = tool.invoke(user_id, tool_parameters)
except Exception as e:
meta.error = str(e)
raise ToolEngineInvokeError(meta=meta)
finally:
ended_at = datetime.now(timezone.utc)
meta.time_cost = (ended_at - started_at).total_seconds()
return meta, response
@staticmethod @staticmethod
def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str: def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str:
""" """

View File

@ -1,8 +1,10 @@
from os import path from os import path
from typing import cast from typing import cast
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file.file_obj import FileTransferMethod, FileType, FileVar from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool_engine import ToolEngine
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.node_entities import NodeRunResult, NodeType
@ -30,7 +32,7 @@ class ToolNode(BaseNode):
parameters = self._generate_parameters(variable_pool, node_data) parameters = self._generate_parameters(variable_pool, node_data)
# get tool runtime # get tool runtime
try: try:
tool_runtime = ToolManager.get_workflow_tool_runtime(self.tenant_id, node_data, None) tool_runtime = ToolManager.get_workflow_tool_runtime(self.tenant_id, node_data)
except Exception as e: except Exception as e:
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, status=WorkflowNodeExecutionStatus.FAILED,
@ -39,7 +41,13 @@ class ToolNode(BaseNode):
) )
try: try:
messages = tool_runtime.invoke(self.user_id, parameters) messages = ToolEngine.workflow_invoke(
tool=tool_runtime,
tool_parameters=parameters,
user_id=self.user_id,
workflow_id=self.workflow_id,
workflow_tool_callback=DifyWorkflowCallbackHandler()
)
except Exception as e: except Exception as e:
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, status=WorkflowNodeExecutionStatus.FAILED,

View File

@ -0,0 +1,31 @@
"""add tool meta
Revision ID: c3311b089690
Revises: e2eacc9a1b63
Create Date: 2024-03-28 11:50:45.364875
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = 'c3311b089690'
down_revision = 'e2eacc9a1b63'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
batch_op.add_column(sa.Column('tool_meta_str', sa.Text(), server_default=sa.text("'{}'::text"), nullable=False))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
batch_op.drop_column('tool_meta_str')
# ### end Alembic commands ###

View File

@ -1134,6 +1134,7 @@ class MessageAgentThought(db.Model):
thought = db.Column(db.Text, nullable=True) thought = db.Column(db.Text, nullable=True)
tool = db.Column(db.Text, nullable=True) tool = db.Column(db.Text, nullable=True)
tool_labels_str = db.Column(db.Text, nullable=False, server_default=db.text("'{}'::text")) tool_labels_str = db.Column(db.Text, nullable=False, server_default=db.text("'{}'::text"))
tool_meta_str = db.Column(db.Text, nullable=False, server_default=db.text("'{}'::text"))
tool_input = db.Column(db.Text, nullable=True) tool_input = db.Column(db.Text, nullable=True)
observation = db.Column(db.Text, nullable=True) observation = db.Column(db.Text, nullable=True)
# plugin_id = db.Column(UUID, nullable=True) ## for future design # plugin_id = db.Column(UUID, nullable=True) ## for future design
@ -1171,6 +1172,16 @@ class MessageAgentThought(db.Model):
return {} return {}
except Exception as e: except Exception as e:
return {} return {}
@property
def tool_meta(self) -> dict:
try:
if self.tool_meta_str:
return json.loads(self.tool_meta_str)
else:
return {}
except Exception as e:
return {}
class DatasetRetrieverResource(db.Model): class DatasetRetrieverResource(db.Model):
__tablename__ = 'dataset_retriever_resources' __tablename__ = 'dataset_retriever_resources'

View File

@ -58,7 +58,7 @@ class PublishedAppTool(db.Model):
description = db.Column(db.Text, nullable=False) description = db.Column(db.Text, nullable=False)
# llm_description of the tool, for LLM # llm_description of the tool, for LLM
llm_description = db.Column(db.Text, nullable=False) llm_description = db.Column(db.Text, nullable=False)
# query decription, query will be seem as a parameter of the tool, to describe this parameter to llm, we need this field # query description, query will be seem as a parameter of the tool, to describe this parameter to llm, we need this field
query_description = db.Column(db.Text, nullable=False) query_description = db.Column(db.Text, nullable=False)
# query name, the name of the query parameter # query name, the name of the query parameter
query_name = db.Column(db.String(40), nullable=False) query_name = db.Column(db.String(40), nullable=False)
@ -123,10 +123,6 @@ class ApiToolProvider(db.Model):
def credentials(self) -> dict: def credentials(self) -> dict:
return json.loads(self.credentials_str) return json.loads(self.credentials_str)
@property
def is_taned(self) -> bool:
return self.tenant_id is not None
@property @property
def user(self) -> Account: def user(self) -> Account:
return db.session.query(Account).filter(Account.id == self.user_id).first() return db.session.query(Account).filter(Account.id == self.user_id).first()