From b24e6edada70f1711b745787919e5b113dd047ad Mon Sep 17 00:00:00 2001 From: wangxiaolei Date: Fri, 6 Feb 2026 11:24:39 +0800 Subject: [PATCH] fix: fix agent node tool type is not right (#32008) Infer real tool type via querying relevant database tables. The root cause for incorrect `type` field is still not clear. --- api/.importlinter | 2 + api/core/workflow/nodes/agent/agent_node.py | 42 +++- .../core/workflow/nodes/agent/__init__.py | 0 .../workflow/nodes/agent/test_agent_node.py | 197 ++++++++++++++++++ 4 files changed, 239 insertions(+), 2 deletions(-) create mode 100644 api/tests/unit_tests/core/workflow/nodes/agent/__init__.py create mode 100644 api/tests/unit_tests/core/workflow/nodes/agent/test_agent_node.py diff --git a/api/.importlinter b/api/.importlinter index 2a6bb66a95..fb66df7334 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -102,6 +102,8 @@ forbidden_modules = core.trigger core.variables ignore_imports = + core.workflow.nodes.agent.agent_node -> core.db.session_factory + core.workflow.nodes.agent.agent_node -> models.tools core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis core.workflow.workflow_entry -> core.app.workflow.layers.observability diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index e195aebe6d..e64a83034c 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -2,7 +2,7 @@ from __future__ import annotations import json from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, Union, cast from packaging.version import Version from pydantic import ValidationError @@ -11,6 +11,7 @@ from sqlalchemy.orm import Session from core.agent.entities import AgentToolEntity from core.agent.plugin_entities import AgentStrategyParameter +from core.db.session_factory import session_factory from core.file import File, FileTransferMethod from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance, ModelManager @@ -49,6 +50,12 @@ from factories import file_factory from factories.agent_factory import get_plugin_agent_strategy from models import ToolFile from models.model import Conversation +from models.tools import ( + ApiToolProvider, + BuiltinToolProvider, + MCPToolProvider, + WorkflowToolProvider, +) from services.tools.builtin_tools_manage_service import BuiltinToolManageService from .exc import ( @@ -259,7 +266,7 @@ class AgentNode(Node[AgentNodeData]): value = cast(list[dict[str, Any]], value) tool_value = [] for tool in value: - provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN)) + provider_type = self._infer_tool_provider_type(tool, self.tenant_id) setting_params = tool.get("settings", {}) parameters = tool.get("parameters", {}) manual_input_params = [key for key, value in parameters.items() if value is not None] @@ -748,3 +755,34 @@ class AgentNode(Node[AgentNodeData]): llm_usage=llm_usage, ) ) + + @staticmethod + def _infer_tool_provider_type(tool_config: dict[str, Any], tenant_id: str) -> ToolProviderType: + provider_type_str = tool_config.get("type") + if provider_type_str: + return ToolProviderType(provider_type_str) + + provider_id = tool_config.get("provider_name") + if not provider_id: + return ToolProviderType.BUILT_IN + + with session_factory.create_session() as session: + provider_map: dict[ + type[Union[WorkflowToolProvider, MCPToolProvider, ApiToolProvider, BuiltinToolProvider]], + ToolProviderType, + ] = { + WorkflowToolProvider: ToolProviderType.WORKFLOW, + MCPToolProvider: ToolProviderType.MCP, + ApiToolProvider: ToolProviderType.API, + BuiltinToolProvider: ToolProviderType.BUILT_IN, + } + + for provider_model, provider_type in provider_map.items(): + stmt = select(provider_model).where( + provider_model.id == provider_id, + provider_model.tenant_id == tenant_id, + ) + if session.scalar(stmt): + return provider_type + + raise AgentNodeError(f"Tool provider with ID '{provider_id}' not found.") diff --git a/api/tests/unit_tests/core/workflow/nodes/agent/__init__.py b/api/tests/unit_tests/core/workflow/nodes/agent/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/workflow/nodes/agent/test_agent_node.py b/api/tests/unit_tests/core/workflow/nodes/agent/test_agent_node.py new file mode 100644 index 0000000000..a95892d0b6 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/agent/test_agent_node.py @@ -0,0 +1,197 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from core.tools.entities.tool_entities import ToolProviderType +from core.workflow.nodes.agent.agent_node import AgentNode + + +class TestInferToolProviderType: + """Test cases for AgentNode._infer_tool_provider_type method.""" + + def test_infer_type_from_config_workflow(self): + """Test inferring workflow provider type from config.""" + tool_config = { + "type": "workflow", + "provider_name": "workflow-provider-id", + } + tenant_id = "test-tenant" + + result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) + + assert result == ToolProviderType.WORKFLOW + + def test_infer_type_from_config_builtin(self): + """Test inferring builtin provider type from config.""" + tool_config = { + "type": "builtin", + "provider_name": "builtin-provider-id", + } + tenant_id = "test-tenant" + + result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) + + assert result == ToolProviderType.BUILT_IN + + def test_infer_type_from_config_api(self): + """Test inferring API provider type from config.""" + tool_config = { + "type": "api", + "provider_name": "api-provider-id", + } + tenant_id = "test-tenant" + + result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) + + assert result == ToolProviderType.API + + def test_infer_type_from_config_mcp(self): + """Test inferring MCP provider type from config.""" + tool_config = { + "type": "mcp", + "provider_name": "mcp-provider-id", + } + tenant_id = "test-tenant" + + result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) + + assert result == ToolProviderType.MCP + + def test_infer_type_invalid_config_value_raises_error(self): + """Test that invalid type value in config raises ValueError.""" + tool_config = { + "type": "invalid-type", + "provider_name": "workflow-provider-id", + } + tenant_id = "test-tenant" + + with pytest.raises(ValueError): + AgentNode._infer_tool_provider_type(tool_config, tenant_id) + + def test_infer_workflow_type_from_database(self): + """Test inferring workflow provider type from database.""" + tool_config = { + "provider_name": "workflow-provider-id", + } + tenant_id = "test-tenant" + + with patch("core.db.session_factory.session_factory.create_session") as mock_create_session: + mock_session = MagicMock() + mock_create_session.return_value.__enter__.return_value = mock_session + + # First query (WorkflowToolProvider) returns a result + mock_session.scalar.return_value = True + + result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) + + assert result == ToolProviderType.WORKFLOW + # Should only query once (after finding WorkflowToolProvider) + assert mock_session.scalar.call_count == 1 + + def test_infer_mcp_type_from_database(self): + """Test inferring MCP provider type from database.""" + tool_config = { + "provider_name": "mcp-provider-id", + } + tenant_id = "test-tenant" + + with patch("core.db.session_factory.session_factory.create_session") as mock_create_session: + mock_session = MagicMock() + mock_create_session.return_value.__enter__.return_value = mock_session + + # First query (WorkflowToolProvider) returns None + # Second query (MCPToolProvider) returns a result + mock_session.scalar.side_effect = [None, True] + + result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) + + assert result == ToolProviderType.MCP + assert mock_session.scalar.call_count == 2 + + def test_infer_api_type_from_database(self): + """Test inferring API provider type from database.""" + tool_config = { + "provider_name": "api-provider-id", + } + tenant_id = "test-tenant" + + with patch("core.db.session_factory.session_factory.create_session") as mock_create_session: + mock_session = MagicMock() + mock_create_session.return_value.__enter__.return_value = mock_session + + # First query (WorkflowToolProvider) returns None + # Second query (MCPToolProvider) returns None + # Third query (ApiToolProvider) returns a result + mock_session.scalar.side_effect = [None, None, True] + + result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) + + assert result == ToolProviderType.API + assert mock_session.scalar.call_count == 3 + + def test_infer_builtin_type_from_database(self): + """Test inferring builtin provider type from database.""" + tool_config = { + "provider_name": "builtin-provider-id", + } + tenant_id = "test-tenant" + + with patch("core.db.session_factory.session_factory.create_session") as mock_create_session: + mock_session = MagicMock() + mock_create_session.return_value.__enter__.return_value = mock_session + + # First three queries return None + # Fourth query (BuiltinToolProvider) returns a result + mock_session.scalar.side_effect = [None, None, None, True] + + result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) + + assert result == ToolProviderType.BUILT_IN + assert mock_session.scalar.call_count == 4 + + def test_infer_type_default_when_not_found(self): + """Test raising AgentNodeError when provider is not found in database.""" + tool_config = { + "provider_name": "unknown-provider-id", + } + tenant_id = "test-tenant" + + with patch("core.db.session_factory.session_factory.create_session") as mock_create_session: + mock_session = MagicMock() + mock_create_session.return_value.__enter__.return_value = mock_session + + # All queries return None + mock_session.scalar.return_value = None + + # Current implementation raises AgentNodeError when provider not found + from core.workflow.nodes.agent.exc import AgentNodeError + + with pytest.raises(AgentNodeError, match="Tool provider with ID 'unknown-provider-id' not found"): + AgentNode._infer_tool_provider_type(tool_config, tenant_id) + + def test_infer_type_default_when_no_provider_name(self): + """Test defaulting to BUILT_IN when provider_name is missing.""" + tool_config = {} + tenant_id = "test-tenant" + + result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) + + assert result == ToolProviderType.BUILT_IN + + def test_infer_type_database_exception_propagates(self): + """Test that database exception propagates (current implementation doesn't catch it).""" + tool_config = { + "provider_name": "provider-id", + } + tenant_id = "test-tenant" + + with patch("core.db.session_factory.session_factory.create_session") as mock_create_session: + mock_session = MagicMock() + mock_create_session.return_value.__enter__.return_value = mock_session + + # Database query raises exception + mock_session.scalar.side_effect = Exception("Database error") + + # Current implementation doesn't catch exceptions, so it propagates + with pytest.raises(Exception, match="Database error"): + AgentNode._infer_tool_provider_type(tool_config, tenant_id)