diff --git a/api/extensions/ext_logstore.py b/api/extensions/ext_logstore.py
index 502f0bb46b..cda2d1ad1e 100644
--- a/api/extensions/ext_logstore.py
+++ b/api/extensions/ext_logstore.py
@@ -10,6 +10,7 @@ import os
 
 from dotenv import load_dotenv
 
+from configs import dify_config
 from dify_app import DifyApp
 
 logger = logging.getLogger(__name__)
@@ -19,12 +20,17 @@ def is_enabled() -> bool:
     """
     Check if logstore extension is enabled.
 
+    Logstore is considered enabled when:
+    1. All required Aliyun SLS environment variables are set
+    2. At least one repository configuration points to a logstore implementation
+
     Returns:
-        True if all required Aliyun SLS environment variables are set, False otherwise
+        True if logstore should be initialized, False otherwise
     """
     # Load environment variables from .env file
     load_dotenv()
 
+    # Check if Aliyun SLS connection parameters are configured
     required_vars = [
         "ALIYUN_SLS_ACCESS_KEY_ID",
         "ALIYUN_SLS_ACCESS_KEY_SECRET",
@@ -33,24 +39,32 @@ def is_enabled() -> bool:
         "ALIYUN_SLS_PROJECT_NAME",
     ]
 
-    all_set = all(os.environ.get(var) for var in required_vars)
+    sls_vars_set = all(os.environ.get(var) for var in required_vars)
 
-    if not all_set:
-        logger.info("Logstore extension disabled: required Aliyun SLS environment variables not set")
+    if not sls_vars_set:
+        return False
 
-    return all_set
+    # Check if any repository configuration points to logstore implementation
+    repository_configs = [
+        dify_config.CORE_WORKFLOW_EXECUTION_REPOSITORY,
+        dify_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY,
+        dify_config.API_WORKFLOW_NODE_EXECUTION_REPOSITORY,
+        dify_config.API_WORKFLOW_RUN_REPOSITORY,
+    ]
+
+    uses_logstore = any("logstore" in config.lower() for config in repository_configs)
+
+    if not uses_logstore:
+        return False
+
+    logger.info("Logstore extension enabled: SLS variables set and repository configured to use logstore")
+    return True
 
 
 def init_app(app: DifyApp):
     """
     Initialize logstore on application startup.
-
-    This function:
-    1. Creates Aliyun SLS project if it doesn't exist
-    2. Creates logstores (workflow_execution, workflow_node_execution) if they don't exist
-    3. Creates indexes with field configurations based on PostgreSQL table structures
-
-    This operation is idempotent and only executes once during application startup.
+    If initialization fails, the application continues running without logstore features.
 
     Args:
         app: The Dify application instance
@@ -58,17 +72,23 @@ def init_app(app: DifyApp):
     try:
         from extensions.logstore.aliyun_logstore import AliyunLogStore
 
-        logger.info("Initializing logstore...")
+        logger.info("Initializing Aliyun SLS Logstore...")
 
-        # Create logstore client and initialize project/logstores/indexes
+        # Create logstore client and initialize resources
         logstore_client = AliyunLogStore()
         logstore_client.init_project_logstore()
 
-        # Attach to app for potential later use
         app.extensions["logstore"] = logstore_client
 
         logger.info("Logstore initialized successfully")
+
     except Exception:
-        logger.exception("Failed to initialize logstore")
-        # Don't raise - allow application to continue even if logstore init fails
-        # This ensures that the application can still run if logstore is misconfigured
+        logger.exception(
+            "Logstore initialization failed. Configuration: endpoint=%s, region=%s, project=%s, timeout=%ss. "
+            "Application will continue but logstore features will NOT work.",
+            os.environ.get("ALIYUN_SLS_ENDPOINT"),
+            os.environ.get("ALIYUN_SLS_REGION"),
+            os.environ.get("ALIYUN_SLS_PROJECT_NAME"),
+            os.environ.get("ALIYUN_SLS_CHECK_CONNECTIVITY_TIMEOUT", "30"),
+        )
+        # Don't raise - allow application to continue even if logstore setup fails
diff --git a/api/extensions/logstore/aliyun_logstore.py b/api/extensions/logstore/aliyun_logstore.py
index 8c64a25be4..f6a4765f14 100644
--- a/api/extensions/logstore/aliyun_logstore.py
+++ b/api/extensions/logstore/aliyun_logstore.py
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import logging
 import os
+import socket
 import threading
 import time
 from collections.abc import Sequence
@@ -179,9 +180,18 @@ class AliyunLogStore:
         self.region: str = os.environ.get("ALIYUN_SLS_REGION", "")
         self.project_name: str = os.environ.get("ALIYUN_SLS_PROJECT_NAME", "")
         self.logstore_ttl: int = int(os.environ.get("ALIYUN_SLS_LOGSTORE_TTL", 365))
-        self.log_enabled: bool = os.environ.get("SQLALCHEMY_ECHO", "false").lower() == "true"
+        self.log_enabled: bool = (
+            os.environ.get("SQLALCHEMY_ECHO", "false").lower() == "true"
+            or os.environ.get("LOGSTORE_SQL_ECHO", "false").lower() == "true"
+        )
         self.pg_mode_enabled: bool = os.environ.get("LOGSTORE_PG_MODE_ENABLED", "true").lower() == "true"
 
+        # Get timeout configuration
+        check_timeout = int(os.environ.get("ALIYUN_SLS_CHECK_CONNECTIVITY_TIMEOUT", 30))
+
+        # Pre-check endpoint connectivity to prevent indefinite hangs
+        self._check_endpoint_connectivity(self.endpoint, check_timeout)
+
         # Initialize SDK client
         self.client = LogClient(
             self.endpoint, self.access_key_id, self.access_key_secret, auth_version=AUTH_VERSION_4, region=self.region
@@ -199,6 +209,49 @@ class AliyunLogStore:
 
         self.__class__._initialized = True
 
+    @staticmethod
+    def _check_endpoint_connectivity(endpoint: str, timeout: int) -> None:
+        """
+        Check if the SLS endpoint is reachable before creating LogClient.
+        Prevents indefinite hangs when the endpoint is unreachable.
+
+        Args:
+            endpoint: SLS endpoint URL
+            timeout: Connection timeout in seconds
+
+        Raises:
+            ConnectionError: If endpoint is not reachable
+        """
+        # Parse endpoint URL to extract hostname and port
+        from urllib.parse import urlparse
+
+        parsed_url = urlparse(endpoint if "://" in endpoint else f"http://{endpoint}")
+        hostname = parsed_url.hostname
+        port = parsed_url.port or (443 if parsed_url.scheme == "https" else 80)
+
+        if not hostname:
+            raise ConnectionError(f"Invalid endpoint URL: {endpoint}")
+
+        sock = None
+        try:
+            # Create socket and set timeout
+            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+            sock.settimeout(timeout)
+            sock.connect((hostname, port))
+        except Exception as e:
+            # Catch all exceptions and provide clear error message
+            error_type = type(e).__name__
+            raise ConnectionError(
+                f"Cannot connect to {hostname}:{port} (timeout={timeout}s): [{error_type}] {e}"
+            ) from e
+        finally:
+            # Ensure socket is properly closed
+            if sock:
+                try:
+                    sock.close()
+                except Exception:  # noqa: S110
+                    pass  # Ignore errors during cleanup
+
     @property
     def supports_pg_protocol(self) -> bool:
         """Check if PG protocol is supported and enabled."""
@@ -220,19 +273,16 @@ class AliyunLogStore:
         try:
             self._use_pg_protocol = self._pg_client.init_connection()
             if self._use_pg_protocol:
-                logger.info("Successfully connected to project %s using PG protocol", self.project_name)
+                logger.info("Using PG protocol for project %s", self.project_name)
                 # Check if scan_index is enabled for all logstores
                 self._check_and_disable_pg_if_scan_index_disabled()
                 return True
             else:
-                logger.info("PG connection failed for project %s. Will use SDK mode.", self.project_name)
+                logger.info("Using SDK mode for project %s", self.project_name)
                 return False
         except Exception as e:
-            logger.warning(
-                "Failed to establish PG connection for project %s: %s. Will use SDK mode.",
-                self.project_name,
-                str(e),
-            )
+            logger.info("Using SDK mode for project %s", self.project_name)
+            logger.debug("PG connection details: %s", str(e))
             self._use_pg_protocol = False
             return False
 
@@ -246,10 +296,6 @@ class AliyunLogStore:
         if self._use_pg_protocol:
             return
 
-        logger.info(
-            "Attempting delayed PG connection for newly created project %s ...",
-            self.project_name,
-        )
         self._attempt_pg_connection_init()
 
         self.__class__._pg_connection_timer = None
@@ -284,11 +330,7 @@ class AliyunLogStore:
         if project_is_new:
             # For newly created projects, schedule delayed PG connection
             self._use_pg_protocol = False
-            logger.info(
-                "Project %s is newly created. Will use SDK mode and schedule PG connection attempt in %d seconds.",
-                self.project_name,
-                self.__class__._pg_connection_delay,
-            )
+            logger.info("Using SDK mode for project %s (newly created)", self.project_name)
             if self.__class__._pg_connection_timer is not None:
                 self.__class__._pg_connection_timer.cancel()
             self.__class__._pg_connection_timer = threading.Timer(
@@ -299,7 +341,6 @@ class AliyunLogStore:
             self.__class__._pg_connection_timer.start()
         else:
             # For existing projects, attempt PG connection immediately
-            logger.info("Project %s already exists. Attempting PG connection...", self.project_name)
             self._attempt_pg_connection_init()
 
     def _check_and_disable_pg_if_scan_index_disabled(self) -> None:
@@ -318,9 +359,9 @@ class AliyunLogStore:
             existing_config = self.get_existing_index_config(logstore_name)
             if existing_config and not existing_config.scan_index:
                 logger.info(
-                    "Logstore %s has scan_index=false, USE SDK mode for read/write operations. "
-                    "PG protocol requires scan_index to be enabled.",
+                    "Logstore %s requires scan_index enabled, using SDK mode for project %s",
                     logstore_name,
+                    self.project_name,
                 )
                 self._use_pg_protocol = False
                 # Close PG connection if it was initialized
@@ -748,7 +789,6 @@ class AliyunLogStore:
             reverse=reverse,
         )
 
-        # Log query info if SQLALCHEMY_ECHO is enabled
         if self.log_enabled:
             logger.info(
                 "[LogStore] GET_LOGS | logstore=%s | project=%s | query=%s | "
@@ -770,7 +810,6 @@ class AliyunLogStore:
         for log in logs:
             result.append(log.get_contents())
 
-        # Log result count if SQLALCHEMY_ECHO is enabled
         if self.log_enabled:
             logger.info(
                 "[LogStore] GET_LOGS RESULT | logstore=%s | returned_count=%d",
@@ -845,7 +884,6 @@ class AliyunLogStore:
             query=full_query,
         )
 
-        # Log query info if SQLALCHEMY_ECHO is enabled
        if self.log_enabled:
             logger.info(
                 "[LogStore-SDK] EXECUTE_SQL | logstore=%s | project=%s | from_time=%d | to_time=%d | full_query=%s",
@@ -853,8 +891,7 @@ class AliyunLogStore:
                 self.project_name,
                 from_time,
                 to_time,
-                query,
-                sql,
+                full_query,
             )
 
         try:
@@ -865,7 +902,6 @@ class AliyunLogStore:
         for log in logs:
             result.append(log.get_contents())
 
-        # Log result count if SQLALCHEMY_ECHO is enabled
         if self.log_enabled:
             logger.info(
                 "[LogStore-SDK] EXECUTE_SQL RESULT | logstore=%s | returned_count=%d",
diff --git a/api/extensions/logstore/aliyun_logstore_pg.py b/api/extensions/logstore/aliyun_logstore_pg.py
index 35aa51ce53..874c20d144 100644
--- a/api/extensions/logstore/aliyun_logstore_pg.py
+++ b/api/extensions/logstore/aliyun_logstore_pg.py
@@ -7,8 +7,7 @@ from contextlib import contextmanager
 from typing import Any
 
 import psycopg2
-import psycopg2.pool
-from psycopg2 import InterfaceError, OperationalError
+from sqlalchemy import create_engine
 
 from configs import dify_config
 
@@ -16,11 +15,7 @@ logger = logging.getLogger(__name__)
 
 
 class AliyunLogStorePG:
-    """
-    PostgreSQL protocol support for Aliyun SLS LogStore.
-
-    Handles PG connection pooling and operations for regions that support PG protocol.
-    """
+    """PostgreSQL protocol support for Aliyun SLS LogStore using SQLAlchemy connection pool."""
 
     def __init__(self, access_key_id: str, access_key_secret: str, endpoint: str, project_name: str):
         """
@@ -36,24 +31,11 @@ class AliyunLogStorePG:
         self._access_key_secret = access_key_secret
         self._endpoint = endpoint
         self.project_name = project_name
-        self._pg_pool: psycopg2.pool.SimpleConnectionPool | None = None
+        self._engine: Any = None  # SQLAlchemy Engine
         self._use_pg_protocol = False
 
     def _check_port_connectivity(self, host: str, port: int, timeout: float = 2.0) -> bool:
-        """
-        Check if a TCP port is reachable using socket connection.
-
-        This provides a fast check before attempting full database connection,
-        preventing long waits when connecting to unsupported regions.
-
-        Args:
-            host: Hostname or IP address
-            port: Port number
-            timeout: Connection timeout in seconds (default: 2.0)
-
-        Returns:
-            True if port is reachable, False otherwise
-        """
+        """Fast TCP port check to avoid long waits on unsupported regions."""
         try:
             sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
             sock.settimeout(timeout)
@@ -65,166 +47,101 @@ class AliyunLogStorePG:
             return False
 
     def init_connection(self) -> bool:
-        """
-        Initialize PostgreSQL connection pool for SLS PG protocol support.
-
-        Attempts to connect to SLS using PostgreSQL protocol. If successful, sets
-        _use_pg_protocol to True and creates a connection pool. If connection fails
-        (region doesn't support PG protocol or other errors), returns False.
-
-        Returns:
-            True if PG protocol is supported and initialized, False otherwise
-        """
+        """Initialize SQLAlchemy connection pool with pool_recycle and TCP keepalive support."""
         try:
-            # Extract hostname from endpoint (remove protocol if present)
             pg_host = self._endpoint.replace("http://", "").replace("https://", "")
 
-            # Get pool configuration
-            pg_max_connections = int(os.environ.get("ALIYUN_SLS_PG_MAX_CONNECTIONS", 10))
+            # Pool configuration
+            pool_size = int(os.environ.get("ALIYUN_SLS_PG_POOL_SIZE", 5))
+            max_overflow = int(os.environ.get("ALIYUN_SLS_PG_MAX_OVERFLOW", 5))
+            pool_recycle = int(os.environ.get("ALIYUN_SLS_PG_POOL_RECYCLE", 3600))
+            pool_pre_ping = os.environ.get("ALIYUN_SLS_PG_POOL_PRE_PING", "false").lower() == "true"
 
-            logger.debug(
-                "Check PG protocol connection to SLS: host=%s, project=%s",
-                pg_host,
-                self.project_name,
-            )
+            logger.debug("Check PG protocol connection to SLS: host=%s, project=%s", pg_host, self.project_name)
 
-            # Fast port connectivity check before attempting full connection
-            # This prevents long waits when connecting to unsupported regions
+            # Fast port check to avoid long waits
             if not self._check_port_connectivity(pg_host, 5432, timeout=1.0):
-                logger.info(
-                    "USE SDK mode for read/write operations, host=%s",
-                    pg_host,
-                )
+                logger.debug("Using SDK mode for host=%s", pg_host)
                 return False
 
-            # Create connection pool
-            self._pg_pool = psycopg2.pool.SimpleConnectionPool(
-                minconn=1,
-                maxconn=pg_max_connections,
-                host=pg_host,
-                port=5432,
-                database=self.project_name,
-                user=self._access_key_id,
-                password=self._access_key_secret,
-                sslmode="require",
-                connect_timeout=5,
-                application_name=f"Dify-{dify_config.project.version}",
+            # Build connection URL
+            from urllib.parse import quote_plus
+
+            username = quote_plus(self._access_key_id)
+            password = quote_plus(self._access_key_secret)
+            database_url = (
+                f"postgresql+psycopg2://{username}:{password}@{pg_host}:5432/{self.project_name}?sslmode=require"
             )
 
-            # Note: Skip test query because SLS PG protocol only supports SELECT/INSERT on actual tables
-            # Connection pool creation success already indicates connectivity
+            # Create SQLAlchemy engine with connection pool
+            self._engine = create_engine(
+                database_url,
+                pool_size=pool_size,
+                max_overflow=max_overflow,
+                pool_recycle=pool_recycle,
+                pool_pre_ping=pool_pre_ping,
+                pool_timeout=30,
+                connect_args={
+                    "connect_timeout": 5,
+                    "application_name": f"Dify-{dify_config.project.version}-fixautocommit",
+                    "keepalives": 1,
+                    "keepalives_idle": 60,
+                    "keepalives_interval": 10,
+                    "keepalives_count": 5,
+                },
+            )
 
             self._use_pg_protocol = True
             logger.info(
-                "PG protocol initialized successfully for SLS project=%s. Will use PG for read/write operations.",
+                "PG protocol initialized for SLS project=%s (pool_size=%d, pool_recycle=%ds)",
                 self.project_name,
+                pool_size,
+                pool_recycle,
             )
             return True
 
         except Exception as e:
-            # PG connection failed - fallback to SDK mode
             self._use_pg_protocol = False
-            if self._pg_pool:
+            if self._engine:
                 try:
-                    self._pg_pool.closeall()
+                    self._engine.dispose()
                 except Exception:
-                    logger.debug("Failed to close PG connection pool during cleanup, ignoring")
-                self._pg_pool = None
+                    logger.debug("Failed to dispose engine during cleanup, ignoring")
+                self._engine = None
 
-            logger.info(
-                "PG protocol connection failed (region may not support PG protocol): %s. "
-                "Falling back to SDK mode for read/write operations.",
-                str(e),
-            )
-            return False
-
-    def _is_connection_valid(self, conn: Any) -> bool:
-        """
-        Check if a connection is still valid.
-
-        Args:
-            conn: psycopg2 connection object
-
-        Returns:
-            True if connection is valid, False otherwise
-        """
-        try:
-            # Check if connection is closed
-            if conn.closed:
-                return False
-
-            # Quick ping test - execute a lightweight query
-            # For SLS PG protocol, we can't use SELECT 1 without FROM,
-            # so we just check the connection status
-            with conn.cursor() as cursor:
-                cursor.execute("SELECT 1")
-                cursor.fetchone()
-            return True
-        except Exception:
+            logger.debug("Using SDK mode for region: %s", str(e))
             return False
 
     @contextmanager
     def _get_connection(self):
-        """
-        Context manager to get a PostgreSQL connection from the pool.
+        """Get connection from SQLAlchemy pool. Pool handles recycle, invalidation, and keepalive automatically."""
+        if not self._engine:
+            raise RuntimeError("SQLAlchemy engine is not initialized")
 
-        Automatically validates and refreshes stale connections.
-
-        Note: Aliyun SLS PG protocol does not support transactions, so we always
-        use autocommit mode.
-
-        Yields:
-            psycopg2 connection object
-
-        Raises:
-            RuntimeError: If PG pool is not initialized
-        """
-        if not self._pg_pool:
-            raise RuntimeError("PG connection pool is not initialized")
-
-        conn = self._pg_pool.getconn()
+        connection = self._engine.raw_connection()
         try:
-            # Validate connection and get a fresh one if needed
-            if not self._is_connection_valid(conn):
-                logger.debug("Connection is stale, marking as bad and getting a new one")
-                # Mark connection as bad and get a new one
-                self._pg_pool.putconn(conn, close=True)
-                conn = self._pg_pool.getconn()
-
-            # Aliyun SLS PG protocol does not support transactions, always use autocommit
-            conn.autocommit = True
-            yield conn
+            connection.autocommit = True  # SLS PG protocol does not support transactions
+            yield connection
+        except Exception:
+            raise
         finally:
-            # Return connection to pool (or close if it's bad)
-            if self._is_connection_valid(conn):
-                self._pg_pool.putconn(conn)
-            else:
-                self._pg_pool.putconn(conn, close=True)
+            connection.close()
 
     def close(self) -> None:
-        """Close the PostgreSQL connection pool."""
-        if self._pg_pool:
+        """Dispose SQLAlchemy engine and close all connections."""
+        if self._engine:
             try:
-                self._pg_pool.closeall()
-                logger.info("PG connection pool closed")
+                self._engine.dispose()
+                logger.info("SQLAlchemy engine disposed")
             except Exception:
-                logger.exception("Failed to close PG connection pool")
+                logger.exception("Failed to dispose engine")
 
     def _is_retriable_error(self, error: Exception) -> bool:
-        """
-        Check if an error is retriable (connection-related issues).
- - Args: - error: Exception to check - - Returns: - True if the error is retriable, False otherwise - """ - # Retry on connection-related errors - if isinstance(error, (OperationalError, InterfaceError)): + """Check if error is retriable (connection-related issues).""" + # Check for psycopg2 connection errors directly + if isinstance(error, (psycopg2.OperationalError, psycopg2.InterfaceError)): return True - # Check error message for specific connection issues error_msg = str(error).lower() retriable_patterns = [ "connection", @@ -234,34 +151,18 @@ class AliyunLogStorePG: "reset by peer", "no route to host", "network", + "operational error", + "interface error", ] return any(pattern in error_msg for pattern in retriable_patterns) def put_log(self, logstore: str, contents: Sequence[tuple[str, str]], log_enabled: bool = False) -> None: - """ - Write log to SLS using PostgreSQL protocol with automatic retry. - - Note: SLS PG protocol only supports INSERT (not UPDATE). This uses append-only - writes with log_version field for versioning, same as SDK implementation. - - Args: - logstore: Name of the logstore table - contents: List of (field_name, value) tuples - log_enabled: Whether to enable logging - - Raises: - psycopg2.Error: If database operation fails after all retries - """ + """Write log to SLS using INSERT with automatic retry (3 attempts with exponential backoff).""" if not contents: return - # Extract field names and values from contents fields = [field_name for field_name, _ in contents] values = [value for _, value in contents] - - # Build INSERT statement with literal values - # Note: Aliyun SLS PG protocol doesn't support parameterized queries, - # so we need to use mogrify to safely create literal values field_list = ", ".join([f'"{field}"' for field in fields]) if log_enabled: @@ -272,67 +173,40 @@ class AliyunLogStorePG: len(contents), ) - # Retry configuration max_retries = 3 - retry_delay = 0.1 # Start with 100ms + retry_delay = 0.1 for attempt in range(max_retries): try: with self._get_connection() as conn: with conn.cursor() as cursor: - # Use mogrify to safely convert values to SQL literals placeholders = ", ".join(["%s"] * len(fields)) values_literal = cursor.mogrify(f"({placeholders})", values).decode("utf-8") insert_sql = f'INSERT INTO "{logstore}" ({field_list}) VALUES {values_literal}' cursor.execute(insert_sql) - # Success - exit retry loop return except psycopg2.Error as e: - # Check if error is retriable if not self._is_retriable_error(e): - # Not a retriable error (e.g., data validation error), fail immediately - logger.exception( - "Failed to put logs to logstore %s via PG protocol (non-retriable error)", - logstore, - ) + logger.exception("Failed to put logs to logstore %s (non-retriable error)", logstore) raise - # Retriable error - log and retry if we have attempts left if attempt < max_retries - 1: logger.warning( - "Failed to put logs to logstore %s via PG protocol (attempt %d/%d): %s. Retrying...", + "Failed to put logs to logstore %s (attempt %d/%d): %s. 
Retrying...", logstore, attempt + 1, max_retries, str(e), ) time.sleep(retry_delay) - retry_delay *= 2 # Exponential backoff + retry_delay *= 2 else: - # Last attempt failed - logger.exception( - "Failed to put logs to logstore %s via PG protocol after %d attempts", - logstore, - max_retries, - ) + logger.exception("Failed to put logs to logstore %s after %d attempts", logstore, max_retries) raise def execute_sql(self, sql: str, logstore: str, log_enabled: bool = False) -> list[dict[str, Any]]: - """ - Execute SQL query using PostgreSQL protocol with automatic retry. - - Args: - sql: SQL query string - logstore: Name of the logstore (for logging purposes) - log_enabled: Whether to enable logging - - Returns: - List of result rows as dictionaries - - Raises: - psycopg2.Error: If database operation fails after all retries - """ + """Execute SQL query with automatic retry (3 attempts with exponential backoff).""" if log_enabled: logger.info( "[LogStore-PG] EXECUTE_SQL | logstore=%s | project=%s | sql=%s", @@ -341,20 +215,16 @@ class AliyunLogStorePG: sql, ) - # Retry configuration max_retries = 3 - retry_delay = 0.1 # Start with 100ms + retry_delay = 0.1 for attempt in range(max_retries): try: with self._get_connection() as conn: with conn.cursor() as cursor: cursor.execute(sql) - - # Get column names from cursor description columns = [desc[0] for desc in cursor.description] - # Fetch all results and convert to list of dicts result = [] for row in cursor.fetchall(): row_dict = {} @@ -372,36 +242,31 @@ class AliyunLogStorePG: return result except psycopg2.Error as e: - # Check if error is retriable if not self._is_retriable_error(e): - # Not a retriable error (e.g., SQL syntax error), fail immediately logger.exception( - "Failed to execute SQL query on logstore %s via PG protocol (non-retriable error): sql=%s", + "Failed to execute SQL on logstore %s (non-retriable error): sql=%s", logstore, sql, ) raise - # Retriable error - log and retry if we have attempts left if attempt < max_retries - 1: logger.warning( - "Failed to execute SQL query on logstore %s via PG protocol (attempt %d/%d): %s. Retrying...", + "Failed to execute SQL on logstore %s (attempt %d/%d): %s. Retrying...", logstore, attempt + 1, max_retries, str(e), ) time.sleep(retry_delay) - retry_delay *= 2 # Exponential backoff + retry_delay *= 2 else: - # Last attempt failed logger.exception( - "Failed to execute SQL query on logstore %s via PG protocol after %d attempts: sql=%s", + "Failed to execute SQL on logstore %s after %d attempts: sql=%s", logstore, max_retries, sql, ) raise - # This line should never be reached due to raise above, but makes type checker happy return [] diff --git a/api/extensions/logstore/repositories/__init__.py b/api/extensions/logstore/repositories/__init__.py index e69de29bb2..b5a4fcf844 100644 --- a/api/extensions/logstore/repositories/__init__.py +++ b/api/extensions/logstore/repositories/__init__.py @@ -0,0 +1,29 @@ +""" +LogStore repository utilities. +""" + +from typing import Any + + +def safe_float(value: Any, default: float = 0.0) -> float: + """ + Safely convert a value to float, handling 'null' strings and None. + """ + if value is None or value in {"null", ""}: + return default + try: + return float(value) + except (ValueError, TypeError): + return default + + +def safe_int(value: Any, default: int = 0) -> int: + """ + Safely convert a value to int, handling 'null' strings and None. 
+ """ + if value is None or value in {"null", ""}: + return default + try: + return int(float(value)) + except (ValueError, TypeError): + return default diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py index 8c804d6bb5..f67723630b 100644 --- a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py @@ -14,6 +14,8 @@ from typing import Any from sqlalchemy.orm import sessionmaker from extensions.logstore.aliyun_logstore import AliyunLogStore +from extensions.logstore.repositories import safe_float, safe_int +from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value from models.workflow import WorkflowNodeExecutionModel from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository @@ -52,9 +54,8 @@ def _dict_to_workflow_node_execution_model(data: dict[str, Any]) -> WorkflowNode model.created_by_role = data.get("created_by_role") or "" model.created_by = data.get("created_by") or "" - # Numeric fields with defaults - model.index = int(data.get("index", 0)) - model.elapsed_time = float(data.get("elapsed_time", 0)) + model.index = safe_int(data.get("index", 0)) + model.elapsed_time = safe_float(data.get("elapsed_time", 0)) # Optional fields model.workflow_run_id = data.get("workflow_run_id") @@ -130,6 +131,12 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep node_id, ) try: + # Escape parameters to prevent SQL injection + escaped_tenant_id = escape_identifier(tenant_id) + escaped_app_id = escape_identifier(app_id) + escaped_workflow_id = escape_identifier(workflow_id) + escaped_node_id = escape_identifier(node_id) + # Check if PG protocol is supported if self.logstore_client.supports_pg_protocol: # Use PG protocol with SQL query (get latest version of each record) @@ -138,10 +145,10 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn FROM "{AliyunLogStore.workflow_node_execution_logstore}" - WHERE tenant_id = '{tenant_id}' - AND app_id = '{app_id}' - AND workflow_id = '{workflow_id}' - AND node_id = '{node_id}' + WHERE tenant_id = '{escaped_tenant_id}' + AND app_id = '{escaped_app_id}' + AND workflow_id = '{escaped_workflow_id}' + AND node_id = '{escaped_node_id}' AND __time__ > 0 ) AS subquery WHERE rn = 1 LIMIT 100 @@ -153,7 +160,8 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep else: # Use SDK with LogStore query syntax query = ( - f"tenant_id: {tenant_id} and app_id: {app_id} and workflow_id: {workflow_id} and node_id: {node_id}" + f"tenant_id: {escaped_tenant_id} and app_id: {escaped_app_id} " + f"and workflow_id: {escaped_workflow_id} and node_id: {escaped_node_id}" ) from_time = 0 to_time = int(time.time()) # now @@ -227,6 +235,11 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep workflow_run_id, ) try: + # Escape parameters to prevent SQL injection + escaped_tenant_id = escape_identifier(tenant_id) + escaped_app_id = escape_identifier(app_id) + escaped_workflow_run_id = escape_identifier(workflow_run_id) + # Check if PG protocol is supported if self.logstore_client.supports_pg_protocol: # Use PG protocol with SQL query (get latest version of 
each record) @@ -235,9 +248,9 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn FROM "{AliyunLogStore.workflow_node_execution_logstore}" - WHERE tenant_id = '{tenant_id}' - AND app_id = '{app_id}' - AND workflow_run_id = '{workflow_run_id}' + WHERE tenant_id = '{escaped_tenant_id}' + AND app_id = '{escaped_app_id}' + AND workflow_run_id = '{escaped_workflow_run_id}' AND __time__ > 0 ) AS subquery WHERE rn = 1 LIMIT 1000 @@ -248,7 +261,10 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep ) else: # Use SDK with LogStore query syntax - query = f"tenant_id: {tenant_id} and app_id: {app_id} and workflow_run_id: {workflow_run_id}" + query = ( + f"tenant_id: {escaped_tenant_id} and app_id: {escaped_app_id} " + f"and workflow_run_id: {escaped_workflow_run_id}" + ) from_time = 0 to_time = int(time.time()) # now @@ -313,16 +329,24 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep """ logger.debug("get_execution_by_id: execution_id=%s, tenant_id=%s", execution_id, tenant_id) try: + # Escape parameters to prevent SQL injection + escaped_execution_id = escape_identifier(execution_id) + # Check if PG protocol is supported if self.logstore_client.supports_pg_protocol: # Use PG protocol with SQL query (get latest version of record) - tenant_filter = f"AND tenant_id = '{tenant_id}'" if tenant_id else "" + if tenant_id: + escaped_tenant_id = escape_identifier(tenant_id) + tenant_filter = f"AND tenant_id = '{escaped_tenant_id}'" + else: + tenant_filter = "" + sql_query = f""" SELECT * FROM ( SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn FROM "{AliyunLogStore.workflow_node_execution_logstore}" - WHERE id = '{execution_id}' {tenant_filter} AND __time__ > 0 + WHERE id = '{escaped_execution_id}' {tenant_filter} AND __time__ > 0 ) AS subquery WHERE rn = 1 LIMIT 1 """ @@ -332,10 +356,14 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep ) else: # Use SDK with LogStore query syntax + # Note: Values must be quoted in LogStore query syntax to prevent injection if tenant_id: - query = f"id: {execution_id} and tenant_id: {tenant_id}" + query = ( + f"id:{escape_logstore_query_value(execution_id)} " + f"and tenant_id:{escape_logstore_query_value(tenant_id)}" + ) else: - query = f"id: {execution_id}" + query = f"id:{escape_logstore_query_value(execution_id)}" from_time = 0 to_time = int(time.time()) # now diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py index 252cdcc4df..14382ed876 100644 --- a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py +++ b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py @@ -10,6 +10,7 @@ Key Features: - Optimized deduplication using finished_at IS NOT NULL filter - Window functions only when necessary (running status queries) - Multi-tenant data isolation and security +- SQL injection prevention via parameter escaping """ import logging @@ -22,6 +23,8 @@ from typing import Any, cast from sqlalchemy.orm import sessionmaker from extensions.logstore.aliyun_logstore import AliyunLogStore +from extensions.logstore.repositories import safe_float, safe_int +from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value, escape_sql_string from 
libs.infinite_scroll_pagination import InfiniteScrollPagination from models.enums import WorkflowRunTriggeredFrom from models.workflow import WorkflowRun @@ -63,10 +66,9 @@ def _dict_to_workflow_run(data: dict[str, Any]) -> WorkflowRun: model.created_by_role = data.get("created_by_role") or "" model.created_by = data.get("created_by") or "" - # Numeric fields with defaults - model.total_tokens = int(data.get("total_tokens", 0)) - model.total_steps = int(data.get("total_steps", 0)) - model.exceptions_count = int(data.get("exceptions_count", 0)) + model.total_tokens = safe_int(data.get("total_tokens", 0)) + model.total_steps = safe_int(data.get("total_steps", 0)) + model.exceptions_count = safe_int(data.get("exceptions_count", 0)) # Optional fields model.graph = data.get("graph") @@ -101,7 +103,8 @@ def _dict_to_workflow_run(data: dict[str, Any]) -> WorkflowRun: if model.finished_at and model.created_at: model.elapsed_time = (model.finished_at - model.created_at).total_seconds() else: - model.elapsed_time = float(data.get("elapsed_time", 0)) + # Use safe conversion to handle 'null' strings and None values + model.elapsed_time = safe_float(data.get("elapsed_time", 0)) return model @@ -165,16 +168,26 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): status, ) # Convert triggered_from to list if needed - if isinstance(triggered_from, WorkflowRunTriggeredFrom): + if isinstance(triggered_from, (WorkflowRunTriggeredFrom, str)): triggered_from_list = [triggered_from] else: triggered_from_list = list(triggered_from) - # Build triggered_from filter - triggered_from_filter = " OR ".join([f"triggered_from='{tf.value}'" for tf in triggered_from_list]) + # Escape parameters to prevent SQL injection + escaped_tenant_id = escape_identifier(tenant_id) + escaped_app_id = escape_identifier(app_id) - # Build status filter - status_filter = f"AND status='{status}'" if status else "" + # Build triggered_from filter with escaped values + # Support both enum and string values for triggered_from + triggered_from_filter = " OR ".join( + [ + f"triggered_from='{escape_sql_string(tf.value if isinstance(tf, WorkflowRunTriggeredFrom) else tf)}'" + for tf in triggered_from_list + ] + ) + + # Build status filter with escaped value + status_filter = f"AND status='{escape_sql_string(status)}'" if status else "" # Build last_id filter for pagination # Note: This is simplified. 
In production, you'd need to track created_at from last record @@ -188,8 +201,8 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): SELECT * FROM ( SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn FROM {AliyunLogStore.workflow_execution_logstore} - WHERE tenant_id='{tenant_id}' - AND app_id='{app_id}' + WHERE tenant_id='{escaped_tenant_id}' + AND app_id='{escaped_app_id}' AND ({triggered_from_filter}) {status_filter} {last_id_filter} @@ -232,6 +245,11 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): logger.debug("get_workflow_run_by_id: tenant_id=%s, app_id=%s, run_id=%s", tenant_id, app_id, run_id) try: + # Escape parameters to prevent SQL injection + escaped_run_id = escape_identifier(run_id) + escaped_tenant_id = escape_identifier(tenant_id) + escaped_app_id = escape_identifier(app_id) + # Check if PG protocol is supported if self.logstore_client.supports_pg_protocol: # Use PG protocol with SQL query (get latest version of record) @@ -240,7 +258,10 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn FROM "{AliyunLogStore.workflow_execution_logstore}" - WHERE id = '{run_id}' AND tenant_id = '{tenant_id}' AND app_id = '{app_id}' AND __time__ > 0 + WHERE id = '{escaped_run_id}' + AND tenant_id = '{escaped_tenant_id}' + AND app_id = '{escaped_app_id}' + AND __time__ > 0 ) AS subquery WHERE rn = 1 LIMIT 100 """ @@ -250,7 +271,12 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): ) else: # Use SDK with LogStore query syntax - query = f"id: {run_id} and tenant_id: {tenant_id} and app_id: {app_id}" + # Note: Values must be quoted in LogStore query syntax to prevent injection + query = ( + f"id:{escape_logstore_query_value(run_id)} " + f"and tenant_id:{escape_logstore_query_value(tenant_id)} " + f"and app_id:{escape_logstore_query_value(app_id)}" + ) from_time = 0 to_time = int(time.time()) # now @@ -323,6 +349,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): logger.debug("get_workflow_run_by_id_without_tenant: run_id=%s", run_id) try: + # Escape parameter to prevent SQL injection + escaped_run_id = escape_identifier(run_id) + # Check if PG protocol is supported if self.logstore_client.supports_pg_protocol: # Use PG protocol with SQL query (get latest version of record) @@ -331,7 +360,7 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn FROM "{AliyunLogStore.workflow_execution_logstore}" - WHERE id = '{run_id}' AND __time__ > 0 + WHERE id = '{escaped_run_id}' AND __time__ > 0 ) AS subquery WHERE rn = 1 LIMIT 100 """ @@ -341,7 +370,8 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): ) else: # Use SDK with LogStore query syntax - query = f"id: {run_id}" + # Note: Values must be quoted in LogStore query syntax + query = f"id:{escape_logstore_query_value(run_id)}" from_time = 0 to_time = int(time.time()) # now @@ -410,6 +440,11 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): triggered_from, status, ) + # Escape parameters to prevent SQL injection + escaped_tenant_id = escape_identifier(tenant_id) + escaped_app_id = escape_identifier(app_id) + escaped_triggered_from = escape_sql_string(triggered_from) + # Build time range filter time_filter = "" if time_range: @@ -418,6 +453,8 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): # If status is provided, 
simple count if status: + escaped_status = escape_sql_string(status) + if status == "running": # Running status requires window function sql = f""" @@ -425,9 +462,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): FROM ( SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn FROM {AliyunLogStore.workflow_execution_logstore} - WHERE tenant_id='{tenant_id}' - AND app_id='{app_id}' - AND triggered_from='{triggered_from}' + WHERE tenant_id='{escaped_tenant_id}' + AND app_id='{escaped_app_id}' + AND triggered_from='{escaped_triggered_from}' AND status='running' {time_filter} ) t @@ -438,10 +475,10 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): sql = f""" SELECT COUNT(DISTINCT id) as count FROM {AliyunLogStore.workflow_execution_logstore} - WHERE tenant_id='{tenant_id}' - AND app_id='{app_id}' - AND triggered_from='{triggered_from}' - AND status='{status}' + WHERE tenant_id='{escaped_tenant_id}' + AND app_id='{escaped_app_id}' + AND triggered_from='{escaped_triggered_from}' + AND status='{escaped_status}' AND finished_at IS NOT NULL {time_filter} """ @@ -467,13 +504,14 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): # No status filter - get counts grouped by status # Use optimized query for finished runs, separate query for running try: + # Escape parameters (already escaped above, reuse variables) # Count finished runs grouped by status finished_sql = f""" SELECT status, COUNT(DISTINCT id) as count FROM {AliyunLogStore.workflow_execution_logstore} - WHERE tenant_id='{tenant_id}' - AND app_id='{app_id}' - AND triggered_from='{triggered_from}' + WHERE tenant_id='{escaped_tenant_id}' + AND app_id='{escaped_app_id}' + AND triggered_from='{escaped_triggered_from}' AND finished_at IS NOT NULL {time_filter} GROUP BY status @@ -485,9 +523,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): FROM ( SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn FROM {AliyunLogStore.workflow_execution_logstore} - WHERE tenant_id='{tenant_id}' - AND app_id='{app_id}' - AND triggered_from='{triggered_from}' + WHERE tenant_id='{escaped_tenant_id}' + AND app_id='{escaped_app_id}' + AND triggered_from='{escaped_triggered_from}' AND status='running' {time_filter} ) t @@ -546,7 +584,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): logger.debug( "get_daily_runs_statistics: tenant_id=%s, app_id=%s, triggered_from=%s", tenant_id, app_id, triggered_from ) - # Build time range filter + + # Escape parameters to prevent SQL injection + escaped_tenant_id = escape_identifier(tenant_id) + escaped_app_id = escape_identifier(app_id) + escaped_triggered_from = escape_sql_string(triggered_from) + + # Build time range filter (datetime.isoformat() is safe) time_filter = "" if start_date: time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))" @@ -557,9 +601,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): sql = f""" SELECT DATE(from_unixtime(__time__)) as date, COUNT(DISTINCT id) as runs FROM {AliyunLogStore.workflow_execution_logstore} - WHERE tenant_id='{tenant_id}' - AND app_id='{app_id}' - AND triggered_from='{triggered_from}' + WHERE tenant_id='{escaped_tenant_id}' + AND app_id='{escaped_app_id}' + AND triggered_from='{escaped_triggered_from}' AND finished_at IS NOT NULL {time_filter} GROUP BY date @@ -601,7 +645,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): app_id, triggered_from, ) - # Build 
time range filter + + # Escape parameters to prevent SQL injection + escaped_tenant_id = escape_identifier(tenant_id) + escaped_app_id = escape_identifier(app_id) + escaped_triggered_from = escape_sql_string(triggered_from) + + # Build time range filter (datetime.isoformat() is safe) time_filter = "" if start_date: time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))" @@ -611,9 +661,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): sql = f""" SELECT DATE(from_unixtime(__time__)) as date, COUNT(DISTINCT created_by) as terminal_count FROM {AliyunLogStore.workflow_execution_logstore} - WHERE tenant_id='{tenant_id}' - AND app_id='{app_id}' - AND triggered_from='{triggered_from}' + WHERE tenant_id='{escaped_tenant_id}' + AND app_id='{escaped_app_id}' + AND triggered_from='{escaped_triggered_from}' AND finished_at IS NOT NULL {time_filter} GROUP BY date @@ -655,7 +705,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): app_id, triggered_from, ) - # Build time range filter + + # Escape parameters to prevent SQL injection + escaped_tenant_id = escape_identifier(tenant_id) + escaped_app_id = escape_identifier(app_id) + escaped_triggered_from = escape_sql_string(triggered_from) + + # Build time range filter (datetime.isoformat() is safe) time_filter = "" if start_date: time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))" @@ -665,9 +721,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): sql = f""" SELECT DATE(from_unixtime(__time__)) as date, SUM(total_tokens) as token_count FROM {AliyunLogStore.workflow_execution_logstore} - WHERE tenant_id='{tenant_id}' - AND app_id='{app_id}' - AND triggered_from='{triggered_from}' + WHERE tenant_id='{escaped_tenant_id}' + AND app_id='{escaped_app_id}' + AND triggered_from='{escaped_triggered_from}' AND finished_at IS NOT NULL {time_filter} GROUP BY date @@ -709,7 +765,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): app_id, triggered_from, ) - # Build time range filter + + # Escape parameters to prevent SQL injection + escaped_tenant_id = escape_identifier(tenant_id) + escaped_app_id = escape_identifier(app_id) + escaped_triggered_from = escape_sql_string(triggered_from) + + # Build time range filter (datetime.isoformat() is safe) time_filter = "" if start_date: time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))" @@ -726,9 +788,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): created_by, COUNT(DISTINCT id) AS interactions FROM {AliyunLogStore.workflow_execution_logstore} - WHERE tenant_id='{tenant_id}' - AND app_id='{app_id}' - AND triggered_from='{triggered_from}' + WHERE tenant_id='{escaped_tenant_id}' + AND app_id='{escaped_app_id}' + AND triggered_from='{escaped_triggered_from}' AND finished_at IS NOT NULL {time_filter} GROUP BY date, created_by diff --git a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py index 1119534d52..9928879a7b 100644 --- a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py @@ -10,6 +10,7 @@ from sqlalchemy.orm import sessionmaker from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository from core.workflow.entities import 
WorkflowExecution from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository +from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter from extensions.logstore.aliyun_logstore import AliyunLogStore from libs.helper import extract_tenant_id from models import ( @@ -22,18 +23,6 @@ from models.enums import WorkflowRunTriggeredFrom logger = logging.getLogger(__name__) -def to_serializable(obj): - """ - Convert non-JSON-serializable objects into JSON-compatible formats. - - - Uses `to_dict()` if it's a callable method. - - Falls back to string representation. - """ - if hasattr(obj, "to_dict") and callable(obj.to_dict): - return obj.to_dict() - return str(obj) - - class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository): def __init__( self, @@ -79,7 +68,7 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository): # Control flag for dual-write (write to both LogStore and SQL database) # Set to True to enable dual-write for safe migration, False to use LogStore only - self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "true").lower() == "true" + self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "false").lower() == "true" # Control flag for whether to write the `graph` field to LogStore. # If LOGSTORE_ENABLE_PUT_GRAPH_FIELD is "true", write the full `graph` field; @@ -113,6 +102,9 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository): # Generate log_version as nanosecond timestamp for record versioning log_version = str(time.time_ns()) + # Use WorkflowRuntimeTypeConverter to handle complex types (Segment, File, etc.) + json_converter = WorkflowRuntimeTypeConverter() + logstore_model = [ ("id", domain_model.id_), ("log_version", log_version), # Add log_version field for append-only writes @@ -127,19 +119,19 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository): ("version", domain_model.workflow_version), ( "graph", - json.dumps(domain_model.graph, ensure_ascii=False, default=to_serializable) + json.dumps(json_converter.to_json_encodable(domain_model.graph), ensure_ascii=False) if domain_model.graph and self._enable_put_graph_field else "{}", ), ( "inputs", - json.dumps(domain_model.inputs, ensure_ascii=False, default=to_serializable) + json.dumps(json_converter.to_json_encodable(domain_model.inputs), ensure_ascii=False) if domain_model.inputs else "{}", ), ( "outputs", - json.dumps(domain_model.outputs, ensure_ascii=False, default=to_serializable) + json.dumps(json_converter.to_json_encodable(domain_model.outputs), ensure_ascii=False) if domain_model.outputs else "{}", ), diff --git a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py index 400a089516..4897171b12 100644 --- a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py @@ -24,6 +24,8 @@ from core.workflow.enums import NodeType from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter from extensions.logstore.aliyun_logstore import AliyunLogStore +from extensions.logstore.repositories import safe_float, safe_int +from extensions.logstore.sql_escape import escape_identifier from libs.helper import 
extract_tenant_id from models import ( Account, @@ -73,7 +75,7 @@ def _dict_to_workflow_node_execution(data: dict[str, Any]) -> WorkflowNodeExecut node_execution_id=data.get("node_execution_id"), workflow_id=data.get("workflow_id", ""), workflow_execution_id=data.get("workflow_run_id"), - index=int(data.get("index", 0)), + index=safe_int(data.get("index", 0)), predecessor_node_id=data.get("predecessor_node_id"), node_id=data.get("node_id", ""), node_type=NodeType(data.get("node_type", "start")), @@ -83,7 +85,7 @@ def _dict_to_workflow_node_execution(data: dict[str, Any]) -> WorkflowNodeExecut outputs=outputs, status=status, error=data.get("error"), - elapsed_time=float(data.get("elapsed_time", 0.0)), + elapsed_time=safe_float(data.get("elapsed_time", 0.0)), metadata=domain_metadata, created_at=created_at, finished_at=finished_at, @@ -147,7 +149,7 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): # Control flag for dual-write (write to both LogStore and SQL database) # Set to True to enable dual-write for safe migration, False to use LogStore only - self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "true").lower() == "true" + self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "false").lower() == "true" def _to_logstore_model(self, domain_model: WorkflowNodeExecution) -> Sequence[tuple[str, str]]: logger.debug( @@ -274,16 +276,34 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): Save or update the inputs, process_data, or outputs associated with a specific node_execution record. - For LogStore implementation, this is similar to save() since we always write - complete records. We append a new record with updated data fields. + For LogStore implementation, this is a no-op for the LogStore write because save() + already writes all fields including inputs, process_data, and outputs. The caller + typically calls save() first to persist status/metadata, then calls save_execution_data() + to persist data fields. Since LogStore writes complete records atomically, we don't + need a separate write here to avoid duplicate records. + + However, if dual-write is enabled, we still need to call the SQL repository's + save_execution_data() method to properly update the SQL database. 
Args: execution: The NodeExecution instance with data to save """ - logger.debug("save_execution_data: id=%s, node_execution_id=%s", execution.id, execution.node_execution_id) - # In LogStore, we simply write a new complete record with the data - # The log_version timestamp will ensure this is treated as the latest version - self.save(execution) + logger.debug( + "save_execution_data: no-op for LogStore (data already saved by save()): id=%s, node_execution_id=%s", + execution.id, + execution.node_execution_id, + ) + # No-op for LogStore: save() already writes all fields including inputs, process_data, and outputs + # Calling save() again would create a duplicate record in the append-only LogStore + + # Dual-write to SQL database if enabled (for safe migration) + if self._enable_dual_write: + try: + self.sql_repository.save_execution_data(execution) + logger.debug("Dual-write: saved node execution data to SQL database: id=%s", execution.id) + except Exception: + logger.exception("Failed to dual-write node execution data to SQL database: id=%s", execution.id) + # Don't raise - LogStore write succeeded, SQL is just a backup def get_by_workflow_run( self, @@ -292,8 +312,8 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): ) -> Sequence[WorkflowNodeExecution]: """ Retrieve all NodeExecution instances for a specific workflow run. - Uses LogStore SQL query with finished_at IS NOT NULL filter for deduplication. - This ensures we only get the final version of each node execution. + Uses LogStore SQL query with window function to get the latest version of each node execution. + This ensures we only get the most recent version of each node execution record. Args: workflow_run_id: The workflow run ID order_config: Optional configuration for ordering results @@ -304,16 +324,19 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): A list of NodeExecution instances Note: - This method filters by finished_at IS NOT NULL to avoid duplicates from - version updates. For complete history including intermediate states, - a different query strategy would be needed. + This method uses ROW_NUMBER() window function partitioned by node_execution_id + to get the latest version (highest log_version) of each node execution. 
""" logger.debug("get_by_workflow_run: workflow_run_id=%s, order_config=%s", workflow_run_id, order_config) - # Build SQL query with deduplication using finished_at IS NOT NULL - # This optimization avoids window functions for common case where we only - # want the final state of each node execution + # Build SQL query with deduplication using window function + # ROW_NUMBER() OVER (PARTITION BY node_execution_id ORDER BY log_version DESC) + # ensures we get the latest version of each node execution - # Build ORDER BY clause + # Escape parameters to prevent SQL injection + escaped_workflow_run_id = escape_identifier(workflow_run_id) + escaped_tenant_id = escape_identifier(self._tenant_id) + + # Build ORDER BY clause for outer query order_clause = "" if order_config and order_config.order_by: order_fields = [] @@ -327,16 +350,23 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): if order_fields: order_clause = "ORDER BY " + ", ".join(order_fields) - sql = f""" - SELECT * - FROM {AliyunLogStore.workflow_node_execution_logstore} - WHERE workflow_run_id='{workflow_run_id}' - AND tenant_id='{self._tenant_id}' - AND finished_at IS NOT NULL - """ - + # Build app_id filter for subquery + app_id_filter = "" if self._app_id: - sql += f" AND app_id='{self._app_id}'" + escaped_app_id = escape_identifier(self._app_id) + app_id_filter = f" AND app_id='{escaped_app_id}'" + + # Use window function to get latest version of each node execution + sql = f""" + SELECT * FROM ( + SELECT *, ROW_NUMBER() OVER (PARTITION BY node_execution_id ORDER BY log_version DESC) AS rn + FROM {AliyunLogStore.workflow_node_execution_logstore} + WHERE workflow_run_id='{escaped_workflow_run_id}' + AND tenant_id='{escaped_tenant_id}' + {app_id_filter} + ) t + WHERE rn = 1 + """ if order_clause: sql += f" {order_clause}" diff --git a/api/extensions/logstore/sql_escape.py b/api/extensions/logstore/sql_escape.py new file mode 100644 index 0000000000..d88d6bd959 --- /dev/null +++ b/api/extensions/logstore/sql_escape.py @@ -0,0 +1,134 @@ +""" +SQL Escape Utility for LogStore Queries + +This module provides escaping utilities to prevent injection attacks in LogStore queries. + +LogStore supports two query modes: +1. PG Protocol Mode: Uses SQL syntax with single quotes for strings +2. SDK Mode: Uses LogStore query syntax (key: value) with double quotes + +Key Security Concerns: +- Prevent tenant A from accessing tenant B's data via injection +- SLS queries are read-only, so we focus on data access control +- Different escaping strategies for SQL vs LogStore query syntax +""" + + +def escape_sql_string(value: str) -> str: + """ + Escape a string value for safe use in SQL queries. + + This function escapes single quotes by doubling them, which is the standard + SQL escaping method. This prevents SQL injection by ensuring that user input + cannot break out of string literals. 
+ + Args: + value: The string value to escape + + Returns: + Escaped string safe for use in SQL queries + + Examples: + >>> escape_sql_string("normal_value") + "normal_value" + >>> escape_sql_string("value' OR '1'='1") + "value'' OR ''1''=''1" + >>> escape_sql_string("tenant's_id") + "tenant''s_id" + + Security: + - Prevents breaking out of string literals + - Stops injection attacks like: ' OR '1'='1 + - Protects against cross-tenant data access + """ + if not value: + return value + + # Escape single quotes by doubling them (standard SQL escaping) + # This prevents breaking out of string literals in SQL queries + return value.replace("'", "''") + + +def escape_identifier(value: str) -> str: + """ + Escape an identifier (tenant_id, app_id, run_id, etc.) for safe SQL use. + + This function is for PG protocol mode (SQL syntax). + For SDK mode, use escape_logstore_query_value() instead. + + Args: + value: The identifier value to escape + + Returns: + Escaped identifier safe for use in SQL queries + + Examples: + >>> escape_identifier("550e8400-e29b-41d4-a716-446655440000") + "550e8400-e29b-41d4-a716-446655440000" + >>> escape_identifier("tenant_id' OR '1'='1") + "tenant_id'' OR ''1''=''1" + + Security: + - Prevents SQL injection via identifiers + - Stops cross-tenant access attempts + - Works for UUIDs, alphanumeric IDs, and similar identifiers + """ + # For identifiers, use the same escaping as strings + # This is simple and effective for preventing injection + return escape_sql_string(value) + + +def escape_logstore_query_value(value: str) -> str: + """ + Escape value for LogStore query syntax (SDK mode). + + LogStore query syntax rules: + 1. Keywords (and/or/not) are case-insensitive + 2. Single quotes are ordinary characters (no special meaning) + 3. Double quotes wrap values: key:"value" + 4. Backslash is the escape character: + - \" for double quote inside value + - \\ for backslash itself + 5. Parentheses can change query structure + + To prevent injection: + - Wrap value in double quotes to treat special chars as literals + - Escape backslashes and double quotes using backslash + + Args: + value: The value to escape for LogStore query syntax + + Returns: + Quoted and escaped value safe for LogStore query syntax (includes the quotes) + + Examples: + >>> escape_logstore_query_value("normal_value") + '"normal_value"' + >>> escape_logstore_query_value("value or field:evil") + '"value or field:evil"' # 'or' and ':' are now literals + >>> escape_logstore_query_value('value"test') + '"value\\"test"' # Internal double quote escaped + >>> escape_logstore_query_value('value\\test') + '"value\\\\test"' # Backslash escaped + + Security: + - Prevents injection via and/or/not keywords + - Prevents injection via colons (:) + - Prevents injection via parentheses + - Protects against cross-tenant data access + + Note: + Escape order is critical: backslash first, then double quotes. + Otherwise, we'd double-escape the escape character itself. 
+ """ + if not value: + return '""' + + # IMPORTANT: Escape backslashes FIRST, then double quotes + # This prevents double-escaping (e.g., " -> \" -> \\" incorrectly) + escaped = value.replace("\\", "\\\\") # \ -> \\ + escaped = escaped.replace('"', '\\"') # " -> \" + + # Wrap in double quotes to treat as literal string + # This prevents and/or/not/:/() from being interpreted as operators + return f'"{escaped}"' diff --git a/api/tests/unit_tests/extensions/logstore/__init__.py b/api/tests/unit_tests/extensions/logstore/__init__.py new file mode 100644 index 0000000000..fe9ada9128 --- /dev/null +++ b/api/tests/unit_tests/extensions/logstore/__init__.py @@ -0,0 +1 @@ +"""LogStore extension unit tests.""" diff --git a/api/tests/unit_tests/extensions/logstore/test_sql_escape.py b/api/tests/unit_tests/extensions/logstore/test_sql_escape.py new file mode 100644 index 0000000000..63172b3f9b --- /dev/null +++ b/api/tests/unit_tests/extensions/logstore/test_sql_escape.py @@ -0,0 +1,469 @@ +""" +Unit tests for SQL escape utility functions. + +These tests ensure that SQL injection attacks are properly prevented +in LogStore queries, particularly for cross-tenant access scenarios. +""" + +import pytest + +from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value, escape_sql_string + + +class TestEscapeSQLString: + """Test escape_sql_string function.""" + + def test_escape_empty_string(self): + """Test escaping empty string.""" + assert escape_sql_string("") == "" + + def test_escape_normal_string(self): + """Test escaping string without special characters.""" + assert escape_sql_string("tenant_abc123") == "tenant_abc123" + assert escape_sql_string("app-uuid-1234") == "app-uuid-1234" + + def test_escape_single_quote(self): + """Test escaping single quote.""" + # Single quote should be doubled + assert escape_sql_string("tenant'id") == "tenant''id" + assert escape_sql_string("O'Reilly") == "O''Reilly" + + def test_escape_multiple_quotes(self): + """Test escaping multiple single quotes.""" + assert escape_sql_string("a'b'c") == "a''b''c" + assert escape_sql_string("'''") == "''''''" + + # === SQL Injection Attack Scenarios === + + def test_prevent_boolean_injection(self): + """Test prevention of boolean injection attacks.""" + # Classic OR 1=1 attack + malicious_input = "tenant' OR '1'='1" + escaped = escape_sql_string(malicious_input) + assert escaped == "tenant'' OR ''1''=''1" + + # When used in SQL, this becomes a safe string literal + sql = f"WHERE tenant_id='{escaped}'" + assert sql == "WHERE tenant_id='tenant'' OR ''1''=''1'" + # The entire input is now a string literal that won't match any tenant + + def test_prevent_or_injection(self): + """Test prevention of OR-based injection.""" + malicious_input = "tenant_a' OR tenant_id='tenant_b" + escaped = escape_sql_string(malicious_input) + assert escaped == "tenant_a'' OR tenant_id=''tenant_b" + + sql = f"WHERE tenant_id='{escaped}'" + # The OR is now part of the string literal, not SQL logic + assert "OR tenant_id=" in sql + # The SQL has: opening ', doubled internal quotes '', and closing ' + assert sql == "WHERE tenant_id='tenant_a'' OR tenant_id=''tenant_b'" + + def test_prevent_union_injection(self): + """Test prevention of UNION-based injection.""" + malicious_input = "xxx' UNION SELECT password FROM users WHERE '1'='1" + escaped = escape_sql_string(malicious_input) + assert escaped == "xxx'' UNION SELECT password FROM users WHERE ''1''=''1" + + # UNION becomes part of the string literal + assert "UNION" in escaped 
+ assert escaped.count("''") == 4 # All internal quotes are doubled + + def test_prevent_comment_injection(self): + """Test prevention of comment-based injection.""" + # SQL comment to bypass remaining conditions + malicious_input = "tenant' --" + escaped = escape_sql_string(malicious_input) + assert escaped == "tenant'' --" + + sql = f"WHERE tenant_id='{escaped}' AND deleted=false" + # The -- is now inside the string, not a SQL comment + assert "--" in sql + assert "AND deleted=false" in sql # This part is NOT commented out + + def test_prevent_semicolon_injection(self): + """Test prevention of semicolon-based multi-statement injection.""" + malicious_input = "tenant'; DROP TABLE users; --" + escaped = escape_sql_string(malicious_input) + assert escaped == "tenant''; DROP TABLE users; --" + + # Semicolons and DROP are now part of the string + assert "DROP TABLE" in escaped + + def test_prevent_time_based_blind_injection(self): + """Test prevention of time-based blind SQL injection.""" + malicious_input = "tenant' AND SLEEP(5) --" + escaped = escape_sql_string(malicious_input) + assert escaped == "tenant'' AND SLEEP(5) --" + + # SLEEP becomes part of the string + assert "SLEEP" in escaped + + def test_prevent_wildcard_injection(self): + """Test prevention of wildcard-based injection.""" + malicious_input = "tenant' OR tenant_id LIKE '%" + escaped = escape_sql_string(malicious_input) + assert escaped == "tenant'' OR tenant_id LIKE ''%" + + # The LIKE and wildcard are now part of the string + assert "LIKE" in escaped + + def test_prevent_null_byte_injection(self): + """Test handling of null bytes.""" + # Null bytes can sometimes bypass filters + malicious_input = "tenant\x00' OR '1'='1" + escaped = escape_sql_string(malicious_input) + # Null byte is preserved, but quote is escaped + assert "''1''=''1" in escaped + + # === Real-world SAAS Scenarios === + + def test_cross_tenant_access_attempt(self): + """Test prevention of cross-tenant data access.""" + # Attacker tries to access another tenant's data + attacker_input = "tenant_b' OR tenant_id='tenant_a" + escaped = escape_sql_string(attacker_input) + + sql = f"SELECT * FROM workflow_runs WHERE tenant_id='{escaped}'" + # The query will look for a tenant literally named "tenant_b' OR tenant_id='tenant_a" + # which doesn't exist - preventing access to either tenant's data + assert "tenant_b'' OR tenant_id=''tenant_a" in sql + + def test_cross_app_access_attempt(self): + """Test prevention of cross-application data access.""" + attacker_input = "app1' OR app_id='app2" + escaped = escape_sql_string(attacker_input) + + sql = f"WHERE app_id='{escaped}'" + # Cannot access app2's data + assert "app1'' OR app_id=''app2" in sql + + def test_bypass_status_filter(self): + """Test prevention of bypassing status filters.""" + # Try to see all statuses instead of just 'running' + attacker_input = "running' OR status LIKE '%" + escaped = escape_sql_string(attacker_input) + + sql = f"WHERE status='{escaped}'" + # Status condition is not bypassed + assert "running'' OR status LIKE ''%" in sql + + # === Edge Cases === + + def test_escape_only_quotes(self): + """Test string with only quotes.""" + assert escape_sql_string("'") == "''" + assert escape_sql_string("''") == "''''" + + def test_escape_mixed_content(self): + """Test string with mixed quotes and other chars.""" + input_str = "It's a 'test' of O'Reilly's code" + escaped = escape_sql_string(input_str) + assert escaped == "It''s a ''test'' of O''Reilly''s code" + + def 
test_escape_unicode_with_quotes(self): + """Test Unicode strings with quotes.""" + input_str = "租户' OR '1'='1" + escaped = escape_sql_string(input_str) + assert escaped == "租户'' OR ''1''=''1" + + +class TestEscapeIdentifier: + """Test escape_identifier function.""" + + def test_escape_uuid(self): + """Test escaping UUID identifiers.""" + uuid = "550e8400-e29b-41d4-a716-446655440000" + assert escape_identifier(uuid) == uuid + + def test_escape_alphanumeric_id(self): + """Test escaping alphanumeric identifiers.""" + assert escape_identifier("tenant_123") == "tenant_123" + assert escape_identifier("app-abc-123") == "app-abc-123" + + def test_escape_identifier_with_quote(self): + """Test escaping identifier with single quote.""" + malicious = "tenant' OR '1'='1" + escaped = escape_identifier(malicious) + assert escaped == "tenant'' OR ''1''=''1" + + def test_identifier_injection_attempt(self): + """Test prevention of injection through identifiers.""" + # Common identifier injection patterns + test_cases = [ + ("id' OR '1'='1", "id'' OR ''1''=''1"), + ("id'; DROP TABLE", "id''; DROP TABLE"), + ("id' UNION SELECT", "id'' UNION SELECT"), + ] + + for malicious, expected in test_cases: + assert escape_identifier(malicious) == expected + + +class TestSQLInjectionIntegration: + """Integration tests simulating real SQL construction scenarios.""" + + def test_complete_where_clause_safety(self): + """Test that a complete WHERE clause is safe from injection.""" + # Simulating typical query construction + tenant_id = "tenant' OR '1'='1" + app_id = "app' UNION SELECT" + run_id = "run' --" + + escaped_tenant = escape_identifier(tenant_id) + escaped_app = escape_identifier(app_id) + escaped_run = escape_identifier(run_id) + + sql = f""" + SELECT * FROM workflow_runs + WHERE tenant_id='{escaped_tenant}' + AND app_id='{escaped_app}' + AND id='{escaped_run}' + """ + + # Verify all special characters are escaped + assert "tenant'' OR ''1''=''1" in sql + assert "app'' UNION SELECT" in sql + assert "run'' --" in sql + + # Verify SQL structure is preserved (3 conditions with AND) + assert sql.count("AND") == 2 + + def test_multiple_conditions_with_injection_attempts(self): + """Test multiple conditions all attempting injection.""" + conditions = { + "tenant_id": "t1' OR tenant_id='t2", + "app_id": "a1' OR app_id='a2", + "status": "running' OR '1'='1", + } + + where_parts = [] + for field, value in conditions.items(): + escaped = escape_sql_string(value) + where_parts.append(f"{field}='{escaped}'") + + where_clause = " AND ".join(where_parts) + + # All injection attempts are neutralized + assert "t1'' OR tenant_id=''t2" in where_clause + assert "a1'' OR app_id=''a2" in where_clause + assert "running'' OR ''1''=''1" in where_clause + + # AND structure is preserved + assert where_clause.count(" AND ") == 2 + + @pytest.mark.parametrize( + ("attack_vector", "description"), + [ + ("' OR '1'='1", "Boolean injection"), + ("' OR '1'='1' --", "Boolean with comment"), + ("' UNION SELECT * FROM users --", "Union injection"), + ("'; DROP TABLE workflow_runs; --", "Destructive command"), + ("' AND SLEEP(10) --", "Time-based blind"), + ("' OR tenant_id LIKE '%", "Wildcard injection"), + ("admin' --", "Comment bypass"), + ("' OR 1=1 LIMIT 1 --", "Limit bypass"), + ], + ) + def test_common_injection_vectors(self, attack_vector, description): + """Test protection against common injection attack vectors.""" + escaped = escape_sql_string(attack_vector) + + # Build SQL + sql = f"WHERE tenant_id='{escaped}'" + + # Verify the attack 
string is now a safe literal + # The key indicator: all internal single quotes are doubled + internal_quotes = escaped.count("''") + original_quotes = attack_vector.count("'") + + # Each original quote should be doubled + assert internal_quotes == original_quotes + + # Verify SQL has exactly 2 quotes (opening and closing) + assert sql.count("'") >= 2 # At least opening and closing + + def test_logstore_specific_scenario(self): + """Test SQL injection prevention in LogStore-specific scenarios.""" + # Simulate LogStore query with window function + tenant_id = "tenant' OR '1'='1" + app_id = "app' UNION SELECT" + + escaped_tenant = escape_identifier(tenant_id) + escaped_app = escape_identifier(app_id) + + sql = f""" + SELECT * FROM ( + SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn + FROM workflow_execution_logstore + WHERE tenant_id='{escaped_tenant}' + AND app_id='{escaped_app}' + AND __time__ > 0 + ) AS subquery WHERE rn = 1 + """ + + # Complex query structure is maintained + assert "ROW_NUMBER()" in sql + assert "PARTITION BY id" in sql + + # Injection attempts are escaped + assert "tenant'' OR ''1''=''1" in sql + assert "app'' UNION SELECT" in sql + + +# ==================================================================================== +# Tests for LogStore Query Syntax (SDK Mode) +# ==================================================================================== + + +class TestLogStoreQueryEscape: + """Test escape_logstore_query_value for SDK mode query syntax.""" + + def test_normal_value(self): + """Test escaping normal alphanumeric value.""" + value = "550e8400-e29b-41d4-a716-446655440000" + escaped = escape_logstore_query_value(value) + + # Should be wrapped in double quotes + assert escaped == '"550e8400-e29b-41d4-a716-446655440000"' + + def test_empty_value(self): + """Test escaping empty string.""" + assert escape_logstore_query_value("") == '""' + + def test_value_with_and_keyword(self): + """Test that 'and' keyword is neutralized when quoted.""" + malicious = "value and field:evil" + escaped = escape_logstore_query_value(malicious) + + # Should be wrapped in quotes, making 'and' a literal + assert escaped == '"value and field:evil"' + + # Simulate using in query + query = f"tenant_id:{escaped}" + assert query == 'tenant_id:"value and field:evil"' + + def test_value_with_or_keyword(self): + """Test that 'or' keyword is neutralized when quoted.""" + malicious = "tenant_a or tenant_id:tenant_b" + escaped = escape_logstore_query_value(malicious) + + assert escaped == '"tenant_a or tenant_id:tenant_b"' + + query = f"tenant_id:{escaped}" + assert "or" in query # Present but as literal string + + def test_value_with_not_keyword(self): + """Test that 'not' keyword is neutralized when quoted.""" + malicious = "not field:value" + escaped = escape_logstore_query_value(malicious) + + assert escaped == '"not field:value"' + + def test_value_with_parentheses(self): + """Test that parentheses are neutralized when quoted.""" + malicious = "(tenant_a or tenant_b)" + escaped = escape_logstore_query_value(malicious) + + assert escaped == '"(tenant_a or tenant_b)"' + assert "(" in escaped # Present as literal + assert ")" in escaped # Present as literal + + def test_value_with_colon(self): + """Test that colons are neutralized when quoted.""" + malicious = "field:value" + escaped = escape_logstore_query_value(malicious) + + assert escaped == '"field:value"' + assert ":" in escaped # Present as literal + + def test_value_with_double_quotes(self): + """Test that 
internal double quotes are escaped.""" + value_with_quotes = 'tenant"test"value' + escaped = escape_logstore_query_value(value_with_quotes) + + # Double quotes should be escaped with backslash + assert escaped == '"tenant\\"test\\"value"' + # Should have outer quotes plus escaped inner quotes + assert '\\"' in escaped + + def test_value_with_backslash(self): + """Test that backslashes are escaped.""" + value_with_backslash = "tenant\\test" + escaped = escape_logstore_query_value(value_with_backslash) + + # Backslash should be escaped + assert escaped == '"tenant\\\\test"' + assert "\\\\" in escaped + + def test_value_with_backslash_and_quote(self): + """Test escaping both backslash and double quote.""" + value = 'path\\to\\"file"' + escaped = escape_logstore_query_value(value) + + # Both should be escaped + assert escaped == '"path\\\\to\\\\\\"file\\""' + # Verify escape order is correct + assert "\\\\" in escaped # Escaped backslash + assert '\\"' in escaped # Escaped double quote + + def test_complex_injection_attempt(self): + """Test complex injection combining multiple operators.""" + malicious = 'tenant_a" or (tenant_id:"tenant_b" and app_id:"evil")' + escaped = escape_logstore_query_value(malicious) + + # All special chars should be literals or escaped + assert escaped.startswith('"') + assert escaped.endswith('"') + # Inner double quotes escaped, operators become literals + assert "or" in escaped + assert "and" in escaped + assert '\\"' in escaped # Escaped quotes + + def test_only_backslash(self): + """Test escaping a single backslash.""" + assert escape_logstore_query_value("\\") == '"\\\\"' + + def test_only_double_quote(self): + """Test escaping a single double quote.""" + assert escape_logstore_query_value('"') == '"\\""' + + def test_multiple_backslashes(self): + """Test escaping multiple consecutive backslashes.""" + assert escape_logstore_query_value("\\\\\\") == '"\\\\\\\\\\\\"' # 3 backslashes -> 6 + + def test_escape_sequence_like_input(self): + """Test that existing escape sequences are properly escaped.""" + # Input looks like already escaped, but we still escape it + value = 'value\\"test' + escaped = escape_logstore_query_value(value) + # \\ -> \\\\, " -> \" + assert escaped == '"value\\\\\\"test"' + + +@pytest.mark.parametrize( + ("attack_scenario", "field", "malicious_value"), + [ + ("Cross-tenant via OR", "tenant_id", "tenant_a or tenant_id:tenant_b"), + ("Cross-app via AND", "app_id", "app_a and (app_id:app_b or app_id:app_c)"), + ("Boolean logic", "status", "succeeded or status:failed"), + ("Negation", "tenant_id", "not tenant_a"), + ("Field injection", "run_id", "run123 and tenant_id:evil_tenant"), + ("Parentheses grouping", "app_id", "app1 or (app_id:app2 and tenant_id:tenant2)"), + ("Quote breaking attempt", "tenant_id", 'tenant" or "1"="1'), + ("Backslash escape bypass", "app_id", "app\\ and app_id:evil"), + ], +) +def test_logstore_query_injection_scenarios(attack_scenario: str, field: str, malicious_value: str): + """Test that various LogStore query injection attempts are neutralized.""" + escaped = escape_logstore_query_value(malicious_value) + + # Build query + query = f"{field}:{escaped}" + + # All operators should be within quoted string (literals) + assert escaped.startswith('"') + assert escaped.endswith('"') + + # Verify the full query structure is safe + assert query.count(":") >= 1 # At least the main field:value separator diff --git a/docker/.env.example b/docker/.env.example index e7cb8711ce..9a3a7239c6 100644 --- a/docker/.env.example +++ 
b/docker/.env.example @@ -1037,18 +1037,26 @@ WORKFLOW_NODE_EXECUTION_STORAGE=rdbms # Options: # - core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository (default) # - core.repositories.celery_workflow_execution_repository.CeleryWorkflowExecutionRepository +# - extensions.logstore.repositories.logstore_workflow_execution_repository.LogstoreWorkflowExecutionRepository CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository # Core workflow node execution repository implementation # Options: # - core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository (default) # - core.repositories.celery_workflow_node_execution_repository.CeleryWorkflowNodeExecutionRepository +# - extensions.logstore.repositories.logstore_workflow_node_execution_repository.LogstoreWorkflowNodeExecutionRepository CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository # API workflow run repository implementation +# Options: +# - repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository (default) +# - extensions.logstore.repositories.logstore_api_workflow_run_repository.LogstoreAPIWorkflowRunRepository API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository # API workflow node execution repository implementation +# Options: +# - repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository (default) +# - extensions.logstore.repositories.logstore_api_workflow_node_execution_repository.LogstoreAPIWorkflowNodeExecutionRepository API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository # Workflow log cleanup configuration
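For reference, a minimal sketch of an .env configuration that would activate the logstore extension: the extension's is_enabled() check requires the Aliyun SLS connection variables to be set and at least one of the repository settings above to point at a logstore implementation. The variable names below come from the extension code and the options listed above; the values are placeholders, not recommendations:

    # Aliyun SLS connection (placeholders; all connection variables must be non-empty)
    ALIYUN_SLS_ACCESS_KEY_ID=<access-key-id>
    ALIYUN_SLS_ACCESS_KEY_SECRET=<access-key-secret>
    ALIYUN_SLS_ENDPOINT=<region>.log.aliyuncs.com
    ALIYUN_SLS_REGION=<region>
    ALIYUN_SLS_PROJECT_NAME=<project-name>
    # Route at least one repository through a logstore implementation
    CORE_WORKFLOW_EXECUTION_REPOSITORY=extensions.logstore.repositories.logstore_workflow_execution_repository.LogstoreWorkflowExecutionRepository
    API_WORKFLOW_RUN_REPOSITORY=extensions.logstore.repositories.logstore_api_workflow_run_repository.LogstoreAPIWorkflowRunRepository

With neither condition met, ext_logstore.is_enabled() returns False and the application starts without initializing the logstore client.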