use squid for ssrf

This commit is contained in:
Yansong Zhang 2025-12-12 17:07:23 +08:00
parent f6accd8ae2
commit 92fa87e729
4 changed files with 73 additions and 333 deletions

View File

@ -2,79 +2,17 @@
Proxy requests to avoid SSRF Proxy requests to avoid SSRF
""" """
import ipaddress
import logging import logging
import time import time
from urllib.parse import urlparse
import httpx import httpx
from configs import dify_config from configs import dify_config
from core.helper.http_client_pooling import get_pooled_http_client from core.helper.http_client_pooling import get_pooled_http_client
from core.tools.errors import ToolSSRFError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def is_private_or_local_address(url: str) -> bool:
"""
Check if URL points to a private/local network address (SSRF protection).
This function validates URLs to prevent Server-Side Request Forgery (SSRF) attacks
by detecting private IP addresses, localhost, and local network domains.
Args:
url: The URL string to check
Returns:
True if the URL points to a private/local address, False otherwise
Examples:
>>> is_private_or_local_address("http://localhost/api")
True
>>> is_private_or_local_address("http://192.168.1.1/api")
True
>>> is_private_or_local_address("https://example.com/api")
False
"""
if not url:
return False
try:
parsed = urlparse(url)
hostname = parsed.hostname
if not hostname:
return False
hostname_lower = hostname.lower()
# Check for localhost variants
if hostname_lower in ("localhost", "127.0.0.1", "::1"):
return True
# Check for .local domains (link-local)
if hostname_lower.endswith(".local"):
return True
# Try to parse as IP address
try:
ip = ipaddress.ip_address(hostname)
# Check if it's a private, loopback, or link-local address.
# - Private: 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16, fc00::/7
# - Loopback: 127.0.0.0/8, ::1
# - Link-local: 169.254.0.0/16, fe80::/10
return ip.is_private or ip.is_loopback or ip.is_link_local
except ValueError:
# Not a valid IP address, might be a domain name
# Domain names could resolve to private IPs, but we only check the literal hostname here
# For more thorough checks, DNS resolution would be needed (but adds latency)
return False
except (ValueError, TypeError, AttributeError):
return False
SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES
BACKOFF_FACTOR = 0.5 BACKOFF_FACTOR = 0.5
@ -156,6 +94,18 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
while retries <= max_retries: while retries <= max_retries:
try: try:
response = client.request(method=method, url=url, **kwargs) response = client.request(method=method, url=url, **kwargs)
# Check for SSRF protection by Squid proxy
if response.status_code in (401, 403):
# Check if this is a Squid SSRF rejection
server_header = response.headers.get("server", "").lower()
via_header = response.headers.get("via", "").lower()
# Squid typically identifies itself in Server or Via headers
if "squid" in server_header or "squid" in via_header:
raise ToolSSRFError(
f"Access to '{url}' was blocked by SSRF protection. "
f"The URL may point to a private or local network address. "
)
if response.status_code not in STATUS_FORCELIST: if response.status_code not in STATUS_FORCELIST:
return response return response

View File

