From 01f17b7ddc5378f6fc8b49c4879a5ab492f6309d Mon Sep 17 00:00:00 2001 From: -LAN- Date: Wed, 14 Jan 2026 14:19:48 +0800 Subject: [PATCH] refactor(http_request_node): apply DI for http request node (#30509) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/core/helper/ssrf_proxy.py | 4 ++ .../workflow/nodes/http_request/executor.py | 30 +++++++++------ api/core/workflow/nodes/http_request/node.py | 37 +++++++++++++++++-- api/core/workflow/nodes/node_factory.py | 25 ++++++++++++- api/core/workflow/nodes/protocols.py | 29 +++++++++++++++ 5 files changed, 109 insertions(+), 16 deletions(-) create mode 100644 api/core/workflow/nodes/protocols.py diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 1785cbde4c..128c64ff2c 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -33,6 +33,10 @@ class MaxRetriesExceededError(ValueError): pass +request_error = httpx.RequestError +max_retries_exceeded_error = MaxRetriesExceededError + + def _create_proxy_mounts() -> dict[str, httpx.HTTPTransport]: return { "http://": httpx.HTTPTransport( diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py index 931c6113a7..429f8411a6 100644 --- a/api/core/workflow/nodes/http_request/executor.py +++ b/api/core/workflow/nodes/http_request/executor.py @@ -17,6 +17,7 @@ from core.helper import ssrf_proxy from core.variables.segments import ArrayFileSegment, FileSegment from core.workflow.runtime import VariablePool +from ..protocols import FileManagerProtocol, HttpClientProtocol from .entities import ( HttpRequestNodeAuthorization, HttpRequestNodeData, @@ -78,6 +79,8 @@ class Executor: timeout: HttpRequestNodeTimeout, variable_pool: VariablePool, max_retries: int = dify_config.SSRF_DEFAULT_MAX_RETRIES, + http_client: HttpClientProtocol = ssrf_proxy, + file_manager: FileManagerProtocol = file_manager, ): # If authorization API key is present, convert the API key using the variable pool if node_data.authorization.type == "api-key": @@ -104,6 +107,8 @@ class Executor: self.data = None self.json = None self.max_retries = max_retries + self._http_client = http_client + self._file_manager = file_manager # init template self.variable_pool = variable_pool @@ -200,7 +205,7 @@ class Executor: if file_variable is None: raise FileFetchError(f"cannot fetch file with selector {file_selector}") file = file_variable.value - self.content = file_manager.download(file) + self.content = self._file_manager.download(file) case "x-www-form-urlencoded": form_data = { self.variable_pool.convert_template(item.key).text: self.variable_pool.convert_template( @@ -239,7 +244,7 @@ class Executor: ): file_tuple = ( file.filename, - file_manager.download(file), + self._file_manager.download(file), file.mime_type or "application/octet-stream", ) if key not in files: @@ -332,19 +337,18 @@ class Executor: do http request depending on api bundle """ _METHOD_MAP = { - "get": ssrf_proxy.get, - "head": ssrf_proxy.head, - "post": ssrf_proxy.post, - "put": ssrf_proxy.put, - "delete": ssrf_proxy.delete, - "patch": ssrf_proxy.patch, + "get": self._http_client.get, + "head": self._http_client.head, + "post": self._http_client.post, + "put": self._http_client.put, + "delete": self._http_client.delete, + "patch": self._http_client.patch, } method_lc = self.method.lower() if method_lc not in _METHOD_MAP: raise InvalidHttpMethodError(f"Invalid http method {self.method}") request_args = { - "url": self.url, "data": self.data, "files": self.files, "json": self.json, @@ -357,8 +361,12 @@ class Executor: } # request_args = {k: v for k, v in request_args.items() if v is not None} try: - response: httpx.Response = _METHOD_MAP[method_lc](**request_args, max_retries=self.max_retries) - except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e: + response: httpx.Response = _METHOD_MAP[method_lc]( + url=self.url, + **request_args, + max_retries=self.max_retries, + ) + except (self._http_client.max_retries_exceeded_error, self._http_client.request_error) as e: raise HttpRequestNodeError(str(e)) from e # FIXME: fix type ignore, this maybe httpx type issue return response diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index 9bd1cb9761..964e53e03c 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -1,10 +1,11 @@ import logging import mimetypes -from collections.abc import Mapping, Sequence -from typing import Any +from collections.abc import Callable, Mapping, Sequence +from typing import TYPE_CHECKING, Any from configs import dify_config -from core.file import File, FileTransferMethod +from core.file import File, FileTransferMethod, file_manager +from core.helper import ssrf_proxy from core.tools.tool_file_manager import ToolFileManager from core.variables.segments import ArrayFileSegment from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus @@ -13,6 +14,7 @@ from core.workflow.nodes.base import variable_template_parser from core.workflow.nodes.base.entities import VariableSelector from core.workflow.nodes.base.node import Node from core.workflow.nodes.http_request.executor import Executor +from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol from factories import file_factory from .entities import ( @@ -30,10 +32,35 @@ HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout( logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from core.workflow.entities import GraphInitParams + from core.workflow.runtime import GraphRuntimeState + class HttpRequestNode(Node[HttpRequestNodeData]): node_type = NodeType.HTTP_REQUEST + def __init__( + self, + id: str, + config: Mapping[str, Any], + graph_init_params: "GraphInitParams", + graph_runtime_state: "GraphRuntimeState", + *, + http_client: HttpClientProtocol = ssrf_proxy, + tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager, + file_manager: FileManagerProtocol = file_manager, + ) -> None: + super().__init__( + id=id, + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + self._http_client = http_client + self._tool_file_manager_factory = tool_file_manager_factory + self._file_manager = file_manager + @classmethod def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: return { @@ -71,6 +98,8 @@ class HttpRequestNode(Node[HttpRequestNodeData]): timeout=self._get_request_timeout(self.node_data), variable_pool=self.graph_runtime_state.variable_pool, max_retries=0, + http_client=self._http_client, + file_manager=self._file_manager, ) process_data["request"] = http_executor.to_log() @@ -199,7 +228,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]): mime_type = ( content_disposition_type or content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream" ) - tool_file_manager = ToolFileManager() + tool_file_manager = self._tool_file_manager_factory() tool_file = tool_file_manager.create_file_by_raw( user_id=self.user_id, diff --git a/api/core/workflow/nodes/node_factory.py b/api/core/workflow/nodes/node_factory.py index 557d3a330d..5c04e5110f 100644 --- a/api/core/workflow/nodes/node_factory.py +++ b/api/core/workflow/nodes/node_factory.py @@ -1,16 +1,21 @@ -from collections.abc import Sequence +from collections.abc import Callable, Sequence from typing import TYPE_CHECKING, final from typing_extensions import override from configs import dify_config +from core.file import file_manager +from core.helper import ssrf_proxy from core.helper.code_executor.code_executor import CodeExecutor from core.helper.code_executor.code_node_provider import CodeNodeProvider +from core.tools.tool_file_manager import ToolFileManager from core.workflow.enums import NodeType from core.workflow.graph import NodeFactory from core.workflow.nodes.base.node import Node from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.code.limits import CodeNodeLimits +from core.workflow.nodes.http_request.node import HttpRequestNode +from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol from core.workflow.nodes.template_transform.template_renderer import ( CodeExecutorJinja2TemplateRenderer, Jinja2TemplateRenderer, @@ -43,6 +48,9 @@ class DifyNodeFactory(NodeFactory): code_providers: Sequence[type[CodeNodeProvider]] | None = None, code_limits: CodeNodeLimits | None = None, template_renderer: Jinja2TemplateRenderer | None = None, + http_request_http_client: HttpClientProtocol = ssrf_proxy, + http_request_tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager, + http_request_file_manager: FileManagerProtocol = file_manager, ) -> None: self.graph_init_params = graph_init_params self.graph_runtime_state = graph_runtime_state @@ -61,6 +69,9 @@ class DifyNodeFactory(NodeFactory): max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH, ) self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer() + self._http_request_http_client = http_request_http_client + self._http_request_tool_file_manager_factory = http_request_tool_file_manager_factory + self._http_request_file_manager = http_request_file_manager @override def create_node(self, node_config: dict[str, object]) -> Node: @@ -113,6 +124,7 @@ class DifyNodeFactory(NodeFactory): code_providers=self._code_providers, code_limits=self._code_limits, ) + if node_type == NodeType.TEMPLATE_TRANSFORM: return TemplateTransformNode( id=node_id, @@ -122,6 +134,17 @@ class DifyNodeFactory(NodeFactory): template_renderer=self._template_renderer, ) + if node_type == NodeType.HTTP_REQUEST: + return HttpRequestNode( + id=node_id, + config=node_config, + graph_init_params=self.graph_init_params, + graph_runtime_state=self.graph_runtime_state, + http_client=self._http_request_http_client, + tool_file_manager_factory=self._http_request_tool_file_manager_factory, + file_manager=self._http_request_file_manager, + ) + return node_class( id=node_id, config=node_config, diff --git a/api/core/workflow/nodes/protocols.py b/api/core/workflow/nodes/protocols.py new file mode 100644 index 0000000000..e7dcf62fcf --- /dev/null +++ b/api/core/workflow/nodes/protocols.py @@ -0,0 +1,29 @@ +from typing import Protocol + +import httpx + +from core.file import File + + +class HttpClientProtocol(Protocol): + @property + def max_retries_exceeded_error(self) -> type[Exception]: ... + + @property + def request_error(self) -> type[Exception]: ... + + def get(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ... + + def head(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ... + + def post(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ... + + def put(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ... + + def delete(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ... + + def patch(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ... + + +class FileManagerProtocol(Protocol): + def download(self, f: File, /) -> bytes: ...