fix(api): centralize remote file retrieval (#36399)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
-LAN- 2026-06-01 17:25:08 +08:00 committed by GitHub
parent cfc1cf2b8c
commit 71ffaacb58
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
34 changed files with 1368 additions and 293 deletions

View File

@ -13,7 +13,7 @@ from controllers.common.errors import (
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import with_current_user
from core.helper import ssrf_proxy
from core.file import remote_fetcher
from extensions.ext_database import db
from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
from graphon.file import helpers as file_helpers
@ -36,9 +36,9 @@ class GetRemoteFileInfo(Resource):
@login_required
def get(self, url: str):
decoded_url = helpers.decode_remote_url(url, request.query_string)
resp = ssrf_proxy.head(decoded_url)
resp = remote_fetcher.make_request("HEAD", decoded_url)
if resp.status_code != httpx.codes.OK:
resp = ssrf_proxy.get(decoded_url, timeout=3)
resp = remote_fetcher.make_request("GET", decoded_url, timeout=3)
resp.raise_for_status()
return RemoteFileInfo(
file_type=resp.headers.get("Content-Type", "application/octet-stream"),
@ -58,9 +58,9 @@ class RemoteFileUpload(Resource):
# Try to fetch remote file metadata/content first
try:
resp = ssrf_proxy.head(url=url)
resp = remote_fetcher.make_request("HEAD", url=url)
if resp.status_code != httpx.codes.OK:
resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
resp = remote_fetcher.make_request("GET", url=url, timeout=3, follow_redirects=True)
if resp.status_code != httpx.codes.OK:
# Normalize into a user-friendly error message expected by tests
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
@ -74,7 +74,7 @@ class RemoteFileUpload(Resource):
raise FileTooLargeError()
# Load content if needed
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
content = resp.content if resp.request.method == "GET" else remote_fetcher.make_request("GET", url).content
try:
upload_file = FileService(db.engine).upload_file(

View File

@ -9,7 +9,7 @@ from controllers.common.errors import (
RemoteFileUploadError,
UnsupportedFileTypeError,
)
from core.helper import ssrf_proxy
from core.file import remote_fetcher
from extensions.ext_database import db
from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
from graphon.file import helpers as file_helpers
@ -60,10 +60,10 @@ class RemoteFileInfoApi(WebApiResource):
HTTPException: If the remote file cannot be accessed
"""
decoded_url = helpers.decode_remote_url(url, request.query_string)
resp = ssrf_proxy.head(decoded_url)
resp = remote_fetcher.make_request("HEAD", decoded_url)
if resp.status_code != httpx.codes.OK:
# fall back to the GET method
resp = ssrf_proxy.get(decoded_url, timeout=3)
resp = remote_fetcher.make_request("GET", decoded_url, timeout=3)
resp.raise_for_status()
info = RemoteFileInfo(
file_type=resp.headers.get("Content-Type", "application/octet-stream"),
@ -112,9 +112,9 @@ class RemoteFileUploadApi(WebApiResource):
url = str(payload.url)
try:
resp = ssrf_proxy.head(url=url)
resp = remote_fetcher.make_request("HEAD", url=url)
if resp.status_code != httpx.codes.OK:
resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
resp = remote_fetcher.make_request("GET", url=url, timeout=3, follow_redirects=True)
if resp.status_code != httpx.codes.OK:
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
except httpx.RequestError as e:
@ -125,7 +125,7 @@ class RemoteFileUploadApi(WebApiResource):
if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size):
raise FileTooLargeError
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
content = resp.content if resp.request.method == "GET" else remote_fetcher.make_request("GET", url).content
try:
upload_file = FileService(db.engine).upload_file(

View File

@ -12,7 +12,7 @@ from typing import TYPE_CHECKING, Literal, override
from configs import dify_config
from core.app.file_access import DatabaseFileAccessController, FileAccessControllerProtocol
from core.db.session_factory import session_factory
from core.helper.ssrf_proxy import graphon_ssrf_proxy
from core.file import remote_fetcher
from core.tools.signature import sign_tool_file
from core.workflow.file_reference import parse_file_reference
from extensions.ext_storage import storage
@ -46,7 +46,7 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
@override
def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol:
return graphon_ssrf_proxy.get(url, follow_redirects=follow_redirects)
return remote_fetcher.graphon_remote_file_fetcher.get(url, follow_redirects=follow_redirects)
@override
def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator:

View File

@ -12,7 +12,7 @@ from uuid import uuid4
import httpx
from configs import dify_config
from core.helper import ssrf_proxy
from core.file import remote_fetcher
from extensions.ext_database import db
from extensions.ext_storage import storage
from extensions.storage.storage_type import StorageType
@ -44,26 +44,6 @@ class DatasourceFileManager:
return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
@staticmethod
def verify_file(datasource_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
"""
verify signature
"""
data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}"
recalculated_sign = hmac.new(
dify_config.SECRET_KEY.encode(),
data_to_sign.encode(),
hashlib.sha256,
).digest()
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
# verify signature
if sign != recalculated_encoded_sign:
return False
current_time = int(time.time())
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
@staticmethod
def create_file_by_raw(
*,
@ -117,7 +97,7 @@ class DatasourceFileManager:
) -> ToolFile:
# try to download image
try:
response = ssrf_proxy.get(file_url)
response = remote_fetcher.make_request("GET", file_url)
response.raise_for_status()
blob = response.content
except httpx.TimeoutException:

View File

@ -0,0 +1,5 @@
"""File retrieval helpers shared by backend file-oriented workflows."""
from . import remote_fetcher
__all__ = ["remote_fetcher"]

View File

@ -0,0 +1,345 @@
"""Unified remote-file retrieval with Dify signed file URL resolution.
Use this module for backend workflows whose intent is to fetch remote file content
or remote file metadata from a URL, even when the URL originally came from a user
upload, a workflow variable, a tool/datasource file, or an app DSL. GET/HEAD
requests can resolve Dify-signed file URLs locally through DB + storage before
falling back to the SSRF-protected network client.
Use `core.helper.ssrf_proxy` directly only for generic outbound HTTP where the
URL is not being treated as a remote file, such as HTTP Request nodes, external
API integrations, auth discovery, or user-configured tool calls. Those calls must
stay as real network requests and should not reinterpret Dify file URLs as stored
files.
"""
from __future__ import annotations
import base64
import hashlib
import hmac
import re
import time
import urllib.parse
from dataclasses import dataclass
from typing import Any, Literal
import httpx
from configs import dify_config
from core.app.file_access import DatabaseFileAccessController
from core.db.session_factory import session_factory
from core.helper import ssrf_proxy
from core.helper.ssrf_proxy import (
SSRF_DEFAULT_MAX_RETRIES,
_to_graphon_http_response,
max_retries_exceeded_error,
request_error,
)
from extensions.ext_storage import storage
from models import ToolFile, UploadFile
# Upload-file URL paths: /files/<file_id>/file-preview or /files/<file_id>/image-preview.
# <file_id> is restricted to hex digits and dashes.
_UPLOAD_FILE_PATH_PATTERN = re.compile(
r"^/files/(?P<file_id>[a-fA-F0-9-]+)/(?P<preview_kind>file-preview|image-preview)$"
)
# Tool-file URL paths: /files/tools/<file_id> with an optional ".<extension>" suffix.
_TOOL_FILE_PATH_PATTERN = re.compile(r"^/files/tools/(?P<file_id>[a-fA-F0-9-]+)(?P<extension>\.[^/]*)?$")
# Datasource-file URL paths: /files/datasources/<file_id> with an optional ".<extension>" suffix.
_DATASOURCE_FILE_PATH_PATTERN = re.compile(r"^/files/datasources/(?P<file_id>[a-fA-F0-9-]+)(?P<extension>\.[^/]*)?$")
# Shared controller used by the response builders in this module to look up
# upload-file and tool-file records by id.
_file_access_controller = DatabaseFileAccessController()
@dataclass(frozen=True)
class _SignedFileUrl:
    """Components recovered from the path of a Dify signed file URL."""

    # Identifier captured from the URL path (hex digits and dashes only).
    file_id: str
    # Preview label; it is the first field of the payload that gets signed.
    preview_kind: Literal["file-preview", "image-preview"]
    # Which path family matched; decides which response builder is used.
    record_kind: Literal["upload", "tool", "datasource"]
def make_request(method: str, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
"""Fetch remote file content or metadata.
GET and HEAD requests for Dify-owned signed file URLs are served from local
storage. Every other request is delegated unchanged to the SSRF proxy.
"""
normalized_method = method.upper()
if normalized_method == "GET":
response = _resolve_dify_signed_file_url("GET", url)
if response is not None:
return response
if normalized_method == "HEAD":
response = _resolve_dify_signed_file_url("HEAD", url)
if response is not None:
return response
return ssrf_proxy.make_request(method=method, url=url, max_retries=max_retries, **kwargs)
class GraphonRemoteFileFetcher:
    """Per-verb HTTP client facade over the unified remote-file fetcher.

    Graphon expects an HTTP client with one method per verb. Each verb method
    here issues the request through ``make_request`` and converts the result
    with ``_to_graphon_http_response``. Regular Dify call sites should call
    ``make_request`` directly instead of going through this adapter.
    """

    @property
    def max_retries_exceeded_error(self) -> type[Exception]:
        """Exception class re-exported from the SSRF proxy module."""
        return max_retries_exceeded_error

    @property
    def request_error(self) -> type[Exception]:
        """Exception class re-exported from the SSRF proxy module."""
        return request_error

    def _send(self, verb: str, url: str, max_retries: int, /, **kwargs: Any):
        # Positional-only parameters so that no key in ``kwargs`` can collide
        # with this helper's own parameter names.
        raw_response = make_request(verb, url=url, max_retries=max_retries, **kwargs)
        return _to_graphon_http_response(raw_response)

    def get(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any):
        return self._send("GET", url, max_retries, **kwargs)

    def head(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any):
        return self._send("HEAD", url, max_retries, **kwargs)

    def post(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any):
        return self._send("POST", url, max_retries, **kwargs)

    def put(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any):
        return self._send("PUT", url, max_retries, **kwargs)

    def delete(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any):
        return self._send("DELETE", url, max_retries, **kwargs)

    def patch(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any):
        return self._send("PATCH", url, max_retries, **kwargs)
def _resolve_dify_signed_file_url(method: Literal["GET", "HEAD"], url: str) -> httpx.Response | None:
    """Build a local response for ``url`` when it is a valid Dify-signed file URL.

    Returns ``None`` when the URL is not handled here, which is the case if:

    * the URL cannot be parsed,
    * its origin is not one of the configured Dify file origins,
    * its path does not match a known signed-file path,
    * ``timestamp``, ``nonce`` or ``sign`` is missing or given more than once,
    * signature verification fails.

    Otherwise returns the response produced by the builder that matches the
    record kind (upload, tool or datasource).
    """
    try:
        parsed_url = urllib.parse.urlparse(url)
    except ValueError:
        # urlparse rejects some malformed URLs (for example an unbalanced "["
        # in the host). Such a URL is not a Dify-signed file URL, so report
        # "not handled" rather than letting the parse error escape.
        return None

    if not _is_dify_file_origin(parsed_url):
        return None

    signed_file_url = _parse_signed_file_path(parsed_url.path)
    if signed_file_url is None:
        return None

    # keep_blank_values=True so an empty "nonce=" still counts as present.
    query = urllib.parse.parse_qs(parsed_url.query, keep_blank_values=True)
    timestamp = _single_query_value(query, "timestamp")
    nonce = _single_query_value(query, "nonce")
    sign = _single_query_value(query, "sign")
    if timestamp is None or nonce is None or sign is None:
        return None

    if not _verify_signed_file_url(
        signed_file_url=signed_file_url,
        timestamp=timestamp,
        nonce=nonce,
        sign=sign,
    ):
        return None

    if signed_file_url.record_kind == "upload":
        return _build_upload_file_response(method=method, url=url, file_id=signed_file_url.file_id)
    if signed_file_url.record_kind == "tool":
        return _build_tool_file_response(method=method, url=url, file_id=signed_file_url.file_id)
    return _build_datasource_file_response(method=method, url=url, file_id=signed_file_url.file_id)
def _parse_signed_file_path(path: str) -> _SignedFileUrl | None:
    """Classify a URL path as an upload, tool or datasource signed-file path.

    The patterns are tried in that order and the first match wins. Upload
    paths carry their own preview kind; tool and datasource paths are always
    treated as ``"file-preview"``. Returns ``None`` when nothing matches.
    """
    if (matched := _UPLOAD_FILE_PATH_PATTERN.match(path)) is not None:
        wants_image = matched.group("preview_kind") == "image-preview"
        return _SignedFileUrl(
            file_id=matched.group("file_id"),
            preview_kind="image-preview" if wants_image else "file-preview",
            record_kind="upload",
        )

    if (matched := _TOOL_FILE_PATH_PATTERN.match(path)) is not None:
        return _SignedFileUrl(
            file_id=matched.group("file_id"),
            preview_kind="file-preview",
            record_kind="tool",
        )

    if (matched := _DATASOURCE_FILE_PATH_PATTERN.match(path)) is not None:
        return _SignedFileUrl(
            file_id=matched.group("file_id"),
            preview_kind="file-preview",
            record_kind="datasource",
        )

    return None
def _is_dify_file_origin(parsed_url: urllib.parse.ParseResult) -> bool:
    """Tell whether ``parsed_url`` has the same origin as a configured files URL.

    The origin (scheme, host, effective port) of the URL is compared with the
    origins of ``dify_config.FILES_URL`` and ``dify_config.INTERNAL_FILES_URL``.
    Configured values that are empty or have no usable origin are ignored.
    """
    if parsed_url.scheme not in {"http", "https"}:
        return False
    if not parsed_url.hostname:
        return False

    candidate = _origin_parts(parsed_url)
    if candidate is None:
        return False

    configured_urls = [dify_config.FILES_URL, dify_config.INTERNAL_FILES_URL]
    allowed: set[tuple[str, str, int]] = set()
    for configured_url in configured_urls:
        if not configured_url:
            continue
        configured_origin = _origin_parts(urllib.parse.urlparse(configured_url))
        if configured_origin is not None:
            allowed.add(configured_origin)

    return candidate in allowed
def _origin_parts(parsed_url: urllib.parse.ParseResult) -> tuple[str, str, int] | None:
    """Return ``(scheme, lowercased host, port)`` for an http(s) URL.

    ``None`` is returned when the scheme is not http/https, when there is no
    host, or when the port component is invalid. A missing port is replaced
    by the scheme's default port.
    """
    scheme = parsed_url.scheme
    host = parsed_url.hostname
    if scheme not in {"http", "https"} or not host:
        return None

    try:
        # Accessing .port validates it and raises ValueError if it is invalid.
        explicit_port = parsed_url.port
    except ValueError:
        return None

    effective_port = explicit_port if explicit_port else _default_port(scheme)
    return scheme, host.lower(), effective_port
def _default_port(scheme: str) -> int:
    """Return 443 for ``"https"`` and 80 for any other scheme."""
    if scheme == "https":
        return 443
    return 80
def _single_query_value(query: dict[str, list[str]], key: str) -> str | None:
    """Return the value of ``key`` only when it occurs exactly once in ``query``.

    A missing key, an empty value list, or a repeated parameter all yield
    ``None``. An empty string is a valid single value.
    """
    candidates = query.get(key)
    if candidates and len(candidates) == 1:
        return candidates[0]
    return None
def _verify_signed_file_url(
    *,
    signed_file_url: _SignedFileUrl,
    timestamp: str,
    nonce: str,
    sign: str,
) -> bool:
    """Check the expiry and HMAC signature of a signed file URL.

    Args:
        signed_file_url: Parsed path components; ``preview_kind`` and
            ``file_id`` are part of the signed payload.
        timestamp: Signing time as given in the URL, expected to be an integer.
        nonce: Nonce as given in the URL.
        sign: URL-safe base64 signature as given in the URL.

    Returns:
        ``True`` only when the timestamp is an integer, the URL is not older
        than ``dify_config.FILES_ACCESS_TIMEOUT`` seconds, and ``sign`` equals
        the HMAC-SHA256 of ``"<preview_kind>|<file_id>|<timestamp>|<nonce>"``
        keyed with ``dify_config.SECRET_KEY``.
    """
    try:
        current_time = int(time.time())
        signed_at = int(timestamp)
    except ValueError:
        return False

    if current_time - signed_at > dify_config.FILES_ACCESS_TIMEOUT:
        return False

    payload = f"{signed_file_url.preview_kind}|{signed_file_url.file_id}|{timestamp}|{nonce}"
    recalculated = hmac.new(dify_config.SECRET_KEY.encode(), payload.encode(), hashlib.sha256).digest()
    expected = base64.urlsafe_b64encode(recalculated)

    # ``sign`` comes straight from the URL and may contain non-ASCII text.
    # hmac.compare_digest raises TypeError for non-ASCII ``str`` arguments, so
    # compare bytes instead. errors="replace" keeps encode() from raising on
    # unencodable code points; the replacement "?" is not in the base64
    # alphabet, so it can never produce a false match.
    provided = sign.encode("utf-8", errors="replace")
    return hmac.compare_digest(provided, expected)
def _build_upload_file_response(*, method: Literal["GET", "HEAD"], url: str, file_id: str) -> httpx.Response:
    """Answer a signed upload-file URL from the upload-file record.

    Returns a 404 response when no upload file with ``file_id`` is found.
    Otherwise returns a 200 response carrying the record's size, MIME type
    and name, plus the stored bytes unless the method is HEAD.
    """
    with session_factory.create_session() as session:
        record = _file_access_controller.get_upload_file(session=session, file_id=file_id)
        if record is not None:
            # Shares the record-to-response step with the datasource path.
            return _build_upload_file_record_response(method=method, url=url, upload_file=record)
        return _build_response(method=method, url=url, status_code=404)
def _build_tool_file_response(*, method: Literal["GET", "HEAD"], url: str, file_id: str) -> httpx.Response:
    """Answer a signed tool-file URL from the tool-file record.

    Returns a 404 response when no tool file with ``file_id`` is found.
    Otherwise returns a 200 response carrying the record's size, MIME type
    and name, plus the stored bytes unless the method is HEAD.
    """
    with session_factory.create_session() as session:
        record = _file_access_controller.get_tool_file(session=session, file_id=file_id)
        if record is not None:
            # Shares the record-to-response step with the datasource path.
            return _build_tool_file_record_response(method=method, url=url, tool_file=record)
        return _build_response(method=method, url=url, status_code=404)
def _build_datasource_file_response(*, method: Literal["GET", "HEAD"], url: str, file_id: str) -> httpx.Response:
    """Answer a signed datasource-file URL.

    The id is looked up as an upload file first and as a tool file second;
    the first record found is turned into the response. If neither lookup
    finds a record, a 404 response is returned.
    """
    with session_factory.create_session() as session:
        upload_record = _file_access_controller.get_upload_file(session=session, file_id=file_id)
        if upload_record is not None:
            return _build_upload_file_record_response(method=method, url=url, upload_file=upload_record)

        tool_record = _file_access_controller.get_tool_file(session=session, file_id=file_id)
        if tool_record is not None:
            return _build_tool_file_record_response(method=method, url=url, tool_file=tool_record)

        return _build_response(method=method, url=url, status_code=404)
def _build_upload_file_record_response(
    *,
    method: Literal["GET", "HEAD"],
    url: str,
    upload_file: UploadFile,
) -> httpx.Response:
    """Turn an upload-file record into a 200 response.

    For HEAD the body is empty and storage is not read; for GET the body is
    loaded from storage under ``upload_file.key``. Size, MIME type and name
    come from the record.
    """
    if method == "HEAD":
        body = b""
    else:
        body = storage.load_once(upload_file.key)

    return _build_response(
        method=method,
        url=url,
        status_code=200,
        content=body,
        content_length=upload_file.size,
        content_type=upload_file.mime_type,
        filename=upload_file.name,
    )
def _build_tool_file_record_response(
    *,
    method: Literal["GET", "HEAD"],
    url: str,
    tool_file: ToolFile,
) -> httpx.Response:
    """Turn a tool-file record into a 200 response.

    For HEAD the body is empty and storage is not read; for GET the body is
    loaded from storage under ``tool_file.file_key``. Size, MIME type and
    name come from the record.
    """
    if method == "HEAD":
        body = b""
    else:
        body = storage.load_once(tool_file.file_key)

    return _build_response(
        method=method,
        url=url,
        status_code=200,
        content=body,
        content_length=tool_file.size,
        content_type=tool_file.mimetype,
        filename=tool_file.name,
    )
def _build_response(
    *,
    method: Literal["GET", "HEAD"],
    url: str,
    status_code: int,
    content: bytes = b"",
    content_length: int | None = None,
    content_type: str | None = None,
    filename: str | None = None,
) -> httpx.Response:
    """Assemble an ``httpx.Response`` for a locally resolved file.

    Headers are only set for values that are usable: ``Content-Type`` when
    ``content_type`` is non-empty, ``Content-Length`` when ``content_length``
    is a non-negative number, and ``Content-Disposition`` (attachment, with a
    percent-encoded UTF-8 ``filename*``) when ``filename`` is non-empty. The
    response is attached to a request built from ``method`` and ``url``.
    """
    has_length = content_length is not None and content_length >= 0
    header_candidates: tuple[tuple[str, str | None], ...] = (
        ("Content-Type", content_type if content_type else None),
        ("Content-Length", str(content_length) if has_length else None),
        (
            "Content-Disposition",
            f"attachment; filename*=UTF-8''{urllib.parse.quote(filename)}" if filename else None,
        ),
    )
    headers: dict[str, str] = {name: value for name, value in header_candidates if value is not None}

    originating_request = httpx.Request(method, url)
    return httpx.Response(
        status_code=status_code,
        headers=headers,
        content=content,
        request=originating_request,
    )
# Module-level adapter instance; other modules reach it as
# `remote_fetcher.graphon_remote_file_fetcher`.
graphon_remote_file_fetcher = GraphonRemoteFileFetcher()

View File

@ -1,8 +1,7 @@
from core.helper import ssrf_proxy
def download_with_size_limit(url, max_download_size: int, **kwargs):
response = ssrf_proxy.get(url, follow_redirects=True, **kwargs)
from core.file import remote_fetcher
response = remote_fetcher.make_request("GET", url, follow_redirects=True, **kwargs)
if response.status_code == 404:
raise ValueError("file not found")

View File

@ -1,5 +1,13 @@
"""
Proxy requests to avoid SSRF
"""SSRF-protected HTTP client for generic outbound requests.
Use this module when the URL represents a normal external HTTP interaction that
must go through network/proxy policy exactly as requested, such as HTTP Request
nodes, provider/API integrations, auth discovery, or custom tool calls.
Do not use this directly for "remote file" retrieval. File downloads, probes,
and metadata checks should use `core.file.remote_fetcher` instead so Dify-signed
file URLs can be resolved through DB + storage before falling back to this SSRF
client.
"""
import logging

View File

@ -5,7 +5,7 @@ from typing import Union
from urllib.parse import unquote
from configs import dify_config
from core.helper import ssrf_proxy
from core.file import remote_fetcher
from core.rag.extractor.csv_extractor import CSVExtractor
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
@ -55,7 +55,7 @@ class ExtractProcessor:
@classmethod
def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]:
response = ssrf_proxy.get(url, headers={"User-Agent": USER_AGENT})
response = remote_fetcher.make_request("GET", url, headers={"User-Agent": USER_AGENT})
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(url).suffix

View File

@ -1,6 +1,6 @@
"""Word (.docx) document extractor used for RAG ingestion.
Supports local file paths and remote URLs (downloaded via `core.helper.ssrf_proxy`).
Supports local file paths and remote URLs downloaded through the unified remote-file fetcher.
"""
import inspect
@ -17,7 +17,7 @@ from docx.oxml.ns import qn
from docx.text.run import Run
from configs import dify_config
from core.helper import ssrf_proxy
from core.file import remote_fetcher
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
from extensions.ext_database import db
@ -51,7 +51,7 @@ class WordExtractor(BaseExtractor):
# If the file is a web path, download it to a temporary file, and use that
if not os.path.isfile(self.file_path) and self._is_valid_url(self.file_path):
response = ssrf_proxy.get(self.file_path)
response = remote_fetcher.make_request("GET", self.file_path)
if response.status_code != 200:
response.close()
@ -120,7 +120,7 @@ class WordExtractor(BaseExtractor):
if not self._is_valid_url(url):
continue
try:
response = ssrf_proxy.get(url)
response = remote_fetcher.make_request("GET", url)
except Exception as e:
logger.warning("Failed to download image from URL: %s: %s", url, str(e))
continue

View File

@ -15,7 +15,7 @@ from sqlalchemy import select
from configs import dify_config
from core.entities.knowledge_entities import PreviewDetail
from core.helper import ssrf_proxy
from core.file import remote_fetcher
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.index_processor.constant.doc_type import DocType
@ -243,7 +243,7 @@ class BaseIndexProcessor(ABC):
try:
# Download with timeout
response = ssrf_proxy.get(image_url, timeout=DOWNLOAD_TIMEOUT)
response = remote_fetcher.make_request("GET", image_url, timeout=DOWNLOAD_TIMEOUT)
response.raise_for_status()
# Check Content-Length header if available

View File

@ -13,7 +13,7 @@ from sqlalchemy import select
from configs import dify_config
from core.db.session_factory import session_factory
from core.helper import ssrf_proxy
from core.file import remote_fetcher
from core.workflow.file_reference import build_file_reference
from extensions.ext_storage import storage
from graphon.file import File, FileTransferMethod, get_file_type_by_mime_type
@ -60,26 +60,6 @@ class ToolFileManager:
return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
@staticmethod
def verify_file(file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
"""
verify signature
"""
data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}"
recalculated_sign = hmac.new(
dify_config.SECRET_KEY.encode(),
data_to_sign.encode(),
hashlib.sha256,
).digest()
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
# verify signature
if sign != recalculated_encoded_sign:
return False
current_time = int(time.time())
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
def create_file_by_raw(
self,
*,
@ -129,7 +109,7 @@ class ToolFileManager:
) -> ToolFile:
# try to download image
try:
response = ssrf_proxy.get(file_url)
response = remote_fetcher.make_request("GET", file_url)
response.raise_for_status()
blob = response.content
except httpx.TimeoutException:

View File

@ -9,7 +9,7 @@ import charset_normalizer
import cloudscraper
from readabilipy import simple_json_from_html_string
from core.helper import ssrf_proxy
from core.file import remote_fetcher
from core.rag.extractor import extract_processor
from core.rag.extractor.extract_processor import ExtractProcessor
@ -38,7 +38,7 @@ def get_url(url: str, user_agent: str | None = None) -> str:
main_content_type = None
supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]
response = ssrf_proxy.head(url, headers=headers, follow_redirects=True, timeout=(5, 10))
response = remote_fetcher.make_request("HEAD", url, headers=headers, follow_redirects=True, timeout=(5, 10))
if response.status_code == 200:
# check content-type
@ -60,10 +60,10 @@ def get_url(url: str, user_agent: str | None = None) -> str:
if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES:
return cast(str, ExtractProcessor.load_from_url(url, return_text=True))
response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
response = remote_fetcher.make_request("GET", url, headers=headers, follow_redirects=True, timeout=(120, 300))
elif response.status_code == 403:
scraper = cloudscraper.create_scraper()
scraper.perform_request = ssrf_proxy.make_request
scraper.perform_request = remote_fetcher.make_request
response = scraper.get(url, headers=headers, timeout=(120, 300))
if response.status_code != 200:

View File

@ -11,6 +11,7 @@ from configs import dify_config
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
from core.app.llm.model_access import build_dify_model_access, fetch_model_config
from core.db.session_factory import session_factory
from core.file import remote_fetcher
from core.helper.code_executor.code_executor import (
CodeExecutionError,
CodeExecutor,
@ -307,6 +308,7 @@ class DifyNodeFactory(NodeFactory):
self._jinja2_template_renderer = CodeExecutorJinja2TemplateRenderer()
self._template_transform_max_output_length = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
self._http_request_http_client = graphon_ssrf_proxy
self._remote_file_http_client = remote_fetcher.graphon_remote_file_fetcher
self._bound_tool_file_manager_factory = lambda: DifyToolFileManager(
self._dify_context,
conversation_id_getter=self._conversation_id,
@ -318,7 +320,7 @@ class DifyNodeFactory(NodeFactory):
)
self._llm_file_saver = build_dify_llm_file_saver(
run_context=self._dify_context,
http_client=self._http_request_http_client,
http_client=self._remote_file_http_client,
conversation_id_getter=self._conversation_id,
)
self._human_input_runtime = DifyHumanInputNodeRuntime(
@ -416,7 +418,7 @@ class DifyNodeFactory(NodeFactory):
),
BuiltinNodeTypes.DOCUMENT_EXTRACTOR: lambda: {
"unstructured_api_config": self._document_extractor_unstructured_api_config,
"http_client": self._http_request_http_client,
"http_client": self._remote_file_http_client,
},
BuiltinNodeTypes.QUESTION_CLASSIFIER: lambda: self._build_llm_compatible_node_init_kwargs(
node_class=node_class,
@ -530,7 +532,7 @@ class DifyNodeFactory(NodeFactory):
if validated_node_data.type == BuiltinNodeTypes.QUESTION_CLASSIFIER:
node_init_kwargs["template_renderer"] = self._jinja2_template_renderer
if include_http_client:
node_init_kwargs["http_client"] = self._http_request_http_client
node_init_kwargs["http_client"] = self._remote_file_http_client
if include_llm_file_saver:
node_init_kwargs["llm_file_saver"] = self._llm_file_saver
if include_prompt_message_serializer:

View File

@ -16,7 +16,7 @@ import uuid
import httpx
from werkzeug.http import parse_options_header
from core.helper import ssrf_proxy
from core.file import remote_fetcher
def extract_filename(url_or_path: str, content_disposition: str | None) -> str | None:
@ -81,7 +81,7 @@ def get_remote_file_info(url: str) -> tuple[str, str, int]:
filename = os.path.basename(url_path)
mime_type = _guess_mime_type(filename)
resp = ssrf_proxy.head(url, follow_redirects=True)
resp = remote_fetcher.make_request("HEAD", url, follow_redirects=True)
if resp.status_code == httpx.codes.OK:
content_disposition = resp.headers.get("Content-Disposition")
extracted_filename = extract_filename(url_path, content_disposition)

View File

@ -17,7 +17,7 @@ from sqlalchemy.orm import Session
from configs import dify_config
from constants.dsl_version import CURRENT_APP_DSL_VERSION
from core.helper import ssrf_proxy
from core.file import remote_fetcher
from core.plugin.entities.plugin import PluginDependency
from core.trigger.constants import (
TRIGGER_PLUGIN_NODE_TYPE,
@ -127,7 +127,7 @@ class AppDslService:
):
yaml_url = yaml_url.replace("https://github.com", "https://raw.githubusercontent.com")
yaml_url = yaml_url.replace("/blob/", "/")
response = ssrf_proxy.get(yaml_url.strip(), follow_redirects=True, timeout=(10, 10))
response = remote_fetcher.make_request("GET", yaml_url.strip(), follow_redirects=True, timeout=(10, 10))
response.raise_for_status()
content = response.content.decode()

View File

@ -17,7 +17,7 @@ from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.helper import ssrf_proxy
from core.file import remote_fetcher
from core.helper.name_generator import generate_incremental_name
from core.plugin.entities.plugin import PluginDependency
from core.rag.index_processor.constant.index_type import IndexTechniqueType
@ -125,7 +125,7 @@ class RagPipelineDslService:
):
yaml_url = yaml_url.replace("https://github.com", "https://raw.githubusercontent.com")
yaml_url = yaml_url.replace("/blob/", "/")
response = ssrf_proxy.get(yaml_url.strip(), follow_redirects=True, timeout=(10, 10))
response = remote_fetcher.make_request("GET", yaml_url.strip(), follow_redirects=True, timeout=(10, 10))
response.raise_for_status()
content = response.content.decode()

View File

@ -316,9 +316,9 @@ class TestAppDslService:
self, db_session_with_containers: Session, monkeypatch: pytest.MonkeyPatch
):
monkeypatch.setattr(
app_dsl_service.ssrf_proxy,
"get",
lambda _url, **_kw: (_ for _ in ()).throw(RuntimeError("boom")),
app_dsl_service.remote_fetcher,
"make_request",
lambda _method, _url, **_kw: (_ for _ in ()).throw(RuntimeError("boom")),
)
service = AppDslService(db_session_with_containers)
@ -336,7 +336,7 @@ class TestAppDslService:
response = MagicMock()
response.content = b""
response.raise_for_status.return_value = None
monkeypatch.setattr(app_dsl_service.ssrf_proxy, "get", lambda _url, **_kw: response)
monkeypatch.setattr(app_dsl_service.remote_fetcher, "make_request", lambda _method, _url, **_kw: response)
service = AppDslService(db_session_with_containers)
result = service.import_app(
@ -353,7 +353,7 @@ class TestAppDslService:
response = MagicMock()
response.content = b"x" * (DSL_MAX_SIZE + 1)
response.raise_for_status.return_value = None
monkeypatch.setattr(app_dsl_service.ssrf_proxy, "get", lambda _url, **_kw: response)
monkeypatch.setattr(app_dsl_service.remote_fetcher, "make_request", lambda _method, _url, **_kw: response)
service = AppDslService(db_session_with_containers)
result = service.import_app(
@ -372,14 +372,15 @@ class TestAppDslService:
requested_urls: list[str] = []
def fake_get(url: str, **kwargs):
def fake_make_request(method: str, url: str, **kwargs):
assert method == "GET"
requested_urls.append(url)
response = MagicMock()
response.content = yaml_bytes
response.raise_for_status.return_value = None
return response
monkeypatch.setattr(app_dsl_service.ssrf_proxy, "get", fake_get)
monkeypatch.setattr(app_dsl_service.remote_fetcher, "make_request", fake_make_request)
service = AppDslService(db_session_with_containers)
result = service.import_app(
@ -401,7 +402,8 @@ class TestAppDslService:
requested_urls: list[str] = []
def fake_get(url: str, **kwargs):
def fake_make_request(method: str, url: str, **kwargs):
assert method == "GET"
requested_urls.append(url)
assert url == raw_url
response = MagicMock()
@ -409,7 +411,7 @@ class TestAppDslService:
response.raise_for_status.return_value = None
return response
monkeypatch.setattr(app_dsl_service.ssrf_proxy, "get", fake_get)
monkeypatch.setattr(app_dsl_service.remote_fetcher, "make_request", fake_make_request)
service = AppDslService(db_session_with_containers)
result = service.import_app(

View File

@ -98,17 +98,14 @@ def test_get_remote_file_info_uses_head_when_successful(app, monkeypatch: pytest
headers={"Content-Type": "text/plain", "Content-Length": "128"},
method="HEAD",
)
head_mock = MagicMock(return_value=head_resp)
get_mock = MagicMock()
monkeypatch.setattr(remote_files_module.ssrf_proxy, "head", head_mock)
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock)
make_request = MagicMock(return_value=head_resp)
monkeypatch.setattr(remote_files_module.remote_fetcher, "make_request", make_request)
with app.test_request_context(method="GET"):
payload = handler(api, url=encoded_url)
assert payload == {"file_type": "text/plain", "file_length": 128}
head_mock.assert_called_once_with(decoded_url)
get_mock.assert_not_called()
make_request.assert_called_once_with("HEAD", decoded_url)
def test_get_remote_file_info_preserves_unencoded_target_query(app, monkeypatch: pytest.MonkeyPatch) -> None:
@ -122,15 +119,14 @@ def test_get_remote_file_info_preserves_unencoded_target_query(app, monkeypatch:
headers={"Content-Type": "text/plain", "Content-Length": "128"},
method="HEAD",
)
head_mock = MagicMock(return_value=head_resp)
monkeypatch.setattr(remote_files_module.ssrf_proxy, "head", head_mock)
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock())
make_request = MagicMock(return_value=head_resp)
monkeypatch.setattr(remote_files_module.remote_fetcher, "make_request", make_request)
with app.test_request_context(f"/remote-files/{target_url}?{query}", method="GET"):
payload = handler(api, url=target_url)
assert payload == {"file_type": "text/plain", "file_length": 128}
head_mock.assert_called_once_with(f"{target_url}?{query}")
make_request.assert_called_once_with("HEAD", f"{target_url}?{query}")
def test_get_remote_file_info_falls_back_to_get_and_uses_default_headers(app, monkeypatch: pytest.MonkeyPatch) -> None:
@ -139,15 +135,21 @@ def test_get_remote_file_info_falls_back_to_get_and_uses_default_headers(app, mo
decoded_url = "https://example.com/test.txt"
encoded_url = urllib.parse.quote(decoded_url, safe="")
monkeypatch.setattr(remote_files_module.ssrf_proxy, "head", MagicMock(return_value=_FakeResponse(status_code=503)))
get_mock = MagicMock(return_value=_FakeResponse(status_code=200, headers={}, method="GET"))
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock)
make_request = MagicMock(
side_effect=[
_FakeResponse(status_code=503),
_FakeResponse(status_code=200, headers={}, method="GET"),
]
)
monkeypatch.setattr(remote_files_module.remote_fetcher, "make_request", make_request)
with app.test_request_context(method="GET"):
payload = handler(api, url=encoded_url)
assert payload == {"file_type": "application/octet-stream", "file_length": 0}
get_mock.assert_called_once_with(decoded_url, timeout=3)
assert make_request.call_args_list[0].args == ("HEAD", decoded_url)
assert make_request.call_args_list[1].args == ("GET", decoded_url)
assert make_request.call_args_list[1].kwargs == {"timeout": 3}
def test_remote_file_upload_success_when_fetch_falls_back_to_get(app, monkeypatch: pytest.MonkeyPatch) -> None:
@ -155,10 +157,9 @@ def test_remote_file_upload_success_when_fetch_falls_back_to_get(app, monkeypatc
handler = _unwrap(api.post)
url = "https://example.com/report.txt"
monkeypatch.setattr(remote_files_module.ssrf_proxy, "head", MagicMock(return_value=_FakeResponse(status_code=404)))
get_resp = _FakeResponse(status_code=200, method="GET", content=b"fallback-content")
get_mock = MagicMock(return_value=get_resp)
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock)
make_request = MagicMock(side_effect=[_FakeResponse(status_code=404), get_resp])
monkeypatch.setattr(remote_files_module.remote_fetcher, "make_request", make_request)
file_service_cls, current_user = _mock_upload_dependencies(monkeypatch)
upload_file = SimpleNamespace(
@ -178,7 +179,10 @@ def test_remote_file_upload_success_when_fetch_falls_back_to_get(app, monkeypatc
assert status == 201
assert payload["id"] == "file-1"
assert payload["url"] == "https://signed.example/file-1"
get_mock.assert_called_once_with(url=url, timeout=3, follow_redirects=True)
assert make_request.call_args_list[0].args == ("HEAD",)
assert make_request.call_args_list[0].kwargs == {"url": url}
assert make_request.call_args_list[1].args == ("GET",)
assert make_request.call_args_list[1].kwargs == {"url": url, "timeout": 3, "follow_redirects": True}
file_service_cls.return_value.upload_file.assert_called_once_with(
filename="report.txt",
content=b"fallback-content",
@ -195,14 +199,10 @@ def test_remote_file_upload_fetches_content_with_second_get_when_head_succeeds(
handler = _unwrap(api.post)
url = "https://example.com/photo.jpg"
monkeypatch.setattr(
remote_files_module.ssrf_proxy,
"head",
MagicMock(return_value=_FakeResponse(status_code=200, method="HEAD", content=b"head-content")),
)
head_resp = _FakeResponse(status_code=200, method="HEAD", content=b"head-content")
extra_get_resp = _FakeResponse(status_code=200, method="GET", content=b"downloaded-content")
get_mock = MagicMock(return_value=extra_get_resp)
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock)
make_request = MagicMock(side_effect=[head_resp, extra_get_resp])
monkeypatch.setattr(remote_files_module.remote_fetcher, "make_request", make_request)
file_service_cls, current_user = _mock_upload_dependencies(monkeypatch)
upload_file = SimpleNamespace(
@ -221,7 +221,7 @@ def test_remote_file_upload_fetches_content_with_second_get_when_head_succeeds(
assert status == 201
assert payload["id"] == "file-2"
get_mock.assert_called_once_with(url)
assert make_request.call_args_list[1].args == ("GET", url)
assert file_service_cls.return_value.upload_file.call_args.kwargs["content"] == b"downloaded-content"
@ -230,12 +230,13 @@ def test_remote_file_upload_raises_when_fallback_get_still_not_ok(app, monkeypat
handler = _unwrap(api.post)
url = "https://example.com/fail.txt"
monkeypatch.setattr(remote_files_module.ssrf_proxy, "head", MagicMock(return_value=_FakeResponse(status_code=500)))
monkeypatch.setattr(
remote_files_module.ssrf_proxy,
"get",
MagicMock(return_value=_FakeResponse(status_code=502, text="bad gateway")),
make_request = MagicMock(
side_effect=[
_FakeResponse(status_code=500),
_FakeResponse(status_code=502, text="bad gateway"),
]
)
monkeypatch.setattr(remote_files_module.remote_fetcher, "make_request", make_request)
with app.test_request_context(method="POST", json={"url": url}):
with pytest.raises(RemoteFileUploadError, match=f"Failed to fetch file from {url}: bad gateway"):
@ -248,11 +249,8 @@ def test_remote_file_upload_raises_on_httpx_request_error(app, monkeypatch: pyte
url = "https://example.com/fail.txt"
request = httpx.Request("HEAD", url)
monkeypatch.setattr(
remote_files_module.ssrf_proxy,
"head",
MagicMock(side_effect=httpx.RequestError("network down", request=request)),
)
make_request = MagicMock(side_effect=httpx.RequestError("network down", request=request))
monkeypatch.setattr(remote_files_module.remote_fetcher, "make_request", make_request)
with app.test_request_context(method="POST", json={"url": url}):
with pytest.raises(RemoteFileUploadError, match=f"Failed to fetch file from {url}: network down"):
@ -264,12 +262,8 @@ def test_remote_file_upload_rejects_oversized_file(app, monkeypatch: pytest.Monk
handler = _unwrap(api.post)
url = "https://example.com/large.bin"
monkeypatch.setattr(
remote_files_module.ssrf_proxy,
"head",
MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload")),
)
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock())
make_request = MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload"))
monkeypatch.setattr(remote_files_module.remote_fetcher, "make_request", make_request)
_, current_user = _mock_upload_dependencies(monkeypatch, file_size_within_limit=False)
@ -283,12 +277,8 @@ def test_remote_file_upload_translates_service_file_too_large_error(app, monkeyp
handler = _unwrap(api.post)
url = "https://example.com/large.bin"
monkeypatch.setattr(
remote_files_module.ssrf_proxy,
"head",
MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload")),
)
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock())
make_request = MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload"))
monkeypatch.setattr(remote_files_module.remote_fetcher, "make_request", make_request)
file_service_cls, current_user = _mock_upload_dependencies(monkeypatch)
file_service_cls.return_value.upload_file.side_effect = ServiceFileTooLargeError("size exceeded")
@ -302,12 +292,8 @@ def test_remote_file_upload_translates_service_unsupported_type_error(app, monke
handler = _unwrap(api.post)
url = "https://example.com/file.exe"
monkeypatch.setattr(
remote_files_module.ssrf_proxy,
"head",
MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload")),
)
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock())
make_request = MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload"))
monkeypatch.setattr(remote_files_module.remote_fetcher, "make_request", make_request)
file_service_cls, current_user = _mock_upload_dependencies(monkeypatch)
file_service_cls.return_value.upload_file.side_effect = ServiceUnsupportedFileTypeError()

View File

@ -25,26 +25,26 @@ def _end_user() -> SimpleNamespace:
# RemoteFileInfoApi
# ---------------------------------------------------------------------------
class TestRemoteFileInfoApi:
@patch("controllers.web.remote_files.ssrf_proxy")
@patch("controllers.web.remote_files.remote_fetcher")
def test_head_success(self, mock_proxy: MagicMock, app: Flask) -> None:
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "application/pdf", "Content-Length": "1024"}
mock_proxy.head.return_value = mock_resp
mock_proxy.make_request.return_value = mock_resp
with app.test_request_context("/remote-files/https%3A%2F%2Fexample.com%2Ffile.pdf"):
result = RemoteFileInfoApi().get(_app_model(), _end_user(), "https%3A%2F%2Fexample.com%2Ffile.pdf")
assert result["file_type"] == "application/pdf"
assert result["file_length"] == 1024
mock_proxy.head.assert_called_once_with("https://example.com/file.pdf")
mock_proxy.make_request.assert_called_once_with("HEAD", "https://example.com/file.pdf")
@patch("controllers.web.remote_files.ssrf_proxy")
@patch("controllers.web.remote_files.remote_fetcher")
def test_preserves_unencoded_target_query(self, mock_proxy: MagicMock, app: Flask) -> None:
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "text/plain", "Content-Length": "128"}
mock_proxy.head.return_value = mock_resp
mock_proxy.make_request.return_value = mock_resp
target_url = "http://example.com/api/aiagent/httpview/txt"
query = "fileNameKey=cankao1_ce4305bc-be20-4c5d-8732-de1741d28e27"
@ -53,14 +53,14 @@ class TestRemoteFileInfoApi:
result = RemoteFileInfoApi().get(_app_model(), _end_user(), target_url)
assert result["file_type"] == "text/plain"
mock_proxy.head.assert_called_once_with(f"{target_url}?{query}")
mock_proxy.make_request.assert_called_once_with("HEAD", f"{target_url}?{query}")
@patch("controllers.web.remote_files.ssrf_proxy")
@patch("controllers.web.remote_files.remote_fetcher")
def test_preserves_encoded_target_query(self, mock_proxy: MagicMock, app: Flask) -> None:
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "text/plain", "Content-Length": "128"}
mock_proxy.head.return_value = mock_resp
mock_proxy.make_request.return_value = mock_resp
target_url = "http://example.com/api/aiagent/httpview/txt?fileNameKey=cankao1"
encoded_url = urllib.parse.quote(target_url, safe="")
@ -69,9 +69,9 @@ class TestRemoteFileInfoApi:
result = RemoteFileInfoApi().get(_app_model(), _end_user(), encoded_url)
assert result["file_type"] == "text/plain"
mock_proxy.head.assert_called_once_with(target_url)
mock_proxy.make_request.assert_called_once_with("HEAD", target_url)
@patch("controllers.web.remote_files.ssrf_proxy")
@patch("controllers.web.remote_files.remote_fetcher")
def test_fallback_to_get(self, mock_proxy: MagicMock, app: Flask) -> None:
head_resp = MagicMock()
head_resp.status_code = 405 # Method not allowed
@ -79,14 +79,13 @@ class TestRemoteFileInfoApi:
get_resp.status_code = 200
get_resp.headers = {"Content-Type": "text/plain", "Content-Length": "42"}
get_resp.raise_for_status = MagicMock()
mock_proxy.head.return_value = head_resp
mock_proxy.get.return_value = get_resp
mock_proxy.make_request.side_effect = [head_resp, get_resp]
with app.test_request_context("/remote-files/https%3A%2F%2Fexample.com%2Ffile.txt"):
result = RemoteFileInfoApi().get(_app_model(), _end_user(), "https%3A%2F%2Fexample.com%2Ffile.txt")
assert result["file_type"] == "text/plain"
mock_proxy.get.assert_called_once()
assert mock_proxy.make_request.call_args_list[1].args == ("GET", "https://example.com/file.txt")
# ---------------------------------------------------------------------------
@ -96,7 +95,7 @@ class TestRemoteFileUploadApi:
@patch("controllers.web.remote_files.file_helpers.get_signed_file_url", return_value="https://signed-url")
@patch("controllers.web.remote_files.FileService")
@patch("controllers.web.remote_files.helpers.guess_file_info_from_response")
@patch("controllers.web.remote_files.ssrf_proxy")
@patch("controllers.web.remote_files.remote_fetcher")
@patch("controllers.web.remote_files.web_ns")
@patch("controllers.web.remote_files.db")
def test_upload_success(
@ -115,10 +114,9 @@ class TestRemoteFileUploadApi:
head_resp.status_code = 200
head_resp.content = b"pdf-content"
head_resp.request.method = "HEAD"
mock_proxy.head.return_value = head_resp
get_resp = MagicMock()
get_resp.content = b"pdf-content"
mock_proxy.get.return_value = get_resp
mock_proxy.make_request.side_effect = [head_resp, get_resp]
mock_guess.return_value = SimpleNamespace(
filename="file.pdf", extension="pdf", mimetype="application/pdf", size=100
@ -146,7 +144,7 @@ class TestRemoteFileUploadApi:
@patch("controllers.web.remote_files.FileService.is_file_size_within_limit", return_value=False)
@patch("controllers.web.remote_files.helpers.guess_file_info_from_response")
@patch("controllers.web.remote_files.ssrf_proxy")
@patch("controllers.web.remote_files.remote_fetcher")
@patch("controllers.web.remote_files.web_ns")
def test_file_too_large(
self,
@ -159,7 +157,7 @@ class TestRemoteFileUploadApi:
mock_ns.payload = {"url": "https://example.com/big.zip"}
head_resp = MagicMock()
head_resp.status_code = 200
mock_proxy.head.return_value = head_resp
mock_proxy.make_request.return_value = head_resp
mock_guess.return_value = SimpleNamespace(
filename="big.zip", extension="zip", mimetype="application/zip", size=999999999
)
@ -168,13 +166,13 @@ class TestRemoteFileUploadApi:
with pytest.raises(FileTooLargeError):
RemoteFileUploadApi().post(_app_model(), _end_user())
@patch("controllers.web.remote_files.ssrf_proxy")
@patch("controllers.web.remote_files.remote_fetcher")
@patch("controllers.web.remote_files.web_ns")
def test_fetch_failure_raises(self, mock_ns: MagicMock, mock_proxy: MagicMock, app: Flask) -> None:
import httpx
mock_ns.payload = {"url": "https://example.com/bad"}
mock_proxy.head.side_effect = httpx.RequestError("connection failed")
mock_proxy.make_request.side_effect = httpx.RequestError("connection failed")
with app.test_request_context("/remote-files/upload", method="POST"):
with pytest.raises(RemoteFileUploadError):

View File

@ -351,7 +351,11 @@ def test_runtime_helper_wrappers_delegate_to_config_and_io(monkeypatch: pytest.M
assert runtime.multimodal_send_format == "url"
with patch.object(file_runtime.graphon_ssrf_proxy, "get", return_value="response") as mock_get:
with patch.object(
file_runtime.remote_fetcher.graphon_remote_file_fetcher,
"get",
return_value="response",
) as mock_get:
assert runtime.http_get("http://example", follow_redirects=False) == "response"
mock_get.assert_called_once_with("http://example", follow_redirects=False)

View File

@ -1,6 +1,3 @@
import base64
import hashlib
import hmac
from unittest.mock import MagicMock, patch
import httpx
@ -34,34 +31,6 @@ class TestDatasourceFileManager:
assert f"nonce={mock_urandom.return_value.hex()}" in signed_url
assert "sign=" in signed_url
@patch("core.datasource.datasource_file_manager.time.time")
@patch("core.datasource.datasource_file_manager.dify_config")
def test_verify_file(self, mock_config, mock_time):
# Setup
mock_config.SECRET_KEY = "test_secret"
mock_config.FILES_ACCESS_TIMEOUT = 300
mock_time.return_value = 1700000000
datasource_file_id = "file_id_123"
timestamp = "1699999800" # 200 seconds ago
nonce = "some_nonce"
# Manually calculate sign
data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}"
secret_key = b"test_secret"
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()
# Execute & Verify Success
assert DatasourceFileManager.verify_file(datasource_file_id, timestamp, nonce, encoded_sign) is True
# Verify Failure - Wrong Sign
assert DatasourceFileManager.verify_file(datasource_file_id, timestamp, nonce, "wrong_sign") is False
# Verify Failure - Timeout
mock_time.return_value = 1700000500 # 700 seconds after timestamp (300 is timeout)
assert DatasourceFileManager.verify_file(datasource_file_id, timestamp, nonce, encoded_sign) is False
@patch("core.datasource.datasource_file_manager.db")
@patch("core.datasource.datasource_file_manager.storage")
@patch("core.datasource.datasource_file_manager.uuid4")
@ -170,7 +139,7 @@ class TestDatasourceFileManager:
assert upload_file.name == "unique_hex.pdf"
assert upload_file.extension == ".pdf"
@patch("core.datasource.datasource_file_manager.ssrf_proxy")
@patch("core.datasource.datasource_file_manager.remote_fetcher")
@patch("core.datasource.datasource_file_manager.db")
@patch("core.datasource.datasource_file_manager.storage")
@patch("core.datasource.datasource_file_manager.uuid4")
@ -180,7 +149,7 @@ class TestDatasourceFileManager:
mock_response = MagicMock()
mock_response.content = b"bits"
mock_response.headers = {} # No content-type in headers
mock_ssrf.get.return_value = mock_response
mock_ssrf.make_request.return_value = mock_response
# Execute
tool_file = DatasourceFileManager.create_file_by_url(
@ -190,7 +159,7 @@ class TestDatasourceFileManager:
# Verify
assert tool_file.mimetype == "image/png" # Guessed from .png in URL
@patch("core.datasource.datasource_file_manager.ssrf_proxy")
@patch("core.datasource.datasource_file_manager.remote_fetcher")
@patch("core.datasource.datasource_file_manager.db")
@patch("core.datasource.datasource_file_manager.storage")
@patch("core.datasource.datasource_file_manager.uuid4")
@ -200,7 +169,7 @@ class TestDatasourceFileManager:
mock_response = MagicMock()
mock_response.content = b"bits"
mock_response.headers = {}
mock_ssrf.get.return_value = mock_response
mock_ssrf.make_request.return_value = mock_response
# Execute
tool_file = DatasourceFileManager.create_file_by_url(
@ -212,7 +181,7 @@ class TestDatasourceFileManager:
# Verify
assert tool_file.mimetype == "application/octet-stream"
@patch("core.datasource.datasource_file_manager.ssrf_proxy")
@patch("core.datasource.datasource_file_manager.remote_fetcher")
@patch("core.datasource.datasource_file_manager.db")
@patch("core.datasource.datasource_file_manager.storage")
@patch("core.datasource.datasource_file_manager.uuid4")
@ -222,7 +191,7 @@ class TestDatasourceFileManager:
mock_response = MagicMock()
mock_response.content = b"downloaded bits"
mock_response.headers = {"Content-Type": "image/jpeg"}
mock_ssrf.get.return_value = mock_response
mock_ssrf.make_request.return_value = mock_response
# Execute
tool_file = DatasourceFileManager.create_file_by_url(
@ -235,10 +204,10 @@ class TestDatasourceFileManager:
assert tool_file.file_key == "tools/tenant_456/unique_hex.jpg"
mock_storage.save.assert_called_once()
@patch("core.datasource.datasource_file_manager.ssrf_proxy")
@patch("core.datasource.datasource_file_manager.remote_fetcher")
def test_create_file_by_url_timeout(self, mock_ssrf):
# Setup
mock_ssrf.get.side_effect = httpx.TimeoutException("Timeout")
mock_ssrf.make_request.side_effect = httpx.TimeoutException("Timeout")
# Execute & Verify
with pytest.raises(ValueError, match="timeout when downloading file"):

View File

@ -0,0 +1,799 @@
import base64
import hashlib
import hmac
import urllib.parse
from types import SimpleNamespace
from unittest.mock import MagicMock
import httpx
import pytest
from core.datasource.datasource_file_manager import DatasourceFileManager
from core.file import remote_fetcher
from core.tools.signature import sign_tool_file, sign_upload_file_preview_url
from core.tools.tool_file_manager import ToolFileManager
# Fixed UUID-shaped identifiers used by the tests in this module to build
# signed URLs and the fake upload/tool/datasource file records.
UPLOAD_FILE_ID = "1602650a-4fe4-423c-85a2-af76c083e3c4"
TOOL_FILE_ID = "2602650a-4fe4-423c-85a2-af76c083e3c4"
DATASOURCE_FILE_ID = "3602650a-4fe4-423c-85a2-af76c083e3c4"
def _signed_url(*, base_url: str, path: str, payload: str, secret: str = "test-secret") -> str:
    """Build a URL whose query string carries an HMAC-SHA256 signature.

    The signed message is ``"{payload}|{timestamp}|{nonce}"`` using a fixed
    timestamp and nonce, keyed by ``secret``. The URL-safe base64 digest is
    sent as ``sign`` next to ``timestamp`` and ``nonce``, appended to
    ``base_url + path``.
    """
    timestamp, nonce = "1700000000", "nonce"
    message = f"{payload}|{timestamp}|{nonce}".encode()
    digest = hmac.new(secret.encode(), message, hashlib.sha256).digest()
    params = {
        "timestamp": timestamp,
        "nonce": nonce,
        "sign": base64.urlsafe_b64encode(digest).decode(),
    }
    encoded_params = urllib.parse.urlencode(params)
    return f"{base_url}{path}?{encoded_params}"
def _patch_file_fetcher_config(monkeypatch):
    """Pin the fetcher's file-URL settings, signing secret, and clock."""
    config_overrides = {
        "FILES_URL": "http://localhost:5001",
        "INTERNAL_FILES_URL": "http://api:5001",
        "SECRET_KEY": "test-secret",
        "FILES_ACCESS_TIMEOUT": 3600,
    }
    for setting, value in config_overrides.items():
        monkeypatch.setattr(remote_fetcher.dify_config, setting, value)
    # Freeze the fetcher's clock 100 seconds after the timestamp that
    # _signed_url embeds, well inside the 3600-second access timeout.
    monkeypatch.setattr(remote_fetcher.time, "time", lambda: 1700000100)
def _patch_session(monkeypatch):
    """Replace ``session_factory.create_session`` with a mock context manager.

    Returns the mock session that the context manager yields on entry, so
    tests can assert it was handed to the file access controller.
    """
    fake_session = MagicMock()
    context_manager = MagicMock()
    context_manager.__enter__.return_value = fake_session
    # A falsy __exit__ result lets exceptions raised inside the block propagate.
    context_manager.__exit__.return_value = False
    create_session = MagicMock(return_value=context_manager)
    monkeypatch.setattr(remote_fetcher.session_factory, "create_session", create_session)
    return fake_session
def _patch_ssrf_make_request(monkeypatch, response=None):
    """Swap ``ssrf_proxy.make_request`` for a mock and return that mock.

    When ``response`` is given the mock returns it; otherwise the mock is left
    with default ``MagicMock`` behaviour.
    """
    if response is None:
        fake_make_request = MagicMock()
    else:
        fake_make_request = MagicMock(return_value=response)
    monkeypatch.setattr(remote_fetcher.ssrf_proxy, "make_request", fake_make_request)
    return fake_make_request
def _patch_signer_times(monkeypatch):
    """Freeze ``time.time`` at 1700000000 as seen from each URL-signing module."""
    frozen_clock_targets = (
        "core.datasource.datasource_file_manager.time.time",
        "core.tools.signature.time.time",
        "core.tools.tool_file_manager.time.time",
    )
    for target in frozen_clock_targets:
        monkeypatch.setattr(target, lambda: 1700000000)
def test_get_signed_upload_file_url_reads_storage_without_ssrf(monkeypatch):
    """GET on a signed file-preview URL is answered from storage, not the SSRF proxy."""
    _patch_file_fetcher_config(monkeypatch)
    db_session = _patch_session(monkeypatch)
    ssrf_make_request = _patch_ssrf_make_request(monkeypatch)

    stored_file = SimpleNamespace(
        id=UPLOAD_FILE_ID,
        key="upload_files/tenant/hello.txt",
        name="hello.txt",
        mime_type="text/plain",
        size=5,
        extension="txt",
    )
    get_upload_file = MagicMock(return_value=stored_file)
    load_once = MagicMock(return_value=b"hello")
    monkeypatch.setattr(remote_fetcher._file_access_controller, "get_upload_file", get_upload_file)
    monkeypatch.setattr(remote_fetcher.storage, "load_once", load_once)

    signed = _signed_url(
        base_url="http://localhost:5001",
        path=f"/files/{UPLOAD_FILE_ID}/file-preview",
        payload=f"file-preview|{UPLOAD_FILE_ID}",
    )
    response = remote_fetcher.make_request("GET", signed)

    assert response.status_code == 200
    assert response.content == b"hello"
    headers = response.headers
    assert headers["Content-Type"] == "text/plain"
    assert headers["Content-Length"] == "5"
    assert response.request.method == "GET"
    get_upload_file.assert_called_once_with(session=db_session, file_id=UPLOAD_FILE_ID)
    load_once.assert_called_once_with("upload_files/tenant/hello.txt")
    ssrf_make_request.assert_not_called()
def test_make_request_resolves_upload_preview_url_generated_by_signer(monkeypatch):
    """A URL from ``sign_upload_file_preview_url`` is served from storage without SSRF."""
    _patch_file_fetcher_config(monkeypatch)
    _patch_signer_times(monkeypatch)
    db_session = _patch_session(monkeypatch)

    stored_file = SimpleNamespace(
        id=UPLOAD_FILE_ID,
        key="upload_files/tenant/image.png",
        name="image.png",
        mime_type="image/png",
        size=6,
        extension=".png",
    )
    get_upload_file = MagicMock(return_value=stored_file)
    load_once = MagicMock(return_value=b"image!")
    monkeypatch.setattr(remote_fetcher._file_access_controller, "get_upload_file", get_upload_file)
    monkeypatch.setattr(remote_fetcher.storage, "load_once", load_once)
    ssrf_make_request = _patch_ssrf_make_request(monkeypatch)

    signed = sign_upload_file_preview_url(UPLOAD_FILE_ID, ".png")
    response = remote_fetcher.make_request("GET", signed)

    assert response.status_code == 200
    assert response.content == b"image!"
    assert response.headers["Content-Type"] == "image/png"
    get_upload_file.assert_called_once_with(session=db_session, file_id=UPLOAD_FILE_ID)
    load_once.assert_called_once_with("upload_files/tenant/image.png")
    ssrf_make_request.assert_not_called()
def test_make_request_resolves_sign_tool_file_url_with_empty_extension(monkeypatch):
    """A ``sign_tool_file`` URL with an empty extension resolves to the tool file in storage."""
    _patch_file_fetcher_config(monkeypatch)
    _patch_signer_times(monkeypatch)
    db_session = _patch_session(monkeypatch)

    stored_tool_file = SimpleNamespace(
        id=TOOL_FILE_ID,
        file_key="tools/tenant/no-extension",
        name="no-extension",
        mimetype="application/octet-stream",
        size=8,
    )
    get_tool_file = MagicMock(return_value=stored_tool_file)
    load_once = MagicMock(return_value=b"tooldata")
    monkeypatch.setattr(remote_fetcher._file_access_controller, "get_tool_file", get_tool_file)
    monkeypatch.setattr(remote_fetcher.storage, "load_once", load_once)
    ssrf_make_request = _patch_ssrf_make_request(monkeypatch)

    signed = sign_tool_file(TOOL_FILE_ID, "", for_external=True)
    response = remote_fetcher.make_request("GET", signed)

    assert response.status_code == 200
    assert response.content == b"tooldata"
    assert response.headers["Content-Type"] == "application/octet-stream"
    get_tool_file.assert_called_once_with(session=db_session, file_id=TOOL_FILE_ID)
    load_once.assert_called_once_with("tools/tenant/no-extension")
    ssrf_make_request.assert_not_called()
def test_make_request_resolves_tool_manager_url_with_empty_extension(monkeypatch):
    """A ``ToolFileManager.sign_file`` URL with an empty extension is served from storage."""
    _patch_file_fetcher_config(monkeypatch)
    _patch_signer_times(monkeypatch)
    db_session = _patch_session(monkeypatch)

    stored_tool_file = SimpleNamespace(
        id=TOOL_FILE_ID,
        file_key="tools/tenant/manager-file",
        name="manager-file",
        mimetype="application/octet-stream",
        size=12,
    )
    get_tool_file = MagicMock(return_value=stored_tool_file)
    load_once = MagicMock(return_value=b"manager-data")
    monkeypatch.setattr(remote_fetcher._file_access_controller, "get_tool_file", get_tool_file)
    monkeypatch.setattr(remote_fetcher.storage, "load_once", load_once)
    ssrf_make_request = _patch_ssrf_make_request(monkeypatch)

    signed = ToolFileManager.sign_file(TOOL_FILE_ID, "")
    response = remote_fetcher.make_request("GET", signed)

    assert response.status_code == 200
    assert response.content == b"manager-data"
    get_tool_file.assert_called_once_with(session=db_session, file_id=TOOL_FILE_ID)
    load_once.assert_called_once_with("tools/tenant/manager-file")
    ssrf_make_request.assert_not_called()
def test_make_request_resolves_datasource_manager_url_with_empty_extension(monkeypatch):
    """A ``DatasourceFileManager.sign_file`` URL resolves via the upload-file lookup.

    The tool-file lookup must stay untouched and nothing may reach the SSRF proxy.
    """
    _patch_file_fetcher_config(monkeypatch)
    _patch_signer_times(monkeypatch)
    _patch_session(monkeypatch)

    stored_file = SimpleNamespace(
        id=DATASOURCE_FILE_ID,
        key="datasources/tenant/no-extension",
        name="no-extension",
        mime_type="application/octet-stream",
        size=10,
        extension="",
    )
    get_upload_file = MagicMock(return_value=stored_file)
    get_tool_file = MagicMock()
    load_once = MagicMock(return_value=b"datasource")
    controller = remote_fetcher._file_access_controller
    monkeypatch.setattr(controller, "get_upload_file", get_upload_file)
    monkeypatch.setattr(controller, "get_tool_file", get_tool_file)
    monkeypatch.setattr(remote_fetcher.storage, "load_once", load_once)
    ssrf_make_request = _patch_ssrf_make_request(monkeypatch)

    signed = DatasourceFileManager.sign_file(DATASOURCE_FILE_ID, "")
    response = remote_fetcher.make_request("GET", signed)

    assert response.status_code == 200
    assert response.content == b"datasource"
    assert response.headers["Content-Type"] == "application/octet-stream"
    get_upload_file.assert_called_once()
    get_tool_file.assert_not_called()
    load_once.assert_called_once_with("datasources/tenant/no-extension")
    ssrf_make_request.assert_not_called()
def test_head_signed_upload_file_url_returns_metadata_without_storage_content(monkeypatch):
    """HEAD on a signed file-preview URL returns headers only and never reads storage."""
    _patch_file_fetcher_config(monkeypatch)
    db_session = _patch_session(monkeypatch)
    ssrf_make_request = _patch_ssrf_make_request(monkeypatch)

    stored_file = SimpleNamespace(
        id=UPLOAD_FILE_ID,
        key="upload_files/tenant/hello.txt",
        name="hello.txt",
        mime_type="text/plain",
        size=5,
        extension="txt",
    )
    get_upload_file = MagicMock(return_value=stored_file)
    load_once = MagicMock(return_value=b"hello")
    monkeypatch.setattr(remote_fetcher._file_access_controller, "get_upload_file", get_upload_file)
    monkeypatch.setattr(remote_fetcher.storage, "load_once", load_once)

    signed = _signed_url(
        base_url="http://localhost:5001",
        path=f"/files/{UPLOAD_FILE_ID}/file-preview",
        payload=f"file-preview|{UPLOAD_FILE_ID}",
    )
    response = remote_fetcher.make_request("HEAD", signed)

    assert response.status_code == 200
    assert response.content == b""
    headers = response.headers
    assert headers["Content-Type"] == "text/plain"
    assert headers["Content-Length"] == "5"
    assert response.request.method == "HEAD"
    get_upload_file.assert_called_once_with(session=db_session, file_id=UPLOAD_FILE_ID)
    load_once.assert_not_called()
    ssrf_make_request.assert_not_called()
def test_make_request_get_signed_upload_file_url_reads_storage_without_ssrf(monkeypatch):
    """``make_request("GET", ...)`` on a signed upload URL loads bytes from storage."""
    _patch_file_fetcher_config(monkeypatch)
    _patch_session(monkeypatch)
    ssrf_make_request = _patch_ssrf_make_request(monkeypatch)

    stored_file = SimpleNamespace(
        id=UPLOAD_FILE_ID,
        key="upload_files/tenant/hello.txt",
        name="hello.txt",
        mime_type="text/plain",
        size=5,
        extension="txt",
    )
    get_upload_file = MagicMock(return_value=stored_file)
    load_once = MagicMock(return_value=b"hello")
    monkeypatch.setattr(remote_fetcher._file_access_controller, "get_upload_file", get_upload_file)
    monkeypatch.setattr(remote_fetcher.storage, "load_once", load_once)

    signed = _signed_url(
        base_url="http://localhost:5001",
        path=f"/files/{UPLOAD_FILE_ID}/file-preview",
        payload=f"file-preview|{UPLOAD_FILE_ID}",
    )
    response = remote_fetcher.make_request("GET", signed)

    assert response.status_code == 200
    assert response.content == b"hello"
    assert response.request.method == "GET"
    get_upload_file.assert_called_once()
    load_once.assert_called_once_with("upload_files/tenant/hello.txt")
    ssrf_make_request.assert_not_called()
def test_make_request_head_signed_upload_file_url_returns_metadata_without_ssrf(monkeypatch):
    """``make_request("HEAD", ...)`` on a signed upload URL yields metadata with an empty body."""
    _patch_file_fetcher_config(monkeypatch)
    _patch_session(monkeypatch)
    ssrf_make_request = _patch_ssrf_make_request(monkeypatch)

    stored_file = SimpleNamespace(
        id=UPLOAD_FILE_ID,
        key="upload_files/tenant/hello.txt",
        name="hello.txt",
        mime_type="text/plain",
        size=5,
        extension="txt",
    )
    get_upload_file = MagicMock(return_value=stored_file)
    load_once = MagicMock(return_value=b"hello")
    monkeypatch.setattr(remote_fetcher._file_access_controller, "get_upload_file", get_upload_file)
    monkeypatch.setattr(remote_fetcher.storage, "load_once", load_once)

    signed = _signed_url(
        base_url="http://localhost:5001",
        path=f"/files/{UPLOAD_FILE_ID}/file-preview",
        payload=f"file-preview|{UPLOAD_FILE_ID}",
    )
    response = remote_fetcher.make_request("HEAD", signed)

    assert response.status_code == 200
    assert response.content == b""
    headers = response.headers
    assert headers["Content-Type"] == "text/plain"
    assert headers["Content-Length"] == "5"
    assert response.request.method == "HEAD"
    get_upload_file.assert_called_once()
    load_once.assert_not_called()
    ssrf_make_request.assert_not_called()
def test_make_request_get_unsigned_dify_url_delegates_to_ssrf_proxy(monkeypatch):
    """A file-preview URL with no ``sign`` parameter is forwarded to the SSRF proxy."""
    _patch_file_fetcher_config(monkeypatch)
    get_upload_file = MagicMock()
    monkeypatch.setattr(remote_fetcher._file_access_controller, "get_upload_file", get_upload_file)

    # The query has a timestamp and nonce but deliberately omits "sign".
    url = f"http://localhost:5001/files/{UPLOAD_FILE_ID}/file-preview?timestamp=1700000000&nonce=nonce"
    proxy_request = httpx.Request("GET", url)
    proxy_response = httpx.Response(403, request=proxy_request)
    ssrf_make_request = _patch_ssrf_make_request(monkeypatch, proxy_response)

    response = remote_fetcher.make_request("GET", url, timeout=3)

    assert response is proxy_response
    get_upload_file.assert_not_called()
    ssrf_make_request.assert_called_once_with(
        method="GET",
        url=url,
        max_retries=remote_fetcher.SSRF_DEFAULT_MAX_RETRIES,
        timeout=3,
    )
def test_make_request_post_signed_upload_file_url_delegates_to_ssrf_proxy(monkeypatch):
    """POST to a validly signed preview URL still goes through the SSRF proxy.

    The upload-file lookup is never consulted and the ``json`` payload is
    forwarded to the proxy untouched.
    """
    _patch_file_fetcher_config(monkeypatch)
    lookup = MagicMock()
    monkeypatch.setattr(remote_fetcher._file_access_controller, "get_upload_file", lookup)

    upstream_request = httpx.Request("POST", f"http://localhost:5001/files/{UPLOAD_FILE_ID}")
    upstream = httpx.Response(201, request=upstream_request)
    proxy_call = _patch_ssrf_make_request(monkeypatch, upstream)

    signed = _signed_url(
        base_url="http://localhost:5001",
        path=f"/files/{UPLOAD_FILE_ID}/file-preview",
        payload=f"file-preview|{UPLOAD_FILE_ID}",
    )
    result = remote_fetcher.make_request("POST", signed, json={"name": "ignored"})

    assert result is upstream
    lookup.assert_not_called()
    proxy_call.assert_called_once_with(
        method="POST",
        url=signed,
        max_retries=remote_fetcher.SSRF_DEFAULT_MAX_RETRIES,
        json={"name": "ignored"},
    )
def test_get_signed_image_preview_url_uses_image_preview_signature(monkeypatch):
    """GET on an image-preview URL signed with an ``image-preview`` payload reads storage.

    The bytes returned by ``storage.load_once`` become the response body and
    the SSRF proxy is not called.
    """
    _patch_file_fetcher_config(monkeypatch)
    _patch_session(monkeypatch)

    image_record = SimpleNamespace(
        id=UPLOAD_FILE_ID,
        key="upload_files/tenant/image.png",
        name="image.png",
        mime_type="image/png",
        size=6,
        extension="png",
    )
    lookup = MagicMock(return_value=image_record)
    storage_read = MagicMock(return_value=b"image!")
    monkeypatch.setattr(remote_fetcher._file_access_controller, "get_upload_file", lookup)
    monkeypatch.setattr(remote_fetcher.storage, "load_once", storage_read)
    proxy_call = _patch_ssrf_make_request(monkeypatch)

    signed = _signed_url(
        base_url="http://localhost:5001",
        path=f"/files/{UPLOAD_FILE_ID}/image-preview",
        payload=f"image-preview|{UPLOAD_FILE_ID}",
    )
    result = remote_fetcher.make_request("GET", signed)

    assert result.status_code == 200
    assert result.content == b"image!"
    assert result.headers["Content-Type"] == "image/png"

    lookup.assert_called_once()
    storage_read.assert_called_once_with("upload_files/tenant/image.png")
    proxy_call.assert_not_called()
def test_image_preview_url_with_file_preview_signature_delegates_to_ssrf_proxy(monkeypatch):
    """An image-preview path signed with a ``file-preview`` payload is delegated.

    The signature payload does not match the path kind, so the request must be
    passed to the SSRF proxy and its response returned as-is.
    """
    _patch_file_fetcher_config(monkeypatch)
    upstream = httpx.Response(403, request=httpx.Request("GET", "http://localhost:5001/bad"))
    proxy_call = _patch_ssrf_make_request(monkeypatch, upstream)

    mismatched = _signed_url(
        base_url="http://localhost:5001",
        path=f"/files/{UPLOAD_FILE_ID}/image-preview",
        payload=f"file-preview|{UPLOAD_FILE_ID}",
    )
    result = remote_fetcher.make_request("GET", mismatched)

    assert result is upstream
    expected_kwargs = {
        "method": "GET",
        "url": mismatched,
        "max_retries": remote_fetcher.SSRF_DEFAULT_MAX_RETRIES,
    }
    proxy_call.assert_called_once_with(**expected_kwargs)
def test_duplicate_signature_query_value_delegates_to_ssrf_proxy(monkeypatch):
    """A signed URL carrying a second ``sign`` query value is delegated to the proxy."""
    _patch_file_fetcher_config(monkeypatch)

    signed = _signed_url(
        base_url="http://localhost:5001",
        path=f"/files/{UPLOAD_FILE_ID}/file-preview",
        payload=f"file-preview|{UPLOAD_FILE_ID}",
    )
    # Append an extra ``sign`` so the query string holds two values for it.
    ambiguous = signed + "&sign=second"

    upstream = httpx.Response(403, request=httpx.Request("GET", ambiguous))
    proxy_call = _patch_ssrf_make_request(monkeypatch, upstream)

    result = remote_fetcher.make_request("GET", ambiguous)

    assert result is upstream
    proxy_call.assert_called_once_with(
        method="GET",
        url=ambiguous,
        max_retries=remote_fetcher.SSRF_DEFAULT_MAX_RETRIES,
    )
def test_malformed_timestamp_delegates_to_ssrf_proxy(monkeypatch):
    """A signed URL whose ``timestamp`` is not an integer is delegated to the proxy."""
    _patch_file_fetcher_config(monkeypatch)

    signed = _signed_url(
        base_url="http://localhost:5001",
        path=f"/files/{UPLOAD_FILE_ID}/file-preview",
        payload=f"file-preview|{UPLOAD_FILE_ID}",
    )
    # Corrupt the timestamp after signing so it can no longer be parsed as an int.
    corrupted = signed.replace("timestamp=1700000000", "timestamp=not-an-int")

    upstream = httpx.Response(403, request=httpx.Request("GET", corrupted))
    proxy_call = _patch_ssrf_make_request(monkeypatch, upstream)

    result = remote_fetcher.make_request("GET", corrupted)

    assert result is upstream
    expected_kwargs = {
        "method": "GET",
        "url": corrupted,
        "max_retries": remote_fetcher.SSRF_DEFAULT_MAX_RETRIES,
    }
    proxy_call.assert_called_once_with(**expected_kwargs)
def test_expired_signature_delegates_to_ssrf_proxy(monkeypatch):
    """With the fetcher's clock pinned to 1700004001 the signed URL is delegated.

    NOTE(review): the expiry window itself is set by ``_patch_file_fetcher_config``
    (not shown here); this test only relies on 1700004001 being past it.
    """
    _patch_file_fetcher_config(monkeypatch)
    monkeypatch.setattr(remote_fetcher.time, "time", lambda: 1700004001)

    stale = _signed_url(
        base_url="http://localhost:5001",
        path=f"/files/{UPLOAD_FILE_ID}/file-preview",
        payload=f"file-preview|{UPLOAD_FILE_ID}",
    )
    upstream = httpx.Response(403, request=httpx.Request("GET", stale))
    proxy_call = _patch_ssrf_make_request(monkeypatch, upstream)

    result = remote_fetcher.make_request("GET", stale)

    assert result is upstream
    proxy_call.assert_called_once_with(
        method="GET",
        url=stale,
        max_retries=remote_fetcher.SSRF_DEFAULT_MAX_RETRIES,
    )
def test_invalid_signature_delegates_to_ssrf_proxy(monkeypatch):
    """A preview URL with ``sign=bad`` is delegated to the proxy, keeping ``timeout``."""
    _patch_file_fetcher_config(monkeypatch)
    upstream = httpx.Response(403, request=httpx.Request("GET", "http://localhost:5001/bad"))
    proxy_call = _patch_ssrf_make_request(monkeypatch, upstream)

    forged = f"http://localhost:5001/files/{UPLOAD_FILE_ID}/file-preview?timestamp=1700000000&nonce=nonce&sign=bad"
    result = remote_fetcher.make_request("GET", forged, timeout=3)

    assert result is upstream
    expected_kwargs = {
        "method": "GET",
        "url": forged,
        "max_retries": remote_fetcher.SSRF_DEFAULT_MAX_RETRIES,
        "timeout": 3,
    }
    proxy_call.assert_called_once_with(**expected_kwargs)
def test_host_mismatch_delegates_to_ssrf_proxy(monkeypatch):
    """A signed URL built on ``http://example.com`` is delegated to the proxy.

    Even though the proxy answers 200 with a body, the result must be exactly
    the proxy's response object.
    """
    _patch_file_fetcher_config(monkeypatch)

    foreign = _signed_url(
        base_url="http://example.com",
        path=f"/files/{UPLOAD_FILE_ID}/file-preview",
        payload=f"file-preview|{UPLOAD_FILE_ID}",
    )
    upstream = httpx.Response(200, request=httpx.Request("GET", foreign), content=b"remote")
    proxy_call = _patch_ssrf_make_request(monkeypatch, upstream)

    result = remote_fetcher.make_request("GET", foreign)

    assert result is upstream
    proxy_call.assert_called_once_with(
        method="GET",
        url=foreign,
        max_retries=remote_fetcher.SSRF_DEFAULT_MAX_RETRIES,
    )
def test_unsupported_dify_path_delegates_to_ssrf_proxy(monkeypatch):
    """HEAD on a signed ``/not-preview`` path is delegated, keeping ``follow_redirects``."""
    _patch_file_fetcher_config(monkeypatch)

    unknown_path = _signed_url(
        base_url="http://localhost:5001",
        path=f"/files/{UPLOAD_FILE_ID}/not-preview",
        payload=f"file-preview|{UPLOAD_FILE_ID}",
    )
    upstream = httpx.Response(404, request=httpx.Request("HEAD", unknown_path))
    proxy_call = _patch_ssrf_make_request(monkeypatch, upstream)

    result = remote_fetcher.make_request("HEAD", unknown_path, follow_redirects=True)

    assert result is upstream
    expected_kwargs = {
        "method": "HEAD",
        "url": unknown_path,
        "max_retries": remote_fetcher.SSRF_DEFAULT_MAX_RETRIES,
        "follow_redirects": True,
    }
    proxy_call.assert_called_once_with(**expected_kwargs)
def test_invalid_url_scheme_delegates_to_ssrf_proxy(monkeypatch):
    """A ``file://`` URL shaped like a preview link is delegated to the proxy."""
    _patch_file_fetcher_config(monkeypatch)

    local_path_url = f"file:///tmp/files/{UPLOAD_FILE_ID}/file-preview?timestamp=1700000000&nonce=nonce&sign=ignored"
    upstream = httpx.Response(403, request=httpx.Request("GET", local_path_url))
    proxy_call = _patch_ssrf_make_request(monkeypatch, upstream)

    result = remote_fetcher.make_request("GET", local_path_url)

    assert result is upstream
    proxy_call.assert_called_once_with(
        method="GET",
        url=local_path_url,
        max_retries=remote_fetcher.SSRF_DEFAULT_MAX_RETRIES,
    )
def test_invalid_url_port_delegates_to_ssrf_proxy(monkeypatch):
    """A URL with the non-numeric port ``invalid`` is delegated to the proxy.

    The canned proxy response is built against a separate, well-formed URL;
    only the identity of the returned response and the forwarded arguments
    are asserted.
    """
    _patch_file_fetcher_config(monkeypatch)

    bad_port_url = (
        f"http://localhost:invalid/files/{UPLOAD_FILE_ID}/file-preview?timestamp=1700000000&nonce=nonce&sign=ignored"
    )
    upstream = httpx.Response(403, request=httpx.Request("GET", "http://proxy.example/fallback"))
    proxy_call = _patch_ssrf_make_request(monkeypatch, upstream)

    result = remote_fetcher.make_request("GET", bad_port_url)

    assert result is upstream
    expected_kwargs = {
        "method": "GET",
        "url": bad_port_url,
        "max_retries": remote_fetcher.SSRF_DEFAULT_MAX_RETRIES,
    }
    proxy_call.assert_called_once_with(**expected_kwargs)
def test_invalid_configured_file_origin_delegates_to_ssrf_proxy(monkeypatch):
    """With ``FILES_URL`` empty and ``INTERNAL_FILES_URL`` a ``file://`` URL, delegate.

    A localhost-signed preview URL must then be forwarded to the SSRF proxy
    and the proxy's response returned unchanged.
    """
    _patch_file_fetcher_config(monkeypatch)
    # Override the configured file origins after the common config patch.
    monkeypatch.setattr(remote_fetcher.dify_config, "FILES_URL", "")
    monkeypatch.setattr(remote_fetcher.dify_config, "INTERNAL_FILES_URL", "file:///tmp/files")

    signed = _signed_url(
        base_url="http://localhost:5001",
        path=f"/files/{UPLOAD_FILE_ID}/file-preview",
        payload=f"file-preview|{UPLOAD_FILE_ID}",
    )
    upstream = httpx.Response(403, request=httpx.Request("GET", signed))
    proxy_call = _patch_ssrf_make_request(monkeypatch, upstream)

    result = remote_fetcher.make_request("GET", signed)

    assert result is upstream
    proxy_call.assert_called_once_with(
        method="GET",
        url=signed,
        max_retries=remote_fetcher.SSRF_DEFAULT_MAX_RETRIES,
    )
def test_signed_upload_file_url_returns_404_when_record_missing(monkeypatch):
    """A signed upload-file URL yields an empty 404 when the lookup returns ``None``.

    The SSRF proxy must not be used as a fallback in that case.
    """
    _patch_file_fetcher_config(monkeypatch)
    _patch_session(monkeypatch)

    lookup = MagicMock(return_value=None)
    monkeypatch.setattr(remote_fetcher._file_access_controller, "get_upload_file", lookup)
    proxy_call = _patch_ssrf_make_request(monkeypatch)

    signed = _signed_url(
        base_url="http://localhost:5001",
        path=f"/files/{UPLOAD_FILE_ID}/file-preview",
        payload=f"file-preview|{UPLOAD_FILE_ID}",
    )
    result = remote_fetcher.make_request("GET", signed)

    assert result.status_code == 404
    assert result.content == b""
    lookup.assert_called_once()
    proxy_call.assert_not_called()
def test_get_signed_tool_file_url_reads_storage_without_ssrf(monkeypatch):
    """GET on a signed ``/files/tools/<id>.txt`` URL is served from storage.

    The tool-file lookup receives the patched session and the file id, the
    stored bytes become the body, and the SSRF proxy is not called.
    """
    _patch_file_fetcher_config(monkeypatch)
    db_session = _patch_session(monkeypatch)

    tool_record = SimpleNamespace(
        id=TOOL_FILE_ID,
        file_key="tools/tenant/result.txt",
        name="result.txt",
        mimetype="text/plain",
        size=6,
    )
    tool_lookup = MagicMock(return_value=tool_record)
    storage_read = MagicMock(return_value=b"result")
    monkeypatch.setattr(remote_fetcher._file_access_controller, "get_tool_file", tool_lookup)
    monkeypatch.setattr(remote_fetcher.storage, "load_once", storage_read)
    proxy_call = _patch_ssrf_make_request(monkeypatch)

    signed = _signed_url(
        base_url="http://localhost:5001",
        path=f"/files/tools/{TOOL_FILE_ID}.txt",
        payload=f"file-preview|{TOOL_FILE_ID}",
    )
    result = remote_fetcher.make_request("GET", signed)

    assert result.status_code == 200
    assert result.content == b"result"
    assert result.headers["Content-Type"] == "text/plain"

    tool_lookup.assert_called_once_with(session=db_session, file_id=TOOL_FILE_ID)
    storage_read.assert_called_once_with("tools/tenant/result.txt")
    proxy_call.assert_not_called()
def test_signed_tool_file_url_returns_404_when_record_missing(monkeypatch):
    """A signed tool-file URL yields an empty 404 when the lookup returns ``None``."""
    _patch_file_fetcher_config(monkeypatch)
    _patch_session(monkeypatch)

    tool_lookup = MagicMock(return_value=None)
    monkeypatch.setattr(remote_fetcher._file_access_controller, "get_tool_file", tool_lookup)
    proxy_call = _patch_ssrf_make_request(monkeypatch)

    signed = _signed_url(
        base_url="http://localhost:5001",
        path=f"/files/tools/{TOOL_FILE_ID}.txt",
        payload=f"file-preview|{TOOL_FILE_ID}",
    )
    result = remote_fetcher.make_request("GET", signed)

    assert result.status_code == 404
    assert result.content == b""
    tool_lookup.assert_called_once()
    proxy_call.assert_not_called()
def test_get_signed_datasource_file_url_reads_upload_storage_without_ssrf(monkeypatch):
    """A signed datasource URL is served from the upload-file record when one exists.

    When the upload-file lookup succeeds, the tool-file lookup must not be
    attempted and the SSRF proxy must stay untouched.
    """
    _patch_file_fetcher_config(monkeypatch)
    _patch_session(monkeypatch)

    upload_record = SimpleNamespace(
        id=DATASOURCE_FILE_ID,
        key="datasources/tenant/data.txt",
        name="data.txt",
        mime_type="text/plain",
        size=4,
        extension="txt",
    )
    upload_lookup = MagicMock(return_value=upload_record)
    tool_lookup = MagicMock(return_value=None)
    storage_read = MagicMock(return_value=b"data")
    controller = remote_fetcher._file_access_controller
    monkeypatch.setattr(controller, "get_upload_file", upload_lookup)
    monkeypatch.setattr(controller, "get_tool_file", tool_lookup)
    monkeypatch.setattr(remote_fetcher.storage, "load_once", storage_read)
    proxy_call = _patch_ssrf_make_request(monkeypatch)

    signed = _signed_url(
        base_url="http://localhost:5001",
        path=f"/files/datasources/{DATASOURCE_FILE_ID}.txt",
        payload=f"file-preview|{DATASOURCE_FILE_ID}",
    )
    result = remote_fetcher.make_request("GET", signed)

    assert result.status_code == 200
    assert result.content == b"data"

    upload_lookup.assert_called_once()
    tool_lookup.assert_not_called()
    storage_read.assert_called_once_with("datasources/tenant/data.txt")
    proxy_call.assert_not_called()
def test_get_signed_datasource_file_url_reads_tool_storage_when_upload_missing(monkeypatch):
    """A signed datasource URL falls back to the tool-file record.

    The upload-file lookup returns ``None``, so the tool-file lookup is used
    and its record supplies the storage key, content type and length.
    """
    _patch_file_fetcher_config(monkeypatch)
    _patch_session(monkeypatch)

    tool_record = SimpleNamespace(
        id=DATASOURCE_FILE_ID,
        file_key="datasources/tenant/tool-data.txt",
        name="tool-data.txt",
        mimetype="text/plain",
        size=9,
    )
    upload_lookup = MagicMock(return_value=None)
    tool_lookup = MagicMock(return_value=tool_record)
    storage_read = MagicMock(return_value=b"tool-data")
    controller = remote_fetcher._file_access_controller
    monkeypatch.setattr(controller, "get_upload_file", upload_lookup)
    monkeypatch.setattr(controller, "get_tool_file", tool_lookup)
    monkeypatch.setattr(remote_fetcher.storage, "load_once", storage_read)
    proxy_call = _patch_ssrf_make_request(monkeypatch)

    signed = _signed_url(
        base_url="http://localhost:5001",
        path=f"/files/datasources/{DATASOURCE_FILE_ID}.txt",
        payload=f"file-preview|{DATASOURCE_FILE_ID}",
    )
    result = remote_fetcher.make_request("GET", signed)

    assert result.status_code == 200
    assert result.content == b"tool-data"
    assert result.headers["Content-Type"] == "text/plain"
    assert result.headers["Content-Length"] == "9"

    # Both lookups run: upload first comes back empty, then the tool record is found.
    upload_lookup.assert_called_once()
    tool_lookup.assert_called_once()
    storage_read.assert_called_once_with("datasources/tenant/tool-data.txt")
    proxy_call.assert_not_called()
def test_signed_datasource_file_url_returns_404_when_records_missing(monkeypatch):
    """A signed datasource URL yields an empty 404 when both lookups return ``None``."""
    _patch_file_fetcher_config(monkeypatch)
    _patch_session(monkeypatch)

    upload_lookup = MagicMock(return_value=None)
    tool_lookup = MagicMock(return_value=None)
    controller = remote_fetcher._file_access_controller
    monkeypatch.setattr(controller, "get_upload_file", upload_lookup)
    monkeypatch.setattr(controller, "get_tool_file", tool_lookup)
    proxy_call = _patch_ssrf_make_request(monkeypatch)

    signed = _signed_url(
        base_url="http://localhost:5001",
        path=f"/files/datasources/{DATASOURCE_FILE_ID}.txt",
        payload=f"file-preview|{DATASOURCE_FILE_ID}",
    )
    result = remote_fetcher.make_request("GET", signed)

    assert result.status_code == 404
    assert result.content == b""
    upload_lookup.assert_called_once()
    tool_lookup.assert_called_once()
    proxy_call.assert_not_called()
@pytest.mark.parametrize("method_name", ["POST", "PUT", "DELETE", "PATCH"])
def test_non_get_make_request_methods_delegate_to_ssrf_proxy(monkeypatch, method_name):
    """POST, PUT, DELETE and PATCH requests are forwarded to the SSRF proxy.

    A caller-supplied ``max_retries`` and ``timeout`` must reach the proxy
    unchanged, and the proxy's response object is returned as-is.
    """
    target = "https://example.com/file.txt"
    upstream = httpx.Response(200, request=httpx.Request(method_name, target))
    proxy_call = _patch_ssrf_make_request(monkeypatch, upstream)

    result = remote_fetcher.make_request(method_name, target, max_retries=2, timeout=3)

    assert result is upstream
    expected_kwargs = {"method": method_name, "url": target, "max_retries": 2, "timeout": 3}
    proxy_call.assert_called_once_with(**expected_kwargs)
def test_graphon_remote_file_fetcher_exposes_ssrf_error_types():
    """The fetcher's error attributes are the module-level error objects themselves."""
    instance = remote_fetcher.GraphonRemoteFileFetcher()

    # Identity, not equality: the fetcher must re-export the very same objects.
    assert instance.max_retries_exceeded_error is remote_fetcher.max_retries_exceeded_error
    assert instance.request_error is remote_fetcher.request_error
@pytest.mark.parametrize("method_name", ["get", "head", "post", "put", "delete", "patch"])
def test_graphon_remote_file_fetcher_adapts_fetcher_responses(monkeypatch, method_name):
    """Each verb method calls ``make_request`` and returns the adapted response.

    The method name is upper-cased into the HTTP verb, keyword arguments are
    passed through, and the value returned is whatever
    ``_to_graphon_http_response`` produces for the raw response.
    """
    target = "https://example.com/file.txt"
    verb = method_name.upper()
    raw_response = httpx.Response(200, request=httpx.Request(verb, target), content=b"ok")

    fake_make_request = MagicMock(return_value=raw_response)
    adapted = object()
    fake_adapter = MagicMock(return_value=adapted)
    monkeypatch.setattr(remote_fetcher, "make_request", fake_make_request)
    monkeypatch.setattr(remote_fetcher, "_to_graphon_http_response", fake_adapter)

    # Build the fetcher only after patching, then look the verb method up by name.
    fetcher = remote_fetcher.GraphonRemoteFileFetcher()
    verb_method = getattr(fetcher, method_name)
    outcome = verb_method(target, max_retries=2, timeout=3)

    assert outcome is adapted
    fake_make_request.assert_called_once_with(verb, url=target, max_retries=2, timeout=3)
    fake_adapter.assert_called_once_with(raw_response)

View File

@ -17,16 +17,19 @@ class _StubResponse:
def test_download_with_size_limit_returns_content(mocker: MockerFixture) -> None:
response = _StubResponse(status_code=200, chunks=[b"ab", b"cd", b"ef"])
mock_get = mocker.patch("core.helper.download.ssrf_proxy.get", return_value=response)
mock_get = mocker.patch("core.file.remote_fetcher.make_request", return_value=response)
content = download_with_size_limit("https://example.com/a.txt", max_download_size=6, timeout=10)
assert content == b"abcdef"
mock_get.assert_called_once_with("https://example.com/a.txt", follow_redirects=True, timeout=10)
mock_get.assert_called_once_with("GET", "https://example.com/a.txt", follow_redirects=True, timeout=10)
def test_download_with_size_limit_raises_for_404(mocker: MockerFixture) -> None:
mocker.patch("core.helper.download.ssrf_proxy.get", return_value=_StubResponse(status_code=404, chunks=[]))
mocker.patch(
"core.file.remote_fetcher.make_request",
return_value=_StubResponse(status_code=404, chunks=[]),
)
with pytest.raises(ValueError, match="file not found"):
download_with_size_limit("https://example.com/missing.txt", max_download_size=10)
@ -36,7 +39,7 @@ def test_download_with_size_limit_raises_when_size_exceeds_limit(
mocker: MockerFixture,
) -> None:
response = _StubResponse(status_code=200, chunks=[b"abc", b"de"])
mocker.patch("core.helper.download.ssrf_proxy.get", return_value=response)
mocker.patch("core.file.remote_fetcher.make_request", return_value=response)
with pytest.raises(ValueError, match="Max file size reached"):
download_with_size_limit("https://example.com/large.bin", max_download_size=4)
@ -46,7 +49,7 @@ def test_download_with_size_limit_accepts_content_equal_to_limit(
mocker: MockerFixture,
) -> None:
response = _StubResponse(status_code=200, chunks=[b"ab", b"cd"])
mocker.patch("core.helper.download.ssrf_proxy.get", return_value=response)
mocker.patch("core.file.remote_fetcher.make_request", return_value=response)
content = download_with_size_limit("https://example.com/exact.bin", max_download_size=4)

View File

@ -97,7 +97,7 @@ class TestExtractProcessorLoaders:
self, monkeypatch: pytest.MonkeyPatch, url, headers, expected_suffix
):
response = SimpleNamespace(headers=headers, content=b"body")
monkeypatch.setattr(processor_module.ssrf_proxy, "get", lambda *args, **kwargs: response)
monkeypatch.setattr(processor_module.remote_fetcher, "make_request", lambda *args, **kwargs: response)
monkeypatch.setattr(processor_module, "ExtractSetting", lambda **kwargs: SimpleNamespace(**kwargs))
captured = {}

View File

@ -61,14 +61,14 @@ def test_parse_row():
assert extractor._parse_row(row, {}, 3) == gt[idx]
def test_init_downloads_via_ssrf_proxy(monkeypatch: pytest.MonkeyPatch):
def test_init_downloads_via_remote_fetcher(monkeypatch: pytest.MonkeyPatch):
doc = Document()
doc.add_paragraph("hello")
buf = io.BytesIO()
doc.save(buf)
docx_bytes = buf.getvalue()
calls: list[tuple[str, object]] = []
calls: list[tuple[str, tuple[str, dict[str, object]] | None]] = []
class FakeResponse:
status_code = 200
@ -77,17 +77,20 @@ def test_init_downloads_via_ssrf_proxy(monkeypatch: pytest.MonkeyPatch):
def close(self) -> None:
calls.append(("close", None))
def fake_get(url: str, **kwargs):
def fake_make_request(method: str, url: str, **kwargs):
assert method == "GET"
calls.append(("get", (url, kwargs)))
return FakeResponse()
monkeypatch.setattr(we, "ssrf_proxy", SimpleNamespace(get=fake_get))
monkeypatch.setattr(we, "remote_fetcher", SimpleNamespace(make_request=fake_make_request))
extractor = WordExtractor("https://example.com/test.docx", "tenant_id", "user_id")
try:
assert calls
assert calls[0][0] == "get"
url, kwargs = calls[0][1]
first_call = calls[0][1]
assert first_call is not None
url, kwargs = first_call
assert url == "https://example.com/test.docx"
assert kwargs.get("timeout") is None
assert extractor.web_path == "https://example.com/test.docx"
@ -139,11 +142,12 @@ def test_extract_images_from_docx(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(we, "UploadFile", FakeUploadFile)
# Patch external image fetcher
def fake_get(url: str, **kwargs):
def fake_make_request(method: str, url: str, **kwargs):
assert method == "GET"
assert url == "https://example.com/image.png"
return SimpleNamespace(status_code=200, headers={"Content-Type": "image/png"}, content=external_bytes)
monkeypatch.setattr(we, "ssrf_proxy", SimpleNamespace(get=fake_get))
monkeypatch.setattr(we, "remote_fetcher", SimpleNamespace(make_request=fake_make_request))
# A hashable internal part object with a blob attribute
class HashablePart:
@ -327,7 +331,7 @@ def test_init_rejects_invalid_url_status(monkeypatch: pytest.MonkeyPatch):
self.closed = True
fake_response = FakeResponse()
monkeypatch.setattr(we, "ssrf_proxy", SimpleNamespace(get=lambda url, **kwargs: fake_response))
monkeypatch.setattr(we, "remote_fetcher", SimpleNamespace(make_request=lambda method, url, **kwargs: fake_response))
with pytest.raises(ValueError, match="returned status code 404"):
WordExtractor("https://example.com/missing.docx", "tenant", "user")
@ -416,12 +420,13 @@ def test_extract_images_handles_invalid_external_cases(monkeypatch: pytest.Monke
)
)
def fake_get(url, **kwargs):
def fake_make_request(method, url, **kwargs):
assert method == "GET"
if "image-error" in url:
raise RuntimeError("network")
return SimpleNamespace(status_code=200, headers={"Content-Type": "application/unknown"}, content=b"x")
monkeypatch.setattr(we, "ssrf_proxy", SimpleNamespace(get=fake_get))
monkeypatch.setattr(we, "remote_fetcher", SimpleNamespace(make_request=fake_make_request))
db_stub = SimpleNamespace(session=SimpleNamespace(add=lambda obj: None, commit=MagicMock()))
monkeypatch.setattr(we, "db", db_stub)
monkeypatch.setattr(we, "storage", SimpleNamespace(save=lambda key, data: None))

View File

@ -200,7 +200,7 @@ class TestBaseIndexProcessor:
mock_db.engine = Mock()
with (
patch("core.rag.index_processor.index_processor_base.ssrf_proxy.get", return_value=response),
patch("core.rag.index_processor.index_processor_base.remote_fetcher.make_request", return_value=response),
patch("core.rag.index_processor.index_processor_base.db", mock_db),
patch("services.file_service.FileService") as mock_file_service,
):
@ -215,7 +215,7 @@ class TestBaseIndexProcessor:
too_large.headers = {"Content-Length": str(3 * 1024 * 1024), "content-type": "image/png"}
too_large.raise_for_status.return_value = None
with patch("core.rag.index_processor.index_processor_base.ssrf_proxy.get", return_value=too_large):
with patch("core.rag.index_processor.index_processor_base.remote_fetcher.make_request", return_value=too_large):
assert processor._download_image("https://example.com/too-large.png", current_user=Mock()) is None
empty = Mock()
@ -223,7 +223,7 @@ class TestBaseIndexProcessor:
empty.raise_for_status.return_value = None
empty.iter_bytes.return_value = []
with patch("core.rag.index_processor.index_processor_base.ssrf_proxy.get", return_value=empty):
with patch("core.rag.index_processor.index_processor_base.remote_fetcher.make_request", return_value=empty):
assert processor._download_image("https://example.com/empty.png", current_user=Mock()) is None
def test_download_image_limits_stream_size(self, processor: _ForwardingBaseIndexProcessor) -> None:
@ -232,7 +232,7 @@ class TestBaseIndexProcessor:
response.raise_for_status.return_value = None
response.iter_bytes.return_value = [b"a" * (3 * 1024 * 1024)]
with patch("core.rag.index_processor.index_processor_base.ssrf_proxy.get", return_value=response):
with patch("core.rag.index_processor.index_processor_base.remote_fetcher.make_request", return_value=response):
assert processor._download_image("https://example.com/big-stream.png", current_user=Mock()) is None
def test_download_image_handles_timeout_request_and_unexpected_errors(
@ -241,19 +241,19 @@ class TestBaseIndexProcessor:
request = httpx.Request("GET", "https://example.com/image.png")
with patch(
"core.rag.index_processor.index_processor_base.ssrf_proxy.get",
"core.rag.index_processor.index_processor_base.remote_fetcher.make_request",
side_effect=httpx.TimeoutException("timeout"),
):
assert processor._download_image("https://example.com/image.png", current_user=Mock()) is None
with patch(
"core.rag.index_processor.index_processor_base.ssrf_proxy.get",
"core.rag.index_processor.index_processor_base.remote_fetcher.make_request",
side_effect=httpx.RequestError("bad request", request=request),
):
assert processor._download_image("https://example.com/image.png", current_user=Mock()) is None
with patch(
"core.rag.index_processor.index_processor_base.ssrf_proxy.get",
"core.rag.index_processor.index_processor_base.remote_fetcher.make_request",
side_effect=RuntimeError("unexpected"),
):
assert processor._download_image("https://example.com/image.png", current_user=Mock()) is None

View File

@ -1,8 +1,8 @@
"""Unit tests for `ToolFileManager` behavior.
Covers signing/verification, file persistence flows, and retrieval APIs with
mocked storage/session boundaries (httpx, SimpleNamespace, Mock/patch) to
avoid real IO.
Covers signing, file persistence flows, and retrieval APIs with mocked
storage/session boundaries (httpx, SimpleNamespace, Mock/patch) to avoid real
IO.
"""
from __future__ import annotations
@ -17,18 +17,6 @@ from core.tools.tool_file_manager import ToolFileManager
from graphon.file import FileTransferMethod
def _setup_tool_file_signing(monkeypatch: pytest.MonkeyPatch) -> dict[str, str]:
monkeypatch.setattr("core.tools.tool_file_manager.time.time", lambda: 1700000000)
monkeypatch.setattr("core.tools.tool_file_manager.os.urandom", lambda _: b"\x01" * 16)
monkeypatch.setattr("core.tools.tool_file_manager.dify_config.SECRET_KEY", "secret")
monkeypatch.setattr("core.tools.tool_file_manager.dify_config.FILES_URL", "https://files.example.com")
monkeypatch.setattr("core.tools.tool_file_manager.dify_config.INTERNAL_FILES_URL", "https://internal.example.com")
monkeypatch.setattr("core.tools.tool_file_manager.dify_config.FILES_ACCESS_TIMEOUT", 100)
url = ToolFileManager.sign_file("tf-1", ".png")
return dict(part.split("=", 1) for part in url.split("?", 1)[1].split("&"))
def _patch_session_factory(session: Mock):
session_cm = MagicMock()
session_cm.__enter__.return_value = session
@ -36,27 +24,10 @@ def _patch_session_factory(session: Mock):
return patch("core.tools.tool_file_manager.session_factory.create_session", return_value=session_cm)
def test_tool_file_manager_sign_verify_valid(monkeypatch: pytest.MonkeyPatch) -> None:
query = _setup_tool_file_signing(monkeypatch)
def test_tool_file_manager_sign_file_builds_url() -> None:
url = ToolFileManager.sign_file("tf-1", ".png")
assert "/files/tools/tf-1.png" in url
assert ToolFileManager.verify_file("tf-1", query["timestamp"], query["nonce"], query["sign"]) is True
def test_tool_file_manager_sign_verify_bad_signature(monkeypatch: pytest.MonkeyPatch) -> None:
query = _setup_tool_file_signing(monkeypatch)
assert ToolFileManager.verify_file("tf-1", query["timestamp"], query["nonce"], "bad") is False
def test_tool_file_manager_sign_verify_expired_timestamp(monkeypatch: pytest.MonkeyPatch) -> None:
query = _setup_tool_file_signing(monkeypatch)
monkeypatch.setattr("core.tools.tool_file_manager.dify_config.FILES_ACCESS_TIMEOUT", 0)
monkeypatch.setattr("core.tools.tool_file_manager.time.time", lambda: 1700000100)
assert ToolFileManager.verify_file("tf-1", query["timestamp"], query["nonce"], query["sign"]) is False
def test_create_file_by_raw_stores_file_and_persists_record() -> None:
manager = ToolFileManager()
@ -106,7 +77,7 @@ def test_create_file_by_url_downloads_and_persists_record() -> None:
patch("core.tools.tool_file_manager.ToolFile", side_effect=tool_file_factory),
patch("core.tools.tool_file_manager.uuid4", return_value=SimpleNamespace(hex="def")),
_patch_session_factory(session),
patch("core.tools.tool_file_manager.ssrf_proxy.get", return_value=response),
patch("core.tools.tool_file_manager.remote_fetcher.make_request", return_value=response),
):
file_model = manager.create_file_by_url("u1", "t1", "https://example.com/f.bin", "c1")
@ -120,7 +91,10 @@ def test_create_file_by_url_downloads_and_persists_record() -> None:
def test_create_file_by_url_raises_on_timeout() -> None:
manager = ToolFileManager()
with patch("core.tools.tool_file_manager.ssrf_proxy.get", side_effect=httpx.TimeoutException("timeout")):
with patch(
"core.tools.tool_file_manager.remote_fetcher.make_request",
side_effect=httpx.TimeoutException("timeout"),
):
with pytest.raises(ValueError, match="timeout when downloading file"):
manager.create_file_by_url("u1", "t1", "https://example.com/f.bin", "c1")

View File

@ -50,6 +50,17 @@ def stub_support_types(monkeypatch: pytest.MonkeyPatch):
return mod
def _patch_remote_fetcher(monkeypatch: pytest.MonkeyPatch, mod, *, head=None, get=None) -> None:
    """Replace ``mod.remote_fetcher.make_request`` with a HEAD/GET dispatcher.

    Args:
        monkeypatch: Patcher used to install the replacement.
        mod: Object exposing a ``remote_fetcher`` attribute to patch.
        head: Optional callable invoked as ``head(url, **kwargs)`` for HEAD.
        get: Optional callable invoked as ``get(url, **kwargs)`` for GET.

    The installed replacement raises ``AssertionError`` for any other method,
    and for HEAD/GET when the matching callable was not supplied.
    """

    def fake_make_request(method, url, **kwargs):
        # Pick the handler for the verb; anything else has no handler.
        if method == "HEAD":
            handler = head
        elif method == "GET":
            handler = get
        else:
            handler = None

        if handler is None:
            raise AssertionError(f"unexpected remote fetcher method: {method}")
        return handler(url, **kwargs)

    monkeypatch.setattr(mod.remote_fetcher, "make_request", fake_make_request)
def test_get_url_unsupported_content_type(monkeypatch: pytest.MonkeyPatch, stub_support_types):
# HEAD 200 but content-type not supported and not text/html
def fake_head(url, headers=None, follow_redirects=True, timeout=None):
@ -58,7 +69,7 @@ def test_get_url_unsupported_content_type(monkeypatch: pytest.MonkeyPatch, stub_
headers={"Content-Type": "image/png"}, # not supported
)
monkeypatch.setattr(stub_support_types.ssrf_proxy, "head", fake_head)
_patch_remote_fetcher(monkeypatch, stub_support_types, head=fake_head)
result = get_url("https://x.test/file.png")
assert result == "Unsupported content-type [image/png] of URL."
@ -82,7 +93,7 @@ def test_get_url_supported_binary_type_uses_extract_processor(monkeypatch: pytes
assert return_text is True
return "PDF extracted text"
monkeypatch.setattr(stub_support_types.ssrf_proxy, "head", fake_head)
_patch_remote_fetcher(monkeypatch, stub_support_types, head=fake_head)
monkeypatch.setattr(stub_support_types.ExtractProcessor, "load_from_url", staticmethod(fake_load_from_url))
result = get_url("https://x.test/doc.pdf")
@ -103,8 +114,7 @@ def test_get_url_html_flow_with_chardet_and_readability(monkeypatch: pytest.Monk
# chardet.detect returns utf-8
import core.tools.utils.web_reader_tool as mod
monkeypatch.setattr(mod.ssrf_proxy, "head", fake_head)
monkeypatch.setattr(mod.ssrf_proxy, "get", fake_get)
_patch_remote_fetcher(monkeypatch, mod, head=fake_head, get=fake_get)
mock_best = SimpleNamespace(encoding="utf-8")
mock_from_bytes = SimpleNamespace(best=lambda: mock_best)
@ -137,8 +147,7 @@ def test_get_url_html_flow_empty_article_text_returns_empty(monkeypatch: pytest.
import core.tools.utils.web_reader_tool as mod
monkeypatch.setattr(mod.ssrf_proxy, "head", fake_head)
monkeypatch.setattr(mod.ssrf_proxy, "get", fake_get)
_patch_remote_fetcher(monkeypatch, mod, head=fake_head, get=fake_get)
mock_best = SimpleNamespace(encoding="utf-8")
mock_from_bytes = SimpleNamespace(best=lambda: mock_best)
monkeypatch.setattr(mod.charset_normalizer, "from_bytes", lambda _: mock_from_bytes)
@ -150,7 +159,7 @@ def test_get_url_html_flow_empty_article_text_returns_empty(monkeypatch: pytest.
def test_get_url_403_cloudscraper_fallback(monkeypatch: pytest.MonkeyPatch, stub_support_types):
"""HEAD 403 → use cloudscraper.get via ssrf_proxy.make_request, then proceed."""
"""HEAD 403 → use cloudscraper.get via remote_fetcher.make_request, then proceed."""
def fake_head(url, headers=None, follow_redirects=True, timeout=None):
return FakeResponse(status_code=403, headers={})
@ -167,7 +176,7 @@ def test_get_url_403_cloudscraper_fallback(monkeypatch: pytest.MonkeyPatch, stub
import core.tools.utils.web_reader_tool as mod
monkeypatch.setattr(mod.ssrf_proxy, "head", fake_head)
_patch_remote_fetcher(monkeypatch, mod, head=fake_head)
monkeypatch.setattr(mod.cloudscraper, "create_scraper", lambda: FakeScraper())
mock_best = SimpleNamespace(encoding="utf-8")
mock_from_bytes = SimpleNamespace(best=lambda: mock_best)
@ -192,7 +201,7 @@ def test_get_url_head_non_200_returns_status(monkeypatch: pytest.MonkeyPatch, st
import core.tools.utils.web_reader_tool as mod
monkeypatch.setattr(mod.ssrf_proxy, "head", fake_head)
_patch_remote_fetcher(monkeypatch, mod, head=fake_head)
out = get_url("https://x.test/fail")
assert out == "URL returned status code 500."
@ -214,7 +223,7 @@ def test_get_url_content_disposition_filename_detection(monkeypatch: pytest.Monk
import core.tools.utils.web_reader_tool as mod
monkeypatch.setattr(mod.ssrf_proxy, "head", fake_head)
_patch_remote_fetcher(monkeypatch, mod, head=fake_head)
monkeypatch.setattr(mod.ExtractProcessor, "load_from_url", staticmethod(fake_load_from_url))
out = get_url("https://x.test/fname")
@ -241,8 +250,7 @@ def test_get_url_html_encoding_fallback_when_decode_fails(monkeypatch: pytest.Mo
import core.tools.utils.web_reader_tool as mod
monkeypatch.setattr(mod.ssrf_proxy, "head", fake_head)
monkeypatch.setattr(mod.ssrf_proxy, "get", fake_get)
_patch_remote_fetcher(monkeypatch, mod, head=fake_head, get=fake_get)
mock_best = SimpleNamespace(encoding="utf-8")
mock_from_bytes = SimpleNamespace(best=lambda: mock_best)

View File

@ -452,6 +452,7 @@ class TestDifyNodeFactoryCreateNode:
factory._jinja2_template_renderer = sentinel.jinja2_template_renderer
factory._template_transform_max_output_length = 2048
factory._http_request_http_client = sentinel.http_client
factory._remote_file_http_client = sentinel.remote_file_http_client
factory._bound_tool_file_manager_factory = MagicMock(return_value=sentinel.tool_file_manager)
factory._file_reference_factory = sentinel.file_reference_factory
factory._prompt_message_serializer = sentinel.prompt_message_serializer
@ -596,7 +597,7 @@ class TestDifyNodeFactoryCreateNode:
factory._bound_tool_file_manager_factory.assert_called_once_with()
elif constructor_name == "DocumentExtractorNode":
assert kwargs["unstructured_api_config"] is sentinel.unstructured_api_config
assert kwargs["http_client"] is sentinel.http_client
assert kwargs["http_client"] is sentinel.remote_file_http_client
def test_build_llm_compatible_node_init_kwargs_preserves_structured_output_switch(self, factory):
node_data = LLMNodeData.model_validate(
@ -732,7 +733,7 @@ class TestDifyNodeFactoryCreateNode:
BuiltinNodeTypes.LLM,
"LLMNode",
{
"http_client": sentinel.http_client,
"http_client": sentinel.remote_file_http_client,
"llm_file_saver": sentinel.llm_file_saver,
"prompt_message_serializer": sentinel.prompt_message_serializer,
"retriever_attachment_loader": sentinel.retriever_attachment_loader,
@ -743,7 +744,7 @@ class TestDifyNodeFactoryCreateNode:
BuiltinNodeTypes.QUESTION_CLASSIFIER,
"QuestionClassifierNode",
{
"http_client": sentinel.http_client,
"http_client": sentinel.remote_file_http_client,
"llm_file_saver": sentinel.llm_file_saver,
"prompt_message_serializer": sentinel.prompt_message_serializer,
"template_renderer": sentinel.jinja2_template_renderer,

View File

@ -22,13 +22,13 @@ from graphon.variables.variables import StringVariable
def _mock_ssrf_head(monkeypatch: pytest.MonkeyPatch):
"""Avoid any real network requests during tests.
factories.file_factory.remote.get_remote_file_info() uses ssrf_proxy.head
to inspect
remote files. We stub it to return a minimal response object with
factories.file_factory.remote.get_remote_file_info() uses remote_fetcher.make_request
to inspect remote files. We stub it to return a minimal response object with
headers so filename/mime/size can be derived deterministically.
"""
def fake_head(url, *args, **kwargs):
def fake_head(method, url, *args, **kwargs):
assert method == "HEAD"
# choose a content-type by file suffix for determinism
if url.endswith(".pdf"):
ctype = "application/pdf"
@ -46,7 +46,7 @@ def _mock_ssrf_head(monkeypatch: pytest.MonkeyPatch):
}
return SimpleNamespace(status_code=200, headers=headers)
monkeypatch.setattr("core.helper.ssrf_proxy.head", fake_head)
monkeypatch.setattr("factories.file_factory.remote.remote_fetcher.make_request", fake_head)
class TestWorkflowEntry:

View File

@ -99,7 +99,7 @@ def mock_http_head():
},
)
with patch("factories.file_factory.remote.ssrf_proxy.head", autospec=True) as mock_head:
with patch("factories.file_factory.remote.remote_fetcher.make_request", autospec=True) as mock_head:
mock_head.return_value = _mock_response("remote_test.jpg", 2048, "image/jpeg")
yield mock_head

View File

@ -15,10 +15,11 @@ class _FakeResponse:
def _mock_head(monkeypatch: pytest.MonkeyPatch, headers: dict[str, str], status_code: int = 200):
def _fake_head(url: str, follow_redirects: bool = True):
def _fake_head(method: str, url: str, follow_redirects: bool = True):
assert method == "HEAD"
return _FakeResponse(status_code=status_code, headers=headers)
monkeypatch.setattr("factories.file_factory.remote.ssrf_proxy.head", _fake_head)
monkeypatch.setattr("factories.file_factory.remote.remote_fetcher.make_request", _fake_head)
class TestGetRemoteFileInfo:

View File

@ -206,7 +206,10 @@ def test_export_rag_pipeline_dsl_raises_when_dataset_missing() -> None:
def test_import_rag_pipeline_url_fetch_error(mocker) -> None:
mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.ssrf_proxy.get", side_effect=Exception("fetch failed"))
mocker.patch(
"services.rag_pipeline.rag_pipeline_dsl_service.remote_fetcher.make_request",
side_effect=Exception("fetch failed"),
)
service = RagPipelineDslService(session=Mock())
account = Mock(current_tenant_id="t1")
@ -813,7 +816,10 @@ def test_import_rag_pipeline_yaml_url_handles_empty_content_after_github_rewrite
response = Mock()
response.raise_for_status.return_value = None
response.content = b""
get_mock = mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.ssrf_proxy.get", return_value=response)
get_mock = mocker.patch(
"services.rag_pipeline.rag_pipeline_dsl_service.remote_fetcher.make_request",
return_value=response,
)
service = RagPipelineDslService(session=Mock())
account = Mock(current_tenant_id="t1")
@ -825,7 +831,7 @@ def test_import_rag_pipeline_yaml_url_handles_empty_content_after_github_rewrite
assert result.status == ImportStatus.FAILED
assert "Empty content from url" in result.error
called_url = get_mock.call_args.args[0]
called_url = get_mock.call_args.args[1]
assert "raw.githubusercontent.com" in called_url
@ -880,7 +886,7 @@ def test_import_rag_pipeline_url_size_exceeds_limit(mocker) -> None:
response = Mock()
response.raise_for_status.return_value = None
response.content = b"x" * (10 * 1024 * 1024 + 1)
mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.ssrf_proxy.get", return_value=response)
mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.remote_fetcher.make_request", return_value=response)
service = RagPipelineDslService(session=Mock())
account = Mock(current_tenant_id="t1")