mirror of
https://github.com/langgenius/dify.git
synced 2026-04-27 02:36:29 +08:00
fix: fix agent node tool type is not right (#32008)
Infer real tool type via querying relevant database tables. The root cause for incorrect `type` field is still not clear.
This commit is contained in:
parent
540e1db83c
commit
fcb53383df
@ -102,6 +102,8 @@ forbidden_modules =
|
|||||||
core.trigger
|
core.trigger
|
||||||
core.variables
|
core.variables
|
||||||
ignore_imports =
|
ignore_imports =
|
||||||
|
core.workflow.nodes.agent.agent_node -> core.db.session_factory
|
||||||
|
core.workflow.nodes.agent.agent_node -> models.tools
|
||||||
core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory
|
core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory
|
||||||
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
|
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
|
||||||
core.workflow.workflow_entry -> core.app.workflow.layers.observability
|
core.workflow.workflow_entry -> core.app.workflow.layers.observability
|
||||||
|
|||||||
@ -2,7 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
from collections.abc import Generator, Mapping, Sequence
|
from collections.abc import Generator, Mapping, Sequence
|
||||||
from typing import TYPE_CHECKING, Any, cast
|
from typing import TYPE_CHECKING, Any, Union, cast
|
||||||
|
|
||||||
from packaging.version import Version
|
from packaging.version import Version
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
from core.agent.entities import AgentToolEntity
|
from core.agent.entities import AgentToolEntity
|
||||||
from core.agent.plugin_entities import AgentStrategyParameter
|
from core.agent.plugin_entities import AgentStrategyParameter
|
||||||
|
from core.db.session_factory import session_factory
|
||||||
from core.file import File, FileTransferMethod
|
from core.file import File, FileTransferMethod
|
||||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||||
from core.model_manager import ModelInstance, ModelManager
|
from core.model_manager import ModelInstance, ModelManager
|
||||||
@ -49,6 +50,12 @@ from factories import file_factory
|
|||||||
from factories.agent_factory import get_plugin_agent_strategy
|
from factories.agent_factory import get_plugin_agent_strategy
|
||||||
from models import ToolFile
|
from models import ToolFile
|
||||||
from models.model import Conversation
|
from models.model import Conversation
|
||||||
|
from models.tools import (
|
||||||
|
ApiToolProvider,
|
||||||
|
BuiltinToolProvider,
|
||||||
|
MCPToolProvider,
|
||||||
|
WorkflowToolProvider,
|
||||||
|
)
|
||||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||||
|
|
||||||
from .exc import (
|
from .exc import (
|
||||||
@ -259,7 +266,7 @@ class AgentNode(Node[AgentNodeData]):
|
|||||||
value = cast(list[dict[str, Any]], value)
|
value = cast(list[dict[str, Any]], value)
|
||||||
tool_value = []
|
tool_value = []
|
||||||
for tool in value:
|
for tool in value:
|
||||||
provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN))
|
provider_type = self._infer_tool_provider_type(tool, self.tenant_id)
|
||||||
setting_params = tool.get("settings", {})
|
setting_params = tool.get("settings", {})
|
||||||
parameters = tool.get("parameters", {})
|
parameters = tool.get("parameters", {})
|
||||||
manual_input_params = [key for key, value in parameters.items() if value is not None]
|
manual_input_params = [key for key, value in parameters.items() if value is not None]
|
||||||
@ -748,3 +755,34 @@ class AgentNode(Node[AgentNodeData]):
|
|||||||
llm_usage=llm_usage,
|
llm_usage=llm_usage,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _infer_tool_provider_type(tool_config: dict[str, Any], tenant_id: str) -> ToolProviderType:
|
||||||
|
provider_type_str = tool_config.get("type")
|
||||||
|
if provider_type_str:
|
||||||
|
return ToolProviderType(provider_type_str)
|
||||||
|
|
||||||
|
provider_id = tool_config.get("provider_name")
|
||||||
|
if not provider_id:
|
||||||
|
return ToolProviderType.BUILT_IN
|
||||||
|
|
||||||
|
with session_factory.create_session() as session:
|
||||||
|
provider_map: dict[
|
||||||
|
type[Union[WorkflowToolProvider, MCPToolProvider, ApiToolProvider, BuiltinToolProvider]],
|
||||||
|
ToolProviderType,
|
||||||
|
] = {
|
||||||
|
WorkflowToolProvider: ToolProviderType.WORKFLOW,
|
||||||
|
MCPToolProvider: ToolProviderType.MCP,
|
||||||
|
ApiToolProvider: ToolProviderType.API,
|
||||||
|
BuiltinToolProvider: ToolProviderType.BUILT_IN,
|
||||||
|
}
|
||||||
|
|
||||||
|
for provider_model, provider_type in provider_map.items():
|
||||||
|
stmt = select(provider_model).where(
|
||||||
|
provider_model.id == provider_id,
|
||||||
|
provider_model.tenant_id == tenant_id,
|
||||||
|
)
|
||||||
|
if session.scalar(stmt):
|
||||||
|
return provider_type
|
||||||
|
|
||||||
|
raise AgentNodeError(f"Tool provider with ID '{provider_id}' not found.")
|
||||||
|
|||||||
@ -0,0 +1,197 @@
|
|||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.tools.entities.tool_entities import ToolProviderType
|
||||||
|
from core.workflow.nodes.agent.agent_node import AgentNode
|
||||||
|
|
||||||
|
|
||||||
|
class TestInferToolProviderType:
|
||||||
|
"""Test cases for AgentNode._infer_tool_provider_type method."""
|
||||||
|
|
||||||
|
def test_infer_type_from_config_workflow(self):
|
||||||
|
"""Test inferring workflow provider type from config."""
|
||||||
|
tool_config = {
|
||||||
|
"type": "workflow",
|
||||||
|
"provider_name": "workflow-provider-id",
|
||||||
|
}
|
||||||
|
tenant_id = "test-tenant"
|
||||||
|
|
||||||
|
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
|
||||||
|
|
||||||
|
assert result == ToolProviderType.WORKFLOW
|
||||||
|
|
||||||
|
def test_infer_type_from_config_builtin(self):
|
||||||
|
"""Test inferring builtin provider type from config."""
|
||||||
|
tool_config = {
|
||||||
|
"type": "builtin",
|
||||||
|
"provider_name": "builtin-provider-id",
|
||||||
|
}
|
||||||
|
tenant_id = "test-tenant"
|
||||||
|
|
||||||
|
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
|
||||||
|
|
||||||
|
assert result == ToolProviderType.BUILT_IN
|
||||||
|
|
||||||
|
def test_infer_type_from_config_api(self):
|
||||||
|
"""Test inferring API provider type from config."""
|
||||||
|
tool_config = {
|
||||||
|
"type": "api",
|
||||||
|
"provider_name": "api-provider-id",
|
||||||
|
}
|
||||||
|
tenant_id = "test-tenant"
|
||||||
|
|
||||||
|
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
|
||||||
|
|
||||||
|
assert result == ToolProviderType.API
|
||||||
|
|
||||||
|
def test_infer_type_from_config_mcp(self):
|
||||||
|
"""Test inferring MCP provider type from config."""
|
||||||
|
tool_config = {
|
||||||
|
"type": "mcp",
|
||||||
|
"provider_name": "mcp-provider-id",
|
||||||
|
}
|
||||||
|
tenant_id = "test-tenant"
|
||||||
|
|
||||||
|
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
|
||||||
|
|
||||||
|
assert result == ToolProviderType.MCP
|
||||||
|
|
||||||
|
def test_infer_type_invalid_config_value_raises_error(self):
|
||||||
|
"""Test that invalid type value in config raises ValueError."""
|
||||||
|
tool_config = {
|
||||||
|
"type": "invalid-type",
|
||||||
|
"provider_name": "workflow-provider-id",
|
||||||
|
}
|
||||||
|
tenant_id = "test-tenant"
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
AgentNode._infer_tool_provider_type(tool_config, tenant_id)
|
||||||
|
|
||||||
|
def test_infer_workflow_type_from_database(self):
|
||||||
|
"""Test inferring workflow provider type from database."""
|
||||||
|
tool_config = {
|
||||||
|
"provider_name": "workflow-provider-id",
|
||||||
|
}
|
||||||
|
tenant_id = "test-tenant"
|
||||||
|
|
||||||
|
with patch("core.db.session_factory.session_factory.create_session") as mock_create_session:
|
||||||
|
mock_session = MagicMock()
|
||||||
|
mock_create_session.return_value.__enter__.return_value = mock_session
|
||||||
|
|
||||||
|
# First query (WorkflowToolProvider) returns a result
|
||||||
|
mock_session.scalar.return_value = True
|
||||||
|
|
||||||
|
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
|
||||||
|
|
||||||
|
assert result == ToolProviderType.WORKFLOW
|
||||||
|
# Should only query once (after finding WorkflowToolProvider)
|
||||||
|
assert mock_session.scalar.call_count == 1
|
||||||
|
|
||||||
|
def test_infer_mcp_type_from_database(self):
|
||||||
|
"""Test inferring MCP provider type from database."""
|
||||||
|
tool_config = {
|
||||||
|
"provider_name": "mcp-provider-id",
|
||||||
|
}
|
||||||
|
tenant_id = "test-tenant"
|
||||||
|
|
||||||
|
with patch("core.db.session_factory.session_factory.create_session") as mock_create_session:
|
||||||
|
mock_session = MagicMock()
|
||||||
|
mock_create_session.return_value.__enter__.return_value = mock_session
|
||||||
|
|
||||||
|
# First query (WorkflowToolProvider) returns None
|
||||||
|
# Second query (MCPToolProvider) returns a result
|
||||||
|
mock_session.scalar.side_effect = [None, True]
|
||||||
|
|
||||||
|
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
|
||||||
|
|
||||||
|
assert result == ToolProviderType.MCP
|
||||||
|
assert mock_session.scalar.call_count == 2
|
||||||
|
|
||||||
|
def test_infer_api_type_from_database(self):
|
||||||
|
"""Test inferring API provider type from database."""
|
||||||
|
tool_config = {
|
||||||
|
"provider_name": "api-provider-id",
|
||||||
|
}
|
||||||
|
tenant_id = "test-tenant"
|
||||||
|
|
||||||
|
with patch("core.db.session_factory.session_factory.create_session") as mock_create_session:
|
||||||
|
mock_session = MagicMock()
|
||||||
|
mock_create_session.return_value.__enter__.return_value = mock_session
|
||||||
|
|
||||||
|
# First query (WorkflowToolProvider) returns None
|
||||||
|
# Second query (MCPToolProvider) returns None
|
||||||
|
# Third query (ApiToolProvider) returns a result
|
||||||
|
mock_session.scalar.side_effect = [None, None, True]
|
||||||
|
|
||||||
|
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
|
||||||
|
|
||||||
|
assert result == ToolProviderType.API
|
||||||
|
assert mock_session.scalar.call_count == 3
|
||||||
|
|
||||||
|
def test_infer_builtin_type_from_database(self):
|
||||||
|
"""Test inferring builtin provider type from database."""
|
||||||
|
tool_config = {
|
||||||
|
"provider_name": "builtin-provider-id",
|
||||||
|
}
|
||||||
|
tenant_id = "test-tenant"
|
||||||
|
|
||||||
|
with patch("core.db.session_factory.session_factory.create_session") as mock_create_session:
|
||||||
|
mock_session = MagicMock()
|
||||||
|
mock_create_session.return_value.__enter__.return_value = mock_session
|
||||||
|
|
||||||
|
# First three queries return None
|
||||||
|
# Fourth query (BuiltinToolProvider) returns a result
|
||||||
|
mock_session.scalar.side_effect = [None, None, None, True]
|
||||||
|
|
||||||
|
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
|
||||||
|
|
||||||
|
assert result == ToolProviderType.BUILT_IN
|
||||||
|
assert mock_session.scalar.call_count == 4
|
||||||
|
|
||||||
|
def test_infer_type_default_when_not_found(self):
|
||||||
|
"""Test raising AgentNodeError when provider is not found in database."""
|
||||||
|
tool_config = {
|
||||||
|
"provider_name": "unknown-provider-id",
|
||||||
|
}
|
||||||
|
tenant_id = "test-tenant"
|
||||||
|
|
||||||
|
with patch("core.db.session_factory.session_factory.create_session") as mock_create_session:
|
||||||
|
mock_session = MagicMock()
|
||||||
|
mock_create_session.return_value.__enter__.return_value = mock_session
|
||||||
|
|
||||||
|
# All queries return None
|
||||||
|
mock_session.scalar.return_value = None
|
||||||
|
|
||||||
|
# Current implementation raises AgentNodeError when provider not found
|
||||||
|
from core.workflow.nodes.agent.exc import AgentNodeError
|
||||||
|
|
||||||
|
with pytest.raises(AgentNodeError, match="Tool provider with ID 'unknown-provider-id' not found"):
|
||||||
|
AgentNode._infer_tool_provider_type(tool_config, tenant_id)
|
||||||
|
|
||||||
|
def test_infer_type_default_when_no_provider_name(self):
|
||||||
|
"""Test defaulting to BUILT_IN when provider_name is missing."""
|
||||||
|
tool_config = {}
|
||||||
|
tenant_id = "test-tenant"
|
||||||
|
|
||||||
|
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
|
||||||
|
|
||||||
|
assert result == ToolProviderType.BUILT_IN
|
||||||
|
|
||||||
|
def test_infer_type_database_exception_propagates(self):
|
||||||
|
"""Test that database exception propagates (current implementation doesn't catch it)."""
|
||||||
|
tool_config = {
|
||||||
|
"provider_name": "provider-id",
|
||||||
|
}
|
||||||
|
tenant_id = "test-tenant"
|
||||||
|
|
||||||
|
with patch("core.db.session_factory.session_factory.create_session") as mock_create_session:
|
||||||
|
mock_session = MagicMock()
|
||||||
|
mock_create_session.return_value.__enter__.return_value = mock_session
|
||||||
|
|
||||||
|
# Database query raises exception
|
||||||
|
mock_session.scalar.side_effect = Exception("Database error")
|
||||||
|
|
||||||
|
# Current implementation doesn't catch exceptions, so it propagates
|
||||||
|
with pytest.raises(Exception, match="Database error"):
|
||||||
|
AgentNode._infer_tool_provider_type(tool_config, tenant_id)
|
||||||
Loading…
Reference in New Issue
Block a user