mirror of
https://github.com/langgenius/dify.git
synced 2026-03-26 13:51:05 +08:00
refactor: tool node decouple db (#33166)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
a480e9beb1
commit
b9d05d3456
@ -45,7 +45,6 @@ allow_indirect_imports = True
|
||||
ignore_imports =
|
||||
dify_graph.nodes.agent.agent_node -> extensions.ext_database
|
||||
dify_graph.nodes.llm.node -> extensions.ext_database
|
||||
dify_graph.nodes.tool.tool_node -> extensions.ext_database
|
||||
dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis
|
||||
dify_graph.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis
|
||||
|
||||
@ -111,7 +110,6 @@ ignore_imports =
|
||||
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
|
||||
dify_graph.nodes.question_classifier.question_classifier_node -> core.model_manager
|
||||
dify_graph.nodes.tool.tool_node -> core.tools.utils.message_transformer
|
||||
dify_graph.nodes.tool.tool_node -> models
|
||||
dify_graph.nodes.agent.agent_node -> models.model
|
||||
dify_graph.nodes.llm.node -> core.helper.code_executor
|
||||
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.errors
|
||||
@ -134,7 +132,6 @@ ignore_imports =
|
||||
dify_graph.nodes.tool.tool_node -> core.tools.errors
|
||||
dify_graph.nodes.agent.agent_node -> extensions.ext_database
|
||||
dify_graph.nodes.llm.node -> extensions.ext_database
|
||||
dify_graph.nodes.tool.tool_node -> extensions.ext_database
|
||||
dify_graph.nodes.agent.agent_node -> models
|
||||
dify_graph.nodes.llm.node -> models.model
|
||||
dify_graph.nodes.agent.agent_node -> services
|
||||
|
||||
@ -14,6 +14,7 @@ import httpx
|
||||
from configs import dify_config
|
||||
from core.db.session_factory import session_factory
|
||||
from core.helper import ssrf_proxy
|
||||
from dify_graph.file.models import ToolFile as ToolFilePydanticModel
|
||||
from extensions.ext_storage import storage
|
||||
from models.model import MessageFile
|
||||
from models.tools import ToolFile
|
||||
@ -207,7 +208,9 @@ class ToolFileManager:
|
||||
|
||||
return blob, tool_file.mimetype
|
||||
|
||||
def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Generator | None, ToolFile | None]:
|
||||
def get_file_generator_by_tool_file_id(
|
||||
self, tool_file_id: str
|
||||
) -> tuple[Generator | None, ToolFilePydanticModel | None]:
|
||||
"""
|
||||
get file binary
|
||||
|
||||
@ -229,7 +232,7 @@ class ToolFileManager:
|
||||
|
||||
stream = storage.load_stream(tool_file.file_key)
|
||||
|
||||
return stream, tool_file
|
||||
return stream, ToolFilePydanticModel.model_validate(tool_file)
|
||||
|
||||
|
||||
# init tool_file_parser
|
||||
|
||||
@ -50,6 +50,7 @@ from dify_graph.nodes.template_transform.template_renderer import (
|
||||
CodeExecutorJinja2TemplateRenderer,
|
||||
)
|
||||
from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||
from dify_graph.nodes.tool.tool_node import ToolNode
|
||||
from dify_graph.variables.segments import StringSegment
|
||||
from extensions.ext_database import db
|
||||
from models.model import Conversation
|
||||
@ -310,6 +311,15 @@ class DifyNodeFactory(NodeFactory):
|
||||
memory=memory,
|
||||
)
|
||||
|
||||
if node_type == NodeType.TOOL:
|
||||
return ToolNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
tool_file_manager_factory=self._http_request_tool_file_manager_factory(),
|
||||
)
|
||||
|
||||
return node_class(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
|
||||
@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
@ -43,6 +44,24 @@ class FileUploadConfig(BaseModel):
|
||||
number_limits: int = 0
|
||||
|
||||
|
||||
class ToolFile(BaseModel):
|
||||
id: UUID = Field(default_factory=uuid4, description="Unique identifier for the file")
|
||||
user_id: UUID = Field(..., description="ID of the user who owns this file")
|
||||
tenant_id: UUID = Field(..., description="ID of the tenant/organization")
|
||||
conversation_id: UUID | None = Field(None, description="ID of the associated conversation")
|
||||
file_key: str = Field(..., max_length=255, description="Storage key for the file")
|
||||
mimetype: str = Field(..., max_length=255, description="MIME type of the file")
|
||||
original_url: str | None = Field(
|
||||
None, max_length=2048, description="Original URL if file was fetched from external source"
|
||||
)
|
||||
name: str = Field(default="", max_length=255, description="Display name of the file")
|
||||
size: int = Field(default=-1, ge=-1, description="File size in bytes (-1 if unknown)")
|
||||
|
||||
class Config:
|
||||
from_attributes = True # Enable ORM mode for SQLAlchemy compatibility
|
||||
populate_by_name = True
|
||||
|
||||
|
||||
class File(BaseModel):
|
||||
# NOTE: dify_model_identity is a special identifier used to distinguish between
|
||||
# new and old data formats during serialization and deserialization.
|
||||
|
||||
@ -1,8 +1,10 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Protocol
|
||||
|
||||
import httpx
|
||||
|
||||
from dify_graph.file import File
|
||||
from dify_graph.file.models import ToolFile
|
||||
|
||||
|
||||
class HttpClientProtocol(Protocol):
|
||||
@ -40,3 +42,5 @@ class ToolFileManagerProtocol(Protocol):
|
||||
mimetype: str,
|
||||
filename: str | None = None,
|
||||
) -> Any: ...
|
||||
|
||||
def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Generator | None, ToolFile | None]: ...
|
||||
|
||||
@ -1,9 +1,6 @@
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
@ -21,11 +18,10 @@ from dify_graph.model_runtime.entities.llm_entities import LLMUsage
|
||||
from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from dify_graph.nodes.protocols import ToolFileManagerProtocol
|
||||
from dify_graph.variables.segments import ArrayAnySegment, ArrayFileSegment
|
||||
from dify_graph.variables.variables import ArrayAnyVariable
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models import ToolFile
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
|
||||
from .entities import ToolNodeData
|
||||
@ -36,7 +32,8 @@ from .exc import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dify_graph.runtime import VariablePool
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
|
||||
|
||||
class ToolNode(Node[ToolNodeData]):
|
||||
@ -46,6 +43,23 @@ class ToolNode(Node[ToolNodeData]):
|
||||
|
||||
node_type = NodeType.TOOL
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
tool_file_manager_factory: ToolFileManagerProtocol,
|
||||
):
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
self._tool_file_manager_factory = tool_file_manager_factory
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
@ -271,11 +285,9 @@ class ToolNode(Node[ToolNodeData]):
|
||||
|
||||
tool_file_id = str(url).split("/")[-1].split(".")[0]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ToolFileError(f"Tool file {tool_file_id} does not exist")
|
||||
_, tool_file = self._tool_file_manager_factory.get_file_generator_by_tool_file_id(tool_file_id)
|
||||
if not tool_file:
|
||||
raise ToolFileError(f"tool file {tool_file_id} not found")
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": tool_file_id,
|
||||
@ -294,11 +306,9 @@ class ToolNode(Node[ToolNodeData]):
|
||||
assert message.meta
|
||||
|
||||
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ToolFileError(f"tool file {tool_file_id} not exists")
|
||||
_, tool_file = self._tool_file_manager_factory.get_file_generator_by_tool_file_id(tool_file_id)
|
||||
if not tool_file:
|
||||
raise ToolFileError(f"tool file {tool_file_id} not exists")
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": tool_file_id,
|
||||
|
||||
@ -8,6 +8,7 @@ from core.workflow.node_factory import DifyNodeFactory
|
||||
from dify_graph.enums import WorkflowNodeExecutionStatus
|
||||
from dify_graph.graph import Graph
|
||||
from dify_graph.node_events import StreamCompletedEvent
|
||||
from dify_graph.nodes.protocols import ToolFileManagerProtocol
|
||||
from dify_graph.nodes.tool.tool_node import ToolNode
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
@ -55,11 +56,14 @@ def init_tool_node(config: dict):
|
||||
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
tool_file_manager_factory = MagicMock(spec=ToolFileManagerProtocol)
|
||||
|
||||
node = ToolNode(
|
||||
id=str(uuid.uuid4()),
|
||||
config=config,
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
tool_file_manager_factory=tool_file_manager_factory,
|
||||
)
|
||||
return node
|
||||
|
||||
|
||||
@ -22,7 +22,7 @@ from dify_graph.nodes.knowledge_retrieval import KnowledgeRetrievalNode
|
||||
from dify_graph.nodes.llm import LLMNode
|
||||
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
from dify_graph.nodes.parameter_extractor import ParameterExtractorNode
|
||||
from dify_graph.nodes.protocols import HttpClientProtocol
|
||||
from dify_graph.nodes.protocols import HttpClientProtocol, ToolFileManagerProtocol
|
||||
from dify_graph.nodes.question_classifier import QuestionClassifierNode
|
||||
from dify_graph.nodes.template_transform import TemplateTransformNode
|
||||
from dify_graph.nodes.template_transform.template_renderer import (
|
||||
@ -73,6 +73,12 @@ class MockNodeMixin:
|
||||
if isinstance(self, TemplateTransformNode):
|
||||
kwargs.setdefault("template_renderer", _TestJinja2Renderer())
|
||||
|
||||
# Provide default tool_file_manager_factory for ToolNode subclasses
|
||||
from dify_graph.nodes.tool import ToolNode as _ToolNode # local import to avoid cycles
|
||||
|
||||
if isinstance(self, _ToolNode):
|
||||
kwargs.setdefault("tool_file_manager_factory", MagicMock(spec=ToolFileManagerProtocol))
|
||||
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
|
||||
@ -31,6 +31,7 @@ def tool_node(monkeypatch) -> ToolNode:
|
||||
ops_stub.TraceTask = object # pragma: no cover - stub attribute
|
||||
monkeypatch.setitem(sys.modules, module_name, ops_stub)
|
||||
|
||||
from dify_graph.nodes.protocols import ToolFileManagerProtocol
|
||||
from dify_graph.nodes.tool.tool_node import ToolNode
|
||||
|
||||
graph_config: dict[str, Any] = {
|
||||
@ -69,11 +70,16 @@ def tool_node(monkeypatch) -> ToolNode:
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
|
||||
|
||||
config = graph_config["nodes"][0]
|
||||
|
||||
# Provide a stub ToolFileManager to satisfy the updated ToolNode constructor
|
||||
tool_file_manager_factory = MagicMock(spec=ToolFileManagerProtocol)
|
||||
|
||||
node = ToolNode(
|
||||
id="node-instance",
|
||||
config=config,
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
tool_file_manager_factory=tool_file_manager_factory,
|
||||
)
|
||||
return node
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user