refactor: tool node decouple db (#33166)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
wangxiaolei 2026-03-10 01:47:15 +08:00 committed by GitHub
parent a480e9beb1
commit b9d05d3456
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 81 additions and 22 deletions

View File

@ -45,7 +45,6 @@ allow_indirect_imports = True
ignore_imports =
dify_graph.nodes.agent.agent_node -> extensions.ext_database
dify_graph.nodes.llm.node -> extensions.ext_database
dify_graph.nodes.tool.tool_node -> extensions.ext_database
dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis
dify_graph.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis
@ -111,7 +110,6 @@ ignore_imports =
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
dify_graph.nodes.question_classifier.question_classifier_node -> core.model_manager
dify_graph.nodes.tool.tool_node -> core.tools.utils.message_transformer
dify_graph.nodes.tool.tool_node -> models
dify_graph.nodes.agent.agent_node -> models.model
dify_graph.nodes.llm.node -> core.helper.code_executor
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.errors
@ -134,7 +132,6 @@ ignore_imports =
dify_graph.nodes.tool.tool_node -> core.tools.errors
dify_graph.nodes.agent.agent_node -> extensions.ext_database
dify_graph.nodes.llm.node -> extensions.ext_database
dify_graph.nodes.tool.tool_node -> extensions.ext_database
dify_graph.nodes.agent.agent_node -> models
dify_graph.nodes.llm.node -> models.model
dify_graph.nodes.agent.agent_node -> services

View File

@ -14,6 +14,7 @@ import httpx
from configs import dify_config
from core.db.session_factory import session_factory
from core.helper import ssrf_proxy
from dify_graph.file.models import ToolFile as ToolFilePydanticModel
from extensions.ext_storage import storage
from models.model import MessageFile
from models.tools import ToolFile
@ -207,7 +208,9 @@ class ToolFileManager:
return blob, tool_file.mimetype
def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Generator | None, ToolFile | None]:
def get_file_generator_by_tool_file_id(
self, tool_file_id: str
) -> tuple[Generator | None, ToolFilePydanticModel | None]:
"""
get file binary
@ -229,7 +232,7 @@ class ToolFileManager:
stream = storage.load_stream(tool_file.file_key)
return stream, tool_file
return stream, ToolFilePydanticModel.model_validate(tool_file)
# init tool_file_parser

View File

@ -50,6 +50,7 @@ from dify_graph.nodes.template_transform.template_renderer import (
CodeExecutorJinja2TemplateRenderer,
)
from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode
from dify_graph.nodes.tool.tool_node import ToolNode
from dify_graph.variables.segments import StringSegment
from extensions.ext_database import db
from models.model import Conversation
@ -310,6 +311,15 @@ class DifyNodeFactory(NodeFactory):
memory=memory,
)
if node_type == NodeType.TOOL:
return ToolNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
tool_file_manager_factory=self._http_request_tool_file_manager_factory(),
)
return node_class(
id=node_id,
config=node_config,

View File

@ -2,6 +2,7 @@ from __future__ import annotations
from collections.abc import Mapping, Sequence
from typing import Any
from uuid import UUID, uuid4
from pydantic import BaseModel, Field, model_validator
@ -43,6 +44,24 @@ class FileUploadConfig(BaseModel):
number_limits: int = 0
class ToolFile(BaseModel):
id: UUID = Field(default_factory=uuid4, description="Unique identifier for the file")
user_id: UUID = Field(..., description="ID of the user who owns this file")
tenant_id: UUID = Field(..., description="ID of the tenant/organization")
conversation_id: UUID | None = Field(None, description="ID of the associated conversation")
file_key: str = Field(..., max_length=255, description="Storage key for the file")
mimetype: str = Field(..., max_length=255, description="MIME type of the file")
original_url: str | None = Field(
None, max_length=2048, description="Original URL if file was fetched from external source"
)
name: str = Field(default="", max_length=255, description="Display name of the file")
size: int = Field(default=-1, ge=-1, description="File size in bytes (-1 if unknown)")
class Config:
from_attributes = True # Enable ORM mode for SQLAlchemy compatibility
populate_by_name = True
class File(BaseModel):
# NOTE: dify_model_identity is a special identifier used to distinguish between
# new and old data formats during serialization and deserialization.

View File

@ -1,8 +1,10 @@
from collections.abc import Generator
from typing import Any, Protocol
import httpx
from dify_graph.file import File
from dify_graph.file.models import ToolFile
class HttpClientProtocol(Protocol):
@ -40,3 +42,5 @@ class ToolFileManagerProtocol(Protocol):
mimetype: str,
filename: str | None = None,
) -> Any: ...
def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Generator | None, ToolFile | None]: ...

View File

@ -1,9 +1,6 @@
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
@ -21,11 +18,10 @@ from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
from dify_graph.nodes.protocols import ToolFileManagerProtocol
from dify_graph.variables.segments import ArrayAnySegment, ArrayFileSegment
from dify_graph.variables.variables import ArrayAnyVariable
from extensions.ext_database import db
from factories import file_factory
from models import ToolFile
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from .entities import ToolNodeData
@ -36,7 +32,8 @@ from .exc import (
)
if TYPE_CHECKING:
from dify_graph.runtime import VariablePool
from dify_graph.entities import GraphInitParams
from dify_graph.runtime import GraphRuntimeState, VariablePool
class ToolNode(Node[ToolNodeData]):
@ -46,6 +43,23 @@ class ToolNode(Node[ToolNodeData]):
node_type = NodeType.TOOL
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
tool_file_manager_factory: ToolFileManagerProtocol,
):
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self._tool_file_manager_factory = tool_file_manager_factory
@classmethod
def version(cls) -> str:
return "1"
@ -271,11 +285,9 @@ class ToolNode(Node[ToolNodeData]):
tool_file_id = str(url).split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ToolFileError(f"Tool file {tool_file_id} does not exist")
_, tool_file = self._tool_file_manager_factory.get_file_generator_by_tool_file_id(tool_file_id)
if not tool_file:
raise ToolFileError(f"tool file {tool_file_id} not found")
mapping = {
"tool_file_id": tool_file_id,
@ -294,11 +306,9 @@ class ToolNode(Node[ToolNodeData]):
assert message.meta
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ToolFileError(f"tool file {tool_file_id} not exists")
_, tool_file = self._tool_file_manager_factory.get_file_generator_by_tool_file_id(tool_file_id)
if not tool_file:
raise ToolFileError(f"tool file {tool_file_id} not exists")
mapping = {
"tool_file_id": tool_file_id,

View File

@ -8,6 +8,7 @@ from core.workflow.node_factory import DifyNodeFactory
from dify_graph.enums import WorkflowNodeExecutionStatus
from dify_graph.graph import Graph
from dify_graph.node_events import StreamCompletedEvent
from dify_graph.nodes.protocols import ToolFileManagerProtocol
from dify_graph.nodes.tool.tool_node import ToolNode
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
@ -55,11 +56,14 @@ def init_tool_node(config: dict):
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
tool_file_manager_factory = MagicMock(spec=ToolFileManagerProtocol)
node = ToolNode(
id=str(uuid.uuid4()),
config=config,
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
tool_file_manager_factory=tool_file_manager_factory,
)
return node

View File

@ -22,7 +22,7 @@ from dify_graph.nodes.knowledge_retrieval import KnowledgeRetrievalNode
from dify_graph.nodes.llm import LLMNode
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
from dify_graph.nodes.parameter_extractor import ParameterExtractorNode
from dify_graph.nodes.protocols import HttpClientProtocol
from dify_graph.nodes.protocols import HttpClientProtocol, ToolFileManagerProtocol
from dify_graph.nodes.question_classifier import QuestionClassifierNode
from dify_graph.nodes.template_transform import TemplateTransformNode
from dify_graph.nodes.template_transform.template_renderer import (
@ -73,6 +73,12 @@ class MockNodeMixin:
if isinstance(self, TemplateTransformNode):
kwargs.setdefault("template_renderer", _TestJinja2Renderer())
# Provide default tool_file_manager_factory for ToolNode subclasses
from dify_graph.nodes.tool import ToolNode as _ToolNode # local import to avoid cycles
if isinstance(self, _ToolNode):
kwargs.setdefault("tool_file_manager_factory", MagicMock(spec=ToolFileManagerProtocol))
super().__init__(
id=id,
config=config,

View File

@ -31,6 +31,7 @@ def tool_node(monkeypatch) -> ToolNode:
ops_stub.TraceTask = object # pragma: no cover - stub attribute
monkeypatch.setitem(sys.modules, module_name, ops_stub)
from dify_graph.nodes.protocols import ToolFileManagerProtocol
from dify_graph.nodes.tool.tool_node import ToolNode
graph_config: dict[str, Any] = {
@ -69,11 +70,16 @@ def tool_node(monkeypatch) -> ToolNode:
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
config = graph_config["nodes"][0]
# Provide a stub ToolFileManager to satisfy the updated ToolNode constructor
tool_file_manager_factory = MagicMock(spec=ToolFileManagerProtocol)
node = ToolNode(
id="node-instance",
config=config,
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
tool_file_manager_factory=tool_file_manager_factory,
)
return node