fix: fix tool type is miss (#32042)

This commit is contained in:
wangxiaolei 2026-02-06 14:38:15 +08:00 committed by GitHub
parent e988266f53
commit a297b06aac
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 5 additions and 239 deletions

View File

@ -102,8 +102,6 @@ forbidden_modules =
core.trigger
core.variables
ignore_imports =
core.workflow.nodes.agent.agent_node -> core.db.session_factory
core.workflow.nodes.agent.agent_node -> models.tools
core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
core.workflow.workflow_entry -> core.app.workflow.layers.observability

View File

@ -2,7 +2,7 @@ from __future__ import annotations
import json
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Union, cast
from typing import TYPE_CHECKING, Any, cast
from packaging.version import Version
from pydantic import ValidationError
@ -11,7 +11,6 @@ from sqlalchemy.orm import Session
from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter
from core.db.session_factory import session_factory
from core.file import File, FileTransferMethod
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
@ -50,12 +49,6 @@ from factories import file_factory
from factories.agent_factory import get_plugin_agent_strategy
from models import ToolFile
from models.model import Conversation
from models.tools import (
ApiToolProvider,
BuiltinToolProvider,
MCPToolProvider,
WorkflowToolProvider,
)
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from .exc import (
@ -266,7 +259,7 @@ class AgentNode(Node[AgentNodeData]):
value = cast(list[dict[str, Any]], value)
tool_value = []
for tool in value:
provider_type = self._infer_tool_provider_type(tool, self.tenant_id)
provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN))
setting_params = tool.get("settings", {})
parameters = tool.get("parameters", {})
manual_input_params = [key for key, value in parameters.items() if value is not None]
@ -755,34 +748,3 @@ class AgentNode(Node[AgentNodeData]):
llm_usage=llm_usage,
)
)
@staticmethod
def _infer_tool_provider_type(tool_config: dict[str, Any], tenant_id: str) -> ToolProviderType:
provider_type_str = tool_config.get("type")
if provider_type_str:
return ToolProviderType(provider_type_str)
provider_id = tool_config.get("provider_name")
if not provider_id:
return ToolProviderType.BUILT_IN
with session_factory.create_session() as session:
provider_map: dict[
type[Union[WorkflowToolProvider, MCPToolProvider, ApiToolProvider, BuiltinToolProvider]],
ToolProviderType,
] = {
WorkflowToolProvider: ToolProviderType.WORKFLOW,
MCPToolProvider: ToolProviderType.MCP,
ApiToolProvider: ToolProviderType.API,
BuiltinToolProvider: ToolProviderType.BUILT_IN,
}
for provider_model, provider_type in provider_map.items():
stmt = select(provider_model).where(
provider_model.id == provider_id,
provider_model.tenant_id == tenant_id,
)
if session.scalar(stmt):
return provider_type
raise AgentNodeError(f"Tool provider with ID '{provider_id}' not found.")

View File

