diff --git a/api/.importlinter b/api/.importlinter index 57773f57d6..5c0a6e1288 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -45,7 +45,6 @@ allow_indirect_imports = True ignore_imports = dify_graph.nodes.agent.agent_node -> extensions.ext_database dify_graph.nodes.llm.node -> extensions.ext_database - dify_graph.nodes.tool.tool_node -> extensions.ext_database dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis dify_graph.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis @@ -111,7 +110,6 @@ ignore_imports = dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager dify_graph.nodes.question_classifier.question_classifier_node -> core.model_manager dify_graph.nodes.tool.tool_node -> core.tools.utils.message_transformer - dify_graph.nodes.tool.tool_node -> models dify_graph.nodes.agent.agent_node -> models.model dify_graph.nodes.llm.node -> core.helper.code_executor dify_graph.nodes.llm.node -> core.llm_generator.output_parser.errors @@ -134,7 +132,6 @@ ignore_imports = dify_graph.nodes.tool.tool_node -> core.tools.errors dify_graph.nodes.agent.agent_node -> extensions.ext_database dify_graph.nodes.llm.node -> extensions.ext_database - dify_graph.nodes.tool.tool_node -> extensions.ext_database dify_graph.nodes.agent.agent_node -> models dify_graph.nodes.llm.node -> models.model dify_graph.nodes.agent.agent_node -> services diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index d16b919561..f6eccc734b 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -14,6 +14,7 @@ import httpx from configs import dify_config from core.db.session_factory import session_factory from core.helper import ssrf_proxy +from dify_graph.file.models import ToolFile as ToolFilePydanticModel from extensions.ext_storage import storage from models.model import MessageFile from models.tools import ToolFile @@ -207,7 +208,9 @@ class ToolFileManager: return blob, tool_file.mimetype - def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Generator | None, ToolFile | None]: + def get_file_generator_by_tool_file_id( + self, tool_file_id: str + ) -> tuple[Generator | None, ToolFilePydanticModel | None]: """ get file binary @@ -229,7 +232,7 @@ class ToolFileManager: stream = storage.load_stream(tool_file.file_key) - return stream, tool_file + return stream, ToolFilePydanticModel.model_validate(tool_file) # init tool_file_parser diff --git a/api/core/workflow/node_factory.py b/api/core/workflow/node_factory.py index c1475f2f18..8c6b1dedee 100644 --- a/api/core/workflow/node_factory.py +++ b/api/core/workflow/node_factory.py @@ -50,6 +50,7 @@ from dify_graph.nodes.template_transform.template_renderer import ( CodeExecutorJinja2TemplateRenderer, ) from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode +from dify_graph.nodes.tool.tool_node import ToolNode from dify_graph.variables.segments import StringSegment from extensions.ext_database import db from models.model import Conversation @@ -310,6 +311,15 @@ class DifyNodeFactory(NodeFactory): memory=memory, ) + if node_type == NodeType.TOOL: + return ToolNode( + id=node_id, + config=node_config, + graph_init_params=self.graph_init_params, + graph_runtime_state=self.graph_runtime_state, + tool_file_manager_factory=self._http_request_tool_file_manager_factory(), + ) + return node_class( id=node_id, config=node_config, diff --git a/api/dify_graph/file/models.py b/api/dify_graph/file/models.py index db12d4f57a..dcba00978e 100644 --- a/api/dify_graph/file/models.py +++ b/api/dify_graph/file/models.py @@ -2,6 +2,7 @@ from __future__ import annotations from collections.abc import Mapping, Sequence from typing import Any +from uuid import UUID, uuid4 from pydantic import BaseModel, Field, model_validator @@ -43,6 +44,24 @@ class FileUploadConfig(BaseModel): number_limits: int = 0 +class ToolFile(BaseModel): + id: UUID = Field(default_factory=uuid4, description="Unique identifier for the file") + user_id: UUID = Field(..., description="ID of the user who owns this file") + tenant_id: UUID = Field(..., description="ID of the tenant/organization") + conversation_id: UUID | None = Field(None, description="ID of the associated conversation") + file_key: str = Field(..., max_length=255, description="Storage key for the file") + mimetype: str = Field(..., max_length=255, description="MIME type of the file") + original_url: str | None = Field( + None, max_length=2048, description="Original URL if file was fetched from external source" + ) + name: str = Field(default="", max_length=255, description="Display name of the file") + size: int = Field(default=-1, ge=-1, description="File size in bytes (-1 if unknown)") + + class Config: + from_attributes = True # Enable ORM mode for SQLAlchemy compatibility + populate_by_name = True + + class File(BaseModel): # NOTE: dify_model_identity is a special identifier used to distinguish between # new and old data formats during serialization and deserialization. diff --git a/api/dify_graph/nodes/protocols.py b/api/dify_graph/nodes/protocols.py index cc007150f1..62d3bcdca1 100644 --- a/api/dify_graph/nodes/protocols.py +++ b/api/dify_graph/nodes/protocols.py @@ -1,8 +1,10 @@ +from collections.abc import Generator from typing import Any, Protocol import httpx from dify_graph.file import File +from dify_graph.file.models import ToolFile class HttpClientProtocol(Protocol): @@ -40,3 +42,5 @@ class ToolFileManagerProtocol(Protocol): mimetype: str, filename: str | None = None, ) -> Any: ... + + def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Generator | None, ToolFile | None]: ... diff --git a/api/dify_graph/nodes/tool/tool_node.py b/api/dify_graph/nodes/tool/tool_node.py index 57fb946559..a6e0b710f1 100644 --- a/api/dify_graph/nodes/tool/tool_node.py +++ b/api/dify_graph/nodes/tool/tool_node.py @@ -1,9 +1,6 @@ from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any -from sqlalchemy import select -from sqlalchemy.orm import Session - from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.tools.__base.tool import Tool from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter @@ -21,11 +18,10 @@ from dify_graph.model_runtime.entities.llm_entities import LLMUsage from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent from dify_graph.nodes.base.node import Node from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser +from dify_graph.nodes.protocols import ToolFileManagerProtocol from dify_graph.variables.segments import ArrayAnySegment, ArrayFileSegment from dify_graph.variables.variables import ArrayAnyVariable -from extensions.ext_database import db from factories import file_factory -from models import ToolFile from services.tools.builtin_tools_manage_service import BuiltinToolManageService from .entities import ToolNodeData @@ -36,7 +32,8 @@ from .exc import ( ) if TYPE_CHECKING: - from dify_graph.runtime import VariablePool + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState, VariablePool class ToolNode(Node[ToolNodeData]): @@ -46,6 +43,23 @@ class ToolNode(Node[ToolNodeData]): node_type = NodeType.TOOL + def __init__( + self, + id: str, + config: Mapping[str, Any], + graph_init_params: "GraphInitParams", + graph_runtime_state: "GraphRuntimeState", + *, + tool_file_manager_factory: ToolFileManagerProtocol, + ): + super().__init__( + id=id, + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + self._tool_file_manager_factory = tool_file_manager_factory + @classmethod def version(cls) -> str: return "1" @@ -271,11 +285,9 @@ class ToolNode(Node[ToolNodeData]): tool_file_id = str(url).split("/")[-1].split(".")[0] - with Session(db.engine) as session: - stmt = select(ToolFile).where(ToolFile.id == tool_file_id) - tool_file = session.scalar(stmt) - if tool_file is None: - raise ToolFileError(f"Tool file {tool_file_id} does not exist") + _, tool_file = self._tool_file_manager_factory.get_file_generator_by_tool_file_id(tool_file_id) + if not tool_file: + raise ToolFileError(f"tool file {tool_file_id} not found") mapping = { "tool_file_id": tool_file_id, @@ -294,11 +306,9 @@ class ToolNode(Node[ToolNodeData]): assert message.meta tool_file_id = message.message.text.split("/")[-1].split(".")[0] - with Session(db.engine) as session: - stmt = select(ToolFile).where(ToolFile.id == tool_file_id) - tool_file = session.scalar(stmt) - if tool_file is None: - raise ToolFileError(f"tool file {tool_file_id} not exists") + _, tool_file = self._tool_file_manager_factory.get_file_generator_by_tool_file_id(tool_file_id) + if not tool_file: + raise ToolFileError(f"tool file {tool_file_id} not exists") mapping = { "tool_file_id": tool_file_id, diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index f70bf46979..23cb56d2a5 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -8,6 +8,7 @@ from core.workflow.node_factory import DifyNodeFactory from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.graph import Graph from dify_graph.node_events import StreamCompletedEvent +from dify_graph.nodes.protocols import ToolFileManagerProtocol from dify_graph.nodes.tool.tool_node import ToolNode from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable @@ -55,11 +56,14 @@ def init_tool_node(config: dict): graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + tool_file_manager_factory = MagicMock(spec=ToolFileManagerProtocol) + node = ToolNode( id=str(uuid.uuid4()), config=config, graph_init_params=init_params, graph_runtime_state=graph_runtime_state, + tool_file_manager_factory=tool_file_manager_factory, ) return node diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py index 43fadadbc2..34e714a227 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py @@ -22,7 +22,7 @@ from dify_graph.nodes.knowledge_retrieval import KnowledgeRetrievalNode from dify_graph.nodes.llm import LLMNode from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory from dify_graph.nodes.parameter_extractor import ParameterExtractorNode -from dify_graph.nodes.protocols import HttpClientProtocol +from dify_graph.nodes.protocols import HttpClientProtocol, ToolFileManagerProtocol from dify_graph.nodes.question_classifier import QuestionClassifierNode from dify_graph.nodes.template_transform import TemplateTransformNode from dify_graph.nodes.template_transform.template_renderer import ( @@ -73,6 +73,12 @@ class MockNodeMixin: if isinstance(self, TemplateTransformNode): kwargs.setdefault("template_renderer", _TestJinja2Renderer()) + # Provide default tool_file_manager_factory for ToolNode subclasses + from dify_graph.nodes.tool import ToolNode as _ToolNode # local import to avoid cycles + + if isinstance(self, _ToolNode): + kwargs.setdefault("tool_file_manager_factory", MagicMock(spec=ToolFileManagerProtocol)) + super().__init__( id=id, config=config, diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py index 11554169e1..3cbd96dfef 100644 --- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py @@ -31,6 +31,7 @@ def tool_node(monkeypatch) -> ToolNode: ops_stub.TraceTask = object # pragma: no cover - stub attribute monkeypatch.setitem(sys.modules, module_name, ops_stub) + from dify_graph.nodes.protocols import ToolFileManagerProtocol from dify_graph.nodes.tool.tool_node import ToolNode graph_config: dict[str, Any] = { @@ -69,11 +70,16 @@ def tool_node(monkeypatch) -> ToolNode: graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) config = graph_config["nodes"][0] + + # Provide a stub ToolFileManager to satisfy the updated ToolNode constructor + tool_file_manager_factory = MagicMock(spec=ToolFileManagerProtocol) + node = ToolNode( id="node-instance", config=config, graph_init_params=init_params, graph_runtime_state=graph_runtime_state, + tool_file_manager_factory=tool_file_manager_factory, ) return node