This commit is contained in:
Yansong Zhang 2025-12-15 10:14:29 +08:00
parent 790ed0845e
commit 3f7d46358c
1 changed files with 1 additions and 143 deletions

View File

@ -3,8 +3,7 @@ from unittest.mock import MagicMock, patch
import pytest
from core.helper.ssrf_proxy import SSRF_DEFAULT_MAX_RETRIES, STATUS_FORCELIST, is_private_or_local_address, make_request
from core.tools.errors import ToolSSRFError
from core.helper.ssrf_proxy import SSRF_DEFAULT_MAX_RETRIES, STATUS_FORCELIST, make_request
@patch("httpx.Client.request")
@ -51,144 +50,3 @@ def test_retry_logic_success(mock_request):
assert response.status_code == 200
assert mock_request.call_count == SSRF_DEFAULT_MAX_RETRIES + 1
assert mock_request.call_args_list[0][1].get("method") == "GET"
@patch("httpx.Client.request")
def test_squid_ssrf_rejection_detected(mock_request):
"""Test that Squid SSRF rejection (403) is converted to ToolSSRFError."""
mock_response = MagicMock()
mock_response.status_code = 403
mock_response.headers = {"server": "squid/5.2", "via": "1.1 squid"}
mock_request.return_value = mock_response
with pytest.raises(ToolSSRFError) as exc_info:
make_request("GET", "http://192.168.1.1/api")
assert "blocked by SSRF protection" in str(exc_info.value)
assert "192.168.1.1" in str(exc_info.value)
assert "squid.conf.template" in str(exc_info.value)
@patch("httpx.Client.request")
def test_squid_ssrf_rejection_via_header(mock_request):
"""Test detection via Via header when Server header is not present."""
mock_response = MagicMock()
mock_response.status_code = 403
mock_response.headers = {"via": "1.1 squid-proxy (squid/5.2)"}
mock_request.return_value = mock_response
with pytest.raises(ToolSSRFError) as exc_info:
make_request("GET", "http://10.0.0.1/api")
assert "SSRF protection" in str(exc_info.value)
@patch("httpx.Client.request")
def test_squid_401_rejection_detected(mock_request):
"""Test that Squid SSRF rejection with 401 is also converted to ToolSSRFError."""
mock_response = MagicMock()
mock_response.status_code = 401
mock_response.headers = {"server": "squid/5.2"}
mock_request.return_value = mock_response
with pytest.raises(ToolSSRFError) as exc_info:
make_request("GET", "http://192.168.1.1/api")
assert "SSRF protection" in str(exc_info.value)
assert "squid.conf.template" in str(exc_info.value)
@patch("httpx.Client.request")
def test_regular_403_not_treated_as_ssrf(mock_request):
"""Test that regular 403 responses (not from Squid) are returned normally."""
mock_response = MagicMock()
mock_response.status_code = 403
mock_response.headers = {"server": "nginx/1.21.0"} # Not Squid
mock_request.return_value = mock_response
# Should not raise ToolSSRFError
response = make_request("GET", "http://example.com/api")
assert response.status_code == 403
class TestIsPrivateOrLocalAddress:
"""Test cases for SSRF protection function."""
def test_localhost_variants(self):
"""Test that localhost variants are detected as private."""
assert is_private_or_local_address("http://localhost/api") is True
assert is_private_or_local_address("http://127.0.0.1/api") is True
assert is_private_or_local_address("http://[::1]/api") is True
assert is_private_or_local_address("https://localhost:8080/") is True
def test_private_ipv4_ranges(self):
"""Test that private IPv4 ranges are detected."""
# 10.0.0.0/8
assert is_private_or_local_address("http://10.0.0.1/api") is True
assert is_private_or_local_address("http://10.255.255.255/api") is True
# 172.16.0.0/12
assert is_private_or_local_address("http://172.16.0.1/api") is True
assert is_private_or_local_address("http://172.31.255.255/api") is True
# 192.168.0.0/16
assert is_private_or_local_address("http://192.168.0.1/api") is True
assert is_private_or_local_address("http://192.168.255.255/api") is True
# 169.254.0.0/16 (link-local)
assert is_private_or_local_address("http://169.254.1.1/api") is True
def test_local_domains(self):
"""Test that .local domains are detected as private."""
assert is_private_or_local_address("http://myserver.local/api") is True
assert is_private_or_local_address("https://test.local:8080/") is True
def test_public_addresses(self):
"""Test that public addresses are not detected as private."""
assert is_private_or_local_address("http://example.com/api") is False
assert is_private_or_local_address("https://api.openai.com/v1") is False
assert is_private_or_local_address("http://8.8.8.8/") is False
assert is_private_or_local_address("https://1.1.1.1/") is False
assert is_private_or_local_address("http://93.184.216.34/") is False
def test_edge_cases(self):
"""Test edge cases and invalid inputs."""
# Empty or None
assert is_private_or_local_address("") is False
assert is_private_or_local_address(None) is False
# Invalid URLs
assert is_private_or_local_address("not-a-url") is False
assert is_private_or_local_address("://invalid") is False
def test_ipv6_private_ranges(self):
"""Test that private IPv6 ranges are detected."""
# IPv6 loopback
assert is_private_or_local_address("http://[::1]/api") is True
# IPv6 link-local (fe80::/10)
assert is_private_or_local_address("http://[fe80::1]/api") is True
# IPv6 unique local (fc00::/7)
assert is_private_or_local_address("http://[fc00::1]/api") is True
assert is_private_or_local_address("http://[fd00::1]/api") is True
def test_public_ipv6(self):
"""Test that public IPv6 addresses are not detected as private."""
# Public IPv6 addresses (real examples)
# Google Public DNS IPv6
assert is_private_or_local_address("http://[2001:4860:4860::8888]/api") is False
# Cloudflare DNS IPv6
assert is_private_or_local_address("http://[2606:4700:4700::1111]/api") is False
def test_url_with_ports(self):
"""Test URLs with custom ports."""
assert is_private_or_local_address("http://localhost:8080/api") is True
assert is_private_or_local_address("http://192.168.1.1:3000/") is True
assert is_private_or_local_address("https://example.com:443/api") is False
def test_url_schemes(self):
"""Test different URL schemes."""
assert is_private_or_local_address("https://127.0.0.1/api") is True
assert is_private_or_local_address("http://127.0.0.1/api") is True
assert is_private_or_local_address("https://example.com/api") is False