fix: when use forward proxy with httpx, httpx will overwrite the use … (#30029)

This commit is contained in:
wangxiaolei 2025-12-23 19:01:12 +08:00 committed by GitHub
parent 3f27b3f0b4
commit aea3a6f80c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 165 additions and 23 deletions

View File

@ -72,6 +72,22 @@ def _get_ssrf_client(ssl_verify_enabled: bool) -> httpx.Client:
)
def _get_user_provided_host_header(headers: dict | None) -> str | None:
"""
Extract the user-provided Host header from the headers dict.
This is needed because when using a forward proxy, httpx may override the Host header.
We preserve the user's explicit Host header to support virtual hosting and other use cases.
"""
if not headers:
return None
# Case-insensitive lookup for Host header
for key, value in headers.items():
if key.lower() == "host":
return value
return None
def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
if "allow_redirects" in kwargs:
allow_redirects = kwargs.pop("allow_redirects")
@ -90,10 +106,26 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
verify_option = kwargs.pop("ssl_verify", dify_config.HTTP_REQUEST_NODE_SSL_VERIFY)
client = _get_ssrf_client(verify_option)
# Preserve user-provided Host header
# When using a forward proxy, httpx may override the Host header based on the URL.
# We extract and preserve any explicitly set Host header to support virtual hosting.
headers = kwargs.get("headers", {})
user_provided_host = _get_user_provided_host_header(headers)
retries = 0
while retries <= max_retries:
try:
response = client.request(method=method, url=url, **kwargs)
# Build the request manually to preserve the Host header
# httpx may override the Host header when using a proxy, so we use
# the request API to explicitly set headers before sending
request = client.build_request(method=method, url=url, **kwargs)
# If user explicitly provided a Host header, ensure it's preserved
if user_provided_host is not None:
request.headers["Host"] = user_provided_host
response = client.send(request)
# Check for SSRF protection by Squid proxy
if response.status_code in (401, 403):
# Check if this is a Squid SSRF rejection

View File

@ -3,50 +3,160 @@ from unittest.mock import MagicMock, patch
import pytest
from core.helper.ssrf_proxy import SSRF_DEFAULT_MAX_RETRIES, STATUS_FORCELIST, make_request
from core.helper.ssrf_proxy import (
SSRF_DEFAULT_MAX_RETRIES,
STATUS_FORCELIST,
_get_user_provided_host_header,
make_request,
)
@patch("httpx.Client.request")
def test_successful_request(mock_request):
@patch("core.helper.ssrf_proxy._get_ssrf_client")
def test_successful_request(mock_get_client):
mock_client = MagicMock()
mock_request = MagicMock()
mock_response = MagicMock()
mock_response.status_code = 200
mock_request.return_value = mock_response
mock_client.send.return_value = mock_response
mock_client.build_request.return_value = mock_request
mock_get_client.return_value = mock_client
response = make_request("GET", "http://example.com")
assert response.status_code == 200
@patch("httpx.Client.request")
def test_retry_exceed_max_retries(mock_request):
@patch("core.helper.ssrf_proxy._get_ssrf_client")
def test_retry_exceed_max_retries(mock_get_client):
mock_client = MagicMock()
mock_request = MagicMock()
mock_response = MagicMock()
mock_response.status_code = 500
side_effects = [mock_response] * SSRF_DEFAULT_MAX_RETRIES
mock_request.side_effect = side_effects
mock_client.send.return_value = mock_response
mock_client.build_request.return_value = mock_request
mock_get_client.return_value = mock_client
with pytest.raises(Exception) as e:
make_request("GET", "http://example.com", max_retries=SSRF_DEFAULT_MAX_RETRIES - 1)
assert str(e.value) == f"Reached maximum retries ({SSRF_DEFAULT_MAX_RETRIES - 1}) for URL http://example.com"
@patch("httpx.Client.request")
def test_retry_logic_success(mock_request):
side_effects = []
@patch("core.helper.ssrf_proxy._get_ssrf_client")
def test_retry_logic_success(mock_get_client):
mock_client = MagicMock()
mock_request = MagicMock()
mock_response = MagicMock()
mock_response.status_code = 200
side_effects = []
for _ in range(SSRF_DEFAULT_MAX_RETRIES):
status_code = secrets.choice(STATUS_FORCELIST)
mock_response = MagicMock()
mock_response.status_code = status_code
side_effects.append(mock_response)
retry_response = MagicMock()
retry_response.status_code = status_code
side_effects.append(retry_response)
mock_response_200 = MagicMock()
mock_response_200.status_code = 200
side_effects.append(mock_response_200)
mock_request.side_effect = side_effects
side_effects.append(mock_response)
mock_client.send.side_effect = side_effects
mock_client.build_request.return_value = mock_request
mock_get_client.return_value = mock_client
response = make_request("GET", "http://example.com", max_retries=SSRF_DEFAULT_MAX_RETRIES)
assert response.status_code == 200
assert mock_request.call_count == SSRF_DEFAULT_MAX_RETRIES + 1
assert mock_request.call_args_list[0][1].get("method") == "GET"
assert mock_client.send.call_count == SSRF_DEFAULT_MAX_RETRIES + 1
assert mock_client.build_request.call_count == SSRF_DEFAULT_MAX_RETRIES + 1
class TestGetUserProvidedHostHeader:
    """Tests for _get_user_provided_host_header function."""

    def test_returns_none_when_headers_is_none(self):
        """A missing headers argument yields no Host value."""
        assert _get_user_provided_host_header(None) is None

    def test_returns_none_when_headers_is_empty(self):
        """An empty headers dict yields no Host value."""
        assert _get_user_provided_host_header({}) is None

    def test_returns_none_when_host_header_not_present(self):
        """Headers without any Host entry yield no Host value."""
        headers = {"Content-Type": "application/json", "Authorization": "Bearer token"}
        assert _get_user_provided_host_header(headers) is None

    def test_returns_host_header_lowercase(self):
        """A lower-case ``host`` key is recognised."""
        headers = {"host": "example.com"}
        assert _get_user_provided_host_header(headers) == "example.com"

    def test_returns_host_header_uppercase(self):
        """An upper-case ``HOST`` key is recognised."""
        headers = {"HOST": "example.com"}
        assert _get_user_provided_host_header(headers) == "example.com"

    def test_returns_host_header_mixed_case(self):
        """A mixed-case ``HoSt`` key is recognised."""
        headers = {"HoSt": "example.com"}
        assert _get_user_provided_host_header(headers) == "example.com"

    def test_returns_host_header_from_multiple_headers(self):
        """The Host value is found among unrelated headers."""
        headers = {"Content-Type": "application/json", "Host": "api.example.com", "Authorization": "Bearer token"}
        assert _get_user_provided_host_header(headers) == "api.example.com"

    def test_returns_first_host_header_when_duplicates(self):
        """With several Host spellings, the first one in iteration order wins."""
        headers = {"host": "first.com", "Host": "second.com"}
        # dicts preserve insertion order, so the result is deterministic; accepting
        # either value here would let a "last match wins" regression slip through.
        assert _get_user_provided_host_header(headers) == "first.com"
