dify/api/tests/integration_tests/ssrf_proxy/test_ssrf_proxy.py
-LAN- 42110a8217
test(ssrf_proxy): Add integration test for ssrf proxy
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-01 13:45:08 +08:00

424 lines
16 KiB
Python
Executable File

#!/usr/bin/env python3
"""
SSRF Proxy Test Suite
This script tests the SSRF proxy configuration to ensure it blocks
private networks while allowing public internet access.
"""
import argparse
import json
import os
import subprocess
import sys
import time
import urllib.error
import urllib.request
from dataclasses import dataclass
from enum import Enum
from typing import final
import yaml
# ANSI escape sequences used to colorize terminal output.
class Colors:
    """ANSI color codes; wrap text as f"{Colors.RED}text{Colors.NC}"."""

    RED: str = "\x1b[0;31m"
    GREEN: str = "\x1b[0;32m"
    YELLOW: str = "\x1b[1;33m"  # bold yellow
    BLUE: str = "\x1b[0;34m"
    NC: str = "\x1b[0m"  # reset attributes ("No Color")
class TestResult(Enum):
    """Outcome of a single SSRF proxy test case.

    The string values are what get stored under the "result" key of each
    recorded result (and therefore in the saved JSON report).
    """

    # This module is named test_*.py and the class name starts with "Test",
    # so pytest would otherwise try to collect it and emit a
    # PytestCollectionWarning. Dunder names are not enum members.
    __test__ = False

    PASSED = "passed"  # observed blocked/allowed state matched the expectation
    FAILED = "failed"  # observed state differed from the expectation
    SKIPPED = "skipped"  # the request raised an unexpected error; no verdict
@dataclass
class TestCase:
    """One URL to request through the proxy and the verdict it should get."""

    # This module is named test_*.py and the class name starts with "Test",
    # so pytest would otherwise try to collect it and warn that it cannot
    # (the dataclass has an __init__). Unannotated, so not a dataclass field.
    __test__ = False

    name: str  # label printed in the report
    url: str  # URL requested through the proxy
    expected_blocked: bool  # True if the proxy should refuse the request
    category: str  # heading the test is grouped under when printing
    description: str = ""  # optional free-text explanation
