Use Squid proxy for SSRF protection

This commit is contained in:
Yansong Zhang 2025-12-12 17:07:23 +08:00
parent f6accd8ae2
commit 92fa87e729
4 changed files with 73 additions and 333 deletions

View File

@ -2,79 +2,17 @@
Proxy requests to avoid SSRF
"""
import ipaddress
import logging
import time
from urllib.parse import urlparse
import httpx
from configs import dify_config
from core.helper.http_client_pooling import get_pooled_http_client
from core.tools.errors import ToolSSRFError
logger = logging.getLogger(__name__)
def is_private_or_local_address(url: str) -> bool:
    """
    Check if URL points to a private/local network address (SSRF protection).

    This function validates URLs to prevent Server-Side Request Forgery (SSRF)
    attacks by detecting private IP addresses, localhost, and local network
    domains. Only the literal hostname in the URL is inspected: no DNS
    resolution is performed, so a public-looking domain that resolves to a
    private IP is NOT detected here.

    Args:
        url: The URL string to check

    Returns:
        True if the URL points to a private/local address, False otherwise
        (including when the URL is empty, has no hostname, or cannot be parsed)

    Examples:
        >>> is_private_or_local_address("http://localhost/api")
        True
        >>> is_private_or_local_address("http://localhost./api")
        True
        >>> is_private_or_local_address("http://192.168.1.1/api")
        True
        >>> is_private_or_local_address("https://example.com/api")
        False
    """
    if not url:
        return False

    try:
        hostname = urlparse(url).hostname
        if not hostname:
            return False

        # A fully-qualified name may end with the DNS root dot ("localhost.",
        # "127.0.0.1."); it resolves exactly like the dot-less form, so strip
        # it before comparing, otherwise the checks below are trivially bypassed.
        hostname_lower = hostname.lower().rstrip(".")
        if not hostname_lower:
            return False

        # "localhost" and every name beneath it are reserved for loopback;
        # ".local" names are link-local (mDNS).
        if hostname_lower == "localhost" or hostname_lower.endswith((".localhost", ".local")):
            return True

        # Try to parse as an IP literal.
        try:
            ip = ipaddress.ip_address(hostname_lower)
        except ValueError:
            # Not an IP literal, so it is treated as a regular domain name.
            # Domain names could still resolve to private IPs; catching that
            # would require DNS resolution (which adds latency).
            return False

        # - Private: 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16, fc00::/7
        # - Loopback: 127.0.0.0/8, ::1
        # - Link-local: 169.254.0.0/16, fe80::/10
        return ip.is_private or ip.is_loopback or ip.is_link_local
    except (ValueError, TypeError, AttributeError):
        # Malformed or non-string input: keep the historical "not private" answer.
        return False
SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES
BACKOFF_FACTOR = 0.5
@ -156,6 +94,18 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
while retries <= max_retries:
try:
response = client.request(method=method, url=url, **kwargs)
# Check for SSRF protection by Squid proxy
if response.status_code in (401, 403):
# Check if this is a Squid SSRF rejection
server_header = response.headers.get("server", "").lower()
via_header = response.headers.get("via", "").lower()
# Squid typically identifies itself in Server or Via headers
if "squid" in server_header or "squid" in via_header:
raise ToolSSRFError(
f"Access to '{url}' was blocked by SSRF protection. "
f"The URL may point to a private or local network address. "
)
if response.status_code not in STATUS_FORCELIST:
return response

View File

@ -8,33 +8,13 @@ import httpx
from flask import request
from yaml import YAMLError, safe_load
from core.helper.ssrf_proxy import is_private_or_local_address
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter
from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError, ToolSSRFError
from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError
class ApiBasedToolSchemaParser:
@staticmethod
def _validate_server_urls(servers: list[dict]) -> None:
    """
    Reject any server entry whose URL targets a private or local address.

    Entries without a "url" key, or with an empty URL, are ignored.

    Args:
        servers: List of server dictionaries containing 'url' keys

    Raises:
        ToolSSRFError: If any server URL points to a private or local network address
    """
    # Lazy generator: each entry is only touched when its turn comes, so the
    # first offending URL raises before later entries are inspected.
    candidate_urls = (entry.get("url", "") for entry in servers)
    for candidate in candidate_urls:
        if not candidate:
            continue
        if is_private_or_local_address(candidate):
            raise ToolSSRFError(
                f"Server URL '{candidate}' points to a private or local network address, "
                "which is not allowed for security reasons (SSRF protection)."
            )
@staticmethod
def parse_openapi_to_tool_bundle(
openapi: dict, extra_info: dict | None = None, warning: dict | None = None
@ -48,9 +28,6 @@ class ApiBasedToolSchemaParser:
if len(openapi["servers"]) == 0:
raise ToolProviderNotFoundError("No server found in the openapi yaml.")
# SSRF Protection: Validate all server URLs before processing
ApiBasedToolSchemaParser._validate_server_urls(openapi["servers"])
server_url = openapi["servers"][0]["url"]
request_env = request.headers.get("X-Request-Env")
if request_env:
@ -310,9 +287,6 @@ class ApiBasedToolSchemaParser:
if len(servers) == 0:
raise ToolApiSchemaError("No server found in the swagger yaml.")
# SSRF Protection: Validate all server URLs before processing
ApiBasedToolSchemaParser._validate_server_urls(servers)
converted_openapi: dict[str, Any] = {
"openapi": "3.0.0",
"info": {
@ -386,13 +360,6 @@ class ApiBasedToolSchemaParser:
if api_type != "openapi":
raise ToolNotSupportedError("Only openapi is supported now.")
# SSRF Protection: Validate API URL before making HTTP request
if is_private_or_local_address(api_url):
raise ToolSSRFError(
f"API URL '{api_url}' points to a private or local network address, "
"which is not allowed for security reasons (SSRF protection)."
)
# get openapi yaml
response = httpx.get(
api_url, headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "}, timeout=5
@ -457,8 +424,6 @@ class ApiBasedToolSchemaParser:
return openapi, schema_type
except ToolApiSchemaError as e:
openapi_error = e
except ToolSSRFError:
raise
# openapi parse error, fallback to swagger
try:
@ -471,18 +436,12 @@ class ApiBasedToolSchemaParser:
), schema_type
except ToolApiSchemaError as e:
swagger_error = e
except ToolSSRFError:
# SSRF protection errors should be raised immediately, don't fallback
raise
# swagger parse error, fallback to openai plugin
try:
openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(
json_dumps(loaded_content), extra_info=extra_info, warning=warning
)
return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN
except ToolSSRFError:
# SSRF protection errors should be raised immediately, don't fallback
raise
except ToolNotSupportedError as e:
# maybe it's not plugin at all
openapi_plugin_error = e

View File

@ -4,6 +4,7 @@ from unittest.mock import MagicMock, patch
import pytest
from core.helper.ssrf_proxy import SSRF_DEFAULT_MAX_RETRIES, STATUS_FORCELIST, is_private_or_local_address, make_request
from core.tools.errors import ToolSSRFError
@patch("httpx.Client.request")
@ -52,6 +53,64 @@ def test_retry_logic_success(mock_request):
assert mock_request.call_args_list[0][1].get("method") == "GET"
@patch("httpx.Client.request")
def test_squid_ssrf_rejection_detected(mock_request):
    """Test that Squid SSRF rejection (403) is converted to ToolSSRFError."""
    mock_response = MagicMock()
    mock_response.status_code = 403
    mock_response.headers = {"server": "squid/5.2", "via": "1.1 squid"}
    mock_request.return_value = mock_response

    with pytest.raises(ToolSSRFError) as exc_info:
        make_request("GET", "http://192.168.1.1/api")

    # Assert only on text the raised message really carries: the blocked URL
    # and the SSRF explanation. The message does not mention any Squid
    # configuration file, so checking for one would always fail.
    message = str(exc_info.value)
    assert "blocked by SSRF protection" in message
    assert "192.168.1.1" in message
    assert "private or local network address" in message
@patch("httpx.Client.request")
def test_squid_ssrf_rejection_via_header(mock_request):
    """A 403 whose only Squid marker is the Via header must still raise ToolSSRFError."""
    # No "server" header on purpose: detection has to rely on "via" alone.
    squid_reply = MagicMock(status_code=403, headers={"via": "1.1 squid-proxy (squid/5.2)"})
    mock_request.return_value = squid_reply

    with pytest.raises(ToolSSRFError) as raised:
        make_request("GET", "http://10.0.0.1/api")

    assert "SSRF protection" in str(raised.value)
@patch("httpx.Client.request")
def test_squid_401_rejection_detected(mock_request):
    """Test that Squid SSRF rejection with 401 is also converted to ToolSSRFError."""
    mock_response = MagicMock()
    mock_response.status_code = 401
    mock_response.headers = {"server": "squid/5.2"}
    mock_request.return_value = mock_response

    with pytest.raises(ToolSSRFError) as exc_info:
        make_request("GET", "http://192.168.1.1/api")

    # Check the explanation the error really contains; the message does not
    # reference a Squid configuration file, so that must not be asserted.
    message = str(exc_info.value)
    assert "SSRF protection" in message
    assert "private or local network address" in message
@patch("httpx.Client.request")
def test_regular_403_not_treated_as_ssrf(mock_request):
    """A 403 from a non-Squid server is handed back to the caller unchanged."""
    upstream_reply = MagicMock(status_code=403, headers={"server": "nginx/1.21.0"})  # Not Squid
    mock_request.return_value = upstream_reply

    # No ToolSSRFError expected here: the response is simply returned.
    result = make_request("GET", "http://example.com/api")

    assert result.status_code == 403
class TestIsPrivateOrLocalAddress:
"""Test cases for SSRF protection function."""

View File

@ -1,228 +0,0 @@
"""Unit tests for SSRF protection in API schema parser."""
import pytest
from flask import Flask
from core.tools.errors import ToolSSRFError
from core.tools.utils.parser import ApiBasedToolSchemaParser
@pytest.fixture
def flask_app():
    """Build and return a minimal Flask application for use in tests."""
    return Flask(__name__)
class TestApiSchemaParserSSRF:
"""Test SSRF protection in API schema parser."""
# Every test feeds a schema document to
# ApiBasedToolSchemaParser.auto_parse_to_tool_bundle inside a Flask request
# context and checks whether ToolSSRFError is raised for the declared server URL.
def test_openapi_with_private_ip_blocked(self, flask_app):
"""Test that OpenAPI schema with private IP is blocked."""
# Server URL is an address in 192.168.0.0/16.
openapi_schema = """
openapi: 3.0.0
info:
title: Test API
version: 1.0.0
servers:
- url: http://192.168.1.1/api
paths:
/test:
get:
summary: Test endpoint
operationId: testGet
responses:
'200':
description: Success
"""
with flask_app.test_request_context():
with pytest.raises(ToolSSRFError) as exc_info:
ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_schema)
# The error text must name the offending host and state the SSRF reason.
assert "192.168.1.1" in str(exc_info.value)
assert "private or local network address" in str(exc_info.value)
assert "SSRF protection" in str(exc_info.value)
def test_openapi_with_localhost_blocked(self, flask_app):
"""Test that OpenAPI schema with localhost is blocked."""
# Server URL uses the "localhost" hostname with an explicit port.
openapi_schema = """
openapi: 3.0.0
info:
title: Test API
version: 1.0.0
servers:
- url: http://localhost:8080/api
paths:
/test:
get:
summary: Test endpoint
operationId: testGet
responses:
'200':
description: Success
"""
with flask_app.test_request_context():
with pytest.raises(ToolSSRFError) as exc_info:
ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_schema)
assert "localhost" in str(exc_info.value)
assert "SSRF protection" in str(exc_info.value)
def test_openapi_with_local_domain_blocked(self, flask_app):
"""Test that OpenAPI schema with .local domain is blocked."""
# Server URL is a hostname ending in ".local", not an IP literal.
openapi_schema = """
openapi: 3.0.0
info:
title: Test API
version: 1.0.0
servers:
- url: http://myserver.local/api
paths:
/test:
get:
summary: Test endpoint
operationId: testGet
responses:
'200':
description: Success
"""
with flask_app.test_request_context():
with pytest.raises(ToolSSRFError) as exc_info:
ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_schema)
assert "myserver.local" in str(exc_info.value)
assert "SSRF protection" in str(exc_info.value)
def test_openapi_with_10_network_blocked(self, flask_app):
"""Test that OpenAPI schema with 10.x.x.x network is blocked."""
# Server URL is an address in 10.0.0.0/8.
openapi_schema = """
openapi: 3.0.0
info:
title: Test API
version: 1.0.0
servers:
- url: http://10.0.0.5/api
paths:
/test:
get:
summary: Test endpoint
operationId: testGet
responses:
'200':
description: Success
"""
with flask_app.test_request_context():
with pytest.raises(ToolSSRFError) as exc_info:
ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_schema)
assert "10.0.0.5" in str(exc_info.value)
assert "SSRF protection" in str(exc_info.value)
def test_openapi_with_public_url_allowed(self, flask_app):
"""Test that OpenAPI schema with public URL is allowed."""
# Positive control: a public HTTPS host must parse without an SSRF error.
openapi_schema = """
openapi: 3.0.0
info:
title: Test API
version: 1.0.0
servers:
- url: https://api.example.com/v1
paths:
/test:
get:
summary: Test endpoint
operationId: testGet
responses:
'200':
description: Success
"""
with flask_app.test_request_context():
# Should not raise any exception
result, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_schema)
# A non-empty result is expected; schema_type is unpacked but not checked.
assert result is not None
assert len(result) > 0
def test_swagger_with_private_ip_blocked(self, flask_app):
"""Test that Swagger schema with private IP is blocked."""
# NOTE(review): despite the test name, the document below declares
# "openapi: 3.0.0" with a "servers" list rather than a Swagger 2.0 header —
# confirm whether the Swagger conversion path is actually exercised here.
swagger_schema = """
openapi: 3.0.0
info:
title: Test API
version: 1.0.0
servers:
- url: http://172.16.0.1/api
paths:
/test:
get:
summary: Test endpoint
operationId: testGet
responses:
'200':
description: Success
"""
with flask_app.test_request_context():
with pytest.raises(ToolSSRFError) as exc_info:
ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(swagger_schema)
assert "172.16.0.1" in str(exc_info.value)
assert "SSRF protection" in str(exc_info.value)
def test_openapi_with_multiple_servers_one_private(self, flask_app):
"""Test that OpenAPI with multiple servers including one private is blocked."""
# The first server is public and the second is private; the error is
# expected to name the second one.
openapi_schema = """
openapi: 3.0.0
info:
title: Test API
version: 1.0.0
servers:
- url: https://api.example.com/v1
- url: http://192.168.1.100/api
paths:
/test:
get:
summary: Test endpoint
operationId: testGet
responses:
'200':
description: Success
"""
with flask_app.test_request_context():
with pytest.raises(ToolSSRFError) as exc_info:
ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_schema)
assert "192.168.1.100" in str(exc_info.value)
assert "SSRF protection" in str(exc_info.value)
def test_openapi_json_format_with_private_ip_blocked(self, flask_app):
"""Test that JSON format OpenAPI schema with private IP is blocked."""
# Same check as the YAML cases, but the document is supplied as JSON and
# the server URL is the 127.0.0.1 loopback address.
openapi_json = """{
"openapi": "3.0.0",
"info": {
"title": "Test API",
"version": "1.0.0"
},
"servers": [
{
"url": "http://127.0.0.1:8080/api"
}
],
"paths": {
"/test": {
"get": {
"summary": "Test endpoint",
"operationId": "testGet",
"responses": {
"200": {
"description": "Success"
}
}
}
}
}
}"""
with flask_app.test_request_context():
with pytest.raises(ToolSSRFError) as exc_info:
ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_json)
assert "127.0.0.1" in str(exc_info.value)
assert "SSRF protection" in str(exc_info.value)