mirror of https://github.com/langgenius/dify.git
Merge a9035301a7 into 2c919efa69
This commit is contained in:
commit
54933a020d
|
|
@ -23,7 +23,7 @@ from core.model_runtime.entities import (
|
||||||
)
|
)
|
||||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
|
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
|
||||||
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
|
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
|
||||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
from core.tools.entities.tool_entities import ToolInvokeMeta, ToolProviderType
|
||||||
from core.tools.tool_engine import ToolEngine
|
from core.tools.tool_engine import ToolEngine
|
||||||
from models.model import Message
|
from models.model import Message
|
||||||
|
|
||||||
|
|
@ -234,6 +234,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||||
"meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict(),
|
"meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict(),
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
|
inputs_to_pass = (
|
||||||
|
kwargs.get("inputs") if tool_instance.tool_provider_type() == ToolProviderType.MCP else None
|
||||||
|
)
|
||||||
# invoke tool
|
# invoke tool
|
||||||
tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
|
tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
|
||||||
tool=tool_instance,
|
tool=tool_instance,
|
||||||
|
|
@ -247,6 +250,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||||
app_id=self.application_generate_entity.app_config.app_id,
|
app_id=self.application_generate_entity.app_config.app_id,
|
||||||
message_id=self.message.id,
|
message_id=self.message.id,
|
||||||
conversation_id=self.conversation.id,
|
conversation_id=self.conversation.id,
|
||||||
|
inputs=inputs_to_pass,
|
||||||
)
|
)
|
||||||
# publish files
|
# publish files
|
||||||
for message_file_id in message_files:
|
for message_file_id in message_files:
|
||||||
|
|
|
||||||
|
|
@ -180,7 +180,9 @@ class MCPClientWithAuthRetry(MCPClient):
|
||||||
"""
|
"""
|
||||||
return self._execute_with_retry(super().list_tools)
|
return self._execute_with_retry(super().list_tools)
|
||||||
|
|
||||||
def invoke_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult:
|
def invoke_tool(
|
||||||
|
self, tool_name: str, tool_args: dict[str, Any], _meta: dict[str, Any] | None = None
|
||||||
|
) -> CallToolResult:
|
||||||
"""
|
"""
|
||||||
Invoke a tool on the MCP server with auth retry.
|
Invoke a tool on the MCP server with auth retry.
|
||||||
|
|
||||||
|
|
@ -194,4 +196,4 @@ class MCPClientWithAuthRetry(MCPClient):
|
||||||
Raises:
|
Raises:
|
||||||
MCPAuthError: If authentication fails after retries
|
MCPAuthError: If authentication fails after retries
|
||||||
"""
|
"""
|
||||||
return self._execute_with_retry(super().invoke_tool, tool_name, tool_args)
|
return self._execute_with_retry(super().invoke_tool, tool_name, tool_args, _meta)
|
||||||
|
|
|
||||||
|
|
@ -96,11 +96,13 @@ class MCPClient:
|
||||||
response = self._session.list_tools()
|
response = self._session.list_tools()
|
||||||
return response.tools
|
return response.tools
|
||||||
|
|
||||||
def invoke_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult:
|
def invoke_tool(
|
||||||
|
self, tool_name: str, tool_args: dict[str, Any], _meta: dict[str, Any] | None = None
|
||||||
|
) -> CallToolResult:
|
||||||
"""Call a tool"""
|
"""Call a tool"""
|
||||||
if not self._session:
|
if not self._session:
|
||||||
raise ValueError("Session not initialized.")
|
raise ValueError("Session not initialized.")
|
||||||
return self._session.call_tool(tool_name, tool_args)
|
return self._session.call_tool(tool_name, tool_args, _meta=_meta)
|
||||||
|
|
||||||
def cleanup(self):
|
def cleanup(self):
|
||||||
"""Clean up resources"""
|
"""Clean up resources"""
|
||||||
|
|
|
||||||
|
|
@ -248,6 +248,7 @@ class ClientSession(
|
||||||
self,
|
self,
|
||||||
name: str,
|
name: str,
|
||||||
arguments: dict[str, Any] | None = None,
|
arguments: dict[str, Any] | None = None,
|
||||||
|
_meta: dict[str, Any] | None = None,
|
||||||
read_timeout_seconds: timedelta | None = None,
|
read_timeout_seconds: timedelta | None = None,
|
||||||
) -> types.CallToolResult:
|
) -> types.CallToolResult:
|
||||||
"""Send a tools/call request."""
|
"""Send a tools/call request."""
|
||||||
|
|
@ -256,7 +257,7 @@ class ClientSession(
|
||||||
types.ClientRequest(
|
types.ClientRequest(
|
||||||
types.CallToolRequest(
|
types.CallToolRequest(
|
||||||
method="tools/call",
|
method="tools/call",
|
||||||
params=types.CallToolRequestParams(name=name, arguments=arguments),
|
params=types.CallToolRequestParams(name=name, arguments=arguments, _meta=_meta),
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
types.CallToolResult,
|
types.CallToolResult,
|
||||||
|
|
|
||||||
|
|
@ -855,6 +855,7 @@ class CallToolRequestParams(RequestParams):
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
arguments: dict[str, Any] | None = None
|
arguments: dict[str, Any] | None = None
|
||||||
|
_meta: dict[str, Any] | None = None
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
import inspect
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
@ -49,6 +50,7 @@ class Tool(ABC):
|
||||||
conversation_id: str | None = None,
|
conversation_id: str | None = None,
|
||||||
app_id: str | None = None,
|
app_id: str | None = None,
|
||||||
message_id: str | None = None,
|
message_id: str | None = None,
|
||||||
|
inputs: dict[str, Any] | None = None,
|
||||||
) -> Generator[ToolInvokeMessage]:
|
) -> Generator[ToolInvokeMessage]:
|
||||||
if self.runtime and self.runtime.runtime_parameters:
|
if self.runtime and self.runtime.runtime_parameters:
|
||||||
tool_parameters.update(self.runtime.runtime_parameters)
|
tool_parameters.update(self.runtime.runtime_parameters)
|
||||||
|
|
@ -56,13 +58,24 @@ class Tool(ABC):
|
||||||
# try parse tool parameters into the correct type
|
# try parse tool parameters into the correct type
|
||||||
tool_parameters = self._transform_tool_parameters_type(tool_parameters)
|
tool_parameters = self._transform_tool_parameters_type(tool_parameters)
|
||||||
|
|
||||||
result = self._invoke(
|
# Construct the call parameter and pass in inputs only when the _invoke of the subclass accepts inputs
|
||||||
user_id=user_id,
|
invoke_kwargs = {
|
||||||
tool_parameters=tool_parameters,
|
"user_id": user_id,
|
||||||
conversation_id=conversation_id,
|
"tool_parameters": tool_parameters,
|
||||||
app_id=app_id,
|
"conversation_id": conversation_id,
|
||||||
message_id=message_id,
|
"app_id": app_id,
|
||||||
)
|
"message_id": message_id,
|
||||||
|
}
|
||||||
|
if inputs is not None:
|
||||||
|
try:
|
||||||
|
sig = inspect.signature(self._invoke)
|
||||||
|
if "inputs" in sig.parameters:
|
||||||
|
invoke_kwargs["inputs"] = inputs
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
# fallback: Do not pass inputs if reflection fails
|
||||||
|
pass
|
||||||
|
|
||||||
|
result = self._invoke(**invoke_kwargs)
|
||||||
|
|
||||||
if isinstance(result, ToolInvokeMessage):
|
if isinstance(result, ToolInvokeMessage):
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -55,8 +55,9 @@ class MCPTool(Tool):
|
||||||
conversation_id: str | None = None,
|
conversation_id: str | None = None,
|
||||||
app_id: str | None = None,
|
app_id: str | None = None,
|
||||||
message_id: str | None = None,
|
message_id: str | None = None,
|
||||||
|
inputs: dict[str, Any] | None = None,
|
||||||
) -> Generator[ToolInvokeMessage, None, None]:
|
) -> Generator[ToolInvokeMessage, None, None]:
|
||||||
result = self.invoke_remote_mcp_tool(tool_parameters)
|
result = self.invoke_remote_mcp_tool(tool_parameters, _meta=inputs)
|
||||||
# handle dify tool output
|
# handle dify tool output
|
||||||
for content in result.content:
|
for content in result.content:
|
||||||
if isinstance(content, TextContent):
|
if isinstance(content, TextContent):
|
||||||
|
|
@ -141,7 +142,7 @@ class MCPTool(Tool):
|
||||||
if value is not None and not (isinstance(value, str) and value.strip() == "")
|
if value is not None and not (isinstance(value, str) and value.strip() == "")
|
||||||
}
|
}
|
||||||
|
|
||||||
def invoke_remote_mcp_tool(self, tool_parameters: dict[str, Any]) -> CallToolResult:
|
def invoke_remote_mcp_tool(self, tool_parameters: dict[str, Any], _meta: dict[str, Any] | None) -> CallToolResult:
|
||||||
headers = self.headers.copy() if self.headers else {}
|
headers = self.headers.copy() if self.headers else {}
|
||||||
tool_parameters = self._handle_none_parameter(tool_parameters)
|
tool_parameters = self._handle_none_parameter(tool_parameters)
|
||||||
|
|
||||||
|
|
@ -176,7 +177,9 @@ class MCPTool(Tool):
|
||||||
sse_read_timeout=self.sse_read_timeout,
|
sse_read_timeout=self.sse_read_timeout,
|
||||||
provider_entity=provider_entity,
|
provider_entity=provider_entity,
|
||||||
) as mcp_client:
|
) as mcp_client:
|
||||||
return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
|
return mcp_client.invoke_tool(
|
||||||
|
tool_name=self.entity.identity.name, tool_args=tool_parameters, _meta=_meta
|
||||||
|
)
|
||||||
except MCPConnectionError as e:
|
except MCPConnectionError as e:
|
||||||
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
|
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
|
|
@ -55,6 +55,7 @@ class ToolEngine:
|
||||||
conversation_id: str | None = None,
|
conversation_id: str | None = None,
|
||||||
app_id: str | None = None,
|
app_id: str | None = None,
|
||||||
message_id: str | None = None,
|
message_id: str | None = None,
|
||||||
|
inputs: dict[str, Any] | None = None,
|
||||||
) -> tuple[str, list[str], ToolInvokeMeta]:
|
) -> tuple[str, list[str], ToolInvokeMeta]:
|
||||||
"""
|
"""
|
||||||
Agent invokes the tool with the given arguments.
|
Agent invokes the tool with the given arguments.
|
||||||
|
|
@ -79,7 +80,7 @@ class ToolEngine:
|
||||||
# hit the callback handler
|
# hit the callback handler
|
||||||
agent_tool_callback.on_tool_start(tool_name=tool.entity.identity.name, tool_inputs=tool_parameters)
|
agent_tool_callback.on_tool_start(tool_name=tool.entity.identity.name, tool_inputs=tool_parameters)
|
||||||
|
|
||||||
messages = ToolEngine._invoke(tool, tool_parameters, user_id, conversation_id, app_id, message_id)
|
messages = ToolEngine._invoke(tool, tool_parameters, user_id, conversation_id, app_id, message_id, inputs)
|
||||||
invocation_meta_dict: dict[str, ToolInvokeMeta] = {}
|
invocation_meta_dict: dict[str, ToolInvokeMeta] = {}
|
||||||
|
|
||||||
def message_callback(
|
def message_callback(
|
||||||
|
|
@ -197,6 +198,7 @@ class ToolEngine:
|
||||||
conversation_id: str | None = None,
|
conversation_id: str | None = None,
|
||||||
app_id: str | None = None,
|
app_id: str | None = None,
|
||||||
message_id: str | None = None,
|
message_id: str | None = None,
|
||||||
|
inputs: dict[str, Any] | None = None,
|
||||||
) -> Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]:
|
) -> Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]:
|
||||||
"""
|
"""
|
||||||
Invoke the tool with the given arguments.
|
Invoke the tool with the given arguments.
|
||||||
|
|
@ -214,7 +216,7 @@ class ToolEngine:
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
yield from tool.invoke(user_id, tool_parameters, conversation_id, app_id, message_id)
|
yield from tool.invoke(user_id, tool_parameters, conversation_id, app_id, message_id, inputs)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
meta.error = str(e)
|
meta.error = str(e)
|
||||||
raise ToolEngineInvokeError(meta)
|
raise ToolEngineInvokeError(meta)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue