fix: SSRF in WordExtractor URL download (credit to @EaEa0001 ) (#31678)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
盐粒 Yanli 2026-01-29 14:01:21 +08:00 committed by GitHub
parent c2473d85dc
commit dbfc47e8b0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 68 additions and 13 deletions

View File

@ -104,6 +104,8 @@ def download(f: File, /):
):
return _download_file_content(f.storage_key)
elif f.transfer_method == FileTransferMethod.REMOTE_URL:
if f.remote_url is None:
raise ValueError("Missing file remote_url")
response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
response.raise_for_status()
return response.content
@ -134,6 +136,8 @@ def _download_file_content(path: str, /):
def _get_encoded_string(f: File, /):
match f.transfer_method:
case FileTransferMethod.REMOTE_URL:
if f.remote_url is None:
raise ValueError("Missing file remote_url")
response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
response.raise_for_status()
data = response.content

View File

@ -4,8 +4,10 @@ Proxy requests to avoid SSRF
import logging
import time
from typing import Any, TypeAlias
import httpx
from pydantic import TypeAdapter, ValidationError
from configs import dify_config
from core.helper.http_client_pooling import get_pooled_http_client
@ -18,6 +20,9 @@ SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES
BACKOFF_FACTOR = 0.5
STATUS_FORCELIST = [429, 500, 502, 503, 504]
Headers: TypeAlias = dict[str, str]
_HEADERS_ADAPTER = TypeAdapter(Headers)
_SSL_VERIFIED_POOL_KEY = "ssrf:verified"
_SSL_UNVERIFIED_POOL_KEY = "ssrf:unverified"
_SSRF_CLIENT_LIMITS = httpx.Limits(
@ -76,7 +81,7 @@ def _get_ssrf_client(ssl_verify_enabled: bool) -> httpx.Client:
)
def _get_user_provided_host_header(headers: dict | None) -> str | None:
def _get_user_provided_host_header(headers: Headers | None) -> str | None:
"""
Extract the user-provided Host header from the headers dict.
@ -92,7 +97,7 @@ def _get_user_provided_host_header(headers: dict | None) -> str | None:
return None
def _inject_trace_headers(headers: dict | None) -> dict:
def _inject_trace_headers(headers: Headers | None) -> Headers:
"""
Inject W3C traceparent header for distributed tracing.
@ -125,7 +130,7 @@ def _inject_trace_headers(headers: dict | None) -> dict:
return headers
def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
def make_request(method: str, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
# Convert requests-style allow_redirects to httpx-style follow_redirects
if "allow_redirects" in kwargs:
allow_redirects = kwargs.pop("allow_redirects")
@ -142,10 +147,15 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
# prioritize per-call option, which can be switched on and off inside the HTTP node on the web UI
verify_option = kwargs.pop("ssl_verify", dify_config.HTTP_REQUEST_NODE_SSL_VERIFY)
if not isinstance(verify_option, bool):
raise ValueError("ssl_verify must be a boolean")
client = _get_ssrf_client(verify_option)
# Inject traceparent header for distributed tracing (when OTEL is not enabled)
headers = kwargs.get("headers") or {}
try:
headers: Headers = _HEADERS_ADAPTER.validate_python(kwargs.get("headers") or {})
except ValidationError as e:
raise ValueError("headers must be a mapping of string keys to string values") from e
headers = _inject_trace_headers(headers)
kwargs["headers"] = headers
@ -198,25 +208,25 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
raise MaxRetriesExceededError(f"Reached maximum retries ({max_retries}) for URL {url}")
def get(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
def get(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
return make_request("GET", url, max_retries=max_retries, **kwargs)
def post(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
def post(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
return make_request("POST", url, max_retries=max_retries, **kwargs)
def put(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
def put(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
return make_request("PUT", url, max_retries=max_retries, **kwargs)
def patch(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
def patch(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
return make_request("PATCH", url, max_retries=max_retries, **kwargs)
def delete(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
def delete(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
return make_request("DELETE", url, max_retries=max_retries, **kwargs)
def head(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
def head(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
return make_request("HEAD", url, max_retries=max_retries, **kwargs)

View File

@ -1,4 +1,7 @@
"""Abstract interface for document loader implementations."""
"""Word (.docx) document extractor used for RAG ingestion.
Supports local file paths and remote URLs (downloaded via `core.helper.ssrf_proxy`).
"""
import logging
import mimetypes
@ -8,7 +11,6 @@ import tempfile
import uuid
from urllib.parse import urlparse
import httpx
from docx import Document as DocxDocument
from docx.oxml.ns import qn
from docx.text.run import Run
@ -44,7 +46,7 @@ class WordExtractor(BaseExtractor):
# If the file is a web path, download it to a temporary file, and use that
if not os.path.isfile(self.file_path) and self._is_valid_url(self.file_path):
response = httpx.get(self.file_path, timeout=None)
response = ssrf_proxy.get(self.file_path)
if response.status_code != 200:
response.close()
@ -55,6 +57,7 @@ class WordExtractor(BaseExtractor):
self.temp_file = tempfile.NamedTemporaryFile() # noqa SIM115
try:
self.temp_file.write(response.content)
self.temp_file.flush()
finally:
response.close()
self.file_path = self.temp_file.name

View File

@ -1,7 +1,9 @@
"""Primarily used for testing merged cell scenarios"""
import io
import os
import tempfile
from pathlib import Path
from types import SimpleNamespace
from docx import Document
@ -56,6 +58,42 @@ def test_parse_row():
assert extractor._parse_row(row, {}, 3) == gt[idx]
def test_init_downloads_via_ssrf_proxy(monkeypatch):
doc = Document()
doc.add_paragraph("hello")
buf = io.BytesIO()
doc.save(buf)
docx_bytes = buf.getvalue()
calls: list[tuple[str, object]] = []
class FakeResponse:
status_code = 200
content = docx_bytes
def close(self) -> None:
calls.append(("close", None))
def fake_get(url: str, **kwargs):
calls.append(("get", (url, kwargs)))
return FakeResponse()
monkeypatch.setattr(we, "ssrf_proxy", SimpleNamespace(get=fake_get))
extractor = WordExtractor("https://example.com/test.docx", "tenant_id", "user_id")
try:
assert calls
assert calls[0][0] == "get"
url, kwargs = calls[0][1]
assert url == "https://example.com/test.docx"
assert kwargs.get("timeout") is None
assert extractor.web_path == "https://example.com/test.docx"
assert extractor.file_path != extractor.web_path
assert Path(extractor.file_path).read_bytes() == docx_bytes
finally:
extractor.temp_file.close()
def test_extract_images_from_docx(monkeypatch):
external_bytes = b"ext-bytes"
internal_bytes = b"int-bytes"