diff --git a/api/.importlinter b/api/.importlinter index fb66df7334..2a6bb66a95 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -102,8 +102,6 @@ forbidden_modules = core.trigger core.variables ignore_imports = - core.workflow.nodes.agent.agent_node -> core.db.session_factory - core.workflow.nodes.agent.agent_node -> models.tools core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis core.workflow.workflow_entry -> core.app.workflow.layers.observability diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index e64a83034c..e195aebe6d 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -2,7 +2,7 @@ from __future__ import annotations import json from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Union, cast +from typing import TYPE_CHECKING, Any, cast from packaging.version import Version from pydantic import ValidationError @@ -11,7 +11,6 @@ from sqlalchemy.orm import Session from core.agent.entities import AgentToolEntity from core.agent.plugin_entities import AgentStrategyParameter -from core.db.session_factory import session_factory from core.file import File, FileTransferMethod from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance, ModelManager @@ -50,12 +49,6 @@ from factories import file_factory from factories.agent_factory import get_plugin_agent_strategy from models import ToolFile from models.model import Conversation -from models.tools import ( - ApiToolProvider, - BuiltinToolProvider, - MCPToolProvider, - WorkflowToolProvider, -) from services.tools.builtin_tools_manage_service import BuiltinToolManageService from .exc import ( @@ -266,7 +259,7 @@ class AgentNode(Node[AgentNodeData]): value = cast(list[dict[str, Any]], value) tool_value = [] for tool in value: - provider_type = self._infer_tool_provider_type(tool, self.tenant_id) + provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN)) setting_params = tool.get("settings", {}) parameters = tool.get("parameters", {}) manual_input_params = [key for key, value in parameters.items() if value is not None] @@ -755,34 +748,3 @@ class AgentNode(Node[AgentNodeData]): llm_usage=llm_usage, ) ) - - @staticmethod - def _infer_tool_provider_type(tool_config: dict[str, Any], tenant_id: str) -> ToolProviderType: - provider_type_str = tool_config.get("type") - if provider_type_str: - return ToolProviderType(provider_type_str) - - provider_id = tool_config.get("provider_name") - if not provider_id: - return ToolProviderType.BUILT_IN - - with session_factory.create_session() as session: - provider_map: dict[ - type[Union[WorkflowToolProvider, MCPToolProvider, ApiToolProvider, BuiltinToolProvider]], - ToolProviderType, - ] = { - WorkflowToolProvider: ToolProviderType.WORKFLOW, - MCPToolProvider: ToolProviderType.MCP, - ApiToolProvider: ToolProviderType.API, - BuiltinToolProvider: ToolProviderType.BUILT_IN, - } - - for provider_model, provider_type in provider_map.items(): - stmt = select(provider_model).where( - provider_model.id == provider_id, - provider_model.tenant_id == tenant_id, - ) - if session.scalar(stmt): - return provider_type - - raise AgentNodeError(f"Tool provider with ID '{provider_id}' not found.") diff --git a/api/tests/unit_tests/core/workflow/nodes/agent/test_agent_node.py b/api/tests/unit_tests/core/workflow/nodes/agent/test_agent_node.py deleted file mode 100644 index a95892d0b6..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/agent/test_agent_node.py +++ /dev/null @@ -1,197 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest - -from core.tools.entities.tool_entities import ToolProviderType -from core.workflow.nodes.agent.agent_node import AgentNode - - -class TestInferToolProviderType: - """Test cases for AgentNode._infer_tool_provider_type method.""" - - def test_infer_type_from_config_workflow(self): - """Test inferring workflow provider type from config.""" - tool_config = { - "type": "workflow", - "provider_name": "workflow-provider-id", - } - tenant_id = "test-tenant" - - result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) - - assert result == ToolProviderType.WORKFLOW - - def test_infer_type_from_config_builtin(self): - """Test inferring builtin provider type from config.""" - tool_config = { - "type": "builtin", - "provider_name": "builtin-provider-id", - } - tenant_id = "test-tenant" - - result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) - - assert result == ToolProviderType.BUILT_IN - - def test_infer_type_from_config_api(self): - """Test inferring API provider type from config.""" - tool_config = { - "type": "api", - "provider_name": "api-provider-id", - } - tenant_id = "test-tenant" - - result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) - - assert result == ToolProviderType.API - - def test_infer_type_from_config_mcp(self): - """Test inferring MCP provider type from config.""" - tool_config = { - "type": "mcp", - "provider_name": "mcp-provider-id", - } - tenant_id = "test-tenant" - - result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) - - assert result == ToolProviderType.MCP - - def test_infer_type_invalid_config_value_raises_error(self): - """Test that invalid type value in config raises ValueError.""" - tool_config = { - "type": "invalid-type", - "provider_name": "workflow-provider-id", - } - tenant_id = "test-tenant" - - with pytest.raises(ValueError): - AgentNode._infer_tool_provider_type(tool_config, tenant_id) - - def test_infer_workflow_type_from_database(self): - """Test inferring workflow provider type from database.""" - tool_config = { - "provider_name": "workflow-provider-id", - } - tenant_id = "test-tenant" - - with patch("core.db.session_factory.session_factory.create_session") as mock_create_session: - mock_session = MagicMock() - mock_create_session.return_value.__enter__.return_value = mock_session - - # First query (WorkflowToolProvider) returns a result - mock_session.scalar.return_value = True - - result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) - - assert result == ToolProviderType.WORKFLOW - # Should only query once (after finding WorkflowToolProvider) - assert mock_session.scalar.call_count == 1 - - def test_infer_mcp_type_from_database(self): - """Test inferring MCP provider type from database.""" - tool_config = { - "provider_name": "mcp-provider-id", - } - tenant_id = "test-tenant" - - with patch("core.db.session_factory.session_factory.create_session") as mock_create_session: - mock_session = MagicMock() - mock_create_session.return_value.__enter__.return_value = mock_session - - # First query (WorkflowToolProvider) returns None - # Second query (MCPToolProvider) returns a result - mock_session.scalar.side_effect = [None, True] - - result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) - - assert result == ToolProviderType.MCP - assert mock_session.scalar.call_count == 2 - - def test_infer_api_type_from_database(self): - """Test inferring API provider type from database.""" - tool_config = { - "provider_name": "api-provider-id", - } - tenant_id = "test-tenant" - - with patch("core.db.session_factory.session_factory.create_session") as mock_create_session: - mock_session = MagicMock() - mock_create_session.return_value.__enter__.return_value = mock_session - - # First query (WorkflowToolProvider) returns None - # Second query (MCPToolProvider) returns None - # Third query (ApiToolProvider) returns a result - mock_session.scalar.side_effect = [None, None, True] - - result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) - - assert result == ToolProviderType.API - assert mock_session.scalar.call_count == 3 - - def test_infer_builtin_type_from_database(self): - """Test inferring builtin provider type from database.""" - tool_config = { - "provider_name": "builtin-provider-id", - } - tenant_id = "test-tenant" - - with patch("core.db.session_factory.session_factory.create_session") as mock_create_session: - mock_session = MagicMock() - mock_create_session.return_value.__enter__.return_value = mock_session - - # First three queries return None - # Fourth query (BuiltinToolProvider) returns a result - mock_session.scalar.side_effect = [None, None, None, True] - - result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) - - assert result == ToolProviderType.BUILT_IN - assert mock_session.scalar.call_count == 4 - - def test_infer_type_default_when_not_found(self): - """Test raising AgentNodeError when provider is not found in database.""" - tool_config = { - "provider_name": "unknown-provider-id", - } - tenant_id = "test-tenant" - - with patch("core.db.session_factory.session_factory.create_session") as mock_create_session: - mock_session = MagicMock() - mock_create_session.return_value.__enter__.return_value = mock_session - - # All queries return None - mock_session.scalar.return_value = None - - # Current implementation raises AgentNodeError when provider not found - from core.workflow.nodes.agent.exc import AgentNodeError - - with pytest.raises(AgentNodeError, match="Tool provider with ID 'unknown-provider-id' not found"): - AgentNode._infer_tool_provider_type(tool_config, tenant_id) - - def test_infer_type_default_when_no_provider_name(self): - """Test defaulting to BUILT_IN when provider_name is missing.""" - tool_config = {} - tenant_id = "test-tenant" - - result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) - - assert result == ToolProviderType.BUILT_IN - - def test_infer_type_database_exception_propagates(self): - """Test that database exception propagates (current implementation doesn't catch it).""" - tool_config = { - "provider_name": "provider-id", - } - tenant_id = "test-tenant" - - with patch("core.db.session_factory.session_factory.create_session") as mock_create_session: - mock_session = MagicMock() - mock_create_session.return_value.__enter__.return_value = mock_session - - # Database query raises exception - mock_session.scalar.side_effect = Exception("Database error") - - # Current implementation doesn't catch exceptions, so it propagates - with pytest.raises(Exception, match="Database error"): - AgentNode._infer_tool_provider_type(tool_config, tenant_id) diff --git a/web/app/components/app/configuration/config/agent/agent-tools/index.tsx b/web/app/components/app/configuration/config/agent/agent-tools/index.tsx index 486c0a8ac9..b97aa6e775 100644 --- a/web/app/components/app/configuration/config/agent/agent-tools/index.tsx +++ b/web/app/components/app/configuration/config/agent/agent-tools/index.tsx @@ -109,6 +109,7 @@ const AgentTools: FC = () => { tool_parameters: paramsWithDefaultValue, notAuthor: !tool.is_team_authorization, enabled: true, + type: tool.provider_type as CollectionType, } } const handleSelectTool = (tool: ToolDefaultValue) => { diff --git a/web/app/components/plugins/plugin-detail-panel/tool-selector/hooks/use-tool-selector-state.ts b/web/app/components/plugins/plugin-detail-panel/tool-selector/hooks/use-tool-selector-state.ts index 44d0ff864e..e5edea5679 100644 --- a/web/app/components/plugins/plugin-detail-panel/tool-selector/hooks/use-tool-selector-state.ts +++ b/web/app/components/plugins/plugin-detail-panel/tool-selector/hooks/use-tool-selector-state.ts @@ -129,6 +129,7 @@ export const useToolSelectorState = ({ extra: { description: tool.tool_description, }, + type: tool.provider_type, } }, []) diff --git a/web/app/components/workflow/block-selector/types.ts b/web/app/components/workflow/block-selector/types.ts index 07efb0d02f..39e7b033bd 100644 --- a/web/app/components/workflow/block-selector/types.ts +++ b/web/app/components/workflow/block-selector/types.ts @@ -87,6 +87,7 @@ export type ToolValue = { enabled?: boolean extra?: { description?: string } & Record credential_id?: string + type?: string } export type DataSourceItem = {