From aea3a6f80c816aa67bae59150a1e49daebbf5e0c Mon Sep 17 00:00:00 2001 From: wangxiaolei Date: Tue, 23 Dec 2025 19:01:12 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20when=20use=20forward=20proxy=20with=20ht?= =?UTF-8?q?tpx,=20httpx=20will=20overwrite=20the=20use=20=E2=80=A6=20(#300?= =?UTF-8?q?29)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/core/helper/ssrf_proxy.py | 34 +++- .../unit_tests/core/helper/test_ssrf_proxy.py | 154 +++++++++++++++--- 2 files changed, 165 insertions(+), 23 deletions(-) diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 6c98aea1be..f2172e4e2f 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -72,6 +72,22 @@ def _get_ssrf_client(ssl_verify_enabled: bool) -> httpx.Client: ) +def _get_user_provided_host_header(headers: dict | None) -> str | None: + """ + Extract the user-provided Host header from the headers dict. + + This is needed because when using a forward proxy, httpx may override the Host header. + We preserve the user's explicit Host header to support virtual hosting and other use cases. + """ + if not headers: + return None + # Case-insensitive lookup for Host header + for key, value in headers.items(): + if key.lower() == "host": + return value + return None + + def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): if "allow_redirects" in kwargs: allow_redirects = kwargs.pop("allow_redirects") @@ -90,10 +106,26 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): verify_option = kwargs.pop("ssl_verify", dify_config.HTTP_REQUEST_NODE_SSL_VERIFY) client = _get_ssrf_client(verify_option) + # Preserve user-provided Host header + # When using a forward proxy, httpx may override the Host header based on the URL. + # We extract and preserve any explicitly set Host header to support virtual hosting. + headers = kwargs.get("headers", {}) + user_provided_host = _get_user_provided_host_header(headers) + retries = 0 while retries <= max_retries: try: - response = client.request(method=method, url=url, **kwargs) + # Build the request manually to preserve the Host header + # httpx may override the Host header when using a proxy, so we use + # the request API to explicitly set headers before sending + request = client.build_request(method=method, url=url, **kwargs) + + # If user explicitly provided a Host header, ensure it's preserved + if user_provided_host is not None: + request.headers["Host"] = user_provided_host + + response = client.send(request) + # Check for SSRF protection by Squid proxy if response.status_code in (401, 403): # Check if this is a Squid SSRF rejection diff --git a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py index 37749f0c66..d5bc3283fe 100644 --- a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py +++ b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py @@ -3,50 +3,160 @@ from unittest.mock import MagicMock, patch import pytest -from core.helper.ssrf_proxy import SSRF_DEFAULT_MAX_RETRIES, STATUS_FORCELIST, make_request +from core.helper.ssrf_proxy import ( + SSRF_DEFAULT_MAX_RETRIES, + STATUS_FORCELIST, + _get_user_provided_host_header, + make_request, +) -@patch("httpx.Client.request") -def test_successful_request(mock_request): +@patch("core.helper.ssrf_proxy._get_ssrf_client") +def test_successful_request(mock_get_client): + mock_client = MagicMock() + mock_request = MagicMock() mock_response = MagicMock() mock_response.status_code = 200 - mock_request.return_value = mock_response + mock_client.send.return_value = mock_response + mock_client.build_request.return_value = mock_request + mock_get_client.return_value = mock_client response = make_request("GET", "http://example.com") assert response.status_code == 200 -@patch("httpx.Client.request") -def test_retry_exceed_max_retries(mock_request): +@patch("core.helper.ssrf_proxy._get_ssrf_client") +def test_retry_exceed_max_retries(mock_get_client): + mock_client = MagicMock() + mock_request = MagicMock() mock_response = MagicMock() mock_response.status_code = 500 - - side_effects = [mock_response] * SSRF_DEFAULT_MAX_RETRIES - mock_request.side_effect = side_effects + mock_client.send.return_value = mock_response + mock_client.build_request.return_value = mock_request + mock_get_client.return_value = mock_client with pytest.raises(Exception) as e: make_request("GET", "http://example.com", max_retries=SSRF_DEFAULT_MAX_RETRIES - 1) assert str(e.value) == f"Reached maximum retries ({SSRF_DEFAULT_MAX_RETRIES - 1}) for URL http://example.com" -@patch("httpx.Client.request") -def test_retry_logic_success(mock_request): - side_effects = [] +@patch("core.helper.ssrf_proxy._get_ssrf_client") +def test_retry_logic_success(mock_get_client): + mock_client = MagicMock() + mock_request = MagicMock() + mock_response = MagicMock() + mock_response.status_code = 200 + side_effects = [] for _ in range(SSRF_DEFAULT_MAX_RETRIES): status_code = secrets.choice(STATUS_FORCELIST) - mock_response = MagicMock() - mock_response.status_code = status_code - side_effects.append(mock_response) + retry_response = MagicMock() + retry_response.status_code = status_code + side_effects.append(retry_response) - mock_response_200 = MagicMock() - mock_response_200.status_code = 200 - side_effects.append(mock_response_200) - - mock_request.side_effect = side_effects + side_effects.append(mock_response) + mock_client.send.side_effect = side_effects + mock_client.build_request.return_value = mock_request + mock_get_client.return_value = mock_client response = make_request("GET", "http://example.com", max_retries=SSRF_DEFAULT_MAX_RETRIES) assert response.status_code == 200 - assert mock_request.call_count == SSRF_DEFAULT_MAX_RETRIES + 1 - assert mock_request.call_args_list[0][1].get("method") == "GET" + assert mock_client.send.call_count == SSRF_DEFAULT_MAX_RETRIES + 1 + assert mock_client.build_request.call_count == SSRF_DEFAULT_MAX_RETRIES + 1 + + +class TestGetUserProvidedHostHeader: + """Tests for _get_user_provided_host_header function.""" + + def test_returns_none_when_headers_is_none(self): + assert _get_user_provided_host_header(None) is None + + def test_returns_none_when_headers_is_empty(self): + assert _get_user_provided_host_header({}) is None + + def test_returns_none_when_host_header_not_present(self): + headers = {"Content-Type": "application/json", "Authorization": "Bearer token"} + assert _get_user_provided_host_header(headers) is None + + def test_returns_host_header_lowercase(self): + headers = {"host": "example.com"} + assert _get_user_provided_host_header(headers) == "example.com" + + def test_returns_host_header_uppercase(self): + headers = {"HOST": "example.com"} + assert _get_user_provided_host_header(headers) == "example.com" + + def test_returns_host_header_mixed_case(self): + headers = {"HoSt": "example.com"} + assert _get_user_provided_host_header(headers) == "example.com" + + def test_returns_host_header_from_multiple_headers(self): + headers = {"Content-Type": "application/json", "Host": "api.example.com", "Authorization": "Bearer token"} + assert _get_user_provided_host_header(headers) == "api.example.com" + + def test_returns_first_host_header_when_duplicates(self): + headers = {"host": "first.com", "Host": "second.com"} + # Should return the first one encountered (iteration order is preserved in dict) + result = _get_user_provided_host_header(headers) + assert result in ("first.com", "second.com") + + +@patch("core.helper.ssrf_proxy._get_ssrf_client") +def test_host_header_preservation_without_user_header(mock_get_client): + """Test that when no Host header is provided, the default behavior is maintained.""" + mock_client = MagicMock() + mock_request = MagicMock() + mock_request.headers = {} + mock_response = MagicMock() + mock_response.status_code = 200 + mock_client.send.return_value = mock_response + mock_client.build_request.return_value = mock_request + mock_get_client.return_value = mock_client + + response = make_request("GET", "http://example.com") + + assert response.status_code == 200 + # build_request should be called without headers dict containing Host + mock_client.build_request.assert_called_once() + # Host should not be set if not provided by user + assert "Host" not in mock_request.headers or mock_request.headers.get("Host") is None + + +@patch("core.helper.ssrf_proxy._get_ssrf_client") +def test_host_header_preservation_with_user_header(mock_get_client): + """Test that user-provided Host header is preserved in the request.""" + mock_client = MagicMock() + mock_request = MagicMock() + mock_request.headers = {} + mock_response = MagicMock() + mock_response.status_code = 200 + mock_client.send.return_value = mock_response + mock_client.build_request.return_value = mock_request + mock_get_client.return_value = mock_client + + custom_host = "custom.example.com:8080" + response = make_request("GET", "http://example.com", headers={"Host": custom_host}) + + assert response.status_code == 200 + # Verify build_request was called + mock_client.build_request.assert_called_once() + # Verify the Host header was set on the request object + assert mock_request.headers.get("Host") == custom_host + mock_client.send.assert_called_once_with(mock_request) + + +@patch("core.helper.ssrf_proxy._get_ssrf_client") +@pytest.mark.parametrize("host_key", ["host", "HOST"]) +def test_host_header_preservation_case_insensitive(mock_get_client, host_key): + """Test that Host header is preserved regardless of case.""" + mock_client = MagicMock() + mock_request = MagicMock() + mock_request.headers = {} + mock_response = MagicMock() + mock_response.status_code = 200 + mock_client.send.return_value = mock_response + mock_client.build_request.return_value = mock_request + mock_get_client.return_value = mock_client + response = make_request("GET", "http://example.com", headers={host_key: "api.example.com"}) + assert mock_request.headers.get("Host") == "api.example.com"