mirror of https://github.com/langgenius/dify.git
feat: add inputs parameter to tool invocation methods for enhanced flexibility
This commit is contained in:
parent
a47276ac24
commit
a87c7b0064
|
|
@ -23,7 +23,7 @@ from core.model_runtime.entities import (
|
|||
)
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
|
||||
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
|
||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
from core.tools.entities.tool_entities import ToolInvokeMeta, ToolProviderType
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from models.model import Message
|
||||
|
||||
|
|
@ -234,6 +234,11 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||
"meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict(),
|
||||
}
|
||||
else:
|
||||
inputs_to_pass = (
|
||||
kwargs.get("inputs")
|
||||
if tool_instance.tool_provider_type() == ToolProviderType.MCP
|
||||
else None
|
||||
)
|
||||
# invoke tool
|
||||
tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
|
||||
tool=tool_instance,
|
||||
|
|
@ -247,6 +252,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||
app_id=self.application_generate_entity.app_config.app_id,
|
||||
message_id=self.message.id,
|
||||
conversation_id=self.conversation.id,
|
||||
inputs=inputs_to_pass,
|
||||
)
|
||||
# publish files
|
||||
for message_file_id in message_files:
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator
|
||||
from copy import deepcopy
|
||||
|
|
@ -49,6 +50,7 @@ class Tool(ABC):
|
|||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
inputs: dict[str, Any] | None = None,
|
||||
) -> Generator[ToolInvokeMessage]:
|
||||
if self.runtime and self.runtime.runtime_parameters:
|
||||
tool_parameters.update(self.runtime.runtime_parameters)
|
||||
|
|
@ -56,13 +58,24 @@ class Tool(ABC):
|
|||
# try parse tool parameters into the correct type
|
||||
tool_parameters = self._transform_tool_parameters_type(tool_parameters)
|
||||
|
||||
result = self._invoke(
|
||||
user_id=user_id,
|
||||
tool_parameters=tool_parameters,
|
||||
conversation_id=conversation_id,
|
||||
app_id=app_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
# Construct the call parameter and pass in inputs only when the _invoke of the subclass accepts inputs
|
||||
invoke_kwargs = {
|
||||
"user_id": user_id,
|
||||
"tool_parameters": tool_parameters,
|
||||
"conversation_id": conversation_id,
|
||||
"app_id": app_id,
|
||||
"message_id": message_id,
|
||||
}
|
||||
if inputs is not None:
|
||||
try:
|
||||
sig = inspect.signature(self._invoke)
|
||||
if "inputs" in sig.parameters:
|
||||
invoke_kwargs["inputs"] = inputs
|
||||
except Exception:
|
||||
# fallback: Do not pass inputs if reflection fails
|
||||
pass
|
||||
|
||||
result = self._invoke(**invoke_kwargs)
|
||||
|
||||
if isinstance(result, ToolInvokeMessage):
|
||||
|
||||
|
|
|
|||
|
|
@ -47,7 +47,13 @@ class MCPTool(Tool):
|
|||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
inputs: dict[str, Any] | None = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
# If inputs is provided, flatten its key/value pairs into the tool_parameters.
|
||||
# This is a conservative merge: only add a key from inputs if it does not already exist in tool_parameters.
|
||||
for key, value in inputs.items():
|
||||
if key not in tool_parameters:
|
||||
tool_parameters[key] = value
|
||||
result = self.invoke_remote_mcp_tool(tool_parameters)
|
||||
# handle dify tool output
|
||||
for content in result.content:
|
||||
|
|
|
|||
|
|
@ -55,6 +55,7 @@ class ToolEngine:
|
|||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
inputs: dict[str, Any] | None = None,
|
||||
) -> tuple[str, list[str], ToolInvokeMeta]:
|
||||
"""
|
||||
Agent invokes the tool with the given arguments.
|
||||
|
|
@ -79,7 +80,7 @@ class ToolEngine:
|
|||
# hit the callback handler
|
||||
agent_tool_callback.on_tool_start(tool_name=tool.entity.identity.name, tool_inputs=tool_parameters)
|
||||
|
||||
messages = ToolEngine._invoke(tool, tool_parameters, user_id, conversation_id, app_id, message_id)
|
||||
messages = ToolEngine._invoke(tool, tool_parameters, user_id, conversation_id, app_id, message_id, inputs)
|
||||
invocation_meta_dict: dict[str, ToolInvokeMeta] = {}
|
||||
|
||||
def message_callback(
|
||||
|
|
@ -197,6 +198,7 @@ class ToolEngine:
|
|||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
inputs: dict[str, Any] | None = None,
|
||||
) -> Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]:
|
||||
"""
|
||||
Invoke the tool with the given arguments.
|
||||
|
|
@ -214,7 +216,7 @@ class ToolEngine:
|
|||
},
|
||||
)
|
||||
try:
|
||||
yield from tool.invoke(user_id, tool_parameters, conversation_id, app_id, message_id)
|
||||
yield from tool.invoke(user_id, tool_parameters, conversation_id, app_id, message_id, inputs)
|
||||
except Exception as e:
|
||||
meta.error = str(e)
|
||||
raise ToolEngineInvokeError(meta)
|
||||
|
|
|
|||
Loading…
Reference in New Issue