@patch("core.helper.ssrf_proxy._get_ssrf_client")
def test_host_header_preservation_without_user_header(mock_get_client):
    """Verify that no Host header is injected when the caller supplies none."""
    # A real dict stands in for the request headers so any write is observable.
    built_request = MagicMock(headers={})
    ok_response = MagicMock(status_code=200)

    fake_client = MagicMock()
    fake_client.build_request.return_value = built_request
    fake_client.send.return_value = ok_response
    mock_get_client.return_value = fake_client

    result = make_request("GET", "http://example.com")

    assert result.status_code == 200
    # The request must have been built exactly once for this single attempt.
    fake_client.build_request.assert_called_once()
    # With no user-supplied Host, the header is either absent or unset.
    assert built_request.headers.get("Host") is None
@patch("core.helper.ssrf_proxy._get_ssrf_client")
def test_host_header_preservation_with_user_header(mock_get_client):
    """Verify that an explicit Host header ends up on the request that is sent."""
    expected_host = "custom.example.com:8080"

    # A real dict stands in for the request headers so the write is observable.
    built_request = MagicMock(headers={})
    ok_response = MagicMock(status_code=200)

    fake_client = MagicMock()
    fake_client.build_request.return_value = built_request
    fake_client.send.return_value = ok_response
    mock_get_client.return_value = fake_client

    result = make_request("GET", "http://example.com", headers={"Host": expected_host})

    assert result.status_code == 200
    # The request is built once ...
    fake_client.build_request.assert_called_once()
    # ... carries the caller's Host value ...
    assert built_request.headers.get("Host") == expected_host
    # ... and that very request object is the one handed to send().
    fake_client.send.assert_called_once_with(built_request)
@patch("core.helper.ssrf_proxy._get_ssrf_client")
@pytest.mark.parametrize("host_key", ["host", "HOST"])
def test_host_header_preservation_case_insensitive(mock_get_client, host_key):
    """Test that Host header is preserved regardless of case."""
    mock_client = MagicMock()
    mock_request = MagicMock()
    # A real dict stands in for the request headers so the write is observable.
    mock_request.headers = {}
    mock_response = MagicMock()
    mock_response.status_code = 200
    mock_client.send.return_value = mock_response
    mock_client.build_request.return_value = mock_request
    mock_get_client.return_value = mock_client

    response = make_request("GET", "http://example.com", headers={host_key: "api.example.com"})

    # The response was previously assigned but never checked.
    assert response.status_code == 200
    # Whatever spelling the caller used, the value lands under "Host" ...
    assert mock_request.headers.get("Host") == "api.example.com"
    # ... on the request object that is actually sent.
    mock_client.send.assert_called_once_with(mock_request)