From 3f3b9beeff3df0220987699873ee4817d0ffa955 Mon Sep 17 00:00:00 2001 From: Yansong Zhang <916125788@qq.com> Date: Fri, 12 Dec 2025 11:24:25 +0800 Subject: [PATCH] add internal ip filter when parse tool schema --- api/core/helper/ssrf_proxy.py | 72 ++++++ api/core/tools/errors.py | 4 + api/core/tools/utils/parser.py | 39 ++- .../unit_tests/core/helper/test_ssrf_proxy.py | 85 ++++++- .../core/tools/utils/test_parser_ssrf.py | 228 ++++++++++++++++++ 5 files changed, 424 insertions(+), 4 deletions(-) create mode 100644 api/tests/unit_tests/core/tools/utils/test_parser_ssrf.py diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 0de026f3c7..ba91ebd9d4 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -2,8 +2,10 @@ Proxy requests to avoid SSRF """ +import ipaddress import logging import time +from urllib.parse import urlparse import httpx @@ -12,6 +14,76 @@ from core.helper.http_client_pooling import get_pooled_http_client logger = logging.getLogger(__name__) + +def is_private_or_local_address(url: str) -> bool: + """ + Check if URL points to a private/local network address (SSRF protection). + + This function validates URLs to prevent Server-Side Request Forgery (SSRF) attacks + by detecting private IP addresses, localhost, and local network domains. + + Args: + url: The URL string to check + + Returns: + True if the URL points to a private/local address, False otherwise + + Examples: + >>> is_private_or_local_address("http://localhost/api") + True + >>> is_private_or_local_address("http://192.168.1.1/api") + True + >>> is_private_or_local_address("https://example.com/api") + False + """ + if not url: + return False + + try: + parsed = urlparse(url) + hostname = parsed.hostname + + if not hostname: + return False + + hostname_lower = hostname.lower() + + # Check for localhost variants + if hostname_lower in ("localhost", "127.0.0.1", "::1"): + return True + + # Check for .local domains (link-local) + if hostname_lower.endswith(".local"): + return True + + # Try to parse as IP address + try: + ip = ipaddress.ip_address(hostname) + + # Check if it's a private IP (10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16 for IPv4) + # For IPv6: fc00::/7 (unique local addresses) + if ip.is_private: + return True + + # Check if it's loopback (127.0.0.0/8 for IPv4, ::1 for IPv6) + if ip.is_loopback: + return True + + # Check if it's link-local (169.254.0.0/16 for IPv4, fe80::/10 for IPv6) + if ip.is_link_local: + return True + + return False + except ValueError: + # Not a valid IP address, might be a domain name + # Domain names could resolve to private IPs, but we only check the literal hostname here + # For more thorough checks, DNS resolution would be needed (but adds latency) + return False + + except (ValueError, TypeError, AttributeError): + return False + + SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES BACKOFF_FACTOR = 0.5 diff --git a/api/core/tools/errors.py b/api/core/tools/errors.py index b0c2232857..e4afe24426 100644 --- a/api/core/tools/errors.py +++ b/api/core/tools/errors.py @@ -29,6 +29,10 @@ class ToolApiSchemaError(ValueError): pass +class ToolSSRFError(ValueError): + pass + + class ToolCredentialPolicyViolationError(ValueError): pass diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index 6eabde3991..b2c7d3db80 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -8,10 +8,11 @@ import httpx from flask import request from yaml import YAMLError, safe_load +from core.helper.ssrf_proxy import is_private_or_local_address from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter -from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError +from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError, ToolSSRFError class ApiBasedToolSchemaParser: @@ -28,6 +29,15 @@ class ApiBasedToolSchemaParser: if len(openapi["servers"]) == 0: raise ToolProviderNotFoundError("No server found in the openapi yaml.") + # SSRF Protection: Validate all server URLs before processing + for server in openapi["servers"]: + server_url_to_check = server.get("url", "") + if server_url_to_check and is_private_or_local_address(server_url_to_check): + raise ToolSSRFError( + f"Server URL '{server_url_to_check}' points to a private or local network address, " + "which is not allowed for security reasons (SSRF protection)." + ) + server_url = openapi["servers"][0]["url"] request_env = request.headers.get("X-Request-Env") if request_env: @@ -287,6 +297,15 @@ class ApiBasedToolSchemaParser: if len(servers) == 0: raise ToolApiSchemaError("No server found in the swagger yaml.") + # SSRF Protection: Validate all server URLs before processing + for server in servers: + server_url_to_check = server.get("url", "") + if server_url_to_check and is_private_or_local_address(server_url_to_check): + raise ToolSSRFError( + f"Server URL '{server_url_to_check}' points to a private or local network address, " + "which is not allowed for security reasons (SSRF protection)." + ) + converted_openapi: dict[str, Any] = { "openapi": "3.0.0", "info": { @@ -360,6 +379,13 @@ class ApiBasedToolSchemaParser: if api_type != "openapi": raise ToolNotSupportedError("Only openapi is supported now.") + # SSRF Protection: Validate API URL before making HTTP request + if is_private_or_local_address(api_url): + raise ToolSSRFError( + f"API URL '{api_url}' points to a private or local network address, " + "which is not allowed for security reasons (SSRF protection)." + ) + # get openapi yaml response = httpx.get( api_url, headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "}, timeout=5 @@ -424,8 +450,10 @@ class ApiBasedToolSchemaParser: return openapi, schema_type except ToolApiSchemaError as e: openapi_error = e + except ToolSSRFError: + raise - # openai parse error, fallback to swagger + # openapi parse error, fallback to swagger try: converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi( loaded_content, extra_info=extra_info, warning=warning @@ -436,13 +464,18 @@ class ApiBasedToolSchemaParser: ), schema_type except ToolApiSchemaError as e: swagger_error = e - + except ToolSSRFError: + # SSRF protection errors should be raised immediately, don't fallback + raise # swagger parse error, fallback to openai plugin try: openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle( json_dumps(loaded_content), extra_info=extra_info, warning=warning ) return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN + except ToolSSRFError: + # SSRF protection errors should be raised immediately, don't fallback + raise except ToolNotSupportedError as e: # maybe it's not plugin at all openapi_plugin_error = e diff --git a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py index 37749f0c66..caf92d5af6 100644 --- a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py +++ b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py @@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch import pytest -from core.helper.ssrf_proxy import SSRF_DEFAULT_MAX_RETRIES, STATUS_FORCELIST, make_request +from core.helper.ssrf_proxy import SSRF_DEFAULT_MAX_RETRIES, STATUS_FORCELIST, is_private_or_local_address, make_request @patch("httpx.Client.request") @@ -50,3 +50,86 @@ def test_retry_logic_success(mock_request): assert response.status_code == 200 assert mock_request.call_count == SSRF_DEFAULT_MAX_RETRIES + 1 assert mock_request.call_args_list[0][1].get("method") == "GET" + + +class TestIsPrivateOrLocalAddress: + """Test cases for SSRF protection function.""" + + def test_localhost_variants(self): + """Test that localhost variants are detected as private.""" + assert is_private_or_local_address("http://localhost/api") is True + assert is_private_or_local_address("http://127.0.0.1/api") is True + assert is_private_or_local_address("http://[::1]/api") is True + assert is_private_or_local_address("https://localhost:8080/") is True + + def test_private_ipv4_ranges(self): + """Test that private IPv4 ranges are detected.""" + # 10.0.0.0/8 + assert is_private_or_local_address("http://10.0.0.1/api") is True + assert is_private_or_local_address("http://10.255.255.255/api") is True + + # 172.16.0.0/12 + assert is_private_or_local_address("http://172.16.0.1/api") is True + assert is_private_or_local_address("http://172.31.255.255/api") is True + + # 192.168.0.0/16 + assert is_private_or_local_address("http://192.168.0.1/api") is True + assert is_private_or_local_address("http://192.168.255.255/api") is True + + # 169.254.0.0/16 (link-local) + assert is_private_or_local_address("http://169.254.1.1/api") is True + + def test_local_domains(self): + """Test that .local domains are detected as private.""" + assert is_private_or_local_address("http://myserver.local/api") is True + assert is_private_or_local_address("https://test.local:8080/") is True + + def test_public_addresses(self): + """Test that public addresses are not detected as private.""" + assert is_private_or_local_address("http://example.com/api") is False + assert is_private_or_local_address("https://api.openai.com/v1") is False + assert is_private_or_local_address("http://8.8.8.8/") is False + assert is_private_or_local_address("https://1.1.1.1/") is False + assert is_private_or_local_address("http://93.184.216.34/") is False + + def test_edge_cases(self): + """Test edge cases and invalid inputs.""" + # Empty or None + assert is_private_or_local_address("") is False + assert is_private_or_local_address(None) is False + + # Invalid URLs + assert is_private_or_local_address("not-a-url") is False + assert is_private_or_local_address("://invalid") is False + + def test_ipv6_private_ranges(self): + """Test that private IPv6 ranges are detected.""" + # IPv6 loopback + assert is_private_or_local_address("http://[::1]/api") is True + + # IPv6 link-local (fe80::/10) + assert is_private_or_local_address("http://[fe80::1]/api") is True + + # IPv6 unique local (fc00::/7) + assert is_private_or_local_address("http://[fc00::1]/api") is True + assert is_private_or_local_address("http://[fd00::1]/api") is True + + def test_public_ipv6(self): + """Test that public IPv6 addresses are not detected as private.""" + # Public IPv6 addresses (real examples) + # Google Public DNS IPv6 + assert is_private_or_local_address("http://[2001:4860:4860::8888]/api") is False + # Cloudflare DNS IPv6 + assert is_private_or_local_address("http://[2606:4700:4700::1111]/api") is False + + def test_url_with_ports(self): + """Test URLs with custom ports.""" + assert is_private_or_local_address("http://localhost:8080/api") is True + assert is_private_or_local_address("http://192.168.1.1:3000/") is True + assert is_private_or_local_address("https://example.com:443/api") is False + + def test_url_schemes(self): + """Test different URL schemes.""" + assert is_private_or_local_address("https://127.0.0.1/api") is True + assert is_private_or_local_address("http://127.0.0.1/api") is True + assert is_private_or_local_address("https://example.com/api") is False diff --git a/api/tests/unit_tests/core/tools/utils/test_parser_ssrf.py b/api/tests/unit_tests/core/tools/utils/test_parser_ssrf.py new file mode 100644 index 0000000000..da2d8e33fc --- /dev/null +++ b/api/tests/unit_tests/core/tools/utils/test_parser_ssrf.py @@ -0,0 +1,228 @@ +"""Unit tests for SSRF protection in API schema parser.""" + +import pytest +from flask import Flask + +from core.tools.errors import ToolSSRFError +from core.tools.utils.parser import ApiBasedToolSchemaParser + + +@pytest.fixture +def flask_app(): + """Create a Flask app for testing.""" + app = Flask(__name__) + return app + + +class TestApiSchemaParserSSRF: + """Test SSRF protection in API schema parser.""" + + def test_openapi_with_private_ip_blocked(self, flask_app): + """Test that OpenAPI schema with private IP is blocked.""" + openapi_schema = """ +openapi: 3.0.0 +info: + title: Test API + version: 1.0.0 +servers: + - url: http://192.168.1.1/api +paths: + /test: + get: + summary: Test endpoint + operationId: testGet + responses: + '200': + description: Success +""" + with flask_app.test_request_context(): + with pytest.raises(ToolSSRFError) as exc_info: + ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_schema) + + assert "192.168.1.1" in str(exc_info.value) + assert "private or local network address" in str(exc_info.value) + assert "SSRF protection" in str(exc_info.value) + + def test_openapi_with_localhost_blocked(self, flask_app): + """Test that OpenAPI schema with localhost is blocked.""" + openapi_schema = """ +openapi: 3.0.0 +info: + title: Test API + version: 1.0.0 +servers: + - url: http://localhost:8080/api +paths: + /test: + get: + summary: Test endpoint + operationId: testGet + responses: + '200': + description: Success +""" + with flask_app.test_request_context(): + with pytest.raises(ToolSSRFError) as exc_info: + ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_schema) + + assert "localhost" in str(exc_info.value) + assert "SSRF protection" in str(exc_info.value) + + def test_openapi_with_local_domain_blocked(self, flask_app): + """Test that OpenAPI schema with .local domain is blocked.""" + openapi_schema = """ +openapi: 3.0.0 +info: + title: Test API + version: 1.0.0 +servers: + - url: http://myserver.local/api +paths: + /test: + get: + summary: Test endpoint + operationId: testGet + responses: + '200': + description: Success +""" + with flask_app.test_request_context(): + with pytest.raises(ToolSSRFError) as exc_info: + ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_schema) + + assert "myserver.local" in str(exc_info.value) + assert "SSRF protection" in str(exc_info.value) + + def test_openapi_with_10_network_blocked(self, flask_app): + """Test that OpenAPI schema with 10.x.x.x network is blocked.""" + openapi_schema = """ +openapi: 3.0.0 +info: + title: Test API + version: 1.0.0 +servers: + - url: http://10.0.0.5/api +paths: + /test: + get: + summary: Test endpoint + operationId: testGet + responses: + '200': + description: Success +""" + with flask_app.test_request_context(): + with pytest.raises(ToolSSRFError) as exc_info: + ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_schema) + + assert "10.0.0.5" in str(exc_info.value) + assert "SSRF protection" in str(exc_info.value) + + def test_openapi_with_public_url_allowed(self, flask_app): + """Test that OpenAPI schema with public URL is allowed.""" + openapi_schema = """ +openapi: 3.0.0 +info: + title: Test API + version: 1.0.0 +servers: + - url: https://api.example.com/v1 +paths: + /test: + get: + summary: Test endpoint + operationId: testGet + responses: + '200': + description: Success +""" + with flask_app.test_request_context(): + # Should not raise any exception + result, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_schema) + assert result is not None + assert len(result) > 0 + + def test_swagger_with_private_ip_blocked(self, flask_app): + """Test that Swagger schema with private IP is blocked.""" + swagger_schema = """ +openapi: 3.0.0 +info: + title: Test API + version: 1.0.0 +servers: + - url: http://172.16.0.1/api +paths: + /test: + get: + summary: Test endpoint + operationId: testGet + responses: + '200': + description: Success +""" + with flask_app.test_request_context(): + with pytest.raises(ToolSSRFError) as exc_info: + ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(swagger_schema) + + assert "172.16.0.1" in str(exc_info.value) + assert "SSRF protection" in str(exc_info.value) + + def test_openapi_with_multiple_servers_one_private(self, flask_app): + """Test that OpenAPI with multiple servers including one private is blocked.""" + openapi_schema = """ +openapi: 3.0.0 +info: + title: Test API + version: 1.0.0 +servers: + - url: https://api.example.com/v1 + - url: http://192.168.1.100/api +paths: + /test: + get: + summary: Test endpoint + operationId: testGet + responses: + '200': + description: Success +""" + with flask_app.test_request_context(): + with pytest.raises(ToolSSRFError) as exc_info: + ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_schema) + + assert "192.168.1.100" in str(exc_info.value) + assert "SSRF protection" in str(exc_info.value) + + def test_openapi_json_format_with_private_ip_blocked(self, flask_app): + """Test that JSON format OpenAPI schema with private IP is blocked.""" + openapi_json = """{ + "openapi": "3.0.0", + "info": { + "title": "Test API", + "version": "1.0.0" + }, + "servers": [ + { + "url": "http://127.0.0.1:8080/api" + } + ], + "paths": { + "/test": { + "get": { + "summary": "Test endpoint", + "operationId": "testGet", + "responses": { + "200": { + "description": "Success" + } + } + } + } + } +}""" + with flask_app.test_request_context(): + with pytest.raises(ToolSSRFError) as exc_info: + ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_json) + + assert "127.0.0.1" in str(exc_info.value) + assert "SSRF protection" in str(exc_info.value)