@final
class SSRFProxyTester:
    """Run SSRF proxy test cases and report which ones behave as expected.

    Optionally manages a throwaway squid container configured from the
    repository's ``docker/ssrf_proxy`` directory, sends each test URL through
    the proxy with ``urllib``, and records per-test outcomes in
    ``self.results``.
    """

    def __init__(self, proxy_host: str = "localhost", proxy_port: int = 3128, test_file: str | None = None):
        """Configure the tester.

        Args:
            proxy_host: Host the proxy is reachable on.
            proxy_port: Proxy port; also passed to the container as HTTP_PORT
                and published on the same host port.
            test_file: YAML file with test cases (default ``test_cases.yaml``).
        """
        self.proxy_host = proxy_host
        self.proxy_port = proxy_port
        self.proxy_url = f"http://{proxy_host}:{proxy_port}"
        self.container_name = "ssrf-proxy-test"
        self.image = "ubuntu/squid:latest"
        # One dict per executed test; see run_test() for the keys.
        self.results: list[dict[str, object]] = []
        self.test_file = test_file or "test_cases.yaml"

    def start_proxy_container(self) -> bool:
        """Start the SSRF proxy container.

        Any previous container with the same name is removed first. The squid
        config template and entrypoint from ``docker/ssrf_proxy`` are mounted
        read-only, plus ``conf.d`` when it is a non-empty directory.

        Returns:
            True if the container is running after a fixed start-up wait,
            False otherwise (container logs are printed on failure).
        """
        print(f"{Colors.YELLOW}Starting SSRF proxy container...{Colors.NC}")

        # Remove a leftover container from an earlier run, if any.
        self.stop_proxy_container()

        # This script lives in api/tests/integration_tests/ssrf_proxy/, so the
        # project root is four levels up; the proxy config is in docker/ssrf_proxy.
        script_dir = os.path.dirname(os.path.abspath(__file__))
        project_root = os.path.abspath(os.path.join(script_dir, "..", "..", "..", ".."))
        docker_config_dir = os.path.join(project_root, "docker", "ssrf_proxy")

        volumes = [
            f"{docker_config_dir}/squid.conf.template:/etc/squid/squid.conf.template:ro",
            f"{docker_config_dir}/docker-entrypoint.sh:/docker-entrypoint-mount.sh:ro",
        ]
        conf_d_path = f"{docker_config_dir}/conf.d"
        # isdir() also guards os.listdir() against a non-directory conf.d entry.
        if os.path.isdir(conf_d_path) and os.listdir(conf_d_path):
            volumes.append(f"{conf_d_path}:/etc/squid/conf.d:ro")

        cmd = [
            "docker",
            "run",
            "-d",
            "--name",
            self.container_name,
            "-p",
            f"{self.proxy_port}:{self.proxy_port}",
            "-p",
            "8194:8194",
        ]
        for volume in volumes:
            cmd += ["-v", volume]
        cmd += [
            "-e",
            f"HTTP_PORT={self.proxy_port}",
            "-e",
            "COREDUMP_DIR=/var/spool/squid",
            "-e",
            "REVERSE_PROXY_PORT=8194",
            "-e",
            "SANDBOX_HOST=sandbox",
            "-e",
            "SANDBOX_PORT=8194",
            "--entrypoint",
            "sh",
            self.image,
            "-c",
            # The entrypoint is mounted read-only and may carry CRLF line
            # endings, so copy it, strip trailing CRs, make it executable, run it.
            "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\\r$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh",  # noqa: E501
        ]

        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            print(f"{Colors.RED}Failed to start container: {result.stderr}{Colors.NC}")
            return False

        # Give the proxy a fixed amount of time to come up.
        print(f"{Colors.YELLOW}Waiting for proxy to start...{Colors.NC}")
        time.sleep(5)

        # `docker ps` lists only running containers, so a missing name means it exited.
        result = subprocess.run(
            ["docker", "ps", "--filter", f"name={self.container_name}"],
            capture_output=True,
            text=True,
        )
        if self.container_name not in result.stdout:
            print(f"{Colors.RED}Container failed to start!{Colors.NC}")
            logs = subprocess.run(["docker", "logs", self.container_name], capture_output=True, text=True)
            # `docker logs` replays the container's stderr on its own stderr;
            # print both streams so startup errors are not lost.
            print(logs.stdout)
            if logs.stderr:
                print(logs.stderr)
            return False

        print(f"{Colors.GREEN}Proxy started successfully!{Colors.NC}\n")
        return True

    def stop_proxy_container(self) -> None:
        """Stop and remove the proxy container (best effort)."""
        # Results are deliberately ignored: the container may not exist.
        _ = subprocess.run(["docker", "stop", self.container_name], capture_output=True, text=True)
        _ = subprocess.run(["docker", "rm", self.container_name], capture_output=True, text=True)

    def test_url(self, test_case: TestCase) -> TestResult:
        """Request ``test_case.url`` through the proxy and grade the outcome.

        A 403/407 HTTP error or any connection-level error counts as
        "blocked"; a successful response or any other HTTP status counts as
        "allowed".

        Returns:
            PASSED or FAILED depending on whether the observed outcome matches
            ``test_case.expected_blocked``; SKIPPED on an unexpected error.
        """
        # Route both http and https URLs through the proxy under test.
        proxy_handler = urllib.request.ProxyHandler({"http": self.proxy_url, "https": self.proxy_url})
        opener = urllib.request.build_opener(proxy_handler)

        try:
            request = urllib.request.Request(test_case.url)
            with opener.open(request, timeout=5):
                # Any response at all means the request was let through.
                is_blocked = False
        except urllib.error.HTTPError as e:
            # 403/407 are treated as a proxy denial; other statuses mean the
            # request reached its destination. NOTE: a 403 sent by the
            # upstream server itself is indistinguishable here.
            is_blocked = e.code in (403, 407)
            # HTTPError wraps the response body; release it.
            e.close()
        except (urllib.error.URLError, OSError, TimeoutError):
            # Connection errors mean blocked by proxy.
            is_blocked = True
        except Exception as e:
            print(f"{Colors.YELLOW}Warning: Unexpected error testing {test_case.url}: {e}{Colors.NC}")
            return TestResult.SKIPPED

        return TestResult.PASSED if is_blocked == test_case.expected_blocked else TestResult.FAILED

    def run_test(self, test_case: TestCase) -> None:
        """Run one test case, print a status line, and append it to ``self.results``."""
        result = self.test_url(test_case)

        # The glyph keeps the verdict readable even when colors are not rendered.
        if result == TestResult.PASSED:
            symbol = f"{Colors.GREEN}✓{Colors.NC}"
        elif result == TestResult.FAILED:
            symbol = f"{Colors.RED}✗{Colors.NC}"
        else:
            symbol = f"{Colors.YELLOW}⊘{Colors.NC}"
        status = "blocked" if test_case.expected_blocked else "allowed"
        print(f" {symbol} {test_case.name} (should be {status})")

        self.results.append(
            {
                "name": test_case.name,
                "category": test_case.category,
                "url": test_case.url,
                "expected_blocked": test_case.expected_blocked,
                "result": result.value,
                "description": test_case.description,
            }
        )

    def run_all_tests(self) -> None:
        """Run every test case, printing them grouped by category."""
        test_cases = self.get_test_cases()
        print("=" * 50)
        print(" SSRF Proxy Test Suite")
        print("=" * 50)

        for category, tests in self._group_by_category(test_cases).items():
            print(f"\n{Colors.YELLOW}{category}:{Colors.NC}")
            for test in tests:
                self.run_test(test)

    @staticmethod
    def _group_by_category(test_cases: list[TestCase]) -> dict[str, list[TestCase]]:
        """Group test cases by category, preserving first-seen category order."""
        categories: dict[str, list[TestCase]] = {}
        for test in test_cases:
            categories.setdefault(test.category, []).append(test)
        return categories

    @staticmethod
    def _resolve_test_file(yaml_file: str) -> str:
        """Return ``yaml_file`` next to this script if it exists there, else the path as given."""
        script_relative = os.path.join(os.path.dirname(os.path.abspath(__file__)), yaml_file)
        # Prefer the script directory (the historical lookup location); fall
        # back to an absolute or CWD-relative path only if that one exists.
        if os.path.exists(script_relative) or not os.path.exists(yaml_file):
            return script_relative
        return yaml_file

    def load_test_cases_from_yaml(self, yaml_file: str = "test_cases.yaml") -> list[TestCase]:
        """Load test cases from a YAML file, falling back to built-in defaults.

        ``yaml_file`` is looked up next to this script first, then as given
        (absolute, or relative to the current working directory). Layout::

            test_categories:
              <key>:
                name: <display name>      # optional, defaults to <key>
                test_cases:
                  - name: ...
                    url: ...
                    expected_blocked: true | false
                    description: ...     # optional

        A missing, unparsable, or empty file yields ``get_default_test_cases()``.
        """
        yaml_path = self._resolve_test_file(yaml_file)
        try:
            with open(yaml_path, encoding="utf-8") as f:
                # An empty document parses to None; treat it as "no test cases".
                config = yaml.safe_load(f) or {}  # pyright: ignore[reportAny]

            test_cases: list[TestCase] = []
            test_categories = config.get("test_categories") or {}  # pyright: ignore[reportAny]
            for category_key, category_data in test_categories.items():  # pyright: ignore[reportAny]
                category_name: str = str(category_data.get("name", category_key))  # pyright: ignore[reportAny]
                test_cases_list = category_data.get("test_cases") or []  # pyright: ignore[reportAny]
                for test_data in test_cases_list:  # pyright: ignore[reportAny]
                    test_cases.append(
                        TestCase(
                            name=str(test_data["name"]),  # pyright: ignore[reportAny]
                            url=str(test_data["url"]),  # pyright: ignore[reportAny]
                            expected_blocked=bool(test_data["expected_blocked"]),  # pyright: ignore[reportAny]
                            category=category_name,
                            description=str(test_data.get("description", "")),  # pyright: ignore[reportAny]
                        )
                    )

            if test_cases:
                print(f"{Colors.BLUE}Loaded {len(test_cases)} test cases from {yaml_file}{Colors.NC}")
                return test_cases
            print(f"{Colors.YELLOW}No test cases found in {yaml_file}, using defaults{Colors.NC}")
            return self.get_default_test_cases()
        except FileNotFoundError:
            print(f"{Colors.YELLOW}Test case file {yaml_file} not found, using defaults{Colors.NC}")
            return self.get_default_test_cases()
        except yaml.YAMLError as e:
            print(f"{Colors.YELLOW}Error parsing {yaml_file}: {e}, using defaults{Colors.NC}")
            return self.get_default_test_cases()
        except Exception as e:
            # Malformed structure (missing keys, wrong types) also falls back.
            print(f"{Colors.YELLOW}Unexpected error loading {yaml_file}: {e}, using defaults{Colors.NC}")
            return self.get_default_test_cases()

    def get_default_test_cases(self) -> list[TestCase]:
        """Fallback test cases used when the YAML file cannot supply any."""
        return [
            TestCase("Loopback", "http://127.0.0.1", True, "Private Networks", "IPv4 loopback"),
            TestCase("Private Network", "http://192.168.1.1", True, "Private Networks", "RFC 1918"),
            TestCase("AWS Metadata", "http://169.254.169.254", True, "Cloud Metadata", "AWS metadata"),
            TestCase("Public Site", "http://example.com", False, "Public Internet", "Public website"),
            TestCase("Port 8080", "http://example.com:8080", True, "Port Restrictions", "Non-standard port"),
        ]

    def get_test_cases(self) -> list[TestCase]:
        """Return the test cases from ``self.test_file``, or the defaults."""
        return self.load_test_cases_from_yaml(self.test_file)

    def print_summary(self) -> bool:
        """Print pass/fail/skip counts and list failed tests.

        Returns:
            True when no test failed (skipped tests do not count as failures).
        """
        passed = sum(1 for r in self.results if r["result"] == "passed")
        failed = sum(1 for r in self.results if r["result"] == "failed")
        skipped = sum(1 for r in self.results if r["result"] == "skipped")

        print("\n" + "=" * 50)
        print(" Test Summary")
        print("=" * 50)
        print(f"Tests Passed: {Colors.GREEN}{passed}{Colors.NC}")
        print(f"Tests Failed: {Colors.RED}{failed}{Colors.NC}")
        if skipped > 0:
            print(f"Tests Skipped: {Colors.YELLOW}{skipped}{Colors.NC}")

        if failed == 0:
            print(f"\n{Colors.GREEN}✓ All tests passed! SSRF proxy is configured correctly.{Colors.NC}")
        else:
            print(f"\n{Colors.RED}✗ Some tests failed. Please review the configuration.{Colors.NC}")
            print("\nFailed tests:")
            for r in self.results:
                if r["result"] == "failed":
                    status = "should be blocked" if r["expected_blocked"] else "should be allowed"
                    print(f" - {r['name']} ({status}): {r['url']}")
        return failed == 0

    def save_results(self, filename: str = "test_results.json") -> None:
        """Write a timestamp, the proxy URL, and all recorded results to ``filename`` as JSON."""
        with open(filename, "w", encoding="utf-8") as f:
            json.dump(
                {
                    "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
                    "proxy_url": self.proxy_url,
                    "results": self.results,
                },
                f,
                indent=2,
            )
        print(f"\nResults saved to {filename}")
