mirror of https://github.com/langgenius/dify.git
Merge f118725c59 into 446df6b50d
This commit is contained in:
commit
2462a7d9ae
|
|
@ -5,6 +5,7 @@ According to ClickZetta's permission model, different Volume types have differen
|
|||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from enum import StrEnum
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -266,8 +267,14 @@ class VolumePermissionManager:
|
|||
# If no explicit permissions found, try executing a simple query to verify permissions
|
||||
if not permissions:
|
||||
try:
|
||||
cursor.execute(f"SELECT COUNT(*) FROM {table_name} LIMIT 1")
|
||||
permissions.add("SELECT")
|
||||
# Use parameterized query to prevent SQL injection
|
||||
# Note: Table names cannot be parameterized in most SQL dialects,
|
||||
# so we validate the table_name to ensure it's a safe identifier
|
||||
if self._is_valid_table_name(table_name):
|
||||
cursor.execute(f"SELECT COUNT(*) FROM {table_name} LIMIT 1")
|
||||
permissions.add("SELECT")
|
||||
else:
|
||||
logger.warning("Invalid table name format: %s", table_name)
|
||||
except Exception:
|
||||
logger.debug("Cannot query table %s, no SELECT permission", table_name)
|
||||
|
||||
|
|
@ -354,6 +361,10 @@ class VolumePermissionManager:
|
|||
with self._connection.cursor() as cursor:
|
||||
# Use correct ClickZetta syntax to check Volume permissions
|
||||
logger.info("Checking permissions for volume: %s", volume_name)
|
||||
# Validate volume_name to prevent SQL injection
|
||||
if not self._is_valid_volume_name(volume_name):
|
||||
logger.warning("Invalid volume name format: %s", volume_name)
|
||||
return permissions
|
||||
cursor.execute(f"SHOW GRANTS ON VOLUME {volume_name}")
|
||||
grants = cursor.fetchall()
|
||||
|
||||
|
|
@ -586,6 +597,47 @@ class VolumePermissionManager:
|
|||
|
||||
return any(pattern in file_path_lower for pattern in sensitive_patterns)
|
||||
|
||||
def _is_valid_table_name(self, table_name: str) -> bool:
|
||||
"""Validate table name to prevent SQL injection
|
||||
|
||||
Table names should only contain alphanumeric characters, underscores,
|
||||
and dots (for schema.table notation).
|
||||
|
||||
Args:
|
||||
table_name: Table name to validate
|
||||
|
||||
Returns:
|
||||
True if table name is valid, False otherwise
|
||||
"""
|
||||
if not table_name or not isinstance(table_name, str):
|
||||
return False
|
||||
|
||||
# Allow alphanumeric characters, underscores, and dots (for schema.table notation)
|
||||
# Also allow backticks for quoted identifiers (common in SQL)
|
||||
pattern = r"^[a-zA-Z0-9_`\.]+$"
|
||||
|
||||
return bool(re.match(pattern, table_name))
|
||||
|
||||
def _is_valid_volume_name(self, volume_name: str) -> bool:
|
||||
"""Validate volume name to prevent SQL injection
|
||||
|
||||
Volume names should only contain alphanumeric characters, underscores,
|
||||
and hyphens.
|
||||
|
||||
Args:
|
||||
volume_name: Volume name to validate
|
||||
|
||||
Returns:
|
||||
True if volume name is valid, False otherwise
|
||||
"""
|
||||
if not volume_name or not isinstance(volume_name, str):
|
||||
return False
|
||||
|
||||
# Allow alphanumeric characters, underscores, hyphens, and dots
|
||||
pattern = r"^[a-zA-Z0-9_\-\.]+$"
|
||||
|
||||
return bool(re.match(pattern, volume_name))
|
||||
|
||||
def validate_operation(self, operation: str, dataset_id: str | None = None) -> bool:
|
||||
"""Validate operation permission
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue