feat: add agent tool invoke meta

This commit is contained in:
Yeuoly 2024-03-28 20:04:31 +08:00
parent d7c4032917
commit 85285931e2
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61
11 changed files with 209 additions and 35 deletions

View File

@ -279,6 +279,7 @@ class BaseAgentRunner(AppRunner):
thought='',
tool=tool_name,
tool_labels_str='{}',
tool_meta_str='{}',
tool_input=tool_input,
message=message,
message_token=0,
@ -313,7 +314,8 @@ class BaseAgentRunner(AppRunner):
tool_name: str,
tool_input: Union[str, dict],
thought: str,
observation: str,
observation: Union[str, str],
tool_invoke_meta: Union[str, dict],
answer: str,
messages_ids: list[str],
llm_usage: LLMUsage = None) -> MessageAgentThought:
@ -340,6 +342,12 @@ class BaseAgentRunner(AppRunner):
agent_thought.tool_input = tool_input
if observation is not None:
if isinstance(observation, dict):
try:
observation = json.dumps(observation, ensure_ascii=False)
except Exception as e:
observation = json.dumps(observation)
agent_thought.observation = observation
if answer is not None:
@ -373,6 +381,15 @@ class BaseAgentRunner(AppRunner):
agent_thought.tool_labels_str = json.dumps(labels)
if tool_invoke_meta is not None:
if isinstance(tool_invoke_meta, dict):
try:
tool_invoke_meta = json.dumps(tool_invoke_meta, ensure_ascii=False)
except Exception as e:
tool_invoke_meta = json.dumps(tool_invoke_meta)
agent_thought.tool_meta_str = tool_invoke_meta
db.session.commit()
db.session.close()

View File

