add internal ip filter when parse tool schema

This commit is contained in:
Yansong Zhang 2025-12-12 11:37:30 +08:00
parent 64f5e34096
commit e66ef9145b
1 changed files with 21 additions and 14 deletions

View File

@ -16,6 +16,25 @@ from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolPro
class ApiBasedToolSchemaParser:
@staticmethod
def _validate_server_urls(servers: list[dict]) -> None:
"""
Validate server URLs to prevent SSRF attacks.
Args:
servers: List of server dictionaries containing 'url' keys
Raises:
ToolSSRFError: If any server URL points to a private or local network address
"""
for server in servers:
server_url = server.get("url", "")
if server_url and is_private_or_local_address(server_url):
raise ToolSSRFError(
f"Server URL '{server_url}' points to a private or local network address, "
"which is not allowed for security reasons (SSRF protection)."
)
@staticmethod
def parse_openapi_to_tool_bundle(
openapi: dict, extra_info: dict | None = None, warning: dict | None = None
@ -30,13 +49,7 @@ class ApiBasedToolSchemaParser:
raise ToolProviderNotFoundError("No server found in the openapi yaml.")
# SSRF Protection: Validate all server URLs before processing
for server in openapi["servers"]:
server_url_to_check = server.get("url", "")
if server_url_to_check and is_private_or_local_address(server_url_to_check):
raise ToolSSRFError(
f"Server URL '{server_url_to_check}' points to a private or local network address, "
"which is not allowed for security reasons (SSRF protection)."
)
ApiBasedToolSchemaParser._validate_server_urls(openapi["servers"])
server_url = openapi["servers"][0]["url"]
request_env = request.headers.get("X-Request-Env")
@ -298,13 +311,7 @@ class ApiBasedToolSchemaParser:
raise ToolApiSchemaError("No server found in the swagger yaml.")
# SSRF Protection: Validate all server URLs before processing
for server in servers:
server_url_to_check = server.get("url", "")
if server_url_to_check and is_private_or_local_address(server_url_to_check):
raise ToolSSRFError(
f"Server URL '{server_url_to_check}' points to a private or local network address, "
"which is not allowed for security reasons (SSRF protection)."
)
ApiBasedToolSchemaParser._validate_server_urls(servers)
converted_openapi: dict[str, Any] = {
"openapi": "3.0.0",