mirror of https://github.com/langgenius/dify.git
fix: SSRF in WordExtractor URL download (credit to @EaEa0001) (#31678)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
parent c2473d85dc
commit dbfc47e8b0
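The core of the change is to route remote downloads in these code paths through the project's SSRF-aware helper (core.helper.ssrf_proxy) instead of calling httpx directly, and to validate inputs before a request is made. A minimal sketch of the resulting download pattern, assuming only the helper shown in the diff below (the function name download_docx is illustrative):

    from core.helper import ssrf_proxy

    def download_docx(url: str) -> bytes:
        # Fetch through the SSRF-aware proxy client rather than httpx.get(url, timeout=None),
        # so the existing outbound-request restrictions also apply to WordExtractor.
        response = ssrf_proxy.get(url)
        try:
            if response.status_code != 200:
                raise ValueError(f"failed to download file, status code: {response.status_code}")
            return response.content
        finally:
            response.close()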
@@ -104,6 +104,8 @@ def download(f: File, /):
     ):
         return _download_file_content(f.storage_key)
     elif f.transfer_method == FileTransferMethod.REMOTE_URL:
+        if f.remote_url is None:
+            raise ValueError("Missing file remote_url")
         response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
         response.raise_for_status()
         return response.content
@@ -134,6 +136,8 @@ def _download_file_content(path: str, /):
 def _get_encoded_string(f: File, /):
     match f.transfer_method:
         case FileTransferMethod.REMOTE_URL:
+            if f.remote_url is None:
+                raise ValueError("Missing file remote_url")
             response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
             response.raise_for_status()
             data = response.content
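Both hunks above add the same guard before the proxied fetch. A standalone sketch of that pattern, using only names visible in the diff (fetch_remote itself is an illustrative wrapper):

    from core.helper import ssrf_proxy

    def fetch_remote(remote_url: str | None) -> bytes:
        # Fail fast on a missing URL instead of passing None into the HTTP client.
        if remote_url is None:
            raise ValueError("Missing file remote_url")
        response = ssrf_proxy.get(remote_url, follow_redirects=True)
        response.raise_for_status()
        return response.content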
@@ -4,8 +4,10 @@ Proxy requests to avoid SSRF
 import logging
 import time
 from typing import Any, TypeAlias

 import httpx
 from pydantic import TypeAdapter, ValidationError

 from configs import dify_config
 from core.helper.http_client_pooling import get_pooled_http_client
@@ -18,6 +20,9 @@ SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES
 BACKOFF_FACTOR = 0.5
 STATUS_FORCELIST = [429, 500, 502, 503, 504]

+Headers: TypeAlias = dict[str, str]
+_HEADERS_ADAPTER = TypeAdapter(Headers)
+
 _SSL_VERIFIED_POOL_KEY = "ssrf:verified"
 _SSL_UNVERIFIED_POOL_KEY = "ssrf:unverified"
 _SSRF_CLIENT_LIMITS = httpx.Limits(
@@ -76,7 +81,7 @@ def _get_ssrf_client(ssl_verify_enabled: bool) -> httpx.Client:
     )


-def _get_user_provided_host_header(headers: dict | None) -> str | None:
+def _get_user_provided_host_header(headers: Headers | None) -> str | None:
     """
     Extract the user-provided Host header from the headers dict.

@@ -92,7 +97,7 @@ def _get_user_provided_host_header(headers: dict | None) -> str | None:
     return None


-def _inject_trace_headers(headers: dict | None) -> dict:
+def _inject_trace_headers(headers: Headers | None) -> Headers:
     """
     Inject W3C traceparent header for distributed tracing.

@@ -125,7 +130,7 @@ def _inject_trace_headers(headers: dict | None) -> dict:
     return headers


-def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
+def make_request(method: str, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
     # Convert requests-style allow_redirects to httpx-style follow_redirects
     if "allow_redirects" in kwargs:
         allow_redirects = kwargs.pop("allow_redirects")
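The comment in this hunk refers to translating the requests-style flag into the name httpx expects before the request is dispatched. In isolation, the conversion is roughly the following (the setdefault fallback is an assumption about how the lines not shown here treat an explicitly passed follow_redirects):

    kwargs = {"allow_redirects": True}
    if "allow_redirects" in kwargs:
        allow_redirects = kwargs.pop("allow_redirects")
        # Keep an explicitly passed follow_redirects if the caller set one.
        kwargs.setdefault("follow_redirects", allow_redirects)
    assert kwargs == {"follow_redirects": True}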
@@ -142,10 +147,15 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):

     # prioritize per-call option, which can be switched on and off inside the HTTP node on the web UI
     verify_option = kwargs.pop("ssl_verify", dify_config.HTTP_REQUEST_NODE_SSL_VERIFY)
+    if not isinstance(verify_option, bool):
+        raise ValueError("ssl_verify must be a boolean")
     client = _get_ssrf_client(verify_option)

     # Inject traceparent header for distributed tracing (when OTEL is not enabled)
-    headers = kwargs.get("headers") or {}
+    try:
+        headers: Headers = _HEADERS_ADAPTER.validate_python(kwargs.get("headers") or {})
+    except ValidationError as e:
+        raise ValueError("headers must be a mapping of string keys to string values") from e
     headers = _inject_trace_headers(headers)
     kwargs["headers"] = headers

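The hunk above tightens make_request's inputs: ssl_verify must be an actual boolean, and headers are validated as a string-to-string mapping with a pydantic TypeAdapter before the trace header is injected. A self-contained sketch of that validation step (the sample header dicts are made up):

    from typing import TypeAlias

    from pydantic import TypeAdapter, ValidationError

    Headers: TypeAlias = dict[str, str]
    _HEADERS_ADAPTER = TypeAdapter(Headers)

    for candidate in ({"Host": "internal.example"}, {"X-Forwarded-For": None}):
        try:
            headers: Headers = _HEADERS_ADAPTER.validate_python(candidate)
            print("accepted:", headers)
        except ValidationError as e:
            # Mirrors the new behaviour: non-string keys or values are rejected
            # before any request is built.
            print("rejected:", e.error_count(), "error(s)")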
@@ -198,25 +208,25 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
     raise MaxRetriesExceededError(f"Reached maximum retries ({max_retries}) for URL {url}")


-def get(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
+def get(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
     return make_request("GET", url, max_retries=max_retries, **kwargs)


-def post(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
+def post(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
     return make_request("POST", url, max_retries=max_retries, **kwargs)


-def put(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
+def put(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
     return make_request("PUT", url, max_retries=max_retries, **kwargs)


-def patch(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
+def patch(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
     return make_request("PATCH", url, max_retries=max_retries, **kwargs)


-def delete(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
+def delete(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
     return make_request("DELETE", url, max_retries=max_retries, **kwargs)


-def head(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
+def head(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
     return make_request("HEAD", url, max_retries=max_retries, **kwargs)
@@ -1,4 +1,7 @@
-"""Abstract interface for document loader implementations."""
+"""Word (.docx) document extractor used for RAG ingestion.
+
+Supports local file paths and remote URLs (downloaded via `core.helper.ssrf_proxy`).
+"""

 import logging
 import mimetypes
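The new module docstring describes the extractor's two input modes. With a remote URL, usage roughly looks like the following (constructor arguments match the new test further down; the module path and the extract() call are assumptions based on the BaseExtractor interface and may differ):

    from core.rag.extractor.word_extractor import WordExtractor  # module path assumed

    extractor = WordExtractor("https://example.com/test.docx", "tenant_id", "user_id")
    documents = extractor.extract()  # the remote file is fetched via ssrf_proxy, then parsed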
@@ -8,7 +11,6 @@ import tempfile
 import uuid
 from urllib.parse import urlparse

-import httpx
 from docx import Document as DocxDocument
 from docx.oxml.ns import qn
 from docx.text.run import Run
@@ -44,7 +46,7 @@ class WordExtractor(BaseExtractor):

         # If the file is a web path, download it to a temporary file, and use that
         if not os.path.isfile(self.file_path) and self._is_valid_url(self.file_path):
-            response = httpx.get(self.file_path, timeout=None)
+            response = ssrf_proxy.get(self.file_path)

             if response.status_code != 200:
                 response.close()
@@ -55,6 +57,7 @@
             self.temp_file = tempfile.NamedTemporaryFile()  # noqa SIM115
             try:
                 self.temp_file.write(response.content)
+                self.temp_file.flush()
             finally:
                 response.close()
             self.file_path = self.temp_file.name
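The change above flushes the temporary file after writing the downloaded bytes, so the content is on disk before python-docx opens it by path, while the finally block keeps the HTTP response closed even if the write fails. The surrounding pattern as a standalone sketch (write_response_to_temp is an illustrative name; response is any object with .content and .close(), as in the diff):

    import tempfile

    def write_response_to_temp(response):
        # Keep a reference to the NamedTemporaryFile so it is not deleted while
        # it is still being read by path.
        temp_file = tempfile.NamedTemporaryFile()  # noqa: SIM115
        try:
            temp_file.write(response.content)
            temp_file.flush()  # ensure the bytes are on disk before the path is used
        finally:
            response.close()  # always release the HTTP response
        return temp_file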
@@ -1,7 +1,9 @@
 """Primarily used for testing merged cell scenarios"""

 import io
 import os
 import tempfile
 from pathlib import Path
 from types import SimpleNamespace

 from docx import Document
@@ -56,6 +58,42 @@ def test_parse_row():
     assert extractor._parse_row(row, {}, 3) == gt[idx]


+def test_init_downloads_via_ssrf_proxy(monkeypatch):
+    doc = Document()
+    doc.add_paragraph("hello")
+    buf = io.BytesIO()
+    doc.save(buf)
+    docx_bytes = buf.getvalue()
+
+    calls: list[tuple[str, object]] = []
+
+    class FakeResponse:
+        status_code = 200
+        content = docx_bytes
+
+        def close(self) -> None:
+            calls.append(("close", None))
+
+    def fake_get(url: str, **kwargs):
+        calls.append(("get", (url, kwargs)))
+        return FakeResponse()
+
+    monkeypatch.setattr(we, "ssrf_proxy", SimpleNamespace(get=fake_get))
+
+    extractor = WordExtractor("https://example.com/test.docx", "tenant_id", "user_id")
+    try:
+        assert calls
+        assert calls[0][0] == "get"
+        url, kwargs = calls[0][1]
+        assert url == "https://example.com/test.docx"
+        assert kwargs.get("timeout") is None
+        assert extractor.web_path == "https://example.com/test.docx"
+        assert extractor.file_path != extractor.web_path
+        assert Path(extractor.file_path).read_bytes() == docx_bytes
+    finally:
+        extractor.temp_file.close()
+
+
 def test_extract_images_from_docx(monkeypatch):
     external_bytes = b"ext-bytes"
     internal_bytes = b"int-bytes"
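The new test replaces the module-level ssrf_proxy reference with a SimpleNamespace exposing only get, so the extractor runs without any network access. The same stubbing technique in isolation, with a throwaway module standing in for the word-extractor module the real test patches:

    import types

    def test_simplenamespace_module_stub(monkeypatch):
        # A disposable module plays the role of the module under test.
        mod = types.ModuleType("fake_extractor_module")
        mod.ssrf_proxy = None  # the attribute the test will replace

        calls = []

        def fake_get(url, **kwargs):
            calls.append(url)
            return types.SimpleNamespace(status_code=200, content=b"", close=lambda: None)

        # Replace mod.ssrf_proxy with an object exposing only .get, as the test above does.
        monkeypatch.setattr(mod, "ssrf_proxy", types.SimpleNamespace(get=fake_get))

        response = mod.ssrf_proxy.get("https://example.com/test.docx")
        assert response.status_code == 200
        assert calls == ["https://example.com/test.docx"]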