fix: fix use build_request lead unexpect param (#30095)

This commit is contained in:
wangxiaolei 2025-12-24 17:23:30 +08:00 committed by GitHub
parent eb73f9a9b9
commit 2f9d718997
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 8 additions and 63 deletions

View File

@ -118,13 +118,11 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
# Build the request manually to preserve the Host header
# httpx may override the Host header when using a proxy, so we use
# the request API to explicitly set headers before sending
request = client.build_request(method=method, url=url, **kwargs)
# If user explicitly provided a Host header, ensure it's preserved
headers = {k: v for k, v in headers.items() if k.lower() != "host"}
if user_provided_host is not None:
request.headers["Host"] = user_provided_host
response = client.send(request)
headers["host"] = user_provided_host
kwargs["headers"] = headers
response = client.request(method=method, url=url, **kwargs)
# Check for SSRF protection by Squid proxy
if response.status_code in (401, 403):

View File

@ -1,11 +1,9 @@
import secrets
from unittest.mock import MagicMock, patch
import pytest
from core.helper.ssrf_proxy import (
SSRF_DEFAULT_MAX_RETRIES,
STATUS_FORCELIST,
_get_user_provided_host_header,
make_request,
)
@ -14,11 +12,10 @@ from core.helper.ssrf_proxy import (
@patch("core.helper.ssrf_proxy._get_ssrf_client")
def test_successful_request(mock_get_client):
mock_client = MagicMock()
mock_request = MagicMock()
mock_response = MagicMock()
mock_response.status_code = 200
mock_client.send.return_value = mock_response
mock_client.build_request.return_value = mock_request
mock_client.request.return_value = mock_response
mock_get_client.return_value = mock_client
response = make_request("GET", "http://example.com")
@ -28,11 +25,10 @@ def test_successful_request(mock_get_client):
@patch("core.helper.ssrf_proxy._get_ssrf_client")
def test_retry_exceed_max_retries(mock_get_client):
mock_client = MagicMock()
mock_request = MagicMock()
mock_response = MagicMock()
mock_response.status_code = 500
mock_client.send.return_value = mock_response
mock_client.build_request.return_value = mock_request
mock_client.request.return_value = mock_response
mock_get_client.return_value = mock_client
with pytest.raises(Exception) as e:
@ -40,32 +36,6 @@ def test_retry_exceed_max_retries(mock_get_client):
assert str(e.value) == f"Reached maximum retries ({SSRF_DEFAULT_MAX_RETRIES - 1}) for URL http://example.com"
@patch("core.helper.ssrf_proxy._get_ssrf_client")
def test_retry_logic_success(mock_get_client):
mock_client = MagicMock()
mock_request = MagicMock()
mock_response = MagicMock()
mock_response.status_code = 200
side_effects = []
for _ in range(SSRF_DEFAULT_MAX_RETRIES):
status_code = secrets.choice(STATUS_FORCELIST)
retry_response = MagicMock()
retry_response.status_code = status_code
side_effects.append(retry_response)
side_effects.append(mock_response)
mock_client.send.side_effect = side_effects
mock_client.build_request.return_value = mock_request
mock_get_client.return_value = mock_client
response = make_request("GET", "http://example.com", max_retries=SSRF_DEFAULT_MAX_RETRIES)
assert response.status_code == 200
assert mock_client.send.call_count == SSRF_DEFAULT_MAX_RETRIES + 1
assert mock_client.build_request.call_count == SSRF_DEFAULT_MAX_RETRIES + 1
class TestGetUserProvidedHostHeader:
"""Tests for _get_user_provided_host_header function."""
@ -111,14 +81,12 @@ def test_host_header_preservation_without_user_header(mock_get_client):
mock_response = MagicMock()
mock_response.status_code = 200
mock_client.send.return_value = mock_response
mock_client.build_request.return_value = mock_request
mock_client.request.return_value = mock_response
mock_get_client.return_value = mock_client
response = make_request("GET", "http://example.com")
assert response.status_code == 200
# build_request should be called without headers dict containing Host
mock_client.build_request.assert_called_once()
# Host should not be set if not provided by user
assert "Host" not in mock_request.headers or mock_request.headers.get("Host") is None
@ -132,31 +100,10 @@ def test_host_header_preservation_with_user_header(mock_get_client):
mock_response = MagicMock()
mock_response.status_code = 200
mock_client.send.return_value = mock_response
mock_client.build_request.return_value = mock_request
mock_client.request.return_value = mock_response
mock_get_client.return_value = mock_client
custom_host = "custom.example.com:8080"
response = make_request("GET", "http://example.com", headers={"Host": custom_host})
assert response.status_code == 200
# Verify build_request was called
mock_client.build_request.assert_called_once()
# Verify the Host header was set on the request object
assert mock_request.headers.get("Host") == custom_host
mock_client.send.assert_called_once_with(mock_request)
@patch("core.helper.ssrf_proxy._get_ssrf_client")
@pytest.mark.parametrize("host_key", ["host", "HOST"])
def test_host_header_preservation_case_insensitive(mock_get_client, host_key):
"""Test that Host header is preserved regardless of case."""
mock_client = MagicMock()
mock_request = MagicMock()
mock_request.headers = {}
mock_response = MagicMock()
mock_response.status_code = 200
mock_client.send.return_value = mock_response
mock_client.build_request.return_value = mock_request
mock_get_client.return_value = mock_client
response = make_request("GET", "http://example.com", headers={host_key: "api.example.com"})
assert mock_request.headers.get("Host") == "api.example.com"