This commit is contained in:
Zino 2025-12-29 16:33:41 +08:00 committed by GitHub
commit 54933a020d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 46 additions and 18 deletions

View File

@ -23,7 +23,7 @@ from core.model_runtime.entities import (
) )
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.entities.tool_entities import ToolInvokeMeta from core.tools.entities.tool_entities import ToolInvokeMeta, ToolProviderType
from core.tools.tool_engine import ToolEngine from core.tools.tool_engine import ToolEngine
from models.model import Message from models.model import Message
@ -234,6 +234,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
"meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict(), "meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict(),
} }
else: else:
inputs_to_pass = (
kwargs.get("inputs") if tool_instance.tool_provider_type() == ToolProviderType.MCP else None
)
# invoke tool # invoke tool
tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke( tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
tool=tool_instance, tool=tool_instance,
@ -247,6 +250,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
app_id=self.application_generate_entity.app_config.app_id, app_id=self.application_generate_entity.app_config.app_id,
message_id=self.message.id, message_id=self.message.id,
conversation_id=self.conversation.id, conversation_id=self.conversation.id,
inputs=inputs_to_pass,
) )
# publish files # publish files
for message_file_id in message_files: for message_file_id in message_files:

View File

@ -180,7 +180,9 @@ class MCPClientWithAuthRetry(MCPClient):
""" """
return self._execute_with_retry(super().list_tools) return self._execute_with_retry(super().list_tools)
def invoke_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult: def invoke_tool(
self, tool_name: str, tool_args: dict[str, Any], _meta: dict[str, Any] | None = None
) -> CallToolResult:
""" """
Invoke a tool on the MCP server with auth retry. Invoke a tool on the MCP server with auth retry.
@ -194,4 +196,4 @@ class MCPClientWithAuthRetry(MCPClient):
Raises: Raises:
MCPAuthError: If authentication fails after retries MCPAuthError: If authentication fails after retries
""" """
return self._execute_with_retry(super().invoke_tool, tool_name, tool_args) return self._execute_with_retry(super().invoke_tool, tool_name, tool_args, _meta)

View File

@ -96,11 +96,13 @@ class MCPClient:
response = self._session.list_tools() response = self._session.list_tools()
return response.tools return response.tools
def invoke_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult: def invoke_tool(
self, tool_name: str, tool_args: dict[str, Any], _meta: dict[str, Any] | None = None
) -> CallToolResult:
"""Call a tool""" """Call a tool"""
if not self._session: if not self._session:
raise ValueError("Session not initialized.") raise ValueError("Session not initialized.")
return self._session.call_tool(tool_name, tool_args) return self._session.call_tool(tool_name, tool_args, _meta=_meta)
def cleanup(self): def cleanup(self):
"""Clean up resources""" """Clean up resources"""

View File

@ -248,6 +248,7 @@ class ClientSession(
self, self,
name: str, name: str,
arguments: dict[str, Any] | None = None, arguments: dict[str, Any] | None = None,
_meta: dict[str, Any] | None = None,
read_timeout_seconds: timedelta | None = None, read_timeout_seconds: timedelta | None = None,
) -> types.CallToolResult: ) -> types.CallToolResult:
"""Send a tools/call request.""" """Send a tools/call request."""
@ -256,7 +257,7 @@ class ClientSession(
types.ClientRequest( types.ClientRequest(
types.CallToolRequest( types.CallToolRequest(
method="tools/call", method="tools/call",
params=types.CallToolRequestParams(name=name, arguments=arguments), params=types.CallToolRequestParams(name=name, arguments=arguments, _meta=_meta),
) )
), ),
types.CallToolResult, types.CallToolResult,

View File

@ -855,6 +855,7 @@ class CallToolRequestParams(RequestParams):
name: str name: str
arguments: dict[str, Any] | None = None arguments: dict[str, Any] | None = None
_meta: dict[str, Any] | None = None
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")

View File

@ -1,3 +1,4 @@
import inspect
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Generator from collections.abc import Generator
from copy import deepcopy from copy import deepcopy
@ -49,6 +50,7 @@ class Tool(ABC):
conversation_id: str | None = None, conversation_id: str | None = None,
app_id: str | None = None, app_id: str | None = None,
message_id: str | None = None, message_id: str | None = None,
inputs: dict[str, Any] | None = None,
) -> Generator[ToolInvokeMessage]: ) -> Generator[ToolInvokeMessage]:
if self.runtime and self.runtime.runtime_parameters: if self.runtime and self.runtime.runtime_parameters:
tool_parameters.update(self.runtime.runtime_parameters) tool_parameters.update(self.runtime.runtime_parameters)
@ -56,13 +58,24 @@ class Tool(ABC):
# try parse tool parameters into the correct type # try parse tool parameters into the correct type
tool_parameters = self._transform_tool_parameters_type(tool_parameters) tool_parameters = self._transform_tool_parameters_type(tool_parameters)
result = self._invoke( # Construct the call parameter and pass in inputs only when the _invoke of the subclass accepts inputs
user_id=user_id, invoke_kwargs = {
tool_parameters=tool_parameters, "user_id": user_id,
conversation_id=conversation_id, "tool_parameters": tool_parameters,
app_id=app_id, "conversation_id": conversation_id,
message_id=message_id, "app_id": app_id,
) "message_id": message_id,
}
if inputs is not None:
try:
sig = inspect.signature(self._invoke)
if "inputs" in sig.parameters:
invoke_kwargs["inputs"] = inputs
except (ValueError, TypeError):
# fallback: Do not pass inputs if reflection fails
pass
result = self._invoke(**invoke_kwargs)
if isinstance(result, ToolInvokeMessage): if isinstance(result, ToolInvokeMessage):

