mirror of https://github.com/langgenius/dify.git
feat: add agent tool invoke meta
This commit is contained in:
parent
d7c4032917
commit
85285931e2
|
|
@ -279,6 +279,7 @@ class BaseAgentRunner(AppRunner):
|
|||
thought='',
|
||||
tool=tool_name,
|
||||
tool_labels_str='{}',
|
||||
tool_meta_str='{}',
|
||||
tool_input=tool_input,
|
||||
message=message,
|
||||
message_token=0,
|
||||
|
|
@ -313,7 +314,8 @@ class BaseAgentRunner(AppRunner):
|
|||
tool_name: str,
|
||||
tool_input: Union[str, dict],
|
||||
thought: str,
|
||||
observation: str,
|
||||
observation: Union[str, str],
|
||||
tool_invoke_meta: Union[str, dict],
|
||||
answer: str,
|
||||
messages_ids: list[str],
|
||||
llm_usage: LLMUsage = None) -> MessageAgentThought:
|
||||
|
|
@ -340,6 +342,12 @@ class BaseAgentRunner(AppRunner):
|
|||
agent_thought.tool_input = tool_input
|
||||
|
||||
if observation is not None:
|
||||
if isinstance(observation, dict):
|
||||
try:
|
||||
observation = json.dumps(observation, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
observation = json.dumps(observation)
|
||||
|
||||
agent_thought.observation = observation
|
||||
|
||||
if answer is not None:
|
||||
|
|
@ -373,6 +381,15 @@ class BaseAgentRunner(AppRunner):
|
|||
|
||||
agent_thought.tool_labels_str = json.dumps(labels)
|
||||
|
||||
if tool_invoke_meta is not None:
|
||||
if isinstance(tool_invoke_meta, dict):
|
||||
try:
|
||||
tool_invoke_meta = json.dumps(tool_invoke_meta, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
tool_invoke_meta = json.dumps(tool_invoke_meta)
|
||||
|
||||
agent_thought.tool_meta_str = tool_invoke_meta
|
||||
|
||||
db.session.commit()
|
||||
db.session.close()
|
||||
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ from core.model_runtime.entities.message_entities import (
|
|||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from models.model import Conversation, Message
|
||||
|
||||
|
|
@ -215,7 +216,10 @@ class CotAgentRunner(BaseAgentRunner):
|
|||
|
||||
self.save_agent_thought(agent_thought=agent_thought,
|
||||
tool_name=scratchpad.action.action_name if scratchpad.action else '',
|
||||
tool_input=scratchpad.action.action_input if scratchpad.action else '',
|
||||
tool_input={
|
||||
scratchpad.action.action_name: scratchpad.action.action_input
|
||||
} if scratchpad.action else '',
|
||||
tool_invoke_meta={},
|
||||
thought=scratchpad.thought,
|
||||
observation='',
|
||||
answer=scratchpad.agent_response,
|
||||
|
|
@ -248,13 +252,20 @@ class CotAgentRunner(BaseAgentRunner):
|
|||
tool_instance = tool_instances.get(tool_call_name)
|
||||
if not tool_instance:
|
||||
answer = f"there is not a tool named {tool_call_name}"
|
||||
self.save_agent_thought(agent_thought=agent_thought,
|
||||
tool_name='',
|
||||
tool_input='',
|
||||
thought=None,
|
||||
observation=answer,
|
||||
answer=answer,
|
||||
messages_ids=[])
|
||||
self.save_agent_thought(
|
||||
agent_thought=agent_thought,
|
||||
tool_name='',
|
||||
tool_input='',
|
||||
tool_invoke_meta=ToolInvokeMeta.error_instance(
|
||||
f"there is not a tool named {tool_call_name}"
|
||||
).to_dict(),
|
||||
thought=None,
|
||||
observation={
|
||||
tool_call_name: answer
|
||||
},
|
||||
answer=answer,
|
||||
messages_ids=[]
|
||||
)
|
||||
self.queue_manager.publish(QueueAgentThoughtEvent(
|
||||
agent_thought_id=agent_thought.id
|
||||
), PublishFrom.APPLICATION_MANAGER)
|
||||
|
|
@ -266,7 +277,7 @@ class CotAgentRunner(BaseAgentRunner):
|
|||
pass
|
||||
|
||||
# invoke tool
|
||||
tool_invoke_response, message_files = ToolEngine.agent_invoke(
|
||||
tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
|
||||
tool=tool_instance,
|
||||
tool_parameters=tool_call_args,
|
||||
user_id=self.user_id,
|
||||
|
|
@ -308,9 +319,16 @@ class CotAgentRunner(BaseAgentRunner):
|
|||
self.save_agent_thought(
|
||||
agent_thought=agent_thought,
|
||||
tool_name=tool_call_name,
|
||||
tool_input=tool_call_args,
|
||||
tool_input={
|
||||
tool_call_name: tool_call_args
|
||||
},
|
||||
tool_invoke_meta={
|
||||
tool_call_name: tool_invoke_meta.to_dict()
|
||||
},
|
||||
thought=None,
|
||||
observation=observation,
|
||||
observation={
|
||||
tool_call_name: observation
|
||||
},
|
||||
answer=scratchpad.agent_response,
|
||||
messages_ids=message_file_ids,
|
||||
)
|
||||
|
|
@ -341,9 +359,10 @@ class CotAgentRunner(BaseAgentRunner):
|
|||
self.save_agent_thought(
|
||||
agent_thought=agent_thought,
|
||||
tool_name='',
|
||||
tool_input='',
|
||||
tool_input={},
|
||||
tool_invoke_meta={},
|
||||
thought=final_answer,
|
||||
observation='',
|
||||
observation={},
|
||||
answer=final_answer,
|
||||
messages_ids=[]
|
||||
)
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from core.model_runtime.entities.message_entities import (
|
|||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from models.model import Conversation, Message, MessageAgentThought
|
||||
|
||||
|
|
@ -226,6 +227,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||
tool_name=tool_call_names,
|
||||
tool_input=tool_call_inputs,
|
||||
thought=response,
|
||||
tool_invoke_meta=None,
|
||||
observation=None,
|
||||
answer=response,
|
||||
messages_ids=[],
|
||||
|
|
@ -251,11 +253,12 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||
tool_response = {
|
||||
"tool_call_id": tool_call_id,
|
||||
"tool_call_name": tool_call_name,
|
||||
"tool_response": f"there is not a tool named {tool_call_name}"
|
||||
"tool_response": f"there is not a tool named {tool_call_name}",
|
||||
"meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict()
|
||||
}
|
||||
else:
|
||||
# invoke tool
|
||||
tool_invoke_response, message_files = ToolEngine.agent_invoke(
|
||||
tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
|
||||
tool=tool_instance,
|
||||
tool_parameters=tool_call_args,
|
||||
user_id=self.user_id,
|
||||
|
|
@ -276,11 +279,11 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||
# add message file ids
|
||||
message_file_ids.append(message_file.id)
|
||||
|
||||
observation = tool_invoke_response
|
||||
tool_response = {
|
||||
"tool_call_id": tool_call_id,
|
||||
"tool_call_name": tool_call_name,
|
||||
"tool_response": observation
|
||||
"tool_response": tool_invoke_response,
|
||||
"meta": tool_invoke_meta.to_dict()
|
||||
}
|
||||
|
||||
tool_responses.append(tool_response)
|
||||
|
|
@ -300,10 +303,14 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||
tool_name=None,
|
||||
tool_input=None,
|
||||
thought=None,
|
||||
observation=json.dumps({
|
||||
tool_invoke_meta={
|
||||
tool_response['tool_call_name']: tool_response['meta']
|
||||
for tool_response in tool_responses
|
||||
},
|
||||
observation={
|
||||
tool_response['tool_call_name']: tool_response['tool_response']
|
||||
for tool_response in tool_responses
|
||||
}),
|
||||
},
|
||||
answer=None,
|
||||
messages_ids=message_file_ids
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,5 @@
|
|||
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
|
||||
|
||||
|
||||
class DifyWorkflowCallbackHandler(DifyAgentCallbackHandler):
|
||||
"""Callback Handler that prints to std out."""
|
||||
|
|
@ -326,4 +326,31 @@ class ModelToolProviderConfiguration(BaseModel):
|
|||
"""
|
||||
provider: str = Field(..., description="The provider of the model tool")
|
||||
models: list[ModelToolConfiguration] = Field(..., description="The models of the model tool")
|
||||
label: I18nObject = Field(..., description="The label of the model tool")
|
||||
label: I18nObject = Field(..., description="The label of the model tool")
|
||||
|
||||
class ToolInvokeMeta(BaseModel):
|
||||
"""
|
||||
Tool invoke meta
|
||||
"""
|
||||
time_cost: float = Field(..., description="The time cost of the tool invoke")
|
||||
error: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def empty(cls) -> 'ToolInvokeMeta':
|
||||
"""
|
||||
Get an empty instance of ToolInvokeMeta
|
||||
"""
|
||||
return cls(time_cost=0.0, error=None)
|
||||
|
||||
@classmethod
|
||||
def error_instance(cls, error: str) -> 'ToolInvokeMeta':
|
||||
"""
|
||||
Get an instance of ToolInvokeMeta with error
|
||||
"""
|
||||
return cls(time_cost=0.0, error=error)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
'time_cost': self.time_cost,
|
||||
'error': self.error,
|
||||
}
|
||||
|
|
@ -1,3 +1,6 @@
|
|||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
|
||||
|
||||
class ToolProviderNotFoundError(ValueError):
|
||||
pass
|
||||
|
||||
|
|
@ -17,4 +20,7 @@ class ToolInvokeError(ValueError):
|
|||
pass
|
||||
|
||||
class ToolApiSchemaError(ValueError):
|
||||
pass
|
||||
pass
|
||||
|
||||
class ToolEngineInvokeError(Exception):
|
||||
meta: ToolInvokeMeta
|
||||
|
|
@ -1,10 +1,13 @@
|
|||
from datetime import datetime, timezone
|
||||
from typing import Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
from core.file.file_obj import FileTransferMethod
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMessageBinary, ToolParameter
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMessageBinary, ToolInvokeMeta, ToolParameter
|
||||
from core.tools.errors import (
|
||||
ToolEngineInvokeError,
|
||||
ToolInvokeError,
|
||||
ToolNotFoundError,
|
||||
ToolNotSupportedError,
|
||||
|
|
@ -26,7 +29,7 @@ class ToolEngine:
|
|||
def agent_invoke(tool: Tool, tool_parameters: Union[str, dict],
|
||||
user_id: str, tenant_id: str, message: Message, invoke_from: InvokeFrom,
|
||||
agent_tool_callback: DifyAgentCallbackHandler) \
|
||||
-> tuple[str, list[tuple[MessageFile, bool]]]:
|
||||
-> tuple[str, list[tuple[MessageFile, bool]], ToolInvokeMeta]:
|
||||
"""
|
||||
Agent invokes the tool with the given arguments.
|
||||
"""
|
||||
|
|
@ -52,7 +55,10 @@ class ToolEngine:
|
|||
tool_inputs=tool_parameters
|
||||
)
|
||||
|
||||
response = tool.invoke(user_id, tool_parameters)
|
||||
try:
|
||||
meta, response = ToolEngine._invoke(tool, tool_parameters, user_id)
|
||||
except ToolEngineInvokeError as e:
|
||||
meta = e.meta
|
||||
|
||||
response = ToolFileMessageTransformer.transform_tool_invoke_messages(
|
||||
messages=response,
|
||||
|
|
@ -81,7 +87,7 @@ class ToolEngine:
|
|||
)
|
||||
|
||||
# transform tool invoke message to get LLM friendly message
|
||||
return plain_text, message_files
|
||||
return plain_text, message_files, meta
|
||||
except ToolProviderCredentialValidationError as e:
|
||||
error_response = "Please check your tool provider credentials"
|
||||
agent_tool_callback.on_tool_error(e)
|
||||
|
|
@ -102,15 +108,56 @@ class ToolEngine:
|
|||
error_response = f"unknown error: {e}"
|
||||
agent_tool_callback.on_tool_error(e)
|
||||
|
||||
return error_response, []
|
||||
return error_response, [], meta
|
||||
|
||||
@staticmethod
|
||||
def workflow_invoke(tool: Tool, tool_parameters: dict,
|
||||
user_id: str, workflow_id: str) -> dict:
|
||||
user_id: str, workflow_id: str,
|
||||
workflow_tool_callback: DifyWorkflowCallbackHandler) \
|
||||
-> list[ToolInvokeMessage]:
|
||||
"""
|
||||
Workflow invokes the tool with the given arguments.
|
||||
"""
|
||||
try:
|
||||
# hit the callback handler
|
||||
workflow_tool_callback.on_tool_start(
|
||||
tool_name=tool.identity.name,
|
||||
tool_inputs=tool_parameters
|
||||
)
|
||||
|
||||
response = tool.invoke(user_id, tool_parameters)
|
||||
|
||||
# hit the callback handler
|
||||
workflow_tool_callback.on_tool_end(
|
||||
tool_name=tool.identity.name,
|
||||
tool_inputs=tool_parameters,
|
||||
tool_outputs=response
|
||||
)
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
workflow_tool_callback.on_tool_error(e)
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def _invoke(tool: Tool, tool_parameters: dict, user_id: str) \
|
||||
-> tuple[ToolInvokeMeta, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
Invoke the tool with the given arguments.
|
||||
"""
|
||||
started_at = datetime.now(timezone.utc)
|
||||
meta = ToolInvokeMeta(time_cost=0.0, error=None)
|
||||
try:
|
||||
response = tool.invoke(user_id, tool_parameters)
|
||||
except Exception as e:
|
||||
meta.error = str(e)
|
||||
raise ToolEngineInvokeError(meta=meta)
|
||||
finally:
|
||||
ended_at = datetime.now(timezone.utc)
|
||||
meta.time_cost = (ended_at - started_at).total_seconds()
|
||||
|
||||
return meta, response
|
||||
|
||||
@staticmethod
|
||||
def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
from os import path
|
||||
from typing import cast
|
||||
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
from core.file.file_obj import FileTransferMethod, FileType, FileVar
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
|
|
@ -30,7 +32,7 @@ class ToolNode(BaseNode):
|
|||
parameters = self._generate_parameters(variable_pool, node_data)
|
||||
# get tool runtime
|
||||
try:
|
||||
tool_runtime = ToolManager.get_workflow_tool_runtime(self.tenant_id, node_data, None)
|
||||
tool_runtime = ToolManager.get_workflow_tool_runtime(self.tenant_id, node_data)
|
||||
except Exception as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
|
|
@ -39,7 +41,13 @@ class ToolNode(BaseNode):
|
|||
)
|
||||
|
||||
try:
|
||||
messages = tool_runtime.invoke(self.user_id, parameters)
|
||||
messages = ToolEngine.workflow_invoke(
|
||||
tool=tool_runtime,
|
||||
tool_parameters=parameters,
|
||||
user_id=self.user_id,
|
||||
workflow_id=self.workflow_id,
|
||||
workflow_tool_callback=DifyWorkflowCallbackHandler()
|
||||
)
|
||||
except Exception as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,31 @@
|
|||
"""add tool meta
|
||||
|
||||
Revision ID: c3311b089690
|
||||
Revises: e2eacc9a1b63
|
||||
Create Date: 2024-03-28 11:50:45.364875
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'c3311b089690'
|
||||
down_revision = 'e2eacc9a1b63'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('tool_meta_str', sa.Text(), server_default=sa.text("'{}'::text"), nullable=False))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
|
||||
batch_op.drop_column('tool_meta_str')
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -1134,6 +1134,7 @@ class MessageAgentThought(db.Model):
|
|||
thought = db.Column(db.Text, nullable=True)
|
||||
tool = db.Column(db.Text, nullable=True)
|
||||
tool_labels_str = db.Column(db.Text, nullable=False, server_default=db.text("'{}'::text"))
|
||||
tool_meta_str = db.Column(db.Text, nullable=False, server_default=db.text("'{}'::text"))
|
||||
tool_input = db.Column(db.Text, nullable=True)
|
||||
observation = db.Column(db.Text, nullable=True)
|
||||
# plugin_id = db.Column(UUID, nullable=True) ## for future design
|
||||
|
|
@ -1171,6 +1172,16 @@ class MessageAgentThought(db.Model):
|
|||
return {}
|
||||
except Exception as e:
|
||||
return {}
|
||||
|
||||
@property
|
||||
def tool_meta(self) -> dict:
|
||||
try:
|
||||
if self.tool_meta_str:
|
||||
return json.loads(self.tool_meta_str)
|
||||
else:
|
||||
return {}
|
||||
except Exception as e:
|
||||
return {}
|
||||
|
||||
class DatasetRetrieverResource(db.Model):
|
||||
__tablename__ = 'dataset_retriever_resources'
|
||||
|
|
|
|||
|
|
@ -58,7 +58,7 @@ class PublishedAppTool(db.Model):
|
|||
description = db.Column(db.Text, nullable=False)
|
||||
# llm_description of the tool, for LLM
|
||||
llm_description = db.Column(db.Text, nullable=False)
|
||||
# query decription, query will be seem as a parameter of the tool, to describe this parameter to llm, we need this field
|
||||
# query description, query will be seem as a parameter of the tool, to describe this parameter to llm, we need this field
|
||||
query_description = db.Column(db.Text, nullable=False)
|
||||
# query name, the name of the query parameter
|
||||
query_name = db.Column(db.String(40), nullable=False)
|
||||
|
|
@ -123,10 +123,6 @@ class ApiToolProvider(db.Model):
|
|||
def credentials(self) -> dict:
|
||||
return json.loads(self.credentials_str)
|
||||
|
||||
@property
|
||||
def is_taned(self) -> bool:
|
||||
return self.tenant_id is not None
|
||||
|
||||
@property
|
||||
def user(self) -> Account:
|
||||
return db.session.query(Account).filter(Account.id == self.user_id).first()
|
||||
|
|
|
|||
Loading…
Reference in New Issue