From 81ef7343d46042e469fe7eafe0a7f3ee7245c716 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=9E=E6=B3=95=E6=93=8D=E4=BD=9C?= Date: Fri, 5 Sep 2025 14:00:20 +0800 Subject: [PATCH] chore: (trigger) refactor webhook service (#25229) --- api/services/plugin/plugin_service.py | 1 - api/services/tools/tools_transform_service.py | 5 +- .../trigger/trigger_provider_service.py | 4 +- api/services/webhook_service.py | 635 ++++++++++-------- 4 files changed, 354 insertions(+), 291 deletions(-) diff --git a/api/services/plugin/plugin_service.py b/api/services/plugin/plugin_service.py index 0a50f78f79..69414ba7cc 100644 --- a/api/services/plugin/plugin_service.py +++ b/api/services/plugin/plugin_service.py @@ -184,7 +184,6 @@ class PluginService: ) return str(url_prefix % {"tenant_id": tenant_id, "filename": filename}) - @staticmethod def get_asset(tenant_id: str, asset_file: str) -> tuple[bytes, str]: """ diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 50269d7617..d6da32eeba 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -31,7 +31,6 @@ logger = logging.getLogger(__name__) class ToolTransformService: - @classmethod def get_tool_provider_icon_url(cls, provider_type: str, provider_name: str, icon: str | dict) -> Union[str, dict]: """ @@ -69,9 +68,7 @@ class ToolTransformService: elif isinstance(provider, ToolProviderApiEntity): if provider.plugin_id: if isinstance(provider.icon, str): - provider.icon = PluginService.get_plugin_icon_url( - tenant_id=tenant_id, filename=provider.icon - ) + provider.icon = PluginService.get_plugin_icon_url(tenant_id=tenant_id, filename=provider.icon) if isinstance(provider.icon_dark, str) and provider.icon_dark: provider.icon_dark = PluginService.get_plugin_icon_url( tenant_id=tenant_id, filename=provider.icon_dark diff --git a/api/services/trigger/trigger_provider_service.py b/api/services/trigger/trigger_provider_service.py index 5570b46062..95794e0b8e 100644 --- a/api/services/trigger/trigger_provider_service.py +++ b/api/services/trigger/trigger_provider_service.py @@ -175,9 +175,7 @@ class TriggerProviderService: if not db_provider: raise ValueError(f"Trigger provider subscription {subscription_id} not found") - provider_controller = TriggerManager.get_trigger_provider( - tenant_id, TriggerProviderID(db_provider.provider_id) - ) + provider_controller = TriggerManager.get_trigger_provider(tenant_id, TriggerProviderID(db_provider.provider_id)) # Clear cache _, cache = create_trigger_provider_encrypter_for_subscription( tenant_id=tenant_id, diff --git a/api/services/webhook_service.py b/api/services/webhook_service.py index 1d831e8d24..38d47be742 100644 --- a/api/services/webhook_service.py +++ b/api/services/webhook_service.py @@ -1,5 +1,6 @@ import json import logging +import mimetypes from collections.abc import Mapping from typing import Any @@ -76,13 +77,7 @@ class WebhookService: @classmethod def extract_webhook_data(cls, webhook_trigger: WorkflowWebhookTrigger) -> dict[str, Any]: """Extract and process data from incoming webhook request.""" - - content_length = request.content_length - if content_length and content_length > dify_config.WEBHOOK_REQUEST_BODY_MAX_SIZE: - raise RequestEntityTooLarge( - f"Webhook request too large: {content_length} bytes exceeds maximum allowed size \ - of {dify_config.WEBHOOK_REQUEST_BODY_MAX_SIZE} bytes" - ) + cls._validate_content_length() data = { "method": request.method, @@ -92,63 +87,89 @@ class WebhookService: "files": {}, } - content_type = request.headers.get("Content-Type", "").lower() + # Extract and normalize content type + content_type = cls._extract_content_type(request.headers) - # Extract body data based on content type - if "application/json" in content_type: - try: - data["body"] = request.get_json() or {} - except Exception: - data["body"] = {} - elif "application/x-www-form-urlencoded" in content_type: - data["body"] = dict(request.form) - elif "multipart/form-data" in content_type: - data["body"] = dict(request.form) - # Handle file uploads - if request.files: - data["files"] = cls._process_file_uploads(request.files, webhook_trigger) - elif "application/octet-stream" in content_type: - # Binary data - process as file using ToolFileManager - try: - file_content = request.get_data() - if file_content: - tool_file_manager = ToolFileManager() + # Route to appropriate extractor based on content type + extractors = { + "application/json": cls._extract_json_body, + "application/x-www-form-urlencoded": cls._extract_form_body, + "multipart/form-data": lambda: cls._extract_multipart_body(webhook_trigger), + "application/octet-stream": lambda: cls._extract_octet_stream_body(webhook_trigger), + "text/plain": cls._extract_text_body, + } - # Create file using ToolFileManager - tool_file = tool_file_manager.create_file_by_raw( - user_id=webhook_trigger.created_by, - tenant_id=webhook_trigger.tenant_id, - conversation_id=None, - file_binary=file_content, - mimetype="application/octet-stream", - ) + extractor = extractors.get(content_type) + if not extractor: + # Default to text/plain for unknown content types + logger.warning("Unknown Content-Type: %s, treating as text/plain", content_type) + extractor = cls._extract_text_body - # Build File object - mapping = { - "tool_file_id": tool_file.id, - "transfer_method": FileTransferMethod.TOOL_FILE.value, - } - file_obj = file_factory.build_from_mapping( - mapping=mapping, - tenant_id=webhook_trigger.tenant_id, - ) - data["body"] = {"raw": file_obj.to_dict()} - else: - data["body"] = {"raw": None} - except Exception: - logger.exception("Failed to process octet-stream data") - data["body"] = {"raw": None} - elif "text/plain" in content_type: - # Text data - store as raw string - try: - data["body"] = {"raw": request.get_data(as_text=True)} - except Exception: - data["body"] = {"raw": ""} - else: - raise ValueError(f"Unsupported Content-Type: {content_type}") + # Extract body and files + body_data, files_data = extractor() + data["body"] = body_data + data["files"] = files_data return data + @classmethod + def _validate_content_length(cls) -> None: + """Validate request content length against maximum allowed size.""" + content_length = request.content_length + if content_length and content_length > dify_config.WEBHOOK_REQUEST_BODY_MAX_SIZE: + raise RequestEntityTooLarge( + f"Webhook request too large: {content_length} bytes exceeds maximum allowed size " + f"of {dify_config.WEBHOOK_REQUEST_BODY_MAX_SIZE} bytes" + ) + + @classmethod + def _extract_json_body(cls) -> tuple[dict[str, Any], dict[str, Any]]: + """Extract JSON body from request.""" + try: + body = request.get_json() or {} + except Exception: + logger.warning("Failed to parse JSON body") + body = {} + return body, {} + + @classmethod + def _extract_form_body(cls) -> tuple[dict[str, Any], dict[str, Any]]: + """Extract form-urlencoded body from request.""" + return dict(request.form), {} + + @classmethod + def _extract_multipart_body(cls, webhook_trigger: WorkflowWebhookTrigger) -> tuple[dict[str, Any], dict[str, Any]]: + """Extract multipart/form-data body and files from request.""" + body = dict(request.form) + files = cls._process_file_uploads(request.files, webhook_trigger) if request.files else {} + return body, files + + @classmethod + def _extract_octet_stream_body( + cls, webhook_trigger: WorkflowWebhookTrigger + ) -> tuple[dict[str, Any], dict[str, Any]]: + """Extract binary data as file from request.""" + try: + file_content = request.get_data() + if file_content: + file_obj = cls._create_file_from_binary(file_content, "application/octet-stream", webhook_trigger) + return {"raw": file_obj.to_dict()}, {} + else: + return {"raw": None}, {} + except Exception: + logger.exception("Failed to process octet-stream data") + return {"raw": None}, {} + + @classmethod + def _extract_text_body(cls) -> tuple[dict[str, Any], dict[str, Any]]: + """Extract text/plain body from request.""" + try: + body = {"raw": request.get_data(as_text=True)} + except Exception: + logger.warning("Failed to extract text body") + body = {"raw": ""} + return body, {} + @classmethod def _process_file_uploads(cls, files, webhook_trigger: WorkflowWebhookTrigger) -> dict[str, Any]: """Process file uploads using ToolFileManager.""" @@ -157,246 +178,288 @@ class WebhookService: for name, file in files.items(): if file and file.filename: try: - tool_file_manager = ToolFileManager() file_content = file.read() - - # Create file using ToolFileManager - tool_file = tool_file_manager.create_file_by_raw( - user_id=webhook_trigger.created_by, - tenant_id=webhook_trigger.tenant_id, - conversation_id=None, - file_binary=file_content, - mimetype=file.content_type or "application/octet-stream", - ) - - # Build File object - mapping = { - "tool_file_id": tool_file.id, - "transfer_method": FileTransferMethod.TOOL_FILE.value, - } - file_obj = file_factory.build_from_mapping( - mapping=mapping, - tenant_id=webhook_trigger.tenant_id, - ) + mimetype = file.content_type or mimetypes.guess_type(file.filename)[0] or "application/octet-stream" + file_obj = cls._create_file_from_binary(file_content, mimetype, webhook_trigger) processed_files[name] = file_obj.to_dict() - except Exception: - logger.exception("Failed to process file upload %s", name) + logger.exception("Failed to process file upload '%s'", name) # Continue processing other files return processed_files + @classmethod + def _create_file_from_binary( + cls, file_content: bytes, mimetype: str, webhook_trigger: WorkflowWebhookTrigger + ) -> Any: + """Create a file object from binary content using ToolFileManager.""" + tool_file_manager = ToolFileManager() + + # Create file using ToolFileManager + tool_file = tool_file_manager.create_file_by_raw( + user_id=webhook_trigger.created_by, + tenant_id=webhook_trigger.tenant_id, + conversation_id=None, + file_binary=file_content, + mimetype=mimetype, + ) + + # Build File object + mapping = { + "tool_file_id": tool_file.id, + "transfer_method": FileTransferMethod.TOOL_FILE.value, + } + return file_factory.build_from_mapping( + mapping=mapping, + tenant_id=webhook_trigger.tenant_id, + ) + @classmethod def validate_webhook_request(cls, webhook_data: dict[str, Any], node_config: Mapping[str, Any]) -> dict[str, Any]: """Validate webhook request against node configuration.""" - try: - node_data = node_config.get("data", {}) + if node_config is None: + return cls._validation_error("Validation failed: Invalid node configuration") - # Validate HTTP method - configured_method = node_data.get("method", "get").upper() - request_method = webhook_data["method"].upper() - if configured_method != request_method: - return { - "valid": False, - "error": f"HTTP method mismatch. Expected {configured_method}, got {request_method}", - } + node_data = node_config.get("data", {}) - # Validate Content-type - configured_content_type = node_data.get("content_type", "application/json").lower() - request_content_type = webhook_data["headers"].get("Content-Type", "").lower() - if not request_content_type: - request_content_type = webhook_data["headers"].get("content-type", "application/json").lower() + # Early validation of HTTP method and content-type + validation_result = cls._validate_http_metadata(webhook_data, node_data) + if not validation_result["valid"]: + return validation_result - # Extract the main content type (ignore parameters like boundary) - request_content_type = request_content_type.split(";")[0].strip() + # Validate headers and query params + validation_result = cls._validate_headers_and_params(webhook_data, node_data) + if not validation_result["valid"]: + return validation_result - if configured_content_type != request_content_type: - return { - "valid": False, - "error": f"Content-type mismatch. Expected {configured_content_type}, got {request_content_type}", - } + # Validate body based on content type + configured_content_type = node_data.get("content_type", "application/json").lower() + return cls._validate_body_by_content_type(webhook_data, node_data, configured_content_type) - # Validate required headers (case-insensitive) - headers = node_data.get("headers", []) - # Create case-insensitive header lookup - webhook_headers_lower = {k.lower(): v for k, v in webhook_data["headers"].items()} - for header in headers: - if header.get("required", False): - header_name = header.get("name", "") - if header_name.lower() not in webhook_headers_lower: - return {"valid": False, "error": f"Required header missing: {header_name}"} + @classmethod + def _validate_http_metadata(cls, webhook_data: dict[str, Any], node_data: dict[str, Any]) -> dict[str, Any]: + """Validate HTTP method and content-type.""" + # Validate HTTP method + configured_method = node_data.get("method", "get").upper() + request_method = webhook_data["method"].upper() + if configured_method != request_method: + return cls._validation_error(f"HTTP method mismatch. Expected {configured_method}, got {request_method}") - # Validate required query parameters - params = node_data.get("params", []) - for param in params: - if param.get("required", False): - param_name = param.get("name", "") - if param_name not in webhook_data["query_params"]: - return {"valid": False, "error": f"Required query parameter missing: {param_name}"} + # Validate Content-type + configured_content_type = node_data.get("content_type", "application/json").lower() + request_content_type = cls._extract_content_type(webhook_data["headers"]) - if configured_content_type == "text/plain": - # For text/plain, just validate that we have a body if any body params are configured as required - body_params = node_data.get("body", []) - if body_params and any(param.get("required", False) for param in body_params): - body_data = webhook_data.get("body", {}) - raw_content = body_data.get("raw", "") - if not raw_content or not isinstance(raw_content, str): - return {"valid": False, "error": "Required body content missing for text/plain request"} + if configured_content_type != request_content_type: + return cls._validation_error( + f"Content-type mismatch. Expected {configured_content_type}, got {request_content_type}" + ) - elif configured_content_type == "application/json": - # For application/json, validate both existence and types of parameters - body_params = node_data.get("body", []) - body_data = webhook_data.get("body", {}) + return {"valid": True} - for body_param in body_params: - param_name = body_param.get("name", "") - param_type = body_param.get("type", SegmentType.STRING) - is_required = body_param.get("required", False) + @classmethod + def _extract_content_type(cls, headers: dict[str, Any]) -> str: + """Extract and normalize content-type from headers.""" + content_type = headers.get("Content-Type", "").lower() + if not content_type: + content_type = headers.get("content-type", "application/json").lower() + # Extract the main content type (ignore parameters like boundary) + return content_type.split(";")[0].strip() - # Handle regular JSON parameters - param_exists = param_name in body_data + @classmethod + def _validate_headers_and_params(cls, webhook_data: dict[str, Any], node_data: dict[str, Any]) -> dict[str, Any]: + """Validate required headers and query parameters.""" + # Validate required headers (case-insensitive) + webhook_headers_lower = {k.lower(): v for k, v in webhook_data["headers"].items()} + for header in node_data.get("headers", []): + if header.get("required", False): + header_name = header.get("name", "") + if header_name.lower() not in webhook_headers_lower: + return cls._validation_error(f"Required header missing: {header_name}") - # Check if required parameter exists - if is_required and not param_exists: - return {"valid": False, "error": f"Required body parameter missing: {param_name}"} + # Validate required query parameters + for param in node_data.get("params", []): + if param.get("required", False): + param_name = param.get("name", "") + if param_name not in webhook_data["query_params"]: + return cls._validation_error(f"Required query parameter missing: {param_name}") - # Validate parameter type if it exists - if param_exists: - param_value = body_data[param_name] - validation_result = cls._validate_json_parameter_type(param_name, param_value, param_type) - if not validation_result["valid"]: - return validation_result + return {"valid": True} - elif configured_content_type == "application/x-www-form-urlencoded": - # For form-urlencoded data, all values must be strings - no other types allowed - body_params = node_data.get("body", []) - body_data = webhook_data.get("body", {}) + @classmethod + def _validate_body_by_content_type( + cls, webhook_data: dict[str, Any], node_data: dict[str, Any], content_type: str + ) -> dict[str, Any]: + """Route body validation to appropriate validator based on content type.""" + validators = { + "text/plain": cls._validate_text_plain_body, + "application/octet-stream": cls._validate_octet_stream_body, + "application/json": cls._validate_json_body, + "application/x-www-form-urlencoded": cls._validate_form_urlencoded_body, + "multipart/form-data": cls._validate_multipart_body, + } - for body_param in body_params: - param_name = body_param.get("name", "") - param_type = body_param.get("type", SegmentType.STRING) - is_required = body_param.get("required", False) + validator = validators.get(content_type) + if not validator: + raise ValueError(f"Unsupported Content-Type for validation: {content_type}") - param_exists = param_name in body_data - if is_required and not param_exists: - return {"valid": False, "error": f"Required body parameter missing: {param_name}"} + return validator(webhook_data, node_data) - # Ensure the actual value is also a string - if param_exists and param_type != SegmentType.STRING: - param_value = body_data[param_name] - validation_result = cls._validate_form_parameter_type(param_name, param_value, param_type) - if not validation_result["valid"]: - return validation_result + @classmethod + def _validate_text_plain_body(cls, webhook_data: dict[str, Any], node_data: dict[str, Any]) -> dict[str, Any]: + """Validate text/plain body.""" + body_params = node_data.get("body", []) + if body_params and any(param.get("required", False) for param in body_params): + body_data = webhook_data.get("body", {}) + raw_content = body_data.get("raw", "") + if not raw_content or not isinstance(raw_content, str): + return cls._validation_error("Required body content missing for text/plain request") + return {"valid": True} - elif configured_content_type == "multipart/form-data": - # For multipart data, supports both strings and files - body_params = node_data.get("body", []) - body_data = webhook_data.get("body", {}) + @classmethod + def _validate_octet_stream_body(cls, webhook_data: dict[str, Any], node_data: dict[str, Any]) -> dict[str, Any]: + """Validate application/octet-stream body.""" + body_params = node_data.get("body", []) + if body_params and any(param.get("required", False) for param in body_params): + body_data = webhook_data.get("body", {}) + raw_content = body_data.get("raw", "") + if not raw_content or not isinstance(raw_content, bytes): + return cls._validation_error("Required body content missing for application/octet-stream request") + return {"valid": True} - for body_param in body_params: - param_name = body_param.get("name", "") - param_type = body_param.get("type", SegmentType.STRING) - is_required = body_param.get("required", False) + @classmethod + def _validate_json_body(cls, webhook_data: dict[str, Any], node_data: dict[str, Any]) -> dict[str, Any]: + """Validate application/json body.""" + body_params = node_data.get("body", []) + body_data = webhook_data.get("body", {}) - if param_type == SegmentType.FILE: - # File parameters are handled separately in files dict - file_obj = webhook_data.get("files", {}).get(param_name) - if is_required and not file_obj: - return {"valid": False, "error": f"Required file parameter missing: {param_name}"} - else: - # Multipart form data parameters are all strings - param_exists = param_name in body_data + for body_param in body_params: + param_name = body_param.get("name", "") + param_type = body_param.get("type", SegmentType.STRING) + is_required = body_param.get("required", False) - if is_required and not param_exists: - return {"valid": False, "error": f"Required body parameter missing: {param_name}"} + param_exists = param_name in body_data - # For form data, validate that non-string types can be converted - if param_exists and param_type != SegmentType.STRING: - param_value = body_data[param_name] - validation_result = cls._validate_form_parameter_type(param_name, param_value, param_type) - if not validation_result["valid"]: - return validation_result + if is_required and not param_exists: + return cls._validation_error(f"Required body parameter missing: {param_name}") + if param_exists: + param_value = body_data[param_name] + validation_result = cls._validate_json_parameter_type(param_name, param_value, param_type) + if not validation_result["valid"]: + return validation_result + + return {"valid": True} + + @classmethod + def _validate_form_urlencoded_body(cls, webhook_data: dict[str, Any], node_data: dict[str, Any]) -> dict[str, Any]: + """Validate application/x-www-form-urlencoded body.""" + body_params = node_data.get("body", []) + body_data = webhook_data.get("body", {}) + + for body_param in body_params: + param_name = body_param.get("name", "") + param_type = body_param.get("type", SegmentType.STRING) + is_required = body_param.get("required", False) + + param_exists = param_name in body_data + if is_required and not param_exists: + return cls._validation_error(f"Required body parameter missing: {param_name}") + + if param_exists and param_type != SegmentType.STRING: + param_value = body_data[param_name] + validation_result = cls._validate_form_parameter_type(param_name, param_value, param_type) + if not validation_result["valid"]: + return validation_result + + return {"valid": True} + + @classmethod + def _validate_multipart_body(cls, webhook_data: dict[str, Any], node_data: dict[str, Any]) -> dict[str, Any]: + """Validate multipart/form-data body.""" + body_params = node_data.get("body", []) + body_data = webhook_data.get("body", {}) + + for body_param in body_params: + param_name = body_param.get("name", "") + param_type = body_param.get("type", SegmentType.STRING) + is_required = body_param.get("required", False) + + if param_type == SegmentType.FILE: + file_obj = webhook_data.get("files", {}).get(param_name) + if is_required and not file_obj: + return cls._validation_error(f"Required file parameter missing: {param_name}") else: - raise ValueError(f"Unsupported Content-Type for validation: {configured_content_type}") + param_exists = param_name in body_data - return {"valid": True} + if is_required and not param_exists: + return cls._validation_error(f"Required body parameter missing: {param_name}") - except Exception: - logger.exception("Validation error") - return {"valid": False, "error": "Validation failed"} + if param_exists and param_type != SegmentType.STRING: + param_value = body_data[param_name] + validation_result = cls._validate_form_parameter_type(param_name, param_value, param_type) + if not validation_result["valid"]: + return validation_result + + return {"valid": True} + + @classmethod + def _validation_error(cls, error_message: str) -> dict[str, Any]: + """Create a standard validation error response.""" + return {"valid": False, "error": error_message} @classmethod def _validate_json_parameter_type(cls, param_name: str, param_value: Any, param_type: str) -> dict[str, Any]: """Validate JSON parameter type against expected type.""" try: - if param_type == SegmentType.STRING: - if not isinstance(param_value, str): - return { - "valid": False, - "error": f"Parameter '{param_name}' must be a string, got {type(param_value).__name__}", - } + # Define type validators + type_validators = { + SegmentType.STRING: (lambda v: isinstance(v, str), "string"), + SegmentType.NUMBER: (lambda v: isinstance(v, (int, float)), "number"), + SegmentType.BOOLEAN: (lambda v: isinstance(v, bool), "boolean"), + SegmentType.OBJECT: (lambda v: isinstance(v, dict), "object"), + SegmentType.ARRAY_STRING: ( + lambda v: isinstance(v, list) and all(isinstance(item, str) for item in v), + "array of strings", + ), + SegmentType.ARRAY_NUMBER: ( + lambda v: isinstance(v, list) and all(isinstance(item, (int, float)) for item in v), + "array of numbers", + ), + SegmentType.ARRAY_BOOLEAN: ( + lambda v: isinstance(v, list) and all(isinstance(item, bool) for item in v), + "array of booleans", + ), + SegmentType.ARRAY_OBJECT: ( + lambda v: isinstance(v, list) and all(isinstance(item, dict) for item in v), + "array of objects", + ), + } - elif param_type == SegmentType.NUMBER: - if not isinstance(param_value, (int, float)): - return { - "valid": False, - "error": f"Parameter '{param_name}' must be a number, got {type(param_value).__name__}", - } - - elif param_type == SegmentType.BOOLEAN: - if not isinstance(param_value, bool): - return { - "valid": False, - "error": f"Parameter '{param_name}' must be a boolean, got {type(param_value).__name__}", - } - - elif param_type == SegmentType.OBJECT: - if not isinstance(param_value, dict): - return { - "valid": False, - "error": f"Parameter '{param_name}' must be an object, got {type(param_value).__name__}", - } - - elif param_type == SegmentType.ARRAY_STRING: - if not isinstance(param_value, list): - return { - "valid": False, - "error": f"Parameter '{param_name}' must be an array, got {type(param_value).__name__}", - } - if not all(isinstance(item, str) for item in param_value): - return {"valid": False, "error": f"Parameter '{param_name}' must be an array of strings"} - - elif param_type == SegmentType.ARRAY_NUMBER: - if not isinstance(param_value, list): - return { - "valid": False, - "error": f"Parameter '{param_name}' must be an array, got {type(param_value).__name__}", - } - if not all(isinstance(item, (int, float)) for item in param_value): - return {"valid": False, "error": f"Parameter '{param_name}' must be an array of numbers"} - - elif param_type == SegmentType.ARRAY_BOOLEAN: - if not isinstance(param_value, list): - return { - "valid": False, - "error": f"Parameter '{param_name}' must be an array, got {type(param_value).__name__}", - } - if not all(isinstance(item, bool) for item in param_value): - return {"valid": False, "error": f"Parameter '{param_name}' must be an array of booleans"} - - elif param_type == SegmentType.ARRAY_OBJECT: - if not isinstance(param_value, list): - return { - "valid": False, - "error": f"Parameter '{param_name}' must be an array, got {type(param_value).__name__}", - } - if not all(isinstance(item, dict) for item in param_value): - return {"valid": False, "error": f"Parameter '{param_name}' must be an array of objects"} - - else: - # Unknown type, skip validation + # Get validator for the type + validator_info = type_validators.get(param_type) + if not validator_info: logger.warning("Unknown parameter type: %s for parameter %s", param_type, param_name) + return {"valid": True} + + validator, expected_type = validator_info + + # Validate the parameter + if not validator(param_value): + # Check if it's an array type first + if param_type.startswith("array") and not isinstance(param_value, list): + actual_type = type(param_value).__name__ + error_msg = f"Parameter '{param_name}' must be an array, got {actual_type}" + else: + actual_type = type(param_value).__name__ + # Format error message based on expected type + if param_type.startswith("array"): + error_msg = f"Parameter '{param_name}' must be an {expected_type}" + elif expected_type in ["string", "number", "boolean"]: + error_msg = f"Parameter '{param_name}' must be a {expected_type}, got {actual_type}" + else: + error_msg = f"Parameter '{param_name}' must be an {expected_type}, got {actual_type}" + + return {"valid": False, "error": error_msg} return {"valid": True} @@ -408,43 +471,49 @@ class WebhookService: def _validate_form_parameter_type(cls, param_name: str, param_value: str, param_type: str) -> dict[str, Any]: """Validate form parameter type against expected type. Form data are always strings but can be converted.""" try: - # Form data values are always strings, but we can validate if they can be interpreted as other types - if param_type == SegmentType.STRING: - # String is always valid - return {"valid": True} + # Define form type converters and validators + form_validators = { + SegmentType.STRING: (lambda _: True, None), # String is always valid + SegmentType.NUMBER: (lambda v: cls._can_convert_to_number(v), "a valid number"), + SegmentType.BOOLEAN: ( + lambda v: v.lower() in ["true", "false", "1", "0", "yes", "no"], + "a boolean value", + ), + } - elif param_type == SegmentType.NUMBER: - # Check if string can be converted to number - try: - float(param_value) - return {"valid": True} - except ValueError: - return { - "valid": False, - "error": f"Parameter '{param_name}' must be a valid number, got '{param_value}'", - } - - elif param_type == SegmentType.BOOLEAN: - # Check if string represents a boolean - if param_value.lower() in ["true", "false", "1", "0", "yes", "no"]: - return {"valid": True} - else: - return { - "valid": False, - "error": f"Parameter '{param_name}' must be a boolean value, got '{param_value}'", - } - - else: - # For other types (object, arrays), form data is not suitable + # Get validator for the type + validator_info = form_validators.get(param_type) + if not validator_info: + # Unsupported type for form data return { "valid": False, "error": f"Parameter '{param_name}' type '{param_type}' is not supported for form data.", } + validator, expected_format = validator_info + + # Validate the parameter + if not validator(param_value): + return { + "valid": False, + "error": f"Parameter '{param_name}' must be {expected_format}, got '{param_value}'", + } + + return {"valid": True} + except Exception: logger.exception("Form type validation error for parameter %s", param_name) return {"valid": False, "error": f"Form type validation failed for parameter '{param_name}'"} + @classmethod + def _can_convert_to_number(cls, value: str) -> bool: + """Check if a string can be converted to a number.""" + try: + float(value) + return True + except ValueError: + return False + @classmethod def trigger_workflow_execution( cls, webhook_trigger: WorkflowWebhookTrigger, webhook_data: dict[str, Any], workflow: Workflow