@ -8,33 +8,13 @@ import httpx
from flask import request from flask import request
from yaml import YAMLError, safe_load from yaml import YAMLError, safe_load
from core.helper.ssrf_proxy import is_private_or_local_address
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter
from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError, ToolSSRFError from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError
class ApiBasedToolSchemaParser: class ApiBasedToolSchemaParser:
@staticmethod
def _validate_server_urls(servers: list[dict]) -> None:
"""
Validate server URLs to prevent SSRF attacks.
Args:
servers: List of server dictionaries containing 'url' keys
Raises:
ToolSSRFError: If any server URL points to a private or local network address
"""
for server in servers:
server_url = server.get("url", "")
if server_url and is_private_or_local_address(server_url):
raise ToolSSRFError(
f"Server URL '{server_url}' points to a private or local network address, "
"which is not allowed for security reasons (SSRF protection)."
)
@staticmethod @staticmethod
def parse_openapi_to_tool_bundle( def parse_openapi_to_tool_bundle(
openapi: dict, extra_info: dict | None = None, warning: dict | None = None openapi: dict, extra_info: dict | None = None, warning: dict | None = None
@ -48,9 +28,6 @@ class ApiBasedToolSchemaParser:
if len(openapi["servers"]) == 0: if len(openapi["servers"]) == 0:
raise ToolProviderNotFoundError("No server found in the openapi yaml.") raise ToolProviderNotFoundError("No server found in the openapi yaml.")
# SSRF Protection: Validate all server URLs before processing
ApiBasedToolSchemaParser._validate_server_urls(openapi["servers"])
server_url = openapi["servers"][0]["url"] server_url = openapi["servers"][0]["url"]
request_env = request.headers.get("X-Request-Env") request_env = request.headers.get("X-Request-Env")
if request_env: if request_env:
@ -310,9 +287,6 @@ class ApiBasedToolSchemaParser:
if len(servers) == 0: if len(servers) == 0:
raise ToolApiSchemaError("No server found in the swagger yaml.") raise ToolApiSchemaError("No server found in the swagger yaml.")
# SSRF Protection: Validate all server URLs before processing
ApiBasedToolSchemaParser._validate_server_urls(servers)
converted_openapi: dict[str, Any] = { converted_openapi: dict[str, Any] = {
"openapi": "3.0.0", "openapi": "3.0.0",
"info": { "info": {
@ -386,13 +360,6 @@ class ApiBasedToolSchemaParser:
if api_type != "openapi": if api_type != "openapi":
raise ToolNotSupportedError("Only openapi is supported now.") raise ToolNotSupportedError("Only openapi is supported now.")
# SSRF Protection: Validate API URL before making HTTP request
if is_private_or_local_address(api_url):
raise ToolSSRFError(
f"API URL '{api_url}' points to a private or local network address, "
"which is not allowed for security reasons (SSRF protection)."
)
# get openapi yaml # get openapi yaml
response = httpx.get( response = httpx.get(
api_url, headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "}, timeout=5 api_url, headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "}, timeout=5
@ -457,8 +424,6 @@ class ApiBasedToolSchemaParser:
return openapi, schema_type return openapi, schema_type
except ToolApiSchemaError as e: except ToolApiSchemaError as e:
openapi_error = e openapi_error = e
except ToolSSRFError:
raise
# openapi parse error, fallback to swagger # openapi parse error, fallback to swagger
try: try:
@ -471,18 +436,12 @@ class ApiBasedToolSchemaParser:
), schema_type ), schema_type
except ToolApiSchemaError as e: except ToolApiSchemaError as e:
swagger_error = e swagger_error = e
except ToolSSRFError:
# SSRF protection errors should be raised immediately, don't fallback
raise
# swagger parse error, fallback to openai plugin # swagger parse error, fallback to openai plugin
try: try:
openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle( openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(
json_dumps(loaded_content), extra_info=extra_info, warning=warning json_dumps(loaded_content), extra_info=extra_info, warning=warning
) )
return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN
except ToolSSRFError:
# SSRF protection errors should be raised immediately, don't fallback
raise
except ToolNotSupportedError as e: except ToolNotSupportedError as e:
# maybe it's not plugin at all # maybe it's not plugin at all
openapi_plugin_error = e openapi_plugin_error = e

View File

@ -4,6 +4,7 @@ from unittest.mock import MagicMock, patch
import pytest import pytest
from core.helper.ssrf_proxy import SSRF_DEFAULT_MAX_RETRIES, STATUS_FORCELIST, is_private_or_local_address, make_request from core.helper.ssrf_proxy import SSRF_DEFAULT_MAX_RETRIES, STATUS_FORCELIST, is_private_or_local_address, make_request
from core.tools.errors import ToolSSRFError
@patch("httpx.Client.request") @patch("httpx.Client.request")
@ -52,6 +53,64 @@ def test_retry_logic_success(mock_request):
assert mock_request.call_args_list[0][1].get("method") == "GET" assert mock_request.call_args_list[0][1].get("method") == "GET"
@patch("httpx.Client.request")
def test_squid_ssrf_rejection_detected(mock_request):
"""Test that Squid SSRF rejection (403) is converted to ToolSSRFError."""
mock_response = MagicMock()
mock_response.status_code = 403
mock_response.headers = {"server": "squid/5.2", "via": "1.1 squid"}
mock_request.return_value = mock_response
with pytest.raises(ToolSSRFError) as exc_info:
make_request("GET", "http://192.168.1.1/api")
assert "blocked by SSRF protection" in str(exc_info.value)
assert "192.168.1.1" in str(exc_info.value)
assert "squid.conf.template" in str(exc_info.value)
@patch("httpx.Client.request")
def test_squid_ssrf_rejection_via_header(mock_request):
"""Test detection via Via header when Server header is not present."""
mock_response = MagicMock()
mock_response.status_code = 403
mock_response.headers = {"via": "1.1 squid-proxy (squid/5.2)"}
mock_request.return_value = mock_response
with pytest.raises(ToolSSRFError) as exc_info:
make_request("GET", "http://10.0.0.1/api")
assert "SSRF protection" in str(exc_info.value)
@patch("httpx.Client.request")
def test_squid_401_rejection_detected(mock_request):
"""Test that Squid SSRF rejection with 401 is also converted to ToolSSRFError."""
mock_response = MagicMock()
mock_response.status_code = 401
mock_response.headers = {"server": "squid/5.2"}
mock_request.return_value = mock_response
with pytest.raises(ToolSSRFError) as exc_info:
make_request("GET", "http://192.168.1.1/api")
assert "SSRF protection" in str(exc_info.value)
assert "squid.conf.template" in str(exc_info.value)
@patch("httpx.Client.request")
def test_regular_403_not_treated_as_ssrf(mock_request):
"""Test that regular 403 responses (not from Squid) are returned normally."""
mock_response = MagicMock()
mock_response.status_code = 403
mock_response.headers = {"server": "nginx/1.21.0"} # Not Squid
mock_request.return_value = mock_response
# Should not raise ToolSSRFError
response = make_request("GET", "http://example.com/api")
assert response.status_code == 403
class TestIsPrivateOrLocalAddress: class TestIsPrivateOrLocalAddress:
"""Test cases for SSRF protection function.""" """Test cases for SSRF protection function."""

View File

@ -1,228 +0,0 @@
"""Unit tests for SSRF protection in API schema parser."""
import pytest
from flask import Flask
from core.tools.errors import ToolSSRFError
from core.tools.utils.parser import ApiBasedToolSchemaParser
@pytest.fixture
def flask_app():
"""Create a Flask app for testing."""
app = Flask(__name__)
return app
class TestApiSchemaParserSSRF:
"""Test SSRF protection in API schema parser."""
def test_openapi_with_private_ip_blocked(self, flask_app):
"""Test that OpenAPI schema with private IP is blocked."""
openapi_schema = """
openapi: 3.0.0
info:
title: Test API
version: 1.0.0
servers:
- url: http://192.168.1.1/api
paths:
/test:
get:
summary: Test endpoint
operationId: testGet
responses:
'200':
description: Success
"""
with flask_app.test_request_context():
with pytest.raises(ToolSSRFError) as exc_info:
ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_schema)
assert "192.168.1.1" in str(exc_info.value)
assert "private or local network address" in str(exc_info.value)
assert "SSRF protection" in str(exc_info.value)
def test_openapi_with_localhost_blocked(self, flask_app):
"""Test that OpenAPI schema with localhost is blocked."""
openapi_schema = """
openapi: 3.0.0
info:
title: Test API
version: 1.0.0
servers:
- url: http://localhost:8080/api
paths:
/test:
get:
summary: Test endpoint
operationId: testGet
responses:
'200':
description: Success
"""
with flask_app.test_request_context():
with pytest.raises(ToolSSRFError) as exc_info:
ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_schema)
assert "localhost" in str(exc_info.value)
assert "SSRF protection" in str(exc_info.value)
def test_openapi_with_local_domain_blocked(self, flask_app):
"""Test that OpenAPI schema with .local domain is blocked."""
openapi_schema = """
openapi: 3.0.0
info:
title: Test API
version: 1.0.0
servers:
- url: http://myserver.local/api
paths:
/test:
get:
summary: Test endpoint
operationId: testGet
responses:
'200':
description: Success
"""
with flask_app.test_request_context():
with pytest.raises(ToolSSRFError) as exc_info:
ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_schema)
assert "myserver.local" in str(exc_info.value)
assert "SSRF protection" in str(exc_info.value)
def test_openapi_with_10_network_blocked(self, flask_app):
"""Test that OpenAPI schema with 10.x.x.x network is blocked."""
openapi_schema = """
openapi: 3.0.0
info:
title: Test API
version: 1.0.0
servers:
- url: http://10.0.0.5/api
paths:
/test:
get:
summary: Test endpoint
operationId: testGet
responses:
'200':
description: Success
"""
with flask_app.test_request_context():
with pytest.raises(ToolSSRFError) as exc_info:
ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_schema)
assert "10.0.0.5" in str(exc_info.value)
assert "SSRF protection" in str(exc_info.value)
def test_openapi_with_public_url_allowed(self, flask_app):
"""Test that OpenAPI schema with public URL is allowed."""
openapi_schema = """
openapi: 3.0.0
info:
title: Test API
version: 1.0.0
servers:
- url: https://api.example.com/v1
paths:
/test:
get:
summary: Test endpoint
operationId: testGet
responses:
'200':
description: Success
"""
with flask_app.test_request_context():
# Should not raise any exception
result, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_schema)
assert result is not None
assert len(result) > 0
def test_swagger_with_private_ip_blocked(self, flask_app):
"""Test that Swagger schema with private IP is blocked."""
swagger_schema = """
openapi: 3.0.0
info:
title: Test API
version: 1.0.0
servers:
- url: http://172.16.0.1/api
paths:
/test:
get:
summary: Test endpoint
operationId: testGet
responses:
'200':
description: Success
"""
with flask_app.test_request_context():
with pytest.raises(ToolSSRFError) as exc_info:
ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(swagger_schema)
assert "172.16.0.1" in str(exc_info.value)
assert "SSRF protection" in str(exc_info.value)
def test_openapi_with_multiple_servers_one_private(self, flask_app):
"""Test that OpenAPI with multiple servers including one private is blocked."""
openapi_schema = """
openapi: 3.0.0
info:
title: Test API
version: 1.0.0
servers:
- url: https://api.example.com/v1
- url: http://192.168.1.100/api
paths:
/test:
get:
summary: Test endpoint
operationId: testGet
responses:
'200':
description: Success
"""
with flask_app.test_request_context():
with pytest.raises(ToolSSRFError) as exc_info:
ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_schema)
assert "192.168.1.100" in str(exc_info.value)
assert "SSRF protection" in str(exc_info.value)
def test_openapi_json_format_with_private_ip_blocked(self, flask_app):
"""Test that JSON format OpenAPI schema with private IP is blocked."""
openapi_json = """{
"openapi": "3.0.0",
"info": {
"title": "Test API",
"version": "1.0.0"
},
"servers": [
{
"url": "http://127.0.0.1:8080/api"
}
],
"paths": {
"/test": {
"get": {
"summary": "Test endpoint",
"operationId": "testGet",
"responses": {
"200": {
"description": "Success"
}
}
}
}
}
}"""
with flask_app.test_request_context():
with pytest.raises(ToolSSRFError) as exc_info:
ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_json)
assert "127.0.0.1" in str(exc_info.value)
assert "SSRF protection" in str(exc_info.value)