@ -1,197 +0,0 @@
from unittest.mock import MagicMock, patch
import pytest
from core.tools.entities.tool_entities import ToolProviderType
from core.workflow.nodes.agent.agent_node import AgentNode
class TestInferToolProviderType:
"""Test cases for AgentNode._infer_tool_provider_type method."""
def test_infer_type_from_config_workflow(self):
"""Test inferring workflow provider type from config."""
tool_config = {
"type": "workflow",
"provider_name": "workflow-provider-id",
}
tenant_id = "test-tenant"
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
assert result == ToolProviderType.WORKFLOW
def test_infer_type_from_config_builtin(self):
"""Test inferring builtin provider type from config."""
tool_config = {
"type": "builtin",
"provider_name": "builtin-provider-id",
}
tenant_id = "test-tenant"
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
assert result == ToolProviderType.BUILT_IN
def test_infer_type_from_config_api(self):
"""Test inferring API provider type from config."""
tool_config = {
"type": "api",
"provider_name": "api-provider-id",
}
tenant_id = "test-tenant"
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
assert result == ToolProviderType.API
def test_infer_type_from_config_mcp(self):
"""Test inferring MCP provider type from config."""
tool_config = {
"type": "mcp",
"provider_name": "mcp-provider-id",
}
tenant_id = "test-tenant"
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
assert result == ToolProviderType.MCP
def test_infer_type_invalid_config_value_raises_error(self):
"""Test that invalid type value in config raises ValueError."""
tool_config = {
"type": "invalid-type",
"provider_name": "workflow-provider-id",
}
tenant_id = "test-tenant"
with pytest.raises(ValueError):
AgentNode._infer_tool_provider_type(tool_config, tenant_id)
def test_infer_workflow_type_from_database(self):
"""Test inferring workflow provider type from database."""
tool_config = {
"provider_name": "workflow-provider-id",
}
tenant_id = "test-tenant"
with patch("core.db.session_factory.session_factory.create_session") as mock_create_session:
mock_session = MagicMock()
mock_create_session.return_value.__enter__.return_value = mock_session
# First query (WorkflowToolProvider) returns a result
mock_session.scalar.return_value = True
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
assert result == ToolProviderType.WORKFLOW
# Should only query once (after finding WorkflowToolProvider)
assert mock_session.scalar.call_count == 1
def test_infer_mcp_type_from_database(self):
"""Test inferring MCP provider type from database."""
tool_config = {
"provider_name": "mcp-provider-id",
}
tenant_id = "test-tenant"
with patch("core.db.session_factory.session_factory.create_session") as mock_create_session:
mock_session = MagicMock()
mock_create_session.return_value.__enter__.return_value = mock_session
# First query (WorkflowToolProvider) returns None
# Second query (MCPToolProvider) returns a result
mock_session.scalar.side_effect = [None, True]
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
assert result == ToolProviderType.MCP
assert mock_session.scalar.call_count == 2
def test_infer_api_type_from_database(self):
"""Test inferring API provider type from database."""
tool_config = {
"provider_name": "api-provider-id",
}
tenant_id = "test-tenant"
with patch("core.db.session_factory.session_factory.create_session") as mock_create_session:
mock_session = MagicMock()
mock_create_session.return_value.__enter__.return_value = mock_session
# First query (WorkflowToolProvider) returns None
# Second query (MCPToolProvider) returns None
# Third query (ApiToolProvider) returns a result
mock_session.scalar.side_effect = [None, None, True]
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
assert result == ToolProviderType.API
assert mock_session.scalar.call_count == 3
def test_infer_builtin_type_from_database(self):
"""Test inferring builtin provider type from database."""
tool_config = {
"provider_name": "builtin-provider-id",
}
tenant_id = "test-tenant"
with patch("core.db.session_factory.session_factory.create_session") as mock_create_session:
mock_session = MagicMock()
mock_create_session.return_value.__enter__.return_value = mock_session
# First three queries return None
# Fourth query (BuiltinToolProvider) returns a result
mock_session.scalar.side_effect = [None, None, None, True]
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
assert result == ToolProviderType.BUILT_IN
assert mock_session.scalar.call_count == 4
def test_infer_type_default_when_not_found(self):
"""Test raising AgentNodeError when provider is not found in database."""
tool_config = {
"provider_name": "unknown-provider-id",
}
tenant_id = "test-tenant"
with patch("core.db.session_factory.session_factory.create_session") as mock_create_session:
mock_session = MagicMock()
mock_create_session.return_value.__enter__.return_value = mock_session
# All queries return None
mock_session.scalar.return_value = None
# Current implementation raises AgentNodeError when provider not found
from core.workflow.nodes.agent.exc import AgentNodeError
with pytest.raises(AgentNodeError, match="Tool provider with ID 'unknown-provider-id' not found"):
AgentNode._infer_tool_provider_type(tool_config, tenant_id)
def test_infer_type_default_when_no_provider_name(self):
"""Test defaulting to BUILT_IN when provider_name is missing."""
tool_config = {}
tenant_id = "test-tenant"
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
assert result == ToolProviderType.BUILT_IN
def test_infer_type_database_exception_propagates(self):
"""Test that database exception propagates (current implementation doesn't catch it)."""
tool_config = {
"provider_name": "provider-id",
}
tenant_id = "test-tenant"
with patch("core.db.session_factory.session_factory.create_session") as mock_create_session:
mock_session = MagicMock()
mock_create_session.return_value.__enter__.return_value = mock_session
# Database query raises exception
mock_session.scalar.side_effect = Exception("Database error")
# Current implementation doesn't catch exceptions, so it propagates
with pytest.raises(Exception, match="Database error"):
AgentNode._infer_tool_provider_type(tool_config, tenant_id)

View File

@ -109,6 +109,7 @@ const AgentTools: FC = () => {
tool_parameters: paramsWithDefaultValue,
notAuthor: !tool.is_team_authorization,
enabled: true,
type: tool.provider_type as CollectionType,
}
}
const handleSelectTool = (tool: ToolDefaultValue) => {

View File

@ -129,6 +129,7 @@ export const useToolSelectorState = ({
extra: {
description: tool.tool_description,
},
type: tool.provider_type,
}
}, [])

View File

@ -87,6 +87,7 @@ export type ToolValue = {
enabled?: boolean
extra?: { description?: string } & Record<string, unknown>
credential_id?: string
type?: string
}
export type DataSourceItem = {