def main():
    """Command-line entry point: parse flags, then list or run the SSRF proxy tests.

    Exits with status 0 after ``--list-tests`` or when no test failed, and 1
    when the proxy container cannot be started or any test failed.
    """

    @dataclass
    class Args:
        # Typed mirror of the argparse namespace.
        host: str = "localhost"
        port: int = 3128
        no_container: bool = False
        save_results: bool = False
        test_file: str | None = None
        list_tests: bool = False

    def parse_args() -> Args:
        """Parse sys.argv into a typed ``Args`` instance."""
        parser = argparse.ArgumentParser(description="Test SSRF Proxy Configuration")
        _ = parser.add_argument("--host", type=str, default="localhost", help="Proxy host (default: localhost)")
        _ = parser.add_argument("--port", type=int, default=3128, help="Proxy port (default: 3128)")
        _ = parser.add_argument(
            "--no-container",
            action="store_true",
            help="Don't start container (assume proxy is already running)",
        )
        _ = parser.add_argument("--save-results", action="store_true", help="Save test results to JSON file")
        _ = parser.add_argument(
            "--test-file", type=str, help="Path to YAML file containing test cases (default: test_cases.yaml)"
        )
        _ = parser.add_argument("--list-tests", action="store_true", help="List all test cases without running them")

        # argparse.Namespace attributes are typed as Any, so convert each one
        # explicitly into the declared Args field type.
        ns = parser.parse_args()
        return Args(
            host=str(ns.host),  # pyright: ignore[reportAny]
            port=int(ns.port),  # pyright: ignore[reportAny]
            no_container=bool(ns.no_container),  # pyright: ignore[reportAny]
            save_results=bool(ns.save_results),  # pyright: ignore[reportAny]
            test_file=ns.test_file or None,  # pyright: ignore[reportAny]
            list_tests=bool(ns.list_tests),  # pyright: ignore[reportAny]
        )

    def show_test_listing(tester: SSRFProxyTester) -> None:
        """Print every known test case grouped by category, without running any."""
        cases = tester.get_test_cases()
        print("\n" + "=" * 50)
        print(" Available Test Cases")
        print("=" * 50)

        grouped: dict[str, list[TestCase]] = {}
        for case in cases:
            grouped.setdefault(case.category, []).append(case)

        for category, members in grouped.items():
            print(f"\n{Colors.YELLOW}{category}:{Colors.NC}")
            for case in members:
                label, color = ("BLOCK", Colors.RED) if case.expected_blocked else ("ALLOW", Colors.GREEN)
                print(f" {color}[{label}]{Colors.NC} {case.name}")
                if case.description:
                    print(f" {case.description}")
                print(f" URL: {case.url}")

        print(f"\nTotal: {len(cases)} test cases")

    args = parse_args()
    tester = SSRFProxyTester(args.host, args.port, args.test_file)

    if args.list_tests:
        show_test_listing(tester)
        sys.exit(0)

    manage_container = not args.no_container
    try:
        # Short-circuit: only try to start a container when we manage one.
        if manage_container and not tester.start_proxy_container():
            sys.exit(1)

        tester.run_all_tests()
        success = tester.print_summary()
        if args.save_results:
            tester.save_results()
        sys.exit(0 if success else 1)
    finally:
        # Runs on every exit path, including sys.exit() and Ctrl-C.
        if manage_container:
            print(f"\n{Colors.YELLOW}Cleaning up...{Colors.NC}")
            tester.stop_proxy_container()
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()