View File

@ -55,8 +55,9 @@ class MCPTool(Tool):
conversation_id: str | None = None, conversation_id: str | None = None,
app_id: str | None = None, app_id: str | None = None,
message_id: str | None = None, message_id: str | None = None,
inputs: dict[str, Any] | None = None,
) -> Generator[ToolInvokeMessage, None, None]: ) -> Generator[ToolInvokeMessage, None, None]:
result = self.invoke_remote_mcp_tool(tool_parameters) result = self.invoke_remote_mcp_tool(tool_parameters, _meta=inputs)
# handle dify tool output # handle dify tool output
for content in result.content: for content in result.content:
if isinstance(content, TextContent): if isinstance(content, TextContent):
@ -141,7 +142,7 @@ class MCPTool(Tool):
if value is not None and not (isinstance(value, str) and value.strip() == "") if value is not None and not (isinstance(value, str) and value.strip() == "")
} }
def invoke_remote_mcp_tool(self, tool_parameters: dict[str, Any]) -> CallToolResult: def invoke_remote_mcp_tool(self, tool_parameters: dict[str, Any], _meta: dict[str, Any] | None) -> CallToolResult:
headers = self.headers.copy() if self.headers else {} headers = self.headers.copy() if self.headers else {}
tool_parameters = self._handle_none_parameter(tool_parameters) tool_parameters = self._handle_none_parameter(tool_parameters)
@ -176,7 +177,9 @@ class MCPTool(Tool):
sse_read_timeout=self.sse_read_timeout, sse_read_timeout=self.sse_read_timeout,
provider_entity=provider_entity, provider_entity=provider_entity,
) as mcp_client: ) as mcp_client:
return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters) return mcp_client.invoke_tool(
tool_name=self.entity.identity.name, tool_args=tool_parameters, _meta=_meta
)
except MCPConnectionError as e: except MCPConnectionError as e:
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
except Exception as e: except Exception as e:

View File

@ -55,6 +55,7 @@ class ToolEngine:
conversation_id: str | None = None, conversation_id: str | None = None,
app_id: str | None = None, app_id: str | None = None,
message_id: str | None = None, message_id: str | None = None,
inputs: dict[str, Any] | None = None,
) -> tuple[str, list[str], ToolInvokeMeta]: ) -> tuple[str, list[str], ToolInvokeMeta]:
""" """
Agent invokes the tool with the given arguments. Agent invokes the tool with the given arguments.
@ -79,7 +80,7 @@ class ToolEngine:
# hit the callback handler # hit the callback handler
agent_tool_callback.on_tool_start(tool_name=tool.entity.identity.name, tool_inputs=tool_parameters) agent_tool_callback.on_tool_start(tool_name=tool.entity.identity.name, tool_inputs=tool_parameters)
messages = ToolEngine._invoke(tool, tool_parameters, user_id, conversation_id, app_id, message_id) messages = ToolEngine._invoke(tool, tool_parameters, user_id, conversation_id, app_id, message_id, inputs)
invocation_meta_dict: dict[str, ToolInvokeMeta] = {} invocation_meta_dict: dict[str, ToolInvokeMeta] = {}
def message_callback( def message_callback(
@ -197,6 +198,7 @@ class ToolEngine:
conversation_id: str | None = None, conversation_id: str | None = None,
app_id: str | None = None, app_id: str | None = None,
message_id: str | None = None, message_id: str | None = None,
inputs: dict[str, Any] | None = None,
) -> Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]: ) -> Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]:
""" """
Invoke the tool with the given arguments. Invoke the tool with the given arguments.
@ -214,7 +216,7 @@ class ToolEngine:
}, },
) )
try: try:
yield from tool.invoke(user_id, tool_parameters, conversation_id, app_id, message_id) yield from tool.invoke(user_id, tool_parameters, conversation_id, app_id, message_id, inputs)
except Exception as e: except Exception as e:
meta.error = str(e) meta.error = str(e)
raise ToolEngineInvokeError(meta) raise ToolEngineInvokeError(meta)