@ -17,6 +17,7 @@ from core.model_runtime.entities.message_entities import (
UserPromptMessage,
)
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from models.model import Conversation, Message
@ -215,7 +216,10 @@ class CotAgentRunner(BaseAgentRunner):
self.save_agent_thought(agent_thought=agent_thought,
tool_name=scratchpad.action.action_name if scratchpad.action else '',
tool_input=scratchpad.action.action_input if scratchpad.action else '',
tool_input={
scratchpad.action.action_name: scratchpad.action.action_input
} if scratchpad.action else '',
tool_invoke_meta={},
thought=scratchpad.thought,
observation='',
answer=scratchpad.agent_response,
@ -248,13 +252,20 @@ class CotAgentRunner(BaseAgentRunner):
tool_instance = tool_instances.get(tool_call_name)
if not tool_instance:
answer = f"there is not a tool named {tool_call_name}"
self.save_agent_thought(agent_thought=agent_thought,
tool_name='',
tool_input='',
thought=None,
observation=answer,
answer=answer,
messages_ids=[])
self.save_agent_thought(
agent_thought=agent_thought,
tool_name='',
tool_input='',
tool_invoke_meta=ToolInvokeMeta.error_instance(
f"there is not a tool named {tool_call_name}"
).to_dict(),
thought=None,
observation={
tool_call_name: answer
},
answer=answer,
messages_ids=[]
)
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
@ -266,7 +277,7 @@ class CotAgentRunner(BaseAgentRunner):
pass
# invoke tool
tool_invoke_response, message_files = ToolEngine.agent_invoke(
tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
tool=tool_instance,
tool_parameters=tool_call_args,
user_id=self.user_id,
@ -308,9 +319,16 @@ class CotAgentRunner(BaseAgentRunner):
self.save_agent_thought(
agent_thought=agent_thought,
tool_name=tool_call_name,
tool_input=tool_call_args,
tool_input={
tool_call_name: tool_call_args
},
tool_invoke_meta={
tool_call_name: tool_invoke_meta.to_dict()
},
thought=None,
observation=observation,
observation={
tool_call_name: observation
},
answer=scratchpad.agent_response,
messages_ids=message_file_ids,
)
@ -341,9 +359,10 @@ class CotAgentRunner(BaseAgentRunner):
self.save_agent_thought(
agent_thought=agent_thought,
tool_name='',
tool_input='',
tool_input={},
tool_invoke_meta={},
thought=final_answer,
observation='',
observation={},
answer=final_answer,
messages_ids=[]
)

View File

@ -15,6 +15,7 @@ from core.model_runtime.entities.message_entities import (
ToolPromptMessage,
UserPromptMessage,
)
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from models.model import Conversation, Message, MessageAgentThought
@ -226,6 +227,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
tool_name=tool_call_names,
tool_input=tool_call_inputs,
thought=response,
tool_invoke_meta=None,
observation=None,
answer=response,
messages_ids=[],
@ -251,11 +253,12 @@ class FunctionCallAgentRunner(BaseAgentRunner):
tool_response = {
"tool_call_id": tool_call_id,
"tool_call_name": tool_call_name,
"tool_response": f"there is not a tool named {tool_call_name}"
"tool_response": f"there is not a tool named {tool_call_name}",
"meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict()
}
else:
# invoke tool
tool_invoke_response, message_files = ToolEngine.agent_invoke(
tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
tool=tool_instance,
tool_parameters=tool_call_args,
user_id=self.user_id,
@ -276,11 +279,11 @@ class FunctionCallAgentRunner(BaseAgentRunner):
# add message file ids
message_file_ids.append(message_file.id)
observation = tool_invoke_response
tool_response = {
"tool_call_id": tool_call_id,
"tool_call_name": tool_call_name,
"tool_response": observation
"tool_response": tool_invoke_response,
"meta": tool_invoke_meta.to_dict()
}
tool_responses.append(tool_response)
@ -300,10 +303,14 @@ class FunctionCallAgentRunner(BaseAgentRunner):
tool_name=None,
tool_input=None,
thought=None,
observation=json.dumps({
tool_invoke_meta={
tool_response['tool_call_name']: tool_response['meta']
for tool_response in tool_responses
},
observation={
tool_response['tool_call_name']: tool_response['tool_response']
for tool_response in tool_responses
}),
},
answer=None,
messages_ids=message_file_ids
)

View File

@ -0,0 +1,5 @@
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
class DifyWorkflowCallbackHandler(DifyAgentCallbackHandler):
"""Callback Handler that prints to std out."""

View File

@ -326,4 +326,31 @@ class ModelToolProviderConfiguration(BaseModel):
"""
provider: str = Field(..., description="The provider of the model tool")
models: list[ModelToolConfiguration] = Field(..., description="The models of the model tool")
label: I18nObject = Field(..., description="The label of the model tool")
label: I18nObject = Field(..., description="The label of the model tool")
class ToolInvokeMeta(BaseModel):
"""
Tool invoke meta
"""
time_cost: float = Field(..., description="The time cost of the tool invoke")
error: Optional[str] = None
@classmethod
def empty(cls) -> 'ToolInvokeMeta':
"""
Get an empty instance of ToolInvokeMeta
"""
return cls(time_cost=0.0, error=None)
@classmethod
def error_instance(cls, error: str) -> 'ToolInvokeMeta':
"""
Get an instance of ToolInvokeMeta with error
"""
return cls(time_cost=0.0, error=error)
def to_dict(self) -> dict:
return {
'time_cost': self.time_cost,
'error': self.error,
}

View File

@ -1,3 +1,6 @@
from core.tools.entities.tool_entities import ToolInvokeMeta
class ToolProviderNotFoundError(ValueError):
pass
@ -17,4 +20,7 @@ class ToolInvokeError(ValueError):
pass
class ToolApiSchemaError(ValueError):
pass
pass
class ToolEngineInvokeError(Exception):
meta: ToolInvokeMeta

View File

@ -1,10 +1,13 @@
from datetime import datetime, timezone
from typing import Union
from core.app.entities.app_invoke_entities import InvokeFrom
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file.file_obj import FileTransferMethod
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMessageBinary, ToolParameter
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMessageBinary, ToolInvokeMeta, ToolParameter
from core.tools.errors import (
ToolEngineInvokeError,
ToolInvokeError,
ToolNotFoundError,
ToolNotSupportedError,
@ -26,7 +29,7 @@ class ToolEngine:
def agent_invoke(tool: Tool, tool_parameters: Union[str, dict],
user_id: str, tenant_id: str, message: Message, invoke_from: InvokeFrom,
agent_tool_callback: DifyAgentCallbackHandler) \
-> tuple[str, list[tuple[MessageFile, bool]]]:
-> tuple[str, list[tuple[MessageFile, bool]], ToolInvokeMeta]:
"""
Agent invokes the tool with the given arguments.
"""
@ -52,7 +55,10 @@ class ToolEngine:
tool_inputs=tool_parameters
)
response = tool.invoke(user_id, tool_parameters)
try:
meta, response = ToolEngine._invoke(tool, tool_parameters, user_id)
except ToolEngineInvokeError as e:
meta = e.meta
response = ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=response,
@ -81,7 +87,7 @@ class ToolEngine:
)
# transform tool invoke message to get LLM friendly message
return plain_text, message_files
return plain_text, message_files, meta
except ToolProviderCredentialValidationError as e:
error_response = "Please check your tool provider credentials"
agent_tool_callback.on_tool_error(e)
@ -102,15 +108,56 @@ class ToolEngine:
error_response = f"unknown error: {e}"
agent_tool_callback.on_tool_error(e)
return error_response, []
return error_response, [], meta
@staticmethod
def workflow_invoke(tool: Tool, tool_parameters: dict,
user_id: str, workflow_id: str) -> dict:
user_id: str, workflow_id: str,
workflow_tool_callback: DifyWorkflowCallbackHandler) \
-> list[ToolInvokeMessage]:
"""
Workflow invokes the tool with the given arguments.
"""
try:
# hit the callback handler
workflow_tool_callback.on_tool_start(
tool_name=tool.identity.name,
tool_inputs=tool_parameters
)
response = tool.invoke(user_id, tool_parameters)
# hit the callback handler
workflow_tool_callback.on_tool_end(
tool_name=tool.identity.name,
tool_inputs=tool_parameters,
tool_outputs=response
)
return response
except Exception as e:
workflow_tool_callback.on_tool_error(e)
raise e
@staticmethod
def _invoke(tool: Tool, tool_parameters: dict, user_id: str) \
-> tuple[ToolInvokeMeta, list[ToolInvokeMessage]]:
"""
Invoke the tool with the given arguments.
"""
started_at = datetime.now(timezone.utc)
meta = ToolInvokeMeta(time_cost=0.0, error=None)
try:
response = tool.invoke(user_id, tool_parameters)
except Exception as e:
meta.error = str(e)
raise ToolEngineInvokeError(meta=meta)
finally:
ended_at = datetime.now(timezone.utc)
meta.time_cost = (ended_at - started_at).total_seconds()
return meta, response
@staticmethod
def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str:
"""

View File

@ -1,8 +1,10 @@
from os import path
from typing import cast
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool_engine import ToolEngine
from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.workflow.entities.node_entities import NodeRunResult, NodeType
@ -30,7 +32,7 @@ class ToolNode(BaseNode):
parameters = self._generate_parameters(variable_pool, node_data)
# get tool runtime
try:
tool_runtime = ToolManager.get_workflow_tool_runtime(self.tenant_id, node_data, None)
tool_runtime = ToolManager.get_workflow_tool_runtime(self.tenant_id, node_data)
except Exception as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
@ -39,7 +41,13 @@ class ToolNode(BaseNode):
)
try:
messages = tool_runtime.invoke(self.user_id, parameters)
messages = ToolEngine.workflow_invoke(
tool=tool_runtime,
tool_parameters=parameters,
user_id=self.user_id,
workflow_id=self.workflow_id,
workflow_tool_callback=DifyWorkflowCallbackHandler()
)
except Exception as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,

View File

@ -0,0 +1,31 @@
"""add tool meta
Revision ID: c3311b089690
Revises: e2eacc9a1b63
Create Date: 2024-03-28 11:50:45.364875
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = 'c3311b089690'
down_revision = 'e2eacc9a1b63'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
batch_op.add_column(sa.Column('tool_meta_str', sa.Text(), server_default=sa.text("'{}'::text"), nullable=False))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
batch_op.drop_column('tool_meta_str')
# ### end Alembic commands ###

View File

@ -1134,6 +1134,7 @@ class MessageAgentThought(db.Model):
thought = db.Column(db.Text, nullable=True)
tool = db.Column(db.Text, nullable=True)
tool_labels_str = db.Column(db.Text, nullable=False, server_default=db.text("'{}'::text"))
tool_meta_str = db.Column(db.Text, nullable=False, server_default=db.text("'{}'::text"))
tool_input = db.Column(db.Text, nullable=True)
observation = db.Column(db.Text, nullable=True)
# plugin_id = db.Column(UUID, nullable=True) ## for future design
@ -1171,6 +1172,16 @@ class MessageAgentThought(db.Model):
return {}
except Exception as e:
return {}
@property
def tool_meta(self) -> dict:
try:
if self.tool_meta_str:
return json.loads(self.tool_meta_str)
else:
return {}
except Exception as e:
return {}
class DatasetRetrieverResource(db.Model):
__tablename__ = 'dataset_retriever_resources'

View File

@ -58,7 +58,7 @@ class PublishedAppTool(db.Model):
description = db.Column(db.Text, nullable=False)
# llm_description of the tool, for LLM
llm_description = db.Column(db.Text, nullable=False)
# query decription, query will be seem as a parameter of the tool, to describe this parameter to llm, we need this field
# query description, query will be seem as a parameter of the tool, to describe this parameter to llm, we need this field
query_description = db.Column(db.Text, nullable=False)
# query name, the name of the query parameter
query_name = db.Column(db.String(40), nullable=False)
@ -123,10 +123,6 @@ class ApiToolProvider(db.Model):
def credentials(self) -> dict:
return json.loads(self.credentials_str)
@property
def is_taned(self) -> bool:
return self.tenant_id is not None
@property
def user(self) -> Account:
return db.session.query(Account).filter(Account.id == self.user_id).first()