From 3f3b9beeff3df0220987699873ee4817d0ffa955 Mon Sep 17 00:00:00 2001 From: Yansong Zhang <916125788@qq.com> Date: Fri, 12 Dec 2025 11:24:25 +0800 Subject: [PATCH 1/8] add internal ip filter when parse tool schema --- api/core/helper/ssrf_proxy.py | 72 ++++++ api/core/tools/errors.py | 4 + api/core/tools/utils/parser.py | 39 ++- .../unit_tests/core/helper/test_ssrf_proxy.py | 85 ++++++- .../core/tools/utils/test_parser_ssrf.py | 228 ++++++++++++++++++ 5 files changed, 424 insertions(+), 4 deletions(-) create mode 100644 api/tests/unit_tests/core/tools/utils/test_parser_ssrf.py diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 0de026f3c7..ba91ebd9d4 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -2,8 +2,10 @@ Proxy requests to avoid SSRF """ +import ipaddress import logging import time +from urllib.parse import urlparse import httpx @@ -12,6 +14,76 @@ from core.helper.http_client_pooling import get_pooled_http_client logger = logging.getLogger(__name__) + +def is_private_or_local_address(url: str) -> bool: + """ + Check if URL points to a private/local network address (SSRF protection). + + This function validates URLs to prevent Server-Side Request Forgery (SSRF) attacks + by detecting private IP addresses, localhost, and local network domains. + + Args: + url: The URL string to check + + Returns: + True if the URL points to a private/local address, False otherwise + + Examples: + >>> is_private_or_local_address("http://localhost/api") + True + >>> is_private_or_local_address("http://192.168.1.1/api") + True + >>> is_private_or_local_address("https://example.com/api") + False + """ + if not url: + return False + + try: + parsed = urlparse(url) + hostname = parsed.hostname + + if not hostname: + return False + + hostname_lower = hostname.lower() + + # Check for localhost variants + if hostname_lower in ("localhost", "127.0.0.1", "::1"): + return True + + # Check for .local domains (link-local) + if hostname_lower.endswith(".local"): + return True + + # Try to parse as IP address + try: + ip = ipaddress.ip_address(hostname) + + # Check if it's a private IP (10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16 for IPv4) + # For IPv6: fc00::/7 (unique local addresses) + if ip.is_private: + return True + + # Check if it's loopback (127.0.0.0/8 for IPv4, ::1 for IPv6) + if ip.is_loopback: + return True + + # Check if it's link-local (169.254.0.0/16 for IPv4, fe80::/10 for IPv6) + if ip.is_link_local: + return True + + return False + except ValueError: + # Not a valid IP address, might be a domain name + # Domain names could resolve to private IPs, but we only check the literal hostname here + # For more thorough checks, DNS resolution would be needed (but adds latency) + return False + + except (ValueError, TypeError, AttributeError): + return False + + SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES BACKOFF_FACTOR = 0.5 diff --git a/api/core/tools/errors.py b/api/core/tools/errors.py index b0c2232857..e4afe24426 100644 --- a/api/core/tools/errors.py +++ b/api/core/tools/errors.py @@ -29,6 +29,10 @@ class ToolApiSchemaError(ValueError): pass +class ToolSSRFError(ValueError): + pass + + class ToolCredentialPolicyViolationError(ValueError): pass diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index 6eabde3991..b2c7d3db80 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -8,10 +8,11 @@ import httpx from flask import request from yaml import YAMLError, safe_load +from core.helper.ssrf_proxy import is_private_or_local_address from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter -from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError +from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError, ToolSSRFError class ApiBasedToolSchemaParser: @@ -28,6 +29,15 @@ class ApiBasedToolSchemaParser: if len(openapi["servers"]) == 0: raise ToolProviderNotFoundError("No server found in the openapi yaml.") + # SSRF Protection: Validate all server URLs before processing + for server in openapi["servers"]: + server_url_to_check = server.get("url", "") + if server_url_to_check and is_private_or_local_address(server_url_to_check): + raise ToolSSRFError( + f"Server URL '{server_url_to_check}' points to a private or local network address, " + "which is not allowed for security reasons (SSRF protection)." + ) + server_url = openapi["servers"][0]["url"] request_env = request.headers.get("X-Request-Env") if request_env: @@ -287,6 +297,15 @@ class ApiBasedToolSchemaParser: if len(servers) == 0: raise ToolApiSchemaError("No server found in the swagger yaml.") + # SSRF Protection: Validate all server URLs before processing + for server in servers: + server_url_to_check = server.get("url", "") + if server_url_to_check and is_private_or_local_address(server_url_to_check): + raise ToolSSRFError( + f"Server URL '{server_url_to_check}' points to a private or local network address, " + "which is not allowed for security reasons (SSRF protection)." + ) + converted_openapi: dict[str, Any] = { "openapi": "3.0.0", "info": { @@ -360,6 +379,13 @@ class ApiBasedToolSchemaParser: if api_type != "openapi": raise ToolNotSupportedError("Only openapi is supported now.") + # SSRF Protection: Validate API URL before making HTTP request + if is_private_or_local_address(api_url): + raise ToolSSRFError( + f"API URL '{api_url}' points to a private or local network address, " + "which is not allowed for security reasons (SSRF protection)." + ) + # get openapi yaml response = httpx.get( api_url, headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "}, timeout=5 @@ -424,8 +450,10 @@ class ApiBasedToolSchemaParser: return openapi, schema_type except ToolApiSchemaError as e: openapi_error = e + except ToolSSRFError: + raise - # openai parse error, fallback to swagger + # openapi parse error, fallback to swagger try: converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi( loaded_content, extra_info=extra_info, warning=warning @@ -436,13 +464,18 @@ class ApiBasedToolSchemaParser: ), schema_type except ToolApiSchemaError as e: swagger_error = e - + except ToolSSRFError: + # SSRF protection errors should be raised immediately, don't fallback + raise # swagger parse error, fallback to openai plugin try: openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle( json_dumps(loaded_content), extra_info=extra_info, warning=warning ) return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN + except ToolSSRFError: + # SSRF protection errors should be raised immediately, don't fallback + raise except ToolNotSupportedError as e: # maybe it's not plugin at all openapi_plugin_error = e diff --git a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py index 37749f0c66..caf92d5af6 100644 --- a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py +++ b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py @@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch import pytest -from core.helper.ssrf_proxy import SSRF_DEFAULT_MAX_RETRIES, STATUS_FORCELIST, make_request +from core.helper.ssrf_proxy import SSRF_DEFAULT_MAX_RETRIES, STATUS_FORCELIST, is_private_or_local_address, make_request @patch("httpx.Client.request") @@ -50,3 +50,86 @@ def test_retry_logic_success(mock_request): assert response.status_code == 200 assert mock_request.call_count == SSRF_DEFAULT_MAX_RETRIES + 1 assert mock_request.call_args_list[0][1].get("method") == "GET" + + +class TestIsPrivateOrLocalAddress: + """Test cases for SSRF protection function.""" + + def test_localhost_variants(self): + """Test that localhost variants are detected as private.""" + assert is_private_or_local_address("http://localhost/api") is True + assert is_private_or_local_address("http://127.0.0.1/api") is True + assert is_private_or_local_address("http://[::1]/api") is True + assert is_private_or_local_address("https://localhost:8080/") is True + + def test_private_ipv4_ranges(self): + """Test that private IPv4 ranges are detected.""" + # 10.0.0.0/8 + assert is_private_or_local_address("http://10.0.0.1/api") is True + assert is_private_or_local_address("http://10.255.255.255/api") is True + + # 172.16.0.0/12 + assert is_private_or_local_address("http://172.16.0.1/api") is True + assert is_private_or_local_address("http://172.31.255.255/api") is True + + # 192.168.0.0/16 + assert is_private_or_local_address("http://192.168.0.1/api") is True + assert is_private_or_local_address("http://192.168.255.255/api") is True + + # 169.254.0.0/16 (link-local) + assert is_private_or_local_address("http://169.254.1.1/api") is True + + def test_local_domains(self): + """Test that .local domains are detected as private.""" + assert is_private_or_local_address("http://myserver.local/api") is True + assert is_private_or_local_address("https://test.local:8080/") is True + + def test_public_addresses(self): + """Test that public addresses are not detected as private.""" + assert is_private_or_local_address("http://example.com/api") is False + assert is_private_or_local_address("https://api.openai.com/v1") is False + assert is_private_or_local_address("http://8.8.8.8/") is False + assert is_private_or_local_address("https://1.1.1.1/") is False + assert is_private_or_local_address("http://93.184.216.34/") is False + + def test_edge_cases(self): + """Test edge cases and invalid inputs.""" + # Empty or None + assert is_private_or_local_address("") is False + assert is_private_or_local_address(None) is False + + # Invalid URLs + assert is_private_or_local_address("not-a-url") is False + assert is_private_or_local_address("://invalid") is False + + def test_ipv6_private_ranges(self): + """Test that private IPv6 ranges are detected.""" + # IPv6 loopback + assert is_private_or_local_address("http://[::1]/api") is True + + # IPv6 link-local (fe80::/10) + assert is_private_or_local_address("http://[fe80::1]/api") is True + + # IPv6 unique local (fc00::/7) + assert is_private_or_local_address("http://[fc00::1]/api") is True + assert is_private_or_local_address("http://[fd00::1]/api") is True + + def test_public_ipv6(self): + """Test that public IPv6 addresses are not detected as private.""" + # Public IPv6 addresses (real examples) + # Google Public DNS IPv6 + assert is_private_or_local_address("http://[2001:4860:4860::8888]/api") is False + # Cloudflare DNS IPv6 + assert is_private_or_local_address("http://[2606:4700:4700::1111]/api") is False + + def test_url_with_ports(self): + """Test URLs with custom ports.""" + assert is_private_or_local_address("http://localhost:8080/api") is True + assert is_private_or_local_address("http://192.168.1.1:3000/") is True + assert is_private_or_local_address("https://example.com:443/api") is False + + def test_url_schemes(self): + """Test different URL schemes.""" + assert is_private_or_local_address("https://127.0.0.1/api") is True + assert is_private_or_local_address("http://127.0.0.1/api") is True + assert is_private_or_local_address("https://example.com/api") is False diff --git a/api/tests/unit_tests/core/tools/utils/test_parser_ssrf.py b/api/tests/unit_tests/core/tools/utils/test_parser_ssrf.py new file mode 100644 index 0000000000..da2d8e33fc --- /dev/null +++ b/api/tests/unit_tests/core/tools/utils/test_parser_ssrf.py @@ -0,0 +1,228 @@ +"""Unit tests for SSRF protection in API schema parser.""" + +import pytest +from flask import Flask + +from core.tools.errors import ToolSSRFError +from core.tools.utils.parser import ApiBasedToolSchemaParser + + +@pytest.fixture +def flask_app(): + """Create a Flask app for testing.""" + app = Flask(__name__) + return app + + +class TestApiSchemaParserSSRF: + """Test SSRF protection in API schema parser.""" + + def test_openapi_with_private_ip_blocked(self, flask_app): + """Test that OpenAPI schema with private IP is blocked.""" + openapi_schema = """ +openapi: 3.0.0 +info: + title: Test API + version: 1.0.0 +servers: + - url: http://192.168.1.1/api +paths: + /test: + get: + summary: Test endpoint + operationId: testGet + responses: + '200': + description: Success +""" + with flask_app.test_request_context(): + with pytest.raises(ToolSSRFError) as exc_info: + ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_schema) + + assert "192.168.1.1" in str(exc_info.value) + assert "private or local network address" in str(exc_info.value) + assert "SSRF protection" in str(exc_info.value) + + def test_openapi_with_localhost_blocked(self, flask_app): + """Test that OpenAPI schema with localhost is blocked.""" + openapi_schema = """ +openapi: 3.0.0 +info: + title: Test API + version: 1.0.0 +servers: + - url: http://localhost:8080/api +paths: + /test: + get: + summary: Test endpoint + operationId: testGet + responses: + '200': + description: Success +""" + with flask_app.test_request_context(): + with pytest.raises(ToolSSRFError) as exc_info: + ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_schema) + + assert "localhost" in str(exc_info.value) + assert "SSRF protection" in str(exc_info.value) + + def test_openapi_with_local_domain_blocked(self, flask_app): + """Test that OpenAPI schema with .local domain is blocked.""" + openapi_schema = """ +openapi: 3.0.0 +info: + title: Test API + version: 1.0.0 +servers: + - url: http://myserver.local/api +paths: + /test: + get: + summary: Test endpoint + operationId: testGet + responses: + '200': + description: Success +""" + with flask_app.test_request_context(): + with pytest.raises(ToolSSRFError) as exc_info: + ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_schema) + + assert "myserver.local" in str(exc_info.value) + assert "SSRF protection" in str(exc_info.value) + + def test_openapi_with_10_network_blocked(self, flask_app): + """Test that OpenAPI schema with 10.x.x.x network is blocked.""" + openapi_schema = """ +openapi: 3.0.0 +info: + title: Test API + version: 1.0.0 +servers: + - url: http://10.0.0.5/api +paths: + /test: + get: + summary: Test endpoint + operationId: testGet + responses: + '200': + description: Success +""" + with flask_app.test_request_context(): + with pytest.raises(ToolSSRFError) as exc_info: + ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_schema) + + assert "10.0.0.5" in str(exc_info.value) + assert "SSRF protection" in str(exc_info.value) + + def test_openapi_with_public_url_allowed(self, flask_app): + """Test that OpenAPI schema with public URL is allowed.""" + openapi_schema = """ +openapi: 3.0.0 +info: + title: Test API + version: 1.0.0 +servers: + - url: https://api.example.com/v1 +paths: + /test: + get: + summary: Test endpoint + operationId: testGet + responses: + '200': + description: Success +""" + with flask_app.test_request_context(): + # Should not raise any exception + result, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_schema) + assert result is not None + assert len(result) > 0 + + def test_swagger_with_private_ip_blocked(self, flask_app): + """Test that Swagger schema with private IP is blocked.""" + swagger_schema = """ +openapi: 3.0.0 +info: + title: Test API + version: 1.0.0 +servers: + - url: http://172.16.0.1/api +paths: + /test: + get: + summary: Test endpoint + operationId: testGet + responses: + '200': + description: Success +""" + with flask_app.test_request_context(): + with pytest.raises(ToolSSRFError) as exc_info: + ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(swagger_schema) + + assert "172.16.0.1" in str(exc_info.value) + assert "SSRF protection" in str(exc_info.value) + + def test_openapi_with_multiple_servers_one_private(self, flask_app): + """Test that OpenAPI with multiple servers including one private is blocked.""" + openapi_schema = """ +openapi: 3.0.0 +info: + title: Test API + version: 1.0.0 +servers: + - url: https://api.example.com/v1 + - url: http://192.168.1.100/api +paths: + /test: + get: + summary: Test endpoint + operationId: testGet + responses: + '200': + description: Success +""" + with flask_app.test_request_context(): + with pytest.raises(ToolSSRFError) as exc_info: + ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_schema) + + assert "192.168.1.100" in str(exc_info.value) + assert "SSRF protection" in str(exc_info.value) + + def test_openapi_json_format_with_private_ip_blocked(self, flask_app): + """Test that JSON format OpenAPI schema with private IP is blocked.""" + openapi_json = """{ + "openapi": "3.0.0", + "info": { + "title": "Test API", + "version": "1.0.0" + }, + "servers": [ + { + "url": "http://127.0.0.1:8080/api" + } + ], + "paths": { + "/test": { + "get": { + "summary": "Test endpoint", + "operationId": "testGet", + "responses": { + "200": { + "description": "Success" + } + } + } + } + } +}""" + with flask_app.test_request_context(): + with pytest.raises(ToolSSRFError) as exc_info: + ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_json) + + assert "127.0.0.1" in str(exc_info.value) + assert "SSRF protection" in str(exc_info.value) From 64f5e3409627d4c34d3a4828e7ebad4d7d506382 Mon Sep 17 00:00:00 2001 From: zyssyz123 <916125788@qq.com> Date: Fri, 12 Dec 2025 11:31:27 +0800 Subject: [PATCH 2/8] Update api/core/helper/ssrf_proxy.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- api/core/helper/ssrf_proxy.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index ba91ebd9d4..691882d522 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -60,20 +60,11 @@ def is_private_or_local_address(url: str) -> bool: try: ip = ipaddress.ip_address(hostname) - # Check if it's a private IP (10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16 for IPv4) - # For IPv6: fc00::/7 (unique local addresses) - if ip.is_private: - return True - - # Check if it's loopback (127.0.0.0/8 for IPv4, ::1 for IPv6) - if ip.is_loopback: - return True - - # Check if it's link-local (169.254.0.0/16 for IPv4, fe80::/10 for IPv6) - if ip.is_link_local: - return True - - return False + # Check if it's a private, loopback, or link-local address. + # - Private: 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16, fc00::/7 + # - Loopback: 127.0.0.0/8, ::1 + # - Link-local: 169.254.0.0/16, fe80::/10 + return ip.is_private or ip.is_loopback or ip.is_link_local except ValueError: # Not a valid IP address, might be a domain name # Domain names could resolve to private IPs, but we only check the literal hostname here From e66ef9145ba56dda83d6b838c77c6d6107784e83 Mon Sep 17 00:00:00 2001 From: Yansong Zhang <916125788@qq.com> Date: Fri, 12 Dec 2025 11:37:30 +0800 Subject: [PATCH 3/8] add internal ip filter when parse tool schema --- api/core/tools/utils/parser.py | 35 ++++++++++++++++++++-------------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index b2c7d3db80..c75d5090f3 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -16,6 +16,25 @@ from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolPro class ApiBasedToolSchemaParser: + @staticmethod + def _validate_server_urls(servers: list[dict]) -> None: + """ + Validate server URLs to prevent SSRF attacks. + + Args: + servers: List of server dictionaries containing 'url' keys + + Raises: + ToolSSRFError: If any server URL points to a private or local network address + """ + for server in servers: + server_url = server.get("url", "") + if server_url and is_private_or_local_address(server_url): + raise ToolSSRFError( + f"Server URL '{server_url}' points to a private or local network address, " + "which is not allowed for security reasons (SSRF protection)." + ) + @staticmethod def parse_openapi_to_tool_bundle( openapi: dict, extra_info: dict | None = None, warning: dict | None = None @@ -30,13 +49,7 @@ class ApiBasedToolSchemaParser: raise ToolProviderNotFoundError("No server found in the openapi yaml.") # SSRF Protection: Validate all server URLs before processing - for server in openapi["servers"]: - server_url_to_check = server.get("url", "") - if server_url_to_check and is_private_or_local_address(server_url_to_check): - raise ToolSSRFError( - f"Server URL '{server_url_to_check}' points to a private or local network address, " - "which is not allowed for security reasons (SSRF protection)." - ) + ApiBasedToolSchemaParser._validate_server_urls(openapi["servers"]) server_url = openapi["servers"][0]["url"] request_env = request.headers.get("X-Request-Env") @@ -298,13 +311,7 @@ class ApiBasedToolSchemaParser: raise ToolApiSchemaError("No server found in the swagger yaml.") # SSRF Protection: Validate all server URLs before processing - for server in servers: - server_url_to_check = server.get("url", "") - if server_url_to_check and is_private_or_local_address(server_url_to_check): - raise ToolSSRFError( - f"Server URL '{server_url_to_check}' points to a private or local network address, " - "which is not allowed for security reasons (SSRF protection)." - ) + ApiBasedToolSchemaParser._validate_server_urls(servers) converted_openapi: dict[str, Any] = { "openapi": "3.0.0", From f6accd8ae2b2d39eb49d7cbdc86ecc41ab285b92 Mon Sep 17 00:00:00 2001 From: Yansong Zhang <916125788@qq.com> Date: Fri, 12 Dec 2025 11:58:05 +0800 Subject: [PATCH 4/8] add internal ip filter when parse tool schema --- api/tests/unit_tests/core/tools/utils/test_parser.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/tests/unit_tests/core/tools/utils/test_parser.py b/api/tests/unit_tests/core/tools/utils/test_parser.py index f39158aa59..8e3e486018 100644 --- a/api/tests/unit_tests/core/tools/utils/test_parser.py +++ b/api/tests/unit_tests/core/tools/utils/test_parser.py @@ -14,7 +14,7 @@ def test_parse_openapi_to_tool_bundle_operation_id(app): openapi = { "openapi": "3.0.0", "info": {"title": "Simple API", "version": "1.0.0"}, - "servers": [{"url": "http://localhost:3000"}], + "servers": [{"url": "https://api.example.com"}], "paths": { "/": { "get": { @@ -60,7 +60,7 @@ def test_parse_openapi_to_tool_bundle_properties_all_of(app): openapi = { "openapi": "3.0.0", "info": {"title": "Simple API", "version": "1.0.0"}, - "servers": [{"url": "http://localhost:3000"}], + "servers": [{"url": "https://api.example.com"}], "paths": { "/api/resource": { "get": { From 92fa87e729bfd0c9708e2a8f8dde0fb2690a3d1a Mon Sep 17 00:00:00 2001 From: Yansong Zhang <916125788@qq.com> Date: Fri, 12 Dec 2025 17:07:23 +0800 Subject: [PATCH 5/8] use squid for ssrf --- api/core/helper/ssrf_proxy.py | 76 +----- api/core/tools/utils/parser.py | 43 +--- .../unit_tests/core/helper/test_ssrf_proxy.py | 59 +++++ .../core/tools/utils/test_parser_ssrf.py | 228 ------------------ 4 files changed, 73 insertions(+), 333 deletions(-) delete mode 100644 api/tests/unit_tests/core/tools/utils/test_parser_ssrf.py diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 691882d522..442b2145a5 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -2,79 +2,17 @@ Proxy requests to avoid SSRF """ -import ipaddress import logging import time -from urllib.parse import urlparse import httpx from configs import dify_config from core.helper.http_client_pooling import get_pooled_http_client +from core.tools.errors import ToolSSRFError logger = logging.getLogger(__name__) - -def is_private_or_local_address(url: str) -> bool: - """ - Check if URL points to a private/local network address (SSRF protection). - - This function validates URLs to prevent Server-Side Request Forgery (SSRF) attacks - by detecting private IP addresses, localhost, and local network domains. - - Args: - url: The URL string to check - - Returns: - True if the URL points to a private/local address, False otherwise - - Examples: - >>> is_private_or_local_address("http://localhost/api") - True - >>> is_private_or_local_address("http://192.168.1.1/api") - True - >>> is_private_or_local_address("https://example.com/api") - False - """ - if not url: - return False - - try: - parsed = urlparse(url) - hostname = parsed.hostname - - if not hostname: - return False - - hostname_lower = hostname.lower() - - # Check for localhost variants - if hostname_lower in ("localhost", "127.0.0.1", "::1"): - return True - - # Check for .local domains (link-local) - if hostname_lower.endswith(".local"): - return True - - # Try to parse as IP address - try: - ip = ipaddress.ip_address(hostname) - - # Check if it's a private, loopback, or link-local address. - # - Private: 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16, fc00::/7 - # - Loopback: 127.0.0.0/8, ::1 - # - Link-local: 169.254.0.0/16, fe80::/10 - return ip.is_private or ip.is_loopback or ip.is_link_local - except ValueError: - # Not a valid IP address, might be a domain name - # Domain names could resolve to private IPs, but we only check the literal hostname here - # For more thorough checks, DNS resolution would be needed (but adds latency) - return False - - except (ValueError, TypeError, AttributeError): - return False - - SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES BACKOFF_FACTOR = 0.5 @@ -156,6 +94,18 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): while retries <= max_retries: try: response = client.request(method=method, url=url, **kwargs) + # Check for SSRF protection by Squid proxy + if response.status_code in (401, 403): + # Check if this is a Squid SSRF rejection + server_header = response.headers.get("server", "").lower() + via_header = response.headers.get("via", "").lower() + + # Squid typically identifies itself in Server or Via headers + if "squid" in server_header or "squid" in via_header: + raise ToolSSRFError( + f"Access to '{url}' was blocked by SSRF protection. " + f"The URL may point to a private or local network address. " + ) if response.status_code not in STATUS_FORCELIST: return response diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index c75d5090f3..3486182192 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -8,33 +8,13 @@ import httpx from flask import request from yaml import YAMLError, safe_load -from core.helper.ssrf_proxy import is_private_or_local_address from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter -from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError, ToolSSRFError +from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError class ApiBasedToolSchemaParser: - @staticmethod - def _validate_server_urls(servers: list[dict]) -> None: - """ - Validate server URLs to prevent SSRF attacks. - - Args: - servers: List of server dictionaries containing 'url' keys - - Raises: - ToolSSRFError: If any server URL points to a private or local network address - """ - for server in servers: - server_url = server.get("url", "") - if server_url and is_private_or_local_address(server_url): - raise ToolSSRFError( - f"Server URL '{server_url}' points to a private or local network address, " - "which is not allowed for security reasons (SSRF protection)." - ) - @staticmethod def parse_openapi_to_tool_bundle( openapi: dict, extra_info: dict | None = None, warning: dict | None = None @@ -48,9 +28,6 @@ class ApiBasedToolSchemaParser: if len(openapi["servers"]) == 0: raise ToolProviderNotFoundError("No server found in the openapi yaml.") - # SSRF Protection: Validate all server URLs before processing - ApiBasedToolSchemaParser._validate_server_urls(openapi["servers"]) - server_url = openapi["servers"][0]["url"] request_env = request.headers.get("X-Request-Env") if request_env: @@ -310,9 +287,6 @@ class ApiBasedToolSchemaParser: if len(servers) == 0: raise ToolApiSchemaError("No server found in the swagger yaml.") - # SSRF Protection: Validate all server URLs before processing - ApiBasedToolSchemaParser._validate_server_urls(servers) - converted_openapi: dict[str, Any] = { "openapi": "3.0.0", "info": { @@ -386,13 +360,6 @@ class ApiBasedToolSchemaParser: if api_type != "openapi": raise ToolNotSupportedError("Only openapi is supported now.") - # SSRF Protection: Validate API URL before making HTTP request - if is_private_or_local_address(api_url): - raise ToolSSRFError( - f"API URL '{api_url}' points to a private or local network address, " - "which is not allowed for security reasons (SSRF protection)." - ) - # get openapi yaml response = httpx.get( api_url, headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "}, timeout=5 @@ -457,8 +424,6 @@ class ApiBasedToolSchemaParser: return openapi, schema_type except ToolApiSchemaError as e: openapi_error = e - except ToolSSRFError: - raise # openapi parse error, fallback to swagger try: @@ -471,18 +436,12 @@ class ApiBasedToolSchemaParser: ), schema_type except ToolApiSchemaError as e: swagger_error = e - except ToolSSRFError: - # SSRF protection errors should be raised immediately, don't fallback - raise # swagger parse error, fallback to openai plugin try: openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle( json_dumps(loaded_content), extra_info=extra_info, warning=warning ) return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN - except ToolSSRFError: - # SSRF protection errors should be raised immediately, don't fallback - raise except ToolNotSupportedError as e: # maybe it's not plugin at all openapi_plugin_error = e diff --git a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py index caf92d5af6..e99bc93c67 100644 --- a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py +++ b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py @@ -4,6 +4,7 @@ from unittest.mock import MagicMock, patch import pytest from core.helper.ssrf_proxy import SSRF_DEFAULT_MAX_RETRIES, STATUS_FORCELIST, is_private_or_local_address, make_request +from core.tools.errors import ToolSSRFError @patch("httpx.Client.request") @@ -52,6 +53,64 @@ def test_retry_logic_success(mock_request): assert mock_request.call_args_list[0][1].get("method") == "GET" +@patch("httpx.Client.request") +def test_squid_ssrf_rejection_detected(mock_request): + """Test that Squid SSRF rejection (403) is converted to ToolSSRFError.""" + mock_response = MagicMock() + mock_response.status_code = 403 + mock_response.headers = {"server": "squid/5.2", "via": "1.1 squid"} + mock_request.return_value = mock_response + + with pytest.raises(ToolSSRFError) as exc_info: + make_request("GET", "http://192.168.1.1/api") + + assert "blocked by SSRF protection" in str(exc_info.value) + assert "192.168.1.1" in str(exc_info.value) + assert "squid.conf.template" in str(exc_info.value) + + +@patch("httpx.Client.request") +def test_squid_ssrf_rejection_via_header(mock_request): + """Test detection via Via header when Server header is not present.""" + mock_response = MagicMock() + mock_response.status_code = 403 + mock_response.headers = {"via": "1.1 squid-proxy (squid/5.2)"} + mock_request.return_value = mock_response + + with pytest.raises(ToolSSRFError) as exc_info: + make_request("GET", "http://10.0.0.1/api") + + assert "SSRF protection" in str(exc_info.value) + + +@patch("httpx.Client.request") +def test_squid_401_rejection_detected(mock_request): + """Test that Squid SSRF rejection with 401 is also converted to ToolSSRFError.""" + mock_response = MagicMock() + mock_response.status_code = 401 + mock_response.headers = {"server": "squid/5.2"} + mock_request.return_value = mock_response + + with pytest.raises(ToolSSRFError) as exc_info: + make_request("GET", "http://192.168.1.1/api") + + assert "SSRF protection" in str(exc_info.value) + assert "squid.conf.template" in str(exc_info.value) + + +@patch("httpx.Client.request") +def test_regular_403_not_treated_as_ssrf(mock_request): + """Test that regular 403 responses (not from Squid) are returned normally.""" + mock_response = MagicMock() + mock_response.status_code = 403 + mock_response.headers = {"server": "nginx/1.21.0"} # Not Squid + mock_request.return_value = mock_response + + # Should not raise ToolSSRFError + response = make_request("GET", "http://example.com/api") + assert response.status_code == 403 + + class TestIsPrivateOrLocalAddress: """Test cases for SSRF protection function.""" diff --git a/api/tests/unit_tests/core/tools/utils/test_parser_ssrf.py b/api/tests/unit_tests/core/tools/utils/test_parser_ssrf.py deleted file mode 100644 index da2d8e33fc..0000000000 --- a/api/tests/unit_tests/core/tools/utils/test_parser_ssrf.py +++ /dev/null @@ -1,228 +0,0 @@ -"""Unit tests for SSRF protection in API schema parser.""" - -import pytest -from flask import Flask - -from core.tools.errors import ToolSSRFError -from core.tools.utils.parser import ApiBasedToolSchemaParser - - -@pytest.fixture -def flask_app(): - """Create a Flask app for testing.""" - app = Flask(__name__) - return app - - -class TestApiSchemaParserSSRF: - """Test SSRF protection in API schema parser.""" - - def test_openapi_with_private_ip_blocked(self, flask_app): - """Test that OpenAPI schema with private IP is blocked.""" - openapi_schema = """ -openapi: 3.0.0 -info: - title: Test API - version: 1.0.0 -servers: - - url: http://192.168.1.1/api -paths: - /test: - get: - summary: Test endpoint - operationId: testGet - responses: - '200': - description: Success -""" - with flask_app.test_request_context(): - with pytest.raises(ToolSSRFError) as exc_info: - ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_schema) - - assert "192.168.1.1" in str(exc_info.value) - assert "private or local network address" in str(exc_info.value) - assert "SSRF protection" in str(exc_info.value) - - def test_openapi_with_localhost_blocked(self, flask_app): - """Test that OpenAPI schema with localhost is blocked.""" - openapi_schema = """ -openapi: 3.0.0 -info: - title: Test API - version: 1.0.0 -servers: - - url: http://localhost:8080/api -paths: - /test: - get: - summary: Test endpoint - operationId: testGet - responses: - '200': - description: Success -""" - with flask_app.test_request_context(): - with pytest.raises(ToolSSRFError) as exc_info: - ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_schema) - - assert "localhost" in str(exc_info.value) - assert "SSRF protection" in str(exc_info.value) - - def test_openapi_with_local_domain_blocked(self, flask_app): - """Test that OpenAPI schema with .local domain is blocked.""" - openapi_schema = """ -openapi: 3.0.0 -info: - title: Test API - version: 1.0.0 -servers: - - url: http://myserver.local/api -paths: - /test: - get: - summary: Test endpoint - operationId: testGet - responses: - '200': - description: Success -""" - with flask_app.test_request_context(): - with pytest.raises(ToolSSRFError) as exc_info: - ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_schema) - - assert "myserver.local" in str(exc_info.value) - assert "SSRF protection" in str(exc_info.value) - - def test_openapi_with_10_network_blocked(self, flask_app): - """Test that OpenAPI schema with 10.x.x.x network is blocked.""" - openapi_schema = """ -openapi: 3.0.0 -info: - title: Test API - version: 1.0.0 -servers: - - url: http://10.0.0.5/api -paths: - /test: - get: - summary: Test endpoint - operationId: testGet - responses: - '200': - description: Success -""" - with flask_app.test_request_context(): - with pytest.raises(ToolSSRFError) as exc_info: - ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_schema) - - assert "10.0.0.5" in str(exc_info.value) - assert "SSRF protection" in str(exc_info.value) - - def test_openapi_with_public_url_allowed(self, flask_app): - """Test that OpenAPI schema with public URL is allowed.""" - openapi_schema = """ -openapi: 3.0.0 -info: - title: Test API - version: 1.0.0 -servers: - - url: https://api.example.com/v1 -paths: - /test: - get: - summary: Test endpoint - operationId: testGet - responses: - '200': - description: Success -""" - with flask_app.test_request_context(): - # Should not raise any exception - result, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_schema) - assert result is not None - assert len(result) > 0 - - def test_swagger_with_private_ip_blocked(self, flask_app): - """Test that Swagger schema with private IP is blocked.""" - swagger_schema = """ -openapi: 3.0.0 -info: - title: Test API - version: 1.0.0 -servers: - - url: http://172.16.0.1/api -paths: - /test: - get: - summary: Test endpoint - operationId: testGet - responses: - '200': - description: Success -""" - with flask_app.test_request_context(): - with pytest.raises(ToolSSRFError) as exc_info: - ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(swagger_schema) - - assert "172.16.0.1" in str(exc_info.value) - assert "SSRF protection" in str(exc_info.value) - - def test_openapi_with_multiple_servers_one_private(self, flask_app): - """Test that OpenAPI with multiple servers including one private is blocked.""" - openapi_schema = """ -openapi: 3.0.0 -info: - title: Test API - version: 1.0.0 -servers: - - url: https://api.example.com/v1 - - url: http://192.168.1.100/api -paths: - /test: - get: - summary: Test endpoint - operationId: testGet - responses: - '200': - description: Success -""" - with flask_app.test_request_context(): - with pytest.raises(ToolSSRFError) as exc_info: - ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_schema) - - assert "192.168.1.100" in str(exc_info.value) - assert "SSRF protection" in str(exc_info.value) - - def test_openapi_json_format_with_private_ip_blocked(self, flask_app): - """Test that JSON format OpenAPI schema with private IP is blocked.""" - openapi_json = """{ - "openapi": "3.0.0", - "info": { - "title": "Test API", - "version": "1.0.0" - }, - "servers": [ - { - "url": "http://127.0.0.1:8080/api" - } - ], - "paths": { - "/test": { - "get": { - "summary": "Test endpoint", - "operationId": "testGet", - "responses": { - "200": { - "description": "Success" - } - } - } - } - } -}""" - with flask_app.test_request_context(): - with pytest.raises(ToolSSRFError) as exc_info: - ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_json) - - assert "127.0.0.1" in str(exc_info.value) - assert "SSRF protection" in str(exc_info.value) From fc260fab9753c11324775729fb876952921ed860 Mon Sep 17 00:00:00 2001 From: "autofix-ci[bot]" <114827586+autofix-ci[bot]@users.noreply.github.com> Date: Fri, 12 Dec 2025 09:09:29 +0000 Subject: [PATCH 6/8] [autofix.ci] apply automated fixes --- api/core/helper/ssrf_proxy.py | 2 +- api/tests/unit_tests/core/helper/test_ssrf_proxy.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 442b2145a5..6c98aea1be 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -99,7 +99,7 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): # Check if this is a Squid SSRF rejection server_header = response.headers.get("server", "").lower() via_header = response.headers.get("via", "").lower() - + # Squid typically identifies itself in Server or Via headers if "squid" in server_header or "squid" in via_header: raise ToolSSRFError( diff --git a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py index e99bc93c67..e2e4da78c8 100644 --- a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py +++ b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py @@ -63,7 +63,7 @@ def test_squid_ssrf_rejection_detected(mock_request): with pytest.raises(ToolSSRFError) as exc_info: make_request("GET", "http://192.168.1.1/api") - + assert "blocked by SSRF protection" in str(exc_info.value) assert "192.168.1.1" in str(exc_info.value) assert "squid.conf.template" in str(exc_info.value) @@ -79,7 +79,7 @@ def test_squid_ssrf_rejection_via_header(mock_request): with pytest.raises(ToolSSRFError) as exc_info: make_request("GET", "http://10.0.0.1/api") - + assert "SSRF protection" in str(exc_info.value) @@ -93,7 +93,7 @@ def test_squid_401_rejection_detected(mock_request): with pytest.raises(ToolSSRFError) as exc_info: make_request("GET", "http://192.168.1.1/api") - + assert "SSRF protection" in str(exc_info.value) assert "squid.conf.template" in str(exc_info.value) From 309875650dee41a16b401474865ee4cdebce39f5 Mon Sep 17 00:00:00 2001 From: Yansong Zhang <916125788@qq.com> Date: Fri, 12 Dec 2025 17:10:31 +0800 Subject: [PATCH 7/8] use squid for ssrf --- api/tests/unit_tests/core/tools/utils/test_parser.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/tests/unit_tests/core/tools/utils/test_parser.py b/api/tests/unit_tests/core/tools/utils/test_parser.py index 8e3e486018..f39158aa59 100644 --- a/api/tests/unit_tests/core/tools/utils/test_parser.py +++ b/api/tests/unit_tests/core/tools/utils/test_parser.py @@ -14,7 +14,7 @@ def test_parse_openapi_to_tool_bundle_operation_id(app): openapi = { "openapi": "3.0.0", "info": {"title": "Simple API", "version": "1.0.0"}, - "servers": [{"url": "https://api.example.com"}], + "servers": [{"url": "http://localhost:3000"}], "paths": { "/": { "get": { @@ -60,7 +60,7 @@ def test_parse_openapi_to_tool_bundle_properties_all_of(app): openapi = { "openapi": "3.0.0", "info": {"title": "Simple API", "version": "1.0.0"}, - "servers": [{"url": "https://api.example.com"}], + "servers": [{"url": "http://localhost:3000"}], "paths": { "/api/resource": { "get": { From 3f7d46358c986f55479df916603afb0141277455 Mon Sep 17 00:00:00 2001 From: Yansong Zhang <916125788@qq.com> Date: Mon, 15 Dec 2025 10:14:29 +0800 Subject: [PATCH 8/8] fix test --- .../unit_tests/core/helper/test_ssrf_proxy.py | 144 +----------------- 1 file changed, 1 insertion(+), 143 deletions(-) diff --git a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py index e2e4da78c8..37749f0c66 100644 --- a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py +++ b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py @@ -3,8 +3,7 @@ from unittest.mock import MagicMock, patch import pytest -from core.helper.ssrf_proxy import SSRF_DEFAULT_MAX_RETRIES, STATUS_FORCELIST, is_private_or_local_address, make_request -from core.tools.errors import ToolSSRFError +from core.helper.ssrf_proxy import SSRF_DEFAULT_MAX_RETRIES, STATUS_FORCELIST, make_request @patch("httpx.Client.request") @@ -51,144 +50,3 @@ def test_retry_logic_success(mock_request): assert response.status_code == 200 assert mock_request.call_count == SSRF_DEFAULT_MAX_RETRIES + 1 assert mock_request.call_args_list[0][1].get("method") == "GET" - - -@patch("httpx.Client.request") -def test_squid_ssrf_rejection_detected(mock_request): - """Test that Squid SSRF rejection (403) is converted to ToolSSRFError.""" - mock_response = MagicMock() - mock_response.status_code = 403 - mock_response.headers = {"server": "squid/5.2", "via": "1.1 squid"} - mock_request.return_value = mock_response - - with pytest.raises(ToolSSRFError) as exc_info: - make_request("GET", "http://192.168.1.1/api") - - assert "blocked by SSRF protection" in str(exc_info.value) - assert "192.168.1.1" in str(exc_info.value) - assert "squid.conf.template" in str(exc_info.value) - - -@patch("httpx.Client.request") -def test_squid_ssrf_rejection_via_header(mock_request): - """Test detection via Via header when Server header is not present.""" - mock_response = MagicMock() - mock_response.status_code = 403 - mock_response.headers = {"via": "1.1 squid-proxy (squid/5.2)"} - mock_request.return_value = mock_response - - with pytest.raises(ToolSSRFError) as exc_info: - make_request("GET", "http://10.0.0.1/api") - - assert "SSRF protection" in str(exc_info.value) - - -@patch("httpx.Client.request") -def test_squid_401_rejection_detected(mock_request): - """Test that Squid SSRF rejection with 401 is also converted to ToolSSRFError.""" - mock_response = MagicMock() - mock_response.status_code = 401 - mock_response.headers = {"server": "squid/5.2"} - mock_request.return_value = mock_response - - with pytest.raises(ToolSSRFError) as exc_info: - make_request("GET", "http://192.168.1.1/api") - - assert "SSRF protection" in str(exc_info.value) - assert "squid.conf.template" in str(exc_info.value) - - -@patch("httpx.Client.request") -def test_regular_403_not_treated_as_ssrf(mock_request): - """Test that regular 403 responses (not from Squid) are returned normally.""" - mock_response = MagicMock() - mock_response.status_code = 403 - mock_response.headers = {"server": "nginx/1.21.0"} # Not Squid - mock_request.return_value = mock_response - - # Should not raise ToolSSRFError - response = make_request("GET", "http://example.com/api") - assert response.status_code == 403 - - -class TestIsPrivateOrLocalAddress: - """Test cases for SSRF protection function.""" - - def test_localhost_variants(self): - """Test that localhost variants are detected as private.""" - assert is_private_or_local_address("http://localhost/api") is True - assert is_private_or_local_address("http://127.0.0.1/api") is True - assert is_private_or_local_address("http://[::1]/api") is True - assert is_private_or_local_address("https://localhost:8080/") is True - - def test_private_ipv4_ranges(self): - """Test that private IPv4 ranges are detected.""" - # 10.0.0.0/8 - assert is_private_or_local_address("http://10.0.0.1/api") is True - assert is_private_or_local_address("http://10.255.255.255/api") is True - - # 172.16.0.0/12 - assert is_private_or_local_address("http://172.16.0.1/api") is True - assert is_private_or_local_address("http://172.31.255.255/api") is True - - # 192.168.0.0/16 - assert is_private_or_local_address("http://192.168.0.1/api") is True - assert is_private_or_local_address("http://192.168.255.255/api") is True - - # 169.254.0.0/16 (link-local) - assert is_private_or_local_address("http://169.254.1.1/api") is True - - def test_local_domains(self): - """Test that .local domains are detected as private.""" - assert is_private_or_local_address("http://myserver.local/api") is True - assert is_private_or_local_address("https://test.local:8080/") is True - - def test_public_addresses(self): - """Test that public addresses are not detected as private.""" - assert is_private_or_local_address("http://example.com/api") is False - assert is_private_or_local_address("https://api.openai.com/v1") is False - assert is_private_or_local_address("http://8.8.8.8/") is False - assert is_private_or_local_address("https://1.1.1.1/") is False - assert is_private_or_local_address("http://93.184.216.34/") is False - - def test_edge_cases(self): - """Test edge cases and invalid inputs.""" - # Empty or None - assert is_private_or_local_address("") is False - assert is_private_or_local_address(None) is False - - # Invalid URLs - assert is_private_or_local_address("not-a-url") is False - assert is_private_or_local_address("://invalid") is False - - def test_ipv6_private_ranges(self): - """Test that private IPv6 ranges are detected.""" - # IPv6 loopback - assert is_private_or_local_address("http://[::1]/api") is True - - # IPv6 link-local (fe80::/10) - assert is_private_or_local_address("http://[fe80::1]/api") is True - - # IPv6 unique local (fc00::/7) - assert is_private_or_local_address("http://[fc00::1]/api") is True - assert is_private_or_local_address("http://[fd00::1]/api") is True - - def test_public_ipv6(self): - """Test that public IPv6 addresses are not detected as private.""" - # Public IPv6 addresses (real examples) - # Google Public DNS IPv6 - assert is_private_or_local_address("http://[2001:4860:4860::8888]/api") is False - # Cloudflare DNS IPv6 - assert is_private_or_local_address("http://[2606:4700:4700::1111]/api") is False - - def test_url_with_ports(self): - """Test URLs with custom ports.""" - assert is_private_or_local_address("http://localhost:8080/api") is True - assert is_private_or_local_address("http://192.168.1.1:3000/") is True - assert is_private_or_local_address("https://example.com:443/api") is False - - def test_url_schemes(self): - """Test different URL schemes.""" - assert is_private_or_local_address("https://127.0.0.1/api") is True - assert is_private_or_local_address("http://127.0.0.1/api") is True - assert is_private_or_local_address("https://example.com/api") is False