mirror of
https://github.com/langgenius/dify.git
synced 2026-03-10 11:10:19 +08:00
refactor: document extract node decouple ssrf_proxy (#32949)
This commit is contained in:
parent
b8a4e0c13b
commit
882b4c9ef6
@ -105,7 +105,6 @@ ignore_imports =
|
||||
dify_graph.nodes.agent.agent_node -> core.model_manager
|
||||
dify_graph.nodes.agent.agent_node -> core.provider_manager
|
||||
dify_graph.nodes.agent.agent_node -> core.tools.tool_manager
|
||||
dify_graph.nodes.document_extractor.node -> core.helper.ssrf_proxy
|
||||
dify_graph.nodes.iteration.iteration_node -> core.workflow.node_factory
|
||||
dify_graph.nodes.iteration.iteration_node -> core.app.workflow.layers.llm_quota
|
||||
dify_graph.nodes.llm.llm_utils -> core.model_manager
|
||||
|
||||
@ -265,6 +265,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
unstructured_api_config=self._document_extractor_unstructured_api_config,
|
||||
http_client=self._http_request_http_client,
|
||||
)
|
||||
|
||||
if node_type == NodeType.QUESTION_CLASSIFIER:
|
||||
|
||||
@ -20,11 +20,11 @@ from docx.oxml.text.paragraph import CT_P
|
||||
from docx.table import Table
|
||||
from docx.text.paragraph import Paragraph
|
||||
|
||||
from core.helper import ssrf_proxy
|
||||
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from dify_graph.file import File, FileTransferMethod, file_manager
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.protocols import HttpClientProtocol
|
||||
from dify_graph.variables import ArrayFileSegment
|
||||
from dify_graph.variables.segments import ArrayStringSegment, FileSegment
|
||||
|
||||
@ -58,6 +58,7 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
unstructured_api_config: UnstructuredApiConfig | None = None,
|
||||
http_client: HttpClientProtocol,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
@ -66,6 +67,7 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
self._unstructured_api_config = unstructured_api_config or UnstructuredApiConfig()
|
||||
self._http_client = http_client
|
||||
|
||||
def _run(self):
|
||||
variable_selector = self.node_data.variable_selector
|
||||
@ -85,7 +87,9 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
|
||||
try:
|
||||
if isinstance(value, list):
|
||||
extracted_text_list = [
|
||||
_extract_text_from_file(file, unstructured_api_config=self._unstructured_api_config)
|
||||
_extract_text_from_file(
|
||||
self._http_client, file, unstructured_api_config=self._unstructured_api_config
|
||||
)
|
||||
for file in value
|
||||
]
|
||||
return NodeRunResult(
|
||||
@ -95,7 +99,9 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
|
||||
outputs={"text": ArrayStringSegment(value=extracted_text_list)},
|
||||
)
|
||||
elif isinstance(value, File):
|
||||
extracted_text = _extract_text_from_file(value, unstructured_api_config=self._unstructured_api_config)
|
||||
extracted_text = _extract_text_from_file(
|
||||
self._http_client, value, unstructured_api_config=self._unstructured_api_config
|
||||
)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=inputs,
|
||||
@ -439,13 +445,13 @@ def _extract_text_from_docx(file_content: bytes) -> str:
|
||||
raise TextExtractionError(f"Failed to extract text from DOCX: {str(e)}") from e
|
||||
|
||||
|
||||
def _download_file_content(file: File) -> bytes:
|
||||
def _download_file_content(http_client: HttpClientProtocol, file: File) -> bytes:
|
||||
"""Download the content of a file based on its transfer method."""
|
||||
try:
|
||||
if file.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
if file.remote_url is None:
|
||||
raise FileDownloadError("Missing URL for remote file")
|
||||
response = ssrf_proxy.get(file.remote_url)
|
||||
response = http_client.get(file.remote_url)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
else:
|
||||
@ -454,8 +460,10 @@ def _download_file_content(file: File) -> bytes:
|
||||
raise FileDownloadError(f"Error downloading file: {str(e)}") from e
|
||||
|
||||
|
||||
def _extract_text_from_file(file: File, *, unstructured_api_config: UnstructuredApiConfig) -> str:
|
||||
file_content = _download_file_content(file)
|
||||
def _extract_text_from_file(
|
||||
http_client: HttpClientProtocol, file: File, *, unstructured_api_config: UnstructuredApiConfig
|
||||
) -> str:
|
||||
file_content = _download_file_content(http_client, file)
|
||||
if file.extension:
|
||||
extracted_text = _extract_text_by_file_extension(
|
||||
file_content=file_content,
|
||||
|
||||
@ -43,11 +43,13 @@ def document_extractor_node(graph_init_params):
|
||||
variable_selector=["node_id", "variable_name"],
|
||||
)
|
||||
node_config = {"id": "test_node_id", "data": node_data.model_dump()}
|
||||
http_client = Mock()
|
||||
node = DocumentExtractorNode(
|
||||
id="test_node_id",
|
||||
config=node_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=Mock(),
|
||||
http_client=http_client,
|
||||
)
|
||||
return node
|
||||
|
||||
@ -141,12 +143,13 @@ def test_run_extract_text(
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = mock_array_file_segment
|
||||
|
||||
mock_download = Mock(return_value=file_content)
|
||||
mock_ssrf_proxy_get = Mock()
|
||||
mock_ssrf_proxy_get.return_value.content = file_content
|
||||
mock_ssrf_proxy_get.return_value.raise_for_status = Mock()
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.content = file_content
|
||||
mock_response.raise_for_status = Mock()
|
||||
document_extractor_node._http_client.get = Mock(return_value=mock_response)
|
||||
|
||||
monkeypatch.setattr("dify_graph.file.file_manager.download", mock_download)
|
||||
monkeypatch.setattr("core.helper.ssrf_proxy.get", mock_ssrf_proxy_get)
|
||||
|
||||
if mime_type == "application/pdf":
|
||||
mock_pdf_extract = Mock(return_value=expected_text[0])
|
||||
@ -163,7 +166,7 @@ def test_run_extract_text(
|
||||
assert result.outputs["text"] == ArrayStringSegment(value=expected_text)
|
||||
|
||||
if transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
mock_ssrf_proxy_get.assert_called_once_with("https://example.com/file.txt")
|
||||
document_extractor_node._http_client.get.assert_called_once_with("https://example.com/file.txt")
|
||||
elif transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
mock_download.assert_called_once_with(mock_file)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user