From 882b4c9ef68ec5db1cebebe77582245325ab7f4c Mon Sep 17 00:00:00 2001 From: wangxiaolei Date: Wed, 4 Mar 2026 16:01:43 +0800 Subject: [PATCH] refactor: document extract node decouple ssrf_proxy (#32949) --- api/.importlinter | 1 - api/core/workflow/node_factory.py | 1 + .../nodes/document_extractor/node.py | 22 +++++++++++++------ .../nodes/test_document_extractor_node.py | 13 ++++++----- 4 files changed, 24 insertions(+), 13 deletions(-) diff --git a/api/.importlinter b/api/.importlinter index 0d9af6e065..10faeb448a 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -105,7 +105,6 @@ ignore_imports = dify_graph.nodes.agent.agent_node -> core.model_manager dify_graph.nodes.agent.agent_node -> core.provider_manager dify_graph.nodes.agent.agent_node -> core.tools.tool_manager - dify_graph.nodes.document_extractor.node -> core.helper.ssrf_proxy dify_graph.nodes.iteration.iteration_node -> core.workflow.node_factory dify_graph.nodes.iteration.iteration_node -> core.app.workflow.layers.llm_quota dify_graph.nodes.llm.llm_utils -> core.model_manager diff --git a/api/core/workflow/node_factory.py b/api/core/workflow/node_factory.py index 1b4937769e..714b0ca3d0 100644 --- a/api/core/workflow/node_factory.py +++ b/api/core/workflow/node_factory.py @@ -265,6 +265,7 @@ class DifyNodeFactory(NodeFactory): graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, unstructured_api_config=self._document_extractor_unstructured_api_config, + http_client=self._http_request_http_client, ) if node_type == NodeType.QUESTION_CLASSIFIER: diff --git a/api/dify_graph/nodes/document_extractor/node.py b/api/dify_graph/nodes/document_extractor/node.py index 01ecd49494..5945e57926 100644 --- a/api/dify_graph/nodes/document_extractor/node.py +++ b/api/dify_graph/nodes/document_extractor/node.py @@ -20,11 +20,11 @@ from docx.oxml.text.paragraph import CT_P from docx.table import Table from docx.text.paragraph import Paragraph -from core.helper import ssrf_proxy from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus from dify_graph.file import File, FileTransferMethod, file_manager from dify_graph.node_events import NodeRunResult from dify_graph.nodes.base.node import Node +from dify_graph.nodes.protocols import HttpClientProtocol from dify_graph.variables import ArrayFileSegment from dify_graph.variables.segments import ArrayStringSegment, FileSegment @@ -58,6 +58,7 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]): graph_runtime_state: "GraphRuntimeState", *, unstructured_api_config: UnstructuredApiConfig | None = None, + http_client: HttpClientProtocol, ) -> None: super().__init__( id=id, @@ -66,6 +67,7 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]): graph_runtime_state=graph_runtime_state, ) self._unstructured_api_config = unstructured_api_config or UnstructuredApiConfig() + self._http_client = http_client def _run(self): variable_selector = self.node_data.variable_selector @@ -85,7 +87,9 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]): try: if isinstance(value, list): extracted_text_list = [ - _extract_text_from_file(file, unstructured_api_config=self._unstructured_api_config) + _extract_text_from_file( + self._http_client, file, unstructured_api_config=self._unstructured_api_config + ) for file in value ] return NodeRunResult( @@ -95,7 +99,9 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]): outputs={"text": ArrayStringSegment(value=extracted_text_list)}, ) elif isinstance(value, File): - extracted_text = _extract_text_from_file(value, unstructured_api_config=self._unstructured_api_config) + extracted_text = _extract_text_from_file( + self._http_client, value, unstructured_api_config=self._unstructured_api_config + ) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=inputs, @@ -439,13 +445,13 @@ def _extract_text_from_docx(file_content: bytes) -> str: raise TextExtractionError(f"Failed to extract text from DOCX: {str(e)}") from e -def _download_file_content(file: File) -> bytes: +def _download_file_content(http_client: HttpClientProtocol, file: File) -> bytes: """Download the content of a file based on its transfer method.""" try: if file.transfer_method == FileTransferMethod.REMOTE_URL: if file.remote_url is None: raise FileDownloadError("Missing URL for remote file") - response = ssrf_proxy.get(file.remote_url) + response = http_client.get(file.remote_url) response.raise_for_status() return response.content else: @@ -454,8 +460,10 @@ def _download_file_content(file: File) -> bytes: raise FileDownloadError(f"Error downloading file: {str(e)}") from e -def _extract_text_from_file(file: File, *, unstructured_api_config: UnstructuredApiConfig) -> str: - file_content = _download_file_content(file) +def _extract_text_from_file( + http_client: HttpClientProtocol, file: File, *, unstructured_api_config: UnstructuredApiConfig +) -> str: + file_content = _download_file_content(http_client, file) if file.extension: extracted_text = _extract_text_by_file_extension( file_content=file_content, diff --git a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py index a74bdd8837..dff84b580a 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py @@ -43,11 +43,13 @@ def document_extractor_node(graph_init_params): variable_selector=["node_id", "variable_name"], ) node_config = {"id": "test_node_id", "data": node_data.model_dump()} + http_client = Mock() node = DocumentExtractorNode( id="test_node_id", config=node_config, graph_init_params=graph_init_params, graph_runtime_state=Mock(), + http_client=http_client, ) return node @@ -141,12 +143,13 @@ def test_run_extract_text( mock_graph_runtime_state.variable_pool.get.return_value = mock_array_file_segment mock_download = Mock(return_value=file_content) - mock_ssrf_proxy_get = Mock() - mock_ssrf_proxy_get.return_value.content = file_content - mock_ssrf_proxy_get.return_value.raise_for_status = Mock() + + mock_response = Mock() + mock_response.content = file_content + mock_response.raise_for_status = Mock() + document_extractor_node._http_client.get = Mock(return_value=mock_response) monkeypatch.setattr("dify_graph.file.file_manager.download", mock_download) - monkeypatch.setattr("core.helper.ssrf_proxy.get", mock_ssrf_proxy_get) if mime_type == "application/pdf": mock_pdf_extract = Mock(return_value=expected_text[0]) @@ -163,7 +166,7 @@ def test_run_extract_text( assert result.outputs["text"] == ArrayStringSegment(value=expected_text) if transfer_method == FileTransferMethod.REMOTE_URL: - mock_ssrf_proxy_get.assert_called_once_with("https://example.com/file.txt") + document_extractor_node._http_client.get.assert_called_once_with("https://example.com/file.txt") elif transfer_method == FileTransferMethod.LOCAL_FILE: mock_download.assert_called_once_with(mock_file)