diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 691882d522..442b2145a5 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -2,79 +2,17 @@ Proxy requests to avoid SSRF """ -import ipaddress import logging import time -from urllib.parse import urlparse import httpx from configs import dify_config from core.helper.http_client_pooling import get_pooled_http_client +from core.tools.errors import ToolSSRFError logger = logging.getLogger(__name__) - -def is_private_or_local_address(url: str) -> bool: - """ - Check if URL points to a private/local network address (SSRF protection). - - This function validates URLs to prevent Server-Side Request Forgery (SSRF) attacks - by detecting private IP addresses, localhost, and local network domains. - - Args: - url: The URL string to check - - Returns: - True if the URL points to a private/local address, False otherwise - - Examples: - >>> is_private_or_local_address("http://localhost/api") - True - >>> is_private_or_local_address("http://192.168.1.1/api") - True - >>> is_private_or_local_address("https://example.com/api") - False - """ - if not url: - return False - - try: - parsed = urlparse(url) - hostname = parsed.hostname - - if not hostname: - return False - - hostname_lower = hostname.lower() - - # Check for localhost variants - if hostname_lower in ("localhost", "127.0.0.1", "::1"): - return True - - # Check for .local domains (link-local) - if hostname_lower.endswith(".local"): - return True - - # Try to parse as IP address - try: - ip = ipaddress.ip_address(hostname) - - # Check if it's a private, loopback, or link-local address. 
- # - Private: 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16, fc00::/7 - # - Loopback: 127.0.0.0/8, ::1 - # - Link-local: 169.254.0.0/16, fe80::/10 - return ip.is_private or ip.is_loopback or ip.is_link_local - except ValueError: - # Not a valid IP address, might be a domain name - # Domain names could resolve to private IPs, but we only check the literal hostname here - # For more thorough checks, DNS resolution would be needed (but adds latency) - return False - - except (ValueError, TypeError, AttributeError): - return False - - SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES BACKOFF_FACTOR = 0.5 @@ -156,6 +94,18 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): while retries <= max_retries: try: response = client.request(method=method, url=url, **kwargs) + # Check for SSRF protection by Squid proxy + if response.status_code in (401, 403): + # Check if this is a Squid SSRF rejection + server_header = response.headers.get("server", "").lower() + via_header = response.headers.get("via", "").lower() + + # Squid typically identifies itself in Server or Via headers + if "squid" in server_header or "squid" in via_header: + raise ToolSSRFError( + f"Access to '{url}' was blocked by SSRF protection. " + f"The URL may point to a private or local network address. 
" + ) if response.status_code not in STATUS_FORCELIST: return response diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index c75d5090f3..3486182192 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -8,33 +8,13 @@ import httpx from flask import request from yaml import YAMLError, safe_load -from core.helper.ssrf_proxy import is_private_or_local_address from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter -from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError, ToolSSRFError +from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError class ApiBasedToolSchemaParser: - @staticmethod - def _validate_server_urls(servers: list[dict]) -> None: - """ - Validate server URLs to prevent SSRF attacks. - - Args: - servers: List of server dictionaries containing 'url' keys - - Raises: - ToolSSRFError: If any server URL points to a private or local network address - """ - for server in servers: - server_url = server.get("url", "") - if server_url and is_private_or_local_address(server_url): - raise ToolSSRFError( - f"Server URL '{server_url}' points to a private or local network address, " - "which is not allowed for security reasons (SSRF protection)." 
- ) - @staticmethod def parse_openapi_to_tool_bundle( openapi: dict, extra_info: dict | None = None, warning: dict | None = None @@ -48,9 +28,6 @@ class ApiBasedToolSchemaParser: if len(openapi["servers"]) == 0: raise ToolProviderNotFoundError("No server found in the openapi yaml.") - # SSRF Protection: Validate all server URLs before processing - ApiBasedToolSchemaParser._validate_server_urls(openapi["servers"]) - server_url = openapi["servers"][0]["url"] request_env = request.headers.get("X-Request-Env") if request_env: @@ -310,9 +287,6 @@ class ApiBasedToolSchemaParser: if len(servers) == 0: raise ToolApiSchemaError("No server found in the swagger yaml.") - # SSRF Protection: Validate all server URLs before processing - ApiBasedToolSchemaParser._validate_server_urls(servers) - converted_openapi: dict[str, Any] = { "openapi": "3.0.0", "info": { @@ -386,13 +360,6 @@ class ApiBasedToolSchemaParser: if api_type != "openapi": raise ToolNotSupportedError("Only openapi is supported now.") - # SSRF Protection: Validate API URL before making HTTP request - if is_private_or_local_address(api_url): - raise ToolSSRFError( - f"API URL '{api_url}' points to a private or local network address, " - "which is not allowed for security reasons (SSRF protection)." 
- ) - # get openapi yaml response = httpx.get( api_url, headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "}, timeout=5 @@ -457,8 +424,6 @@ class ApiBasedToolSchemaParser: return openapi, schema_type except ToolApiSchemaError as e: openapi_error = e - except ToolSSRFError: - raise # openapi parse error, fallback to swagger try: @@ -471,18 +436,12 @@ class ApiBasedToolSchemaParser: ), schema_type except ToolApiSchemaError as e: swagger_error = e - except ToolSSRFError: - # SSRF protection errors should be raised immediately, don't fallback - raise # swagger parse error, fallback to openai plugin try: openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle( json_dumps(loaded_content), extra_info=extra_info, warning=warning ) return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN - except ToolSSRFError: - # SSRF protection errors should be raised immediately, don't fallback - raise except ToolNotSupportedError as e: # maybe it's not plugin at all openapi_plugin_error = e diff --git a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py index caf92d5af6..e99bc93c67 100644 --- a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py +++ b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py @@ -4,6 +4,7 @@ from unittest.mock import MagicMock, patch import pytest from core.helper.ssrf_proxy import SSRF_DEFAULT_MAX_RETRIES, STATUS_FORCELIST, is_private_or_local_address, make_request +from core.tools.errors import ToolSSRFError @patch("httpx.Client.request") @@ -52,6 +53,64 @@ def test_retry_logic_success(mock_request): assert mock_request.call_args_list[0][1].get("method") == "GET" +@patch("httpx.Client.request") +def test_squid_ssrf_rejection_detected(mock_request): + """Test that Squid SSRF rejection (403) is converted to ToolSSRFError.""" + mock_response = MagicMock() + mock_response.status_code = 403 + mock_response.headers = {"server": "squid/5.2", "via": "1.1 squid"} 
+ mock_request.return_value = mock_response + + with pytest.raises(ToolSSRFError) as exc_info: + make_request("GET", "http://192.168.1.1/api") + + assert "blocked by SSRF protection" in str(exc_info.value) + assert "192.168.1.1" in str(exc_info.value) + assert "private or local network address" in str(exc_info.value) + + +@patch("httpx.Client.request") +def test_squid_ssrf_rejection_via_header(mock_request): + """Test detection via Via header when Server header is not present.""" + mock_response = MagicMock() + mock_response.status_code = 403 + mock_response.headers = {"via": "1.1 squid-proxy (squid/5.2)"} + mock_request.return_value = mock_response + + with pytest.raises(ToolSSRFError) as exc_info: + make_request("GET", "http://10.0.0.1/api") + + assert "SSRF protection" in str(exc_info.value) + + +@patch("httpx.Client.request") +def test_squid_401_rejection_detected(mock_request): + """Test that Squid SSRF rejection with 401 is also converted to ToolSSRFError.""" + mock_response = MagicMock() + mock_response.status_code = 401 + mock_response.headers = {"server": "squid/5.2"} + mock_request.return_value = mock_response + + with pytest.raises(ToolSSRFError) as exc_info: + make_request("GET", "http://192.168.1.1/api") + + assert "SSRF protection" in str(exc_info.value) + assert "private or local network address" in str(exc_info.value) + + +@patch("httpx.Client.request") +def test_regular_403_not_treated_as_ssrf(mock_request): + """Test that regular 403 responses (not from Squid) are returned normally.""" + mock_response = MagicMock() + mock_response.status_code = 403 + mock_response.headers = {"server": "nginx/1.21.0"} # Not Squid + mock_request.return_value = mock_response + + # Should not raise ToolSSRFError + response = make_request("GET", "http://example.com/api") + assert response.status_code == 403 + + class TestIsPrivateOrLocalAddress: """Test cases for SSRF protection function.""" diff --git a/api/tests/unit_tests/core/tools/utils/test_parser_ssrf.py
b/api/tests/unit_tests/core/tools/utils/test_parser_ssrf.py deleted file mode 100644 index da2d8e33fc..0000000000 --- a/api/tests/unit_tests/core/tools/utils/test_parser_ssrf.py +++ /dev/null @@ -1,228 +0,0 @@ -"""Unit tests for SSRF protection in API schema parser.""" - -import pytest -from flask import Flask - -from core.tools.errors import ToolSSRFError -from core.tools.utils.parser import ApiBasedToolSchemaParser - - -@pytest.fixture -def flask_app(): - """Create a Flask app for testing.""" - app = Flask(__name__) - return app - - -class TestApiSchemaParserSSRF: - """Test SSRF protection in API schema parser.""" - - def test_openapi_with_private_ip_blocked(self, flask_app): - """Test that OpenAPI schema with private IP is blocked.""" - openapi_schema = """ -openapi: 3.0.0 -info: - title: Test API - version: 1.0.0 -servers: - - url: http://192.168.1.1/api -paths: - /test: - get: - summary: Test endpoint - operationId: testGet - responses: - '200': - description: Success -""" - with flask_app.test_request_context(): - with pytest.raises(ToolSSRFError) as exc_info: - ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_schema) - - assert "192.168.1.1" in str(exc_info.value) - assert "private or local network address" in str(exc_info.value) - assert "SSRF protection" in str(exc_info.value) - - def test_openapi_with_localhost_blocked(self, flask_app): - """Test that OpenAPI schema with localhost is blocked.""" - openapi_schema = """ -openapi: 3.0.0 -info: - title: Test API - version: 1.0.0 -servers: - - url: http://localhost:8080/api -paths: - /test: - get: - summary: Test endpoint - operationId: testGet - responses: - '200': - description: Success -""" - with flask_app.test_request_context(): - with pytest.raises(ToolSSRFError) as exc_info: - ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_schema) - - assert "localhost" in str(exc_info.value) - assert "SSRF protection" in str(exc_info.value) - - def test_openapi_with_local_domain_blocked(self, 
flask_app): - """Test that OpenAPI schema with .local domain is blocked.""" - openapi_schema = """ -openapi: 3.0.0 -info: - title: Test API - version: 1.0.0 -servers: - - url: http://myserver.local/api -paths: - /test: - get: - summary: Test endpoint - operationId: testGet - responses: - '200': - description: Success -""" - with flask_app.test_request_context(): - with pytest.raises(ToolSSRFError) as exc_info: - ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_schema) - - assert "myserver.local" in str(exc_info.value) - assert "SSRF protection" in str(exc_info.value) - - def test_openapi_with_10_network_blocked(self, flask_app): - """Test that OpenAPI schema with 10.x.x.x network is blocked.""" - openapi_schema = """ -openapi: 3.0.0 -info: - title: Test API - version: 1.0.0 -servers: - - url: http://10.0.0.5/api -paths: - /test: - get: - summary: Test endpoint - operationId: testGet - responses: - '200': - description: Success -""" - with flask_app.test_request_context(): - with pytest.raises(ToolSSRFError) as exc_info: - ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_schema) - - assert "10.0.0.5" in str(exc_info.value) - assert "SSRF protection" in str(exc_info.value) - - def test_openapi_with_public_url_allowed(self, flask_app): - """Test that OpenAPI schema with public URL is allowed.""" - openapi_schema = """ -openapi: 3.0.0 -info: - title: Test API - version: 1.0.0 -servers: - - url: https://api.example.com/v1 -paths: - /test: - get: - summary: Test endpoint - operationId: testGet - responses: - '200': - description: Success -""" - with flask_app.test_request_context(): - # Should not raise any exception - result, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_schema) - assert result is not None - assert len(result) > 0 - - def test_swagger_with_private_ip_blocked(self, flask_app): - """Test that Swagger schema with private IP is blocked.""" - swagger_schema = """ -openapi: 3.0.0 -info: - title: Test API - 
version: 1.0.0 -servers: - - url: http://172.16.0.1/api -paths: - /test: - get: - summary: Test endpoint - operationId: testGet - responses: - '200': - description: Success -""" - with flask_app.test_request_context(): - with pytest.raises(ToolSSRFError) as exc_info: - ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(swagger_schema) - - assert "172.16.0.1" in str(exc_info.value) - assert "SSRF protection" in str(exc_info.value) - - def test_openapi_with_multiple_servers_one_private(self, flask_app): - """Test that OpenAPI with multiple servers including one private is blocked.""" - openapi_schema = """ -openapi: 3.0.0 -info: - title: Test API - version: 1.0.0 -servers: - - url: https://api.example.com/v1 - - url: http://192.168.1.100/api -paths: - /test: - get: - summary: Test endpoint - operationId: testGet - responses: - '200': - description: Success -""" - with flask_app.test_request_context(): - with pytest.raises(ToolSSRFError) as exc_info: - ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_schema) - - assert "192.168.1.100" in str(exc_info.value) - assert "SSRF protection" in str(exc_info.value) - - def test_openapi_json_format_with_private_ip_blocked(self, flask_app): - """Test that JSON format OpenAPI schema with private IP is blocked.""" - openapi_json = """{ - "openapi": "3.0.0", - "info": { - "title": "Test API", - "version": "1.0.0" - }, - "servers": [ - { - "url": "http://127.0.0.1:8080/api" - } - ], - "paths": { - "/test": { - "get": { - "summary": "Test endpoint", - "operationId": "testGet", - "responses": { - "200": { - "description": "Success" - } - } - } - } - } -}""" - with flask_app.test_request_context(): - with pytest.raises(ToolSSRFError) as exc_info: - ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_json) - - assert "127.0.0.1" in str(exc_info.value) - assert "SSRF protection" in str(exc_info.value)