diff --git a/.github/workflows/translate-i18n-base-on-english.yml b/.github/workflows/translate-i18n-base-on-english.yml index 1cb9c0967b..4b06174ee1 100644 --- a/.github/workflows/translate-i18n-base-on-english.yml +++ b/.github/workflows/translate-i18n-base-on-english.yml @@ -1,9 +1,10 @@ name: Check i18n Files and Create PR on: - pull_request: - types: [closed] + push: branches: [main] + paths: + - 'web/i18n/en-US/*.ts' permissions: contents: write @@ -11,7 +12,7 @@ permissions: jobs: check-and-update: - if: github.event.pull_request.merged == true + if: github.repository == 'langgenius/dify' runs-on: ubuntu-latest defaults: run: @@ -19,7 +20,7 @@ jobs: steps: - uses: actions/checkout@v4 with: - fetch-depth: 2 # last 2 commits + fetch-depth: 2 token: ${{ secrets.GITHUB_TOKEN }} - name: Check for file changes in i18n/en-US @@ -31,6 +32,13 @@ jobs: echo "Changed files: $changed_files" if [ -n "$changed_files" ]; then echo "FILES_CHANGED=true" >> $GITHUB_ENV + file_args="" + for file in $changed_files; do + filename=$(basename "$file" .ts) + file_args="$file_args --file=$filename" + done + echo "FILE_ARGS=$file_args" >> $GITHUB_ENV + echo "File arguments: $file_args" else echo "FILES_CHANGED=false" >> $GITHUB_ENV fi @@ -55,7 +63,7 @@ jobs: - name: Generate i18n translations if: env.FILES_CHANGED == 'true' - run: pnpm run auto-gen-i18n + run: pnpm run auto-gen-i18n ${{ env.FILE_ARGS }} - name: Create Pull Request if: env.FILES_CHANGED == 'true' diff --git a/.gitignore b/.gitignore index c60957db72..5c68d89a4d 100644 --- a/.gitignore +++ b/.gitignore @@ -215,10 +215,4 @@ mise.toml # AI Assistant .roo/ api/.env.backup - -# Clickzetta test credentials -.env.clickzetta -.env.clickzetta.test - -# Clickzetta plugin development folder (keep local, ignore for PR) -clickzetta/ +/clickzetta diff --git a/api/Dockerfile b/api/Dockerfile index e097b5811e..d69291f7ea 100644 --- a/api/Dockerfile +++ b/api/Dockerfile @@ -19,7 +19,7 @@ RUN apt-get update \ # Install Python dependencies COPY pyproject.toml uv.lock ./ -RUN uv sync --locked +RUN uv sync --locked --no-dev # production stage FROM base AS production diff --git a/api/commands.py b/api/commands.py index 27f558c339..d9a24fa4a8 100644 --- a/api/commands.py +++ b/api/commands.py @@ -9,7 +9,7 @@ import sqlalchemy as sa from flask import current_app from pydantic import TypeAdapter from sqlalchemy import select -from werkzeug.exceptions import NotFound +from sqlalchemy.exc import SQLAlchemyError from configs import dify_config from constants.languages import languages @@ -186,8 +186,8 @@ def migrate_annotation_vector_database(): ) if not apps: break - except NotFound: - break + except SQLAlchemyError: + raise page += 1 for app in apps: @@ -313,8 +313,8 @@ def migrate_knowledge_vector_database(): ) datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False) - except NotFound: - break + except SQLAlchemyError: + raise page += 1 for dataset in datasets: @@ -566,8 +566,8 @@ def old_metadata_migration(): .order_by(DatasetDocument.created_at.desc()) ) documents = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False) - except NotFound: - break + except SQLAlchemyError: + raise if not documents: break for document in documents: diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 9f1646ea7d..4dbc8207f1 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -330,17 +330,17 @@ class HttpConfig(BaseSettings): def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]: return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",") - HTTP_REQUEST_MAX_CONNECT_TIMEOUT: Annotated[ - PositiveInt, Field(ge=10, description="Maximum connection timeout in seconds for HTTP requests") - ] = 10 + HTTP_REQUEST_MAX_CONNECT_TIMEOUT: int = Field( + ge=1, description="Maximum connection timeout in seconds for HTTP requests", default=10 + ) - HTTP_REQUEST_MAX_READ_TIMEOUT: Annotated[ - PositiveInt, Field(ge=60, description="Maximum read timeout in seconds for HTTP requests") - ] = 60 + HTTP_REQUEST_MAX_READ_TIMEOUT: int = Field( + ge=1, description="Maximum read timeout in seconds for HTTP requests", default=60 + ) - HTTP_REQUEST_MAX_WRITE_TIMEOUT: Annotated[ - PositiveInt, Field(ge=10, description="Maximum write timeout in seconds for HTTP requests") - ] = 20 + HTTP_REQUEST_MAX_WRITE_TIMEOUT: int = Field( + ge=1, description="Maximum write timeout in seconds for HTTP requests", default=20 + ) HTTP_REQUEST_NODE_MAX_BINARY_SIZE: PositiveInt = Field( description="Maximum allowed size in bytes for binary data in HTTP requests", diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 9fe32dde6d..1cc13d669c 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -28,6 +28,12 @@ from services.feature_service import FeatureService ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"] +def _validate_description_length(description): + if description and len(description) > 400: + raise ValueError("Description cannot exceed 400 characters.") + return description + + class AppListApi(Resource): @setup_required @login_required @@ -94,7 +100,7 @@ class AppListApi(Resource): """Create app""" parser = reqparse.RequestParser() parser.add_argument("name", type=str, required=True, location="json") - parser.add_argument("description", type=str, location="json") + parser.add_argument("description", type=_validate_description_length, location="json") parser.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json") parser.add_argument("icon_type", type=str, location="json") parser.add_argument("icon", type=str, location="json") @@ -146,7 +152,7 @@ class AppApi(Resource): parser = reqparse.RequestParser() parser.add_argument("name", type=str, required=True, nullable=False, location="json") - parser.add_argument("description", type=str, location="json") + parser.add_argument("description", type=_validate_description_length, location="json") parser.add_argument("icon_type", type=str, location="json") parser.add_argument("icon", type=str, location="json") parser.add_argument("icon_background", type=str, location="json") @@ -189,7 +195,7 @@ class AppCopyApi(Resource): parser = reqparse.RequestParser() parser.add_argument("name", type=str, location="json") - parser.add_argument("description", type=str, location="json") + parser.add_argument("description", type=_validate_description_length, location="json") parser.add_argument("icon_type", type=str, location="json") parser.add_argument("icon", type=str, location="json") parser.add_argument("icon_background", type=str, location="json") diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 13eab40476..cbc234deb7 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -41,7 +41,7 @@ def _validate_name(name): def _validate_description_length(description): - if len(description) > 400: + if description and len(description) > 400: raise ValueError("Description cannot exceed 400 characters.") return description @@ -113,7 +113,7 @@ class DatasetListApi(Resource): ) parser.add_argument( "description", - type=str, + type=_validate_description_length, nullable=True, required=False, default="", diff --git a/api/controllers/service_api/__init__.py b/api/controllers/service_api/__init__.py index d964e27819..b26f29d98d 100644 --- a/api/controllers/service_api/__init__.py +++ b/api/controllers/service_api/__init__.py @@ -6,6 +6,6 @@ bp = Blueprint("service_api", __name__, url_prefix="/v1") api = ExternalApi(bp) from . import index -from .app import annotation, app, audio, completion, conversation, file, message, site, workflow +from .app import annotation, app, audio, completion, conversation, file, file_preview, message, site, workflow from .dataset import dataset, document, hit_testing, metadata, segment, upload_file from .workspace import models diff --git a/api/controllers/service_api/app/error.py b/api/controllers/service_api/app/error.py index ca91da80c1..ba705f71e2 100644 --- a/api/controllers/service_api/app/error.py +++ b/api/controllers/service_api/app/error.py @@ -107,3 +107,15 @@ class UnsupportedFileTypeError(BaseHTTPException): error_code = "unsupported_file_type" description = "File type not allowed." code = 415 + + +class FileNotFoundError(BaseHTTPException): + error_code = "file_not_found" + description = "The requested file was not found." + code = 404 + + +class FileAccessDeniedError(BaseHTTPException): + error_code = "file_access_denied" + description = "Access to the requested file is denied." + code = 403 diff --git a/api/controllers/service_api/app/file_preview.py b/api/controllers/service_api/app/file_preview.py new file mode 100644 index 0000000000..57141033d1 --- /dev/null +++ b/api/controllers/service_api/app/file_preview.py @@ -0,0 +1,186 @@ +import logging +from urllib.parse import quote + +from flask import Response +from flask_restful import Resource, reqparse + +from controllers.service_api import api +from controllers.service_api.app.error import ( + FileAccessDeniedError, + FileNotFoundError, +) +from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token +from extensions.ext_database import db +from extensions.ext_storage import storage +from models.model import App, EndUser, Message, MessageFile, UploadFile + +logger = logging.getLogger(__name__) + + +class FilePreviewApi(Resource): + """ + Service API File Preview endpoint + + Provides secure file preview/download functionality for external API users. + Files can only be accessed if they belong to messages within the requesting app's context. + """ + + @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) + def get(self, app_model: App, end_user: EndUser, file_id: str): + """ + Preview/Download a file that was uploaded via Service API + + Args: + app_model: The authenticated app model + end_user: The authenticated end user (optional) + file_id: UUID of the file to preview + + Query Parameters: + user: Optional user identifier + as_attachment: Boolean, whether to download as attachment (default: false) + + Returns: + Stream response with file content + + Raises: + FileNotFoundError: File does not exist + FileAccessDeniedError: File access denied (not owned by app) + """ + file_id = str(file_id) + + # Parse query parameters + parser = reqparse.RequestParser() + parser.add_argument("as_attachment", type=bool, required=False, default=False, location="args") + args = parser.parse_args() + + # Validate file ownership and get file objects + message_file, upload_file = self._validate_file_ownership(file_id, app_model.id) + + # Get file content generator + try: + generator = storage.load(upload_file.key, stream=True) + except Exception as e: + raise FileNotFoundError(f"Failed to load file content: {str(e)}") + + # Build response with appropriate headers + response = self._build_file_response(generator, upload_file, args["as_attachment"]) + + return response + + def _validate_file_ownership(self, file_id: str, app_id: str) -> tuple[MessageFile, UploadFile]: + """ + Validate that the file belongs to a message within the requesting app's context + + Security validations performed: + 1. File exists in MessageFile table (was used in a conversation) + 2. Message belongs to the requesting app + 3. UploadFile record exists and is accessible + 4. File tenant matches app tenant (additional security layer) + + Args: + file_id: UUID of the file to validate + app_id: UUID of the requesting app + + Returns: + Tuple of (MessageFile, UploadFile) if validation passes + + Raises: + FileNotFoundError: File or related records not found + FileAccessDeniedError: File does not belong to the app's context + """ + try: + # Input validation + if not file_id or not app_id: + raise FileAccessDeniedError("Invalid file or app identifier") + + # First, find the MessageFile that references this upload file + message_file = db.session.query(MessageFile).where(MessageFile.upload_file_id == file_id).first() + + if not message_file: + raise FileNotFoundError("File not found in message context") + + # Get the message and verify it belongs to the requesting app + message = ( + db.session.query(Message).where(Message.id == message_file.message_id, Message.app_id == app_id).first() + ) + + if not message: + raise FileAccessDeniedError("File access denied: not owned by requesting app") + + # Get the actual upload file record + upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first() + + if not upload_file: + raise FileNotFoundError("Upload file record not found") + + # Additional security: verify tenant isolation + app = db.session.query(App).where(App.id == app_id).first() + if app and upload_file.tenant_id != app.tenant_id: + raise FileAccessDeniedError("File access denied: tenant mismatch") + + return message_file, upload_file + + except (FileNotFoundError, FileAccessDeniedError): + # Re-raise our custom exceptions + raise + except Exception as e: + # Log unexpected errors for debugging + logger.exception( + "Unexpected error during file ownership validation", + extra={"file_id": file_id, "app_id": app_id, "error": str(e)}, + ) + raise FileAccessDeniedError("File access validation failed") + + def _build_file_response(self, generator, upload_file: UploadFile, as_attachment: bool = False) -> Response: + """ + Build Flask Response object with appropriate headers for file streaming + + Args: + generator: File content generator from storage + upload_file: UploadFile database record + as_attachment: Whether to set Content-Disposition as attachment + + Returns: + Flask Response object with streaming file content + """ + response = Response( + generator, + mimetype=upload_file.mime_type, + direct_passthrough=True, + headers={}, + ) + + # Add Content-Length if known + if upload_file.size and upload_file.size > 0: + response.headers["Content-Length"] = str(upload_file.size) + + # Add Accept-Ranges header for audio/video files to support seeking + if upload_file.mime_type in [ + "audio/mpeg", + "audio/wav", + "audio/mp4", + "audio/ogg", + "audio/flac", + "audio/aac", + "video/mp4", + "video/webm", + "video/quicktime", + "audio/x-m4a", + ]: + response.headers["Accept-Ranges"] = "bytes" + + # Set Content-Disposition for downloads + if as_attachment and upload_file.name: + encoded_filename = quote(upload_file.name) + response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}" + # Override content-type for downloads to force download + response.headers["Content-Type"] = "application/octet-stream" + + # Add caching headers for performance + response.headers["Cache-Control"] = "public, max-age=3600" # Cache for 1 hour + + return response + + +# Register the API endpoint +api.add_resource(FilePreviewApi, "/files//preview") diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index a499719fc3..29eef41253 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -29,7 +29,7 @@ def _validate_name(name): def _validate_description_length(description): - if len(description) > 400: + if description and len(description) > 400: raise ValueError("Description cannot exceed 400 characters.") return description @@ -87,7 +87,7 @@ class DatasetListApi(DatasetApiResource): ) parser.add_argument( "description", - type=str, + type=_validate_description_length, nullable=True, required=False, default="", diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index 94a525a75d..197859e8f3 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -1,5 +1,6 @@ from flask import request from flask_restful import Resource, marshal_with, reqparse +from werkzeug.exceptions import Unauthorized from controllers.common import fields from controllers.web import api @@ -75,14 +76,14 @@ class AppWebAuthPermission(Resource): try: auth_header = request.headers.get("Authorization") if auth_header is None: - raise + raise Unauthorized("Authorization header is missing.") if " " not in auth_header: - raise + raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") auth_scheme, tk = auth_header.split(None, 1) auth_scheme = auth_scheme.lower() if auth_scheme != "bearer": - raise + raise Unauthorized("Authorization scheme must be 'Bearer'") decoded = PassportService().verify(tk) user_id = decoded.get("user_id", "visitor") diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index a75e17af64..3de2f5ca9e 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -118,26 +118,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): ): return - # Init conversation variables - stmt = select(ConversationVariable).where( - ConversationVariable.app_id == self.conversation.app_id, - ConversationVariable.conversation_id == self.conversation.id, - ) - with Session(db.engine) as session: - db_conversation_variables = session.scalars(stmt).all() - if not db_conversation_variables: - # Create conversation variables if they don't exist. - db_conversation_variables = [ - ConversationVariable.from_variable( - app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable - ) - for variable in self._workflow.conversation_variables - ] - session.add_all(db_conversation_variables) - # Convert database entities to variables. - conversation_variables = [item.to_variable() for item in db_conversation_variables] - - session.commit() + # Initialize conversation variables + conversation_variables = self._initialize_conversation_variables() # Create a variable pool. system_inputs = SystemVariable( @@ -292,3 +274,100 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): message_id=message_id, trace_manager=app_generate_entity.trace_manager, ) + + def _initialize_conversation_variables(self) -> list[VariableUnion]: + """ + Initialize conversation variables for the current conversation. + + This method: + 1. Loads existing variables from the database + 2. Creates new variables if none exist + 3. Syncs missing variables from the workflow definition + + :return: List of conversation variables ready for use + """ + with Session(db.engine) as session: + existing_variables = self._load_existing_conversation_variables(session) + + if not existing_variables: + # First time initialization - create all variables + existing_variables = self._create_all_conversation_variables(session) + else: + # Check and add any missing variables from the workflow + existing_variables = self._sync_missing_conversation_variables(session, existing_variables) + + # Convert to Variable objects for use in the workflow + conversation_variables = [var.to_variable() for var in existing_variables] + + session.commit() + return cast(list[VariableUnion], conversation_variables) + + def _load_existing_conversation_variables(self, session: Session) -> list[ConversationVariable]: + """ + Load existing conversation variables from the database. + + :param session: Database session + :return: List of existing conversation variables + """ + stmt = select(ConversationVariable).where( + ConversationVariable.app_id == self.conversation.app_id, + ConversationVariable.conversation_id == self.conversation.id, + ) + return list(session.scalars(stmt).all()) + + def _create_all_conversation_variables(self, session: Session) -> list[ConversationVariable]: + """ + Create all conversation variables for a new conversation. + + :param session: Database session + :return: List of created conversation variables + """ + new_variables = [ + ConversationVariable.from_variable( + app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable + ) + for variable in self._workflow.conversation_variables + ] + + if new_variables: + session.add_all(new_variables) + + return new_variables + + def _sync_missing_conversation_variables( + self, session: Session, existing_variables: list[ConversationVariable] + ) -> list[ConversationVariable]: + """ + Sync missing conversation variables from the workflow definition. + + This handles the case where new variables are added to a workflow + after conversations have already been created. + + :param session: Database session + :param existing_variables: List of existing conversation variables + :return: Updated list including any newly created variables + """ + # Get IDs of existing and workflow variables + existing_ids = {var.id for var in existing_variables} + workflow_variables = {var.id: var for var in self._workflow.conversation_variables} + + # Find missing variable IDs + missing_ids = set(workflow_variables.keys()) - existing_ids + + if not missing_ids: + return existing_variables + + # Create missing variables with their default values + new_variables = [ + ConversationVariable.from_variable( + app_id=self.conversation.app_id, + conversation_id=self.conversation.id, + variable=workflow_variables[var_id], + ) + for var_id in missing_ids + ] + + session.add_all(new_variables) + + # Return combined list + return existing_variables + new_variables diff --git a/api/core/app/task_pipeline/message_cycle_manager.py b/api/core/app/task_pipeline/message_cycle_manager.py index f0e9425e3f..f3b9dbf758 100644 --- a/api/core/app/task_pipeline/message_cycle_manager.py +++ b/api/core/app/task_pipeline/message_cycle_manager.py @@ -23,6 +23,7 @@ from core.app.entities.task_entities import ( MessageFileStreamResponse, MessageReplaceStreamResponse, MessageStreamResponse, + StreamEvent, WorkflowTaskState, ) from core.llm_generator.llm_generator import LLMGenerator @@ -180,11 +181,15 @@ class MessageCycleManager: :param message_id: message id :return: """ + message_file = db.session.query(MessageFile).filter(MessageFile.id == message_id).first() + event_type = StreamEvent.MESSAGE_FILE if message_file else StreamEvent.MESSAGE + return MessageStreamResponse( task_id=self._application_generate_entity.task_id, id=message_id, answer=answer, from_variable_selector=from_variable_selector, + event=event_type, ) def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse: diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index f9bb0149b5..4d2d590e07 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -843,7 +843,7 @@ class ProviderConfiguration(BaseModel): continue status = ModelStatus.ACTIVE - if m.model in model_setting_map: + if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]: model_setting = model_setting_map[m.model_type][m.model] if model_setting.enabled is False: status = ModelStatus.DISABLED diff --git a/api/core/rag/datasource/vdb/clickzetta/README.md b/api/core/rag/datasource/vdb/clickzetta/README.md index 40229f8d44..2ee3e657d3 100644 --- a/api/core/rag/datasource/vdb/clickzetta/README.md +++ b/api/core/rag/datasource/vdb/clickzetta/README.md @@ -185,6 +185,6 @@ Clickzetta supports advanced full-text search with multiple analyzers: ## References -- [Clickzetta Vector Search Documentation](../../../../../../../yunqidoc/cn_markdown_20250526/vector-search.md) -- [Clickzetta Inverted Index Documentation](../../../../../../../yunqidoc/cn_markdown_20250526/inverted-index.md) -- [Clickzetta SQL Functions](../../../../../../../yunqidoc/cn_markdown_20250526/sql_functions/) +- [Clickzetta Vector Search Documentation](https://yunqi.tech/documents/vector-search) +- [Clickzetta Inverted Index Documentation](https://yunqi.tech/documents/inverted-index) +- [Clickzetta SQL Functions](https://yunqi.tech/documents/sql-reference) diff --git a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py index d295bab5aa..1059b855a2 100644 --- a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py +++ b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py @@ -1,9 +1,11 @@ import json import logging import queue +import re import threading +import time import uuid -from typing import Any, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Optional import clickzetta # type: ignore from pydantic import BaseModel, model_validator @@ -67,6 +69,243 @@ class ClickzettaConfig(BaseModel): return values +class ClickzettaConnectionPool: + """ + Global connection pool for ClickZetta connections. + Manages connection reuse across ClickzettaVector instances. + """ + + _instance: Optional["ClickzettaConnectionPool"] = None + _lock = threading.Lock() + + def __init__(self): + self._pools: dict[str, list[tuple[Connection, float]]] = {} # config_key -> [(connection, last_used_time)] + self._pool_locks: dict[str, threading.Lock] = {} + self._max_pool_size = 5 # Maximum connections per configuration + self._connection_timeout = 300 # 5 minutes timeout + self._cleanup_thread: Optional[threading.Thread] = None + self._shutdown = False + self._start_cleanup_thread() + + @classmethod + def get_instance(cls) -> "ClickzettaConnectionPool": + """Get singleton instance of connection pool.""" + if cls._instance is None: + with cls._lock: + if cls._instance is None: + cls._instance = cls() + return cls._instance + + def _get_config_key(self, config: ClickzettaConfig) -> str: + """Generate unique key for connection configuration.""" + return ( + f"{config.username}:{config.instance}:{config.service}:" + f"{config.workspace}:{config.vcluster}:{config.schema_name}" + ) + + def _create_connection(self, config: ClickzettaConfig) -> "Connection": + """Create a new ClickZetta connection.""" + max_retries = 3 + retry_delay = 1.0 + + for attempt in range(max_retries): + try: + connection = clickzetta.connect( + username=config.username, + password=config.password, + instance=config.instance, + service=config.service, + workspace=config.workspace, + vcluster=config.vcluster, + schema=config.schema_name, + ) + + # Configure connection session settings + self._configure_connection(connection) + logger.debug("Created new ClickZetta connection (attempt %d/%d)", attempt + 1, max_retries) + return connection + except Exception: + logger.exception("ClickZetta connection attempt %d/%d failed", attempt + 1, max_retries) + if attempt < max_retries - 1: + time.sleep(retry_delay * (2**attempt)) + else: + raise + + raise RuntimeError(f"Failed to create ClickZetta connection after {max_retries} attempts") + + def _configure_connection(self, connection: "Connection") -> None: + """Configure connection session settings.""" + try: + with connection.cursor() as cursor: + # Temporarily suppress ClickZetta client logging to reduce noise + clickzetta_logger = logging.getLogger("clickzetta") + original_level = clickzetta_logger.level + clickzetta_logger.setLevel(logging.WARNING) + + try: + # Use quote mode for string literal escaping + cursor.execute("SET cz.sql.string.literal.escape.mode = 'quote'") + + # Apply performance optimization hints + performance_hints = [ + # Vector index optimization + "SET cz.storage.parquet.vector.index.read.memory.cache = true", + "SET cz.storage.parquet.vector.index.read.local.cache = false", + # Query optimization + "SET cz.sql.table.scan.push.down.filter = true", + "SET cz.sql.table.scan.enable.ensure.filter = true", + "SET cz.storage.always.prefetch.internal = true", + "SET cz.optimizer.generate.columns.always.valid = true", + "SET cz.sql.index.prewhere.enabled = true", + # Storage optimization + "SET cz.storage.parquet.enable.io.prefetch = false", + "SET cz.optimizer.enable.mv.rewrite = false", + "SET cz.sql.dump.as.lz4 = true", + "SET cz.optimizer.limited.optimization.naive.query = true", + "SET cz.sql.table.scan.enable.push.down.log = false", + "SET cz.storage.use.file.format.local.stats = false", + "SET cz.storage.local.file.object.cache.level = all", + # Job execution optimization + "SET cz.sql.job.fast.mode = true", + "SET cz.storage.parquet.non.contiguous.read = true", + "SET cz.sql.compaction.after.commit = true", + ] + + for hint in performance_hints: + cursor.execute(hint) + finally: + # Restore original logging level + clickzetta_logger.setLevel(original_level) + + except Exception: + logger.exception("Failed to configure connection, continuing with defaults") + + def _is_connection_valid(self, connection: "Connection") -> bool: + """Check if connection is still valid.""" + try: + with connection.cursor() as cursor: + cursor.execute("SELECT 1") + return True + except Exception: + return False + + def get_connection(self, config: ClickzettaConfig) -> "Connection": + """Get a connection from the pool or create a new one.""" + config_key = self._get_config_key(config) + + # Ensure pool lock exists + if config_key not in self._pool_locks: + with self._lock: + if config_key not in self._pool_locks: + self._pool_locks[config_key] = threading.Lock() + self._pools[config_key] = [] + + with self._pool_locks[config_key]: + pool = self._pools[config_key] + current_time = time.time() + + # Try to reuse existing connection + while pool: + connection, last_used = pool.pop(0) + + # Check if connection is not expired and still valid + if current_time - last_used < self._connection_timeout and self._is_connection_valid(connection): + logger.debug("Reusing ClickZetta connection from pool") + return connection + else: + # Connection expired or invalid, close it + try: + connection.close() + except Exception: + pass + + # No valid connection found, create new one + return self._create_connection(config) + + def return_connection(self, config: ClickzettaConfig, connection: "Connection") -> None: + """Return a connection to the pool.""" + config_key = self._get_config_key(config) + + if config_key not in self._pool_locks: + # Pool was cleaned up, just close the connection + try: + connection.close() + except Exception: + pass + return + + with self._pool_locks[config_key]: + pool = self._pools[config_key] + + # Only return to pool if not at capacity and connection is valid + if len(pool) < self._max_pool_size and self._is_connection_valid(connection): + pool.append((connection, time.time())) + logger.debug("Returned ClickZetta connection to pool") + else: + # Pool full or connection invalid, close it + try: + connection.close() + except Exception: + pass + + def _cleanup_expired_connections(self) -> None: + """Clean up expired connections from all pools.""" + current_time = time.time() + + with self._lock: + for config_key in list(self._pools.keys()): + if config_key not in self._pool_locks: + continue + + with self._pool_locks[config_key]: + pool = self._pools[config_key] + valid_connections = [] + + for connection, last_used in pool: + if current_time - last_used < self._connection_timeout: + valid_connections.append((connection, last_used)) + else: + try: + connection.close() + except Exception: + pass + + self._pools[config_key] = valid_connections + + def _start_cleanup_thread(self) -> None: + """Start background thread for connection cleanup.""" + + def cleanup_worker(): + while not self._shutdown: + try: + time.sleep(60) # Cleanup every minute + if not self._shutdown: + self._cleanup_expired_connections() + except Exception: + logger.exception("Error in connection pool cleanup") + + self._cleanup_thread = threading.Thread(target=cleanup_worker, daemon=True) + self._cleanup_thread.start() + + def shutdown(self) -> None: + """Shutdown connection pool and close all connections.""" + self._shutdown = True + + with self._lock: + for config_key in list(self._pools.keys()): + if config_key not in self._pool_locks: + continue + + with self._pool_locks[config_key]: + pool = self._pools[config_key] + for connection, _ in pool: + try: + connection.close() + except Exception: + pass + pool.clear() + + class ClickzettaVector(BaseVector): """ Clickzetta vector storage implementation. @@ -82,71 +321,74 @@ class ClickzettaVector(BaseVector): super().__init__(collection_name) self._config = config self._table_name = collection_name.replace("-", "_").lower() # Ensure valid table name - self._connection: Optional["Connection"] = None - self._init_connection() + self._connection_pool = ClickzettaConnectionPool.get_instance() self._init_write_queue() - def _init_connection(self): - """Initialize Clickzetta connection.""" - self._connection = clickzetta.connect( - username=self._config.username, - password=self._config.password, - instance=self._config.instance, - service=self._config.service, - workspace=self._config.workspace, - vcluster=self._config.vcluster, - schema=self._config.schema_name - ) + def _get_connection(self) -> "Connection": + """Get a connection from the pool.""" + return self._connection_pool.get_connection(self._config) - # Set session parameters for better string handling and performance optimization - if self._connection is not None: - with self._connection.cursor() as cursor: - # Use quote mode for string literal escaping to handle quotes better - cursor.execute("SET cz.sql.string.literal.escape.mode = 'quote'") - logger.info("Set string literal escape mode to 'quote' for better quote handling") + def _return_connection(self, connection: "Connection") -> None: + """Return a connection to the pool.""" + self._connection_pool.return_connection(self._config, connection) - # Performance optimization hints for vector operations - self._set_performance_hints(cursor) + class ConnectionContext: + """Context manager for borrowing and returning connections.""" - def _set_performance_hints(self, cursor): - """Set ClickZetta performance optimization hints for vector operations.""" + def __init__(self, vector_instance: "ClickzettaVector"): + self.vector = vector_instance + self.connection: Optional[Connection] = None + + def __enter__(self) -> "Connection": + self.connection = self.vector._get_connection() + return self.connection + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.connection: + self.vector._return_connection(self.connection) + + def get_connection_context(self) -> "ClickzettaVector.ConnectionContext": + """Get a connection context manager.""" + return self.ConnectionContext(self) + + def _parse_metadata(self, raw_metadata: str, row_id: str) -> dict: + """ + Parse metadata from JSON string with proper error handling and fallback. + + Args: + raw_metadata: Raw JSON string from database + row_id: Row ID for fallback document_id + + Returns: + Parsed metadata dict with guaranteed required fields + """ try: - # Performance optimization hints for vector operations and query processing - performance_hints = [ - # Vector index optimization - "SET cz.storage.parquet.vector.index.read.memory.cache = true", - "SET cz.storage.parquet.vector.index.read.local.cache = false", + if raw_metadata: + metadata = json.loads(raw_metadata) - # Query optimization - "SET cz.sql.table.scan.push.down.filter = true", - "SET cz.sql.table.scan.enable.ensure.filter = true", - "SET cz.storage.always.prefetch.internal = true", - "SET cz.optimizer.generate.columns.always.valid = true", - "SET cz.sql.index.prewhere.enabled = true", + # Handle double-encoded JSON + if isinstance(metadata, str): + metadata = json.loads(metadata) - # Storage optimization - "SET cz.storage.parquet.enable.io.prefetch = false", - "SET cz.optimizer.enable.mv.rewrite = false", - "SET cz.sql.dump.as.lz4 = true", - "SET cz.optimizer.limited.optimization.naive.query = true", - "SET cz.sql.table.scan.enable.push.down.log = false", - "SET cz.storage.use.file.format.local.stats = false", - "SET cz.storage.local.file.object.cache.level = all", + # Ensure we have a dict + if not isinstance(metadata, dict): + metadata = {} + else: + metadata = {} + except (json.JSONDecodeError, TypeError): + logger.exception("JSON parsing failed for metadata") + # Fallback: extract document_id with regex + doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', raw_metadata or "") + metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {} - # Job execution optimization - "SET cz.sql.job.fast.mode = true", - "SET cz.storage.parquet.non.contiguous.read = true", - "SET cz.sql.compaction.after.commit = true" - ] + # Ensure required fields are set + metadata["doc_id"] = row_id # segment id - for hint in performance_hints: - cursor.execute(hint) + # Ensure document_id exists (critical for Dify's format_retrieval_documents) + if "document_id" not in metadata: + metadata["document_id"] = row_id # fallback to segment id - logger.info("Applied %d performance optimization hints for ClickZetta vector operations", len(performance_hints)) - - except Exception: - # Catch any errors setting performance hints but continue with defaults - logger.exception("Failed to set some performance hints, continuing with default settings") + return metadata @classmethod def _init_write_queue(cls): @@ -205,24 +447,33 @@ class ClickzettaVector(BaseVector): return "clickzetta" def _ensure_connection(self) -> "Connection": - """Ensure connection is available and return it.""" - if self._connection is None: - raise RuntimeError("Database connection not initialized") - return self._connection + """Get a connection from the pool.""" + return self._get_connection() def _table_exists(self) -> bool: """Check if the table exists.""" try: - connection = self._ensure_connection() - with connection.cursor() as cursor: - cursor.execute(f"DESC {self._config.schema_name}.{self._table_name}") - return True - except (RuntimeError, ValueError) as e: - if "table or view not found" in str(e).lower(): + with self.get_connection_context() as connection: + with connection.cursor() as cursor: + cursor.execute(f"DESC {self._config.schema_name}.{self._table_name}") + return True + except Exception as e: + error_message = str(e).lower() + # Handle ClickZetta specific "table or view not found" errors + if any( + phrase in error_message + for phrase in ["table or view not found", "czlh-42000", "semantic analysis exception"] + ): + logger.debug("Table %s.%s does not exist", self._config.schema_name, self._table_name) return False else: - # Re-raise if it's a different error - raise + # For other connection/permission errors, log warning but return False to avoid blocking cleanup + logger.exception( + "Table existence check failed for %s.%s, assuming it doesn't exist", + self._config.schema_name, + self._table_name, + ) + return False def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): """Create the collection and add initial documents.""" @@ -254,17 +505,17 @@ class ClickzettaVector(BaseVector): ) COMMENT 'Dify RAG knowledge base vector storage table for document embeddings and content' """ - connection = self._ensure_connection() - with connection.cursor() as cursor: - cursor.execute(create_table_sql) - logger.info("Created table %s.%s", self._config.schema_name, self._table_name) + with self.get_connection_context() as connection: + with connection.cursor() as cursor: + cursor.execute(create_table_sql) + logger.info("Created table %s.%s", self._config.schema_name, self._table_name) - # Create vector index - self._create_vector_index(cursor) + # Create vector index + self._create_vector_index(cursor) - # Create inverted index for full-text search if enabled - if self._config.enable_inverted_index: - self._create_inverted_index(cursor) + # Create inverted index for full-text search if enabled + if self._config.enable_inverted_index: + self._create_inverted_index(cursor) def _create_vector_index(self, cursor): """Create HNSW vector index for similarity search.""" @@ -298,9 +549,7 @@ class ClickzettaVector(BaseVector): logger.info("Created vector index: %s", index_name) except (RuntimeError, ValueError) as e: error_msg = str(e).lower() - if ("already exists" in error_msg or - "already has index" in error_msg or - "with the same type" in error_msg): + if "already exists" in error_msg or "already has index" in error_msg or "with the same type" in error_msg: logger.info("Vector index already exists: %s", e) else: logger.exception("Failed to create vector index") @@ -318,9 +567,11 @@ class ClickzettaVector(BaseVector): for idx in existing_indexes: idx_str = str(idx).lower() # More precise check: look for inverted index specifically on the content column - if ("inverted" in idx_str and - Field.CONTENT_KEY.value.lower() in idx_str and - (index_name.lower() in idx_str or f"idx_{self._table_name}_text" in idx_str)): + if ( + "inverted" in idx_str + and Field.CONTENT_KEY.value.lower() in idx_str + and (index_name.lower() in idx_str or f"idx_{self._table_name}_text" in idx_str) + ): logger.info("Inverted index already exists on column %s: %s", Field.CONTENT_KEY.value, idx) return except (RuntimeError, ValueError) as e: @@ -340,11 +591,12 @@ class ClickzettaVector(BaseVector): except (RuntimeError, ValueError) as e: error_msg = str(e).lower() # Handle ClickZetta specific error messages - if (("already exists" in error_msg or - "already has index" in error_msg or - "with the same type" in error_msg or - "cannot create inverted index" in error_msg) and - "already has index" in error_msg): + if ( + "already exists" in error_msg + or "already has index" in error_msg + or "with the same type" in error_msg + or "cannot create inverted index" in error_msg + ) and "already has index" in error_msg: logger.info("Inverted index already exists on column %s", Field.CONTENT_KEY.value) # Try to get the existing index name for logging try: @@ -360,7 +612,6 @@ class ClickzettaVector(BaseVector): logger.warning("Failed to create inverted index: %s", e) # Continue without inverted index - full-text search will fall back to LIKE - def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): """Add documents with embeddings to the collection.""" if not documents: @@ -370,14 +621,20 @@ class ClickzettaVector(BaseVector): total_batches = (len(documents) + batch_size - 1) // batch_size for i in range(0, len(documents), batch_size): - batch_docs = documents[i:i + batch_size] - batch_embeddings = embeddings[i:i + batch_size] + batch_docs = documents[i : i + batch_size] + batch_embeddings = embeddings[i : i + batch_size] # Execute batch insert through write queue self._execute_write(self._insert_batch, batch_docs, batch_embeddings, i, batch_size, total_batches) - def _insert_batch(self, batch_docs: list[Document], batch_embeddings: list[list[float]], - batch_index: int, batch_size: int, total_batches: int): + def _insert_batch( + self, + batch_docs: list[Document], + batch_embeddings: list[list[float]], + batch_index: int, + batch_size: int, + total_batches: int, + ): """Insert a batch of documents using parameterized queries (executed in write worker thread).""" if not batch_docs or not batch_embeddings: logger.warning("Empty batch provided, skipping insertion") @@ -411,7 +668,7 @@ class ClickzettaVector(BaseVector): # According to ClickZetta docs, vector should be formatted as array string # for external systems: '[1.0, 2.0, 3.0]' - vector_str = '[' + ','.join(map(str, embedding)) + ']' + vector_str = "[" + ",".join(map(str, embedding)) + "]" data_rows.append([doc_id, content, metadata_json, vector_str]) # Check if we have any valid data to insert @@ -427,37 +684,53 @@ class ClickzettaVector(BaseVector): f"VALUES (?, ?, CAST(? AS JSON), CAST(? AS VECTOR({vector_dimension})))" ) - connection = self._ensure_connection() - with connection.cursor() as cursor: - try: - # Set session-level hints for batch insert operations - # Note: executemany doesn't support hints parameter, so we set them as session variables - cursor.execute("SET cz.sql.job.fast.mode = true") - cursor.execute("SET cz.sql.compaction.after.commit = true") - cursor.execute("SET cz.storage.always.prefetch.internal = true") + with self.get_connection_context() as connection: + with connection.cursor() as cursor: + try: + # Set session-level hints for batch insert operations + # Note: executemany doesn't support hints parameter, so we set them as session variables + # Temporarily suppress ClickZetta client logging to reduce noise + clickzetta_logger = logging.getLogger("clickzetta") + original_level = clickzetta_logger.level + clickzetta_logger.setLevel(logging.WARNING) - cursor.executemany(insert_sql, data_rows) - logger.info( - f"Inserted batch {batch_index // batch_size + 1}/{total_batches} " - f"({len(data_rows)} valid docs using parameterized query with VECTOR({vector_dimension}) cast)" - ) - except (RuntimeError, ValueError, TypeError, ConnectionError) as e: - logger.exception("Parameterized SQL execution failed for %d documents: %s", len(data_rows), e) - logger.exception("SQL template: %s", insert_sql) - logger.exception("Sample data row: %s", data_rows[0] if data_rows else 'None') - raise + try: + cursor.execute("SET cz.sql.job.fast.mode = true") + cursor.execute("SET cz.sql.compaction.after.commit = true") + cursor.execute("SET cz.storage.always.prefetch.internal = true") + finally: + # Restore original logging level + clickzetta_logger.setLevel(original_level) + + cursor.executemany(insert_sql, data_rows) + logger.info( + "Inserted batch %d/%d (%d valid docs using parameterized query with VECTOR(%d) cast)", + batch_index // batch_size + 1, + total_batches, + len(data_rows), + vector_dimension, + ) + except (RuntimeError, ValueError, TypeError, ConnectionError) as e: + logger.exception("Parameterized SQL execution failed for %d documents", len(data_rows)) + logger.exception("SQL template: %s", insert_sql) + logger.exception("Sample data row: %s", data_rows[0] if data_rows else "None") + raise def text_exists(self, id: str) -> bool: """Check if a document exists by ID.""" + # Check if table exists first + if not self._table_exists(): + return False + safe_id = self._safe_doc_id(id) - connection = self._ensure_connection() - with connection.cursor() as cursor: - cursor.execute( - f"SELECT COUNT(*) FROM {self._config.schema_name}.{self._table_name} WHERE id = ?", - [safe_id] - ) - result = cursor.fetchone() - return result[0] > 0 if result else False + with self.get_connection_context() as connection: + with connection.cursor() as cursor: + cursor.execute( + f"SELECT COUNT(*) FROM {self._config.schema_name}.{self._table_name} WHERE id = ?", + binding_params=[safe_id], + ) + result = cursor.fetchone() + return result[0] > 0 if result else False def delete_by_ids(self, ids: list[str]) -> None: """Delete documents by IDs.""" @@ -475,13 +748,14 @@ class ClickzettaVector(BaseVector): def _delete_by_ids_impl(self, ids: list[str]) -> None: """Implementation of delete by IDs (executed in write worker thread).""" safe_ids = [self._safe_doc_id(id) for id in ids] - # Create properly escaped string literals for SQL - id_list = ",".join(f"'{id}'" for id in safe_ids) - sql = f"DELETE FROM {self._config.schema_name}.{self._table_name} WHERE id IN ({id_list})" - connection = self._ensure_connection() - with connection.cursor() as cursor: - cursor.execute(sql) + # Use parameterized query to prevent SQL injection + placeholders = ",".join("?" for _ in safe_ids) + sql = f"DELETE FROM {self._config.schema_name}.{self._table_name} WHERE id IN ({placeholders})" + + with self.get_connection_context() as connection: + with connection.cursor() as cursor: + cursor.execute(sql, binding_params=safe_ids) def delete_by_metadata_field(self, key: str, value: str) -> None: """Delete documents by metadata field.""" @@ -495,17 +769,28 @@ class ClickzettaVector(BaseVector): def _delete_by_metadata_field_impl(self, key: str, value: str) -> None: """Implementation of delete by metadata field (executed in write worker thread).""" - connection = self._ensure_connection() - with connection.cursor() as cursor: - # Using JSON path to filter with parameterized query - # Note: JSON path requires literal key name, cannot be parameterized - # Use json_extract_string function for ClickZetta compatibility - sql = (f"DELETE FROM {self._config.schema_name}.{self._table_name} " - f"WHERE json_extract_string({Field.METADATA_KEY.value}, '$.{key}') = ?") - cursor.execute(sql, [value]) + with self.get_connection_context() as connection: + with connection.cursor() as cursor: + # Using JSON path to filter with parameterized query + # Note: JSON path requires literal key name, cannot be parameterized + # Use json_extract_string function for ClickZetta compatibility + sql = ( + f"DELETE FROM {self._config.schema_name}.{self._table_name} " + f"WHERE json_extract_string({Field.METADATA_KEY.value}, '$.{key}') = ?" + ) + cursor.execute(sql, binding_params=[value]) def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: """Search for documents by vector similarity.""" + # Check if table exists first + if not self._table_exists(): + logger.warning( + "Table %s.%s does not exist, returning empty results", + self._config.schema_name, + self._table_name, + ) + return [] + top_k = kwargs.get("top_k", 10) score_threshold = kwargs.get("score_threshold", 0.0) document_ids_filter = kwargs.get("document_ids_filter") @@ -532,15 +817,15 @@ class ClickzettaVector(BaseVector): distance_func = "COSINE_DISTANCE" if score_threshold > 0: query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))" - filter_clauses.append(f"{distance_func}({Field.VECTOR.value}, " - f"{query_vector_str}) < {2 - score_threshold}") + filter_clauses.append( + f"{distance_func}({Field.VECTOR.value}, {query_vector_str}) < {2 - score_threshold}" + ) else: # For L2 distance, smaller is better distance_func = "L2_DISTANCE" if score_threshold > 0: query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))" - filter_clauses.append(f"{distance_func}({Field.VECTOR.value}, " - f"{query_vector_str}) < {score_threshold}") + filter_clauses.append(f"{distance_func}({Field.VECTOR.value}, {query_vector_str}) < {score_threshold}") where_clause = " AND ".join(filter_clauses) if filter_clauses else "1=1" @@ -556,55 +841,31 @@ class ClickzettaVector(BaseVector): """ documents = [] - connection = self._ensure_connection() - with connection.cursor() as cursor: - # Use hints parameter for vector search optimization - search_hints = { - 'hints': { - 'sdk.job.timeout': 60, # Increase timeout for vector search - 'cz.sql.job.fast.mode': True, - 'cz.storage.parquet.vector.index.read.memory.cache': True + with self.get_connection_context() as connection: + with connection.cursor() as cursor: + # Use hints parameter for vector search optimization + search_hints = { + "hints": { + "sdk.job.timeout": 60, # Increase timeout for vector search + "cz.sql.job.fast.mode": True, + "cz.storage.parquet.vector.index.read.memory.cache": True, + } } - } - cursor.execute(search_sql, parameters=search_hints) - results = cursor.fetchall() + cursor.execute(search_sql, search_hints) + results = cursor.fetchall() - for row in results: - # Parse metadata from JSON string (may be double-encoded) - try: - if row[2]: - metadata = json.loads(row[2]) + for row in results: + # Parse metadata using centralized method + metadata = self._parse_metadata(row[2], row[0]) - # If result is a string, it's double-encoded JSON - parse again - if isinstance(metadata, str): - metadata = json.loads(metadata) - - if not isinstance(metadata, dict): - metadata = {} + # Add score based on distance + if self._config.vector_distance_function == "cosine_distance": + metadata["score"] = 1 - (row[3] / 2) else: - metadata = {} - except (json.JSONDecodeError, TypeError) as e: - logger.error("JSON parsing failed: %s", e) - # Fallback: extract document_id with regex - import re - doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or '')) - metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {} + metadata["score"] = 1 / (1 + row[3]) - # Ensure required fields are set - metadata["doc_id"] = row[0] # segment id - - # Ensure document_id exists (critical for Dify's format_retrieval_documents) - if "document_id" not in metadata: - metadata["document_id"] = row[0] # fallback to segment id - - # Add score based on distance - if self._config.vector_distance_function == "cosine_distance": - metadata["score"] = 1 - (row[3] / 2) - else: - metadata["score"] = 1 / (1 + row[3]) - - doc = Document(page_content=row[1], metadata=metadata) - documents.append(doc) + doc = Document(page_content=row[1], metadata=metadata) + documents.append(doc) return documents @@ -614,6 +875,15 @@ class ClickzettaVector(BaseVector): logger.warning("Full-text search is not enabled. Enable inverted index in config.") return [] + # Check if table exists first + if not self._table_exists(): + logger.warning( + "Table %s.%s does not exist, returning empty results", + self._config.schema_name, + self._table_name, + ) + return [] + top_k = kwargs.get("top_k", 10) document_ids_filter = kwargs.get("document_ids_filter") @@ -649,61 +919,70 @@ class ClickzettaVector(BaseVector): """ documents = [] - connection = self._ensure_connection() - with connection.cursor() as cursor: - try: - # Use hints parameter for full-text search optimization - fulltext_hints = { - 'hints': { - 'sdk.job.timeout': 30, # Timeout for full-text search - 'cz.sql.job.fast.mode': True, - 'cz.sql.index.prewhere.enabled': True + with self.get_connection_context() as connection: + with connection.cursor() as cursor: + try: + # Use hints parameter for full-text search optimization + fulltext_hints = { + "hints": { + "sdk.job.timeout": 30, # Timeout for full-text search + "cz.sql.job.fast.mode": True, + "cz.sql.index.prewhere.enabled": True, + } } - } - cursor.execute(search_sql, parameters=fulltext_hints) - results = cursor.fetchall() + cursor.execute(search_sql, fulltext_hints) + results = cursor.fetchall() - for row in results: - # Parse metadata from JSON string (may be double-encoded) - try: - if row[2]: - metadata = json.loads(row[2]) + for row in results: + # Parse metadata from JSON string (may be double-encoded) + try: + if row[2]: + metadata = json.loads(row[2]) - # If result is a string, it's double-encoded JSON - parse again - if isinstance(metadata, str): - metadata = json.loads(metadata) + # If result is a string, it's double-encoded JSON - parse again + if isinstance(metadata, str): + metadata = json.loads(metadata) - if not isinstance(metadata, dict): + if not isinstance(metadata, dict): + metadata = {} + else: metadata = {} - else: - metadata = {} - except (json.JSONDecodeError, TypeError) as e: - logger.error("JSON parsing failed: %s", e) - # Fallback: extract document_id with regex - import re - doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or '')) - metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {} + except (json.JSONDecodeError, TypeError) as e: + logger.exception("JSON parsing failed") + # Fallback: extract document_id with regex - # Ensure required fields are set - metadata["doc_id"] = row[0] # segment id + doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or "")) + metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {} - # Ensure document_id exists (critical for Dify's format_retrieval_documents) - if "document_id" not in metadata: - metadata["document_id"] = row[0] # fallback to segment id + # Ensure required fields are set + metadata["doc_id"] = row[0] # segment id - # Add a relevance score for full-text search - metadata["score"] = 1.0 # Clickzetta doesn't provide relevance scores - doc = Document(page_content=row[1], metadata=metadata) - documents.append(doc) - except (RuntimeError, ValueError, TypeError, ConnectionError) as e: - logger.exception("Full-text search failed") - # Fallback to LIKE search if full-text search fails - return self._search_by_like(query, **kwargs) + # Ensure document_id exists (critical for Dify's format_retrieval_documents) + if "document_id" not in metadata: + metadata["document_id"] = row[0] # fallback to segment id + + # Add a relevance score for full-text search + metadata["score"] = 1.0 # Clickzetta doesn't provide relevance scores + doc = Document(page_content=row[1], metadata=metadata) + documents.append(doc) + except (RuntimeError, ValueError, TypeError, ConnectionError) as e: + logger.exception("Full-text search failed") + # Fallback to LIKE search if full-text search fails + return self._search_by_like(query, **kwargs) return documents def _search_by_like(self, query: str, **kwargs: Any) -> list[Document]: """Fallback search using LIKE operator.""" + # Check if table exists first + if not self._table_exists(): + logger.warning( + "Table %s.%s does not exist, returning empty results", + self._config.schema_name, + self._table_name, + ) + return [] + top_k = kwargs.get("top_k", 10) document_ids_filter = kwargs.get("document_ids_filter") @@ -735,62 +1014,37 @@ class ClickzettaVector(BaseVector): """ documents = [] - connection = self._ensure_connection() - with connection.cursor() as cursor: - # Use hints parameter for LIKE search optimization - like_hints = { - 'hints': { - 'sdk.job.timeout': 20, # Timeout for LIKE search - 'cz.sql.job.fast.mode': True + with self.get_connection_context() as connection: + with connection.cursor() as cursor: + # Use hints parameter for LIKE search optimization + like_hints = { + "hints": { + "sdk.job.timeout": 20, # Timeout for LIKE search + "cz.sql.job.fast.mode": True, + } } - } - cursor.execute(search_sql, parameters=like_hints) - results = cursor.fetchall() + cursor.execute(search_sql, like_hints) + results = cursor.fetchall() - for row in results: - # Parse metadata from JSON string (may be double-encoded) - try: - if row[2]: - metadata = json.loads(row[2]) + for row in results: + # Parse metadata using centralized method + metadata = self._parse_metadata(row[2], row[0]) - # If result is a string, it's double-encoded JSON - parse again - if isinstance(metadata, str): - metadata = json.loads(metadata) - - if not isinstance(metadata, dict): - metadata = {} - else: - metadata = {} - except (json.JSONDecodeError, TypeError) as e: - logger.error("JSON parsing failed: %s", e) - # Fallback: extract document_id with regex - import re - doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or '')) - metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {} - - # Ensure required fields are set - metadata["doc_id"] = row[0] # segment id - - # Ensure document_id exists (critical for Dify's format_retrieval_documents) - if "document_id" not in metadata: - metadata["document_id"] = row[0] # fallback to segment id - - metadata["score"] = 0.5 # Lower score for LIKE search - doc = Document(page_content=row[1], metadata=metadata) - documents.append(doc) + metadata["score"] = 0.5 # Lower score for LIKE search + doc = Document(page_content=row[1], metadata=metadata) + documents.append(doc) return documents def delete(self) -> None: """Delete the entire collection.""" - connection = self._ensure_connection() - with connection.cursor() as cursor: - cursor.execute(f"DROP TABLE IF EXISTS {self._config.schema_name}.{self._table_name}") - + with self.get_connection_context() as connection: + with connection.cursor() as cursor: + cursor.execute(f"DROP TABLE IF EXISTS {self._config.schema_name}.{self._table_name}") def _format_vector_simple(self, vector: list[float]) -> str: """Simple vector formatting for SQL queries.""" - return ','.join(map(str, vector)) + return ",".join(map(str, vector)) def _safe_doc_id(self, doc_id: str) -> str: """Ensure doc_id is safe for SQL and doesn't contain special characters.""" @@ -799,13 +1053,12 @@ class ClickzettaVector(BaseVector): # Remove or replace potentially problematic characters safe_id = str(doc_id) # Only allow alphanumeric, hyphens, underscores - safe_id = ''.join(c for c in safe_id if c.isalnum() or c in '-_') + safe_id = "".join(c for c in safe_id if c.isalnum() or c in "-_") if not safe_id: # If all characters were removed return str(uuid.uuid4()) return safe_id[:255] # Limit length - class ClickzettaVectorFactory(AbstractVectorFactory): """Factory for creating Clickzetta vector instances.""" @@ -831,4 +1084,3 @@ class ClickzettaVectorFactory(AbstractVectorFactory): collection_name = Dataset.gen_collection_name_by_id(dataset.id).lower() return ClickzettaVector(collection_name=collection_name, config=config) - diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/core/rag/datasource/vdb/tencent/tencent_vector.py index 3aa4b67a78..0517d5a6d1 100644 --- a/api/core/rag/datasource/vdb/tencent/tencent_vector.py +++ b/api/core/rag/datasource/vdb/tencent/tencent_vector.py @@ -246,6 +246,10 @@ class TencentVector(BaseVector): return self._get_search_res(res, score_threshold) def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: + document_ids_filter = kwargs.get("document_ids_filter") + filter = None + if document_ids_filter: + filter = Filter(Filter.In("metadata.document_id", document_ids_filter)) if not self._enable_hybrid_search: return [] res = self._client.hybrid_search( @@ -269,6 +273,7 @@ class TencentVector(BaseVector): ), retrieve_vector=False, limit=kwargs.get("top_k", 4), + filter=filter, ) score_threshold = float(kwargs.get("score_threshold") or 0.0) return self._get_search_res(res, score_threshold) diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index 14363de7d4..0eff7c186a 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -62,7 +62,7 @@ class WordExtractor(BaseExtractor): def extract(self) -> list[Document]: """Load given path as single page.""" - content = self.parse_docx(self.file_path, "storage") + content = self.parse_docx(self.file_path) return [ Document( page_content=content, @@ -189,23 +189,8 @@ class WordExtractor(BaseExtractor): paragraph_content.append(run.text) return "".join(paragraph_content).strip() - def _parse_paragraph(self, paragraph, image_map): - paragraph_content = [] - for run in paragraph.runs: - if run.element.xpath(".//a:blip"): - for blip in run.element.xpath(".//a:blip"): - embed_id = blip.get("{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed") - if embed_id: - rel_target = run.part.rels[embed_id].target_ref - if rel_target in image_map: - paragraph_content.append(image_map[rel_target]) - if run.text.strip(): - paragraph_content.append(run.text.strip()) - return " ".join(paragraph_content) if paragraph_content else "" - - def parse_docx(self, docx_path, image_folder): + def parse_docx(self, docx_path): doc = DocxDocument(docx_path) - os.makedirs(image_folder, exist_ok=True) content = [] diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 178f2b9689..83444c02d8 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -29,7 +29,7 @@ from core.tools.errors import ( ToolProviderCredentialValidationError, ToolProviderNotFoundError, ) -from core.tools.utils.message_transformer import ToolFileMessageTransformer +from core.tools.utils.message_transformer import ToolFileMessageTransformer, safe_json_value from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db from models.enums import CreatorUserRole @@ -247,7 +247,8 @@ class ToolEngine: ) elif response.type == ToolInvokeMessage.MessageType.JSON: result += json.dumps( - cast(ToolInvokeMessage.JsonMessage, response.message).json_object, ensure_ascii=False + safe_json_value(cast(ToolInvokeMessage.JsonMessage, response.message).json_object), + ensure_ascii=False, ) else: result += str(response.message) diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index 9998de0465..ac12d83ef2 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -1,7 +1,14 @@ import logging from collections.abc import Generator +from datetime import date, datetime +from decimal import Decimal from mimetypes import guess_extension -from typing import Optional +from typing import Optional, cast +from uuid import UUID + +import numpy as np +import pytz +from flask_login import current_user from core.file import File, FileTransferMethod, FileType from core.tools.entities.tool_entities import ToolInvokeMessage @@ -10,6 +17,41 @@ from core.tools.tool_file_manager import ToolFileManager logger = logging.getLogger(__name__) +def safe_json_value(v): + if isinstance(v, datetime): + tz_name = getattr(current_user, "timezone", None) if current_user is not None else None + if not tz_name: + tz_name = "UTC" + return v.astimezone(pytz.timezone(tz_name)).isoformat() + elif isinstance(v, date): + return v.isoformat() + elif isinstance(v, UUID): + return str(v) + elif isinstance(v, Decimal): + return float(v) + elif isinstance(v, bytes): + try: + return v.decode("utf-8") + except UnicodeDecodeError: + return v.hex() + elif isinstance(v, memoryview): + return v.tobytes().hex() + elif isinstance(v, np.ndarray): + return v.tolist() + elif isinstance(v, dict): + return safe_json_dict(v) + elif isinstance(v, list | tuple | set): + return [safe_json_value(i) for i in v] + else: + return v + + +def safe_json_dict(d): + if not isinstance(d, dict): + raise TypeError("safe_json_dict() expects a dictionary (dict) as input") + return {k: safe_json_value(v) for k, v in d.items()} + + class ToolFileMessageTransformer: @classmethod def transform_tool_invoke_messages( @@ -113,6 +155,12 @@ class ToolFileMessageTransformer: ) else: yield message + + elif message.type == ToolInvokeMessage.MessageType.JSON: + if isinstance(message.message, ToolInvokeMessage.JsonMessage): + json_msg = cast(ToolInvokeMessage.JsonMessage, message.message) + json_msg.json_object = safe_json_value(json_msg.json_object) + yield message else: yield message diff --git a/api/core/variables/segments.py b/api/core/variables/segments.py index 13274f4e0e..a99f5eece3 100644 --- a/api/core/variables/segments.py +++ b/api/core/variables/segments.py @@ -119,6 +119,13 @@ class ObjectSegment(Segment): class ArraySegment(Segment): + @property + def text(self) -> str: + # Return empty string for empty arrays instead of "[]" + if not self.value: + return "" + return super().text + @property def markdown(self) -> str: items = [] @@ -155,6 +162,9 @@ class ArrayStringSegment(ArraySegment): @property def text(self) -> str: + # Return empty string for empty arrays instead of "[]" + if not self.value: + return "" return json.dumps(self.value, ensure_ascii=False) diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py index 23512c8ce4..a61e6ba4ac 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -168,7 +168,57 @@ def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str: def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str) -> str: """Extract text from a file based on its file extension.""" match file_extension: - case ".txt" | ".markdown" | ".md" | ".html" | ".htm" | ".xml": + case ( + ".txt" + | ".markdown" + | ".md" + | ".html" + | ".htm" + | ".xml" + | ".c" + | ".h" + | ".cpp" + | ".hpp" + | ".cc" + | ".cxx" + | ".c++" + | ".py" + | ".js" + | ".ts" + | ".jsx" + | ".tsx" + | ".java" + | ".php" + | ".rb" + | ".go" + | ".rs" + | ".swift" + | ".kt" + | ".scala" + | ".sh" + | ".bash" + | ".bat" + | ".ps1" + | ".sql" + | ".r" + | ".m" + | ".pl" + | ".lua" + | ".vim" + | ".asm" + | ".s" + | ".css" + | ".scss" + | ".less" + | ".sass" + | ".ini" + | ".cfg" + | ".conf" + | ".toml" + | ".env" + | ".log" + | ".vtt" + ): return _extract_text_from_plain_text(file_content) case ".json": return _extract_text_from_json(file_content) @@ -194,8 +244,6 @@ def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str) return _extract_text_from_eml(file_content) case ".msg": return _extract_text_from_msg(file_content) - case ".vtt": - return _extract_text_from_vtt(file_content) case ".properties": return _extract_text_from_properties(file_content) case _: diff --git a/api/libs/rsa.py b/api/libs/rsa.py index 598e5bc9e3..c72032701f 100644 --- a/api/libs/rsa.py +++ b/api/libs/rsa.py @@ -1,5 +1,4 @@ import hashlib -import os from typing import Union from Crypto.Cipher import AES @@ -18,7 +17,7 @@ def generate_key_pair(tenant_id: str) -> str: pem_private = private_key.export_key() pem_public = public_key.export_key() - filepath = os.path.join("privkeys", tenant_id, "private.pem") + filepath = f"privkeys/{tenant_id}/private.pem" storage.save(filepath, pem_private) @@ -48,7 +47,7 @@ def encrypt(text: str, public_key: Union[str, bytes]) -> bytes: def get_decrypt_decoding(tenant_id: str) -> tuple[RSA.RsaKey, object]: - filepath = os.path.join("privkeys", tenant_id, "private.pem") + filepath = f"privkeys/{tenant_id}/private.pem" cache_key = f"tenant_privkey:{hashlib.sha3_256(filepath.encode()).hexdigest()}" private_key = redis_client.get(cache_key) diff --git a/api/schedule/clean_embedding_cache_task.py b/api/schedule/clean_embedding_cache_task.py index 2298acf6eb..2b74fb2dd0 100644 --- a/api/schedule/clean_embedding_cache_task.py +++ b/api/schedule/clean_embedding_cache_task.py @@ -3,7 +3,7 @@ import time import click from sqlalchemy import text -from werkzeug.exceptions import NotFound +from sqlalchemy.exc import SQLAlchemyError import app from configs import dify_config @@ -27,8 +27,8 @@ def clean_embedding_cache_task(): .all() ) embedding_ids = [embedding_id[0] for embedding_id in embedding_ids] - except NotFound: - break + except SQLAlchemyError: + raise if embedding_ids: for embedding_id in embedding_ids: db.session.execute( diff --git a/api/schedule/clean_messages.py b/api/schedule/clean_messages.py index 4c35745959..a896c818a5 100644 --- a/api/schedule/clean_messages.py +++ b/api/schedule/clean_messages.py @@ -3,7 +3,7 @@ import logging import time import click -from werkzeug.exceptions import NotFound +from sqlalchemy.exc import SQLAlchemyError import app from configs import dify_config @@ -42,8 +42,8 @@ def clean_messages(): .all() ) - except NotFound: - break + except SQLAlchemyError: + raise if not messages: break for message in messages: diff --git a/api/schedule/clean_unused_datasets_task.py b/api/schedule/clean_unused_datasets_task.py index 7887835bc5..940da5309e 100644 --- a/api/schedule/clean_unused_datasets_task.py +++ b/api/schedule/clean_unused_datasets_task.py @@ -3,7 +3,7 @@ import time import click from sqlalchemy import func, select -from werkzeug.exceptions import NotFound +from sqlalchemy.exc import SQLAlchemyError import app from configs import dify_config @@ -65,8 +65,8 @@ def clean_unused_datasets_task(): datasets = db.paginate(stmt, page=1, per_page=50) - except NotFound: - break + except SQLAlchemyError: + raise if datasets.items is None or len(datasets.items) == 0: break for dataset in datasets: @@ -146,8 +146,8 @@ def clean_unused_datasets_task(): ) datasets = db.paginate(stmt, page=1, per_page=50) - except NotFound: - break + except SQLAlchemyError: + raise if datasets.items is None or len(datasets.items) == 0: break for dataset in datasets: diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index 692a3639cd..713c4c6782 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -50,12 +50,16 @@ class ConversationService: Conversation.from_account_id == (user.id if isinstance(user, Account) else None), or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value), ) - # Check if include_ids is not None and not empty to avoid WHERE false condition - if include_ids is not None and len(include_ids) > 0: + # Check if include_ids is not None to apply filter + if include_ids is not None: + if len(include_ids) == 0: + # If include_ids is empty, return empty result + return InfiniteScrollPagination(data=[], limit=limit, has_more=False) stmt = stmt.where(Conversation.id.in_(include_ids)) - # Check if exclude_ids is not None and not empty to avoid WHERE false condition - if exclude_ids is not None and len(exclude_ids) > 0: - stmt = stmt.where(~Conversation.id.in_(exclude_ids)) + # Check if exclude_ids is not None to apply filter + if exclude_ids is not None: + if len(exclude_ids) > 0: + stmt = stmt.where(~Conversation.id.in_(exclude_ids)) # define sort fields and directions sort_field, sort_direction = cls._get_sort_params(sort_by) diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index 2d62d49d91..6bbb3bca04 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -256,7 +256,7 @@ class WorkflowDraftVariableService: def _reset_node_var_or_sys_var( self, workflow: Workflow, variable: WorkflowDraftVariable ) -> WorkflowDraftVariable | None: - # If a variable does not allow updating, it makes no sence to resetting it. + # If a variable does not allow updating, it makes no sense to reset it. if not variable.editable: return variable # No execution record for this variable, delete the variable instead. @@ -478,7 +478,7 @@ def _batch_upsert_draft_variable( "node_execution_id": stmt.excluded.node_execution_id, }, ) - elif _UpsertPolicy.IGNORE: + elif policy == _UpsertPolicy.IGNORE: stmt = stmt.on_conflict_do_nothing(index_elements=WorkflowDraftVariable.unique_app_id_node_id_name()) else: raise Exception("Invalid value for update policy.") diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py index fe6d613b1c..69e5df0253 100644 --- a/api/tasks/clean_dataset_task.py +++ b/api/tasks/clean_dataset_task.py @@ -56,15 +56,24 @@ def clean_dataset_task( documents = db.session.query(Document).where(Document.dataset_id == dataset_id).all() segments = db.session.query(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id).all() + # Fix: Always clean vector database resources regardless of document existence + # This ensures all 33 vector databases properly drop tables/collections/indices + if doc_form is None: + # Use default paragraph index type for empty datasets to enable vector database cleanup + from core.rag.index_processor.constant.index_type import IndexType + + doc_form = IndexType.PARAGRAPH_INDEX + logging.info( + click.style(f"No documents found, using default index type for cleanup: {doc_form}", fg="yellow") + ) + + index_processor = IndexProcessorFactory(doc_form).init_index_processor() + index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True) + if documents is None or len(documents) == 0: logging.info(click.style(f"No documents found for dataset: {dataset_id}", fg="green")) else: logging.info(click.style(f"Cleaning documents for dataset: {dataset_id}", fg="green")) - # Specify the index type before initializing the index processor - if doc_form is None: - raise ValueError("Index type must be specified.") - index_processor = IndexProcessorFactory(doc_form).init_index_processor() - index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True) for document in documents: db.session.delete(document) diff --git a/api/tests/integration_tests/controllers/console/app/test_description_validation.py b/api/tests/integration_tests/controllers/console/app/test_description_validation.py new file mode 100644 index 0000000000..2d0ceac760 --- /dev/null +++ b/api/tests/integration_tests/controllers/console/app/test_description_validation.py @@ -0,0 +1,168 @@ +""" +Unit tests for App description validation functions. + +This test module validates the 400-character limit enforcement +for App descriptions across all creation and editing endpoints. +""" + +import os +import sys + +import pytest + +# Add the API root to Python path for imports +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..", "..")) + + +class TestAppDescriptionValidationUnit: + """Unit tests for description validation function""" + + def test_validate_description_length_function(self): + """Test the _validate_description_length function directly""" + from controllers.console.app.app import _validate_description_length + + # Test valid descriptions + assert _validate_description_length("") == "" + assert _validate_description_length("x" * 400) == "x" * 400 + assert _validate_description_length(None) is None + + # Test invalid descriptions + with pytest.raises(ValueError) as exc_info: + _validate_description_length("x" * 401) + assert "Description cannot exceed 400 characters." in str(exc_info.value) + + with pytest.raises(ValueError) as exc_info: + _validate_description_length("x" * 500) + assert "Description cannot exceed 400 characters." in str(exc_info.value) + + with pytest.raises(ValueError) as exc_info: + _validate_description_length("x" * 1000) + assert "Description cannot exceed 400 characters." in str(exc_info.value) + + def test_validation_consistency_with_dataset(self): + """Test that App and Dataset validation functions are consistent""" + from controllers.console.app.app import _validate_description_length as app_validate + from controllers.console.datasets.datasets import _validate_description_length as dataset_validate + from controllers.service_api.dataset.dataset import _validate_description_length as service_dataset_validate + + # Test same valid inputs + valid_desc = "x" * 400 + assert app_validate(valid_desc) == dataset_validate(valid_desc) == service_dataset_validate(valid_desc) + assert app_validate("") == dataset_validate("") == service_dataset_validate("") + assert app_validate(None) == dataset_validate(None) == service_dataset_validate(None) + + # Test same invalid inputs produce same error + invalid_desc = "x" * 401 + + app_error = None + dataset_error = None + service_dataset_error = None + + try: + app_validate(invalid_desc) + except ValueError as e: + app_error = str(e) + + try: + dataset_validate(invalid_desc) + except ValueError as e: + dataset_error = str(e) + + try: + service_dataset_validate(invalid_desc) + except ValueError as e: + service_dataset_error = str(e) + + assert app_error == dataset_error == service_dataset_error + assert app_error == "Description cannot exceed 400 characters." + + def test_boundary_values(self): + """Test boundary values for description validation""" + from controllers.console.app.app import _validate_description_length + + # Test exact boundary + exactly_400 = "x" * 400 + assert _validate_description_length(exactly_400) == exactly_400 + + # Test just over boundary + just_over_400 = "x" * 401 + with pytest.raises(ValueError): + _validate_description_length(just_over_400) + + # Test just under boundary + just_under_400 = "x" * 399 + assert _validate_description_length(just_under_400) == just_under_400 + + def test_edge_cases(self): + """Test edge cases for description validation""" + from controllers.console.app.app import _validate_description_length + + # Test None input + assert _validate_description_length(None) is None + + # Test empty string + assert _validate_description_length("") == "" + + # Test single character + assert _validate_description_length("a") == "a" + + # Test unicode characters + unicode_desc = "测试" * 200 # 400 characters in Chinese + assert _validate_description_length(unicode_desc) == unicode_desc + + # Test unicode over limit + unicode_over = "测试" * 201 # 402 characters + with pytest.raises(ValueError): + _validate_description_length(unicode_over) + + def test_whitespace_handling(self): + """Test how validation handles whitespace""" + from controllers.console.app.app import _validate_description_length + + # Test description with spaces + spaces_400 = " " * 400 + assert _validate_description_length(spaces_400) == spaces_400 + + # Test description with spaces over limit + spaces_401 = " " * 401 + with pytest.raises(ValueError): + _validate_description_length(spaces_401) + + # Test mixed content + mixed_400 = "a" * 200 + " " * 200 + assert _validate_description_length(mixed_400) == mixed_400 + + # Test mixed over limit + mixed_401 = "a" * 200 + " " * 201 + with pytest.raises(ValueError): + _validate_description_length(mixed_401) + + +if __name__ == "__main__": + # Run tests directly + import traceback + + test_instance = TestAppDescriptionValidationUnit() + test_methods = [method for method in dir(test_instance) if method.startswith("test_")] + + passed = 0 + failed = 0 + + for test_method in test_methods: + try: + print(f"Running {test_method}...") + getattr(test_instance, test_method)() + print(f"✅ {test_method} PASSED") + passed += 1 + except Exception as e: + print(f"❌ {test_method} FAILED: {str(e)}") + traceback.print_exc() + failed += 1 + + print(f"\n📊 Test Results: {passed} passed, {failed} failed") + + if failed == 0: + print("🎉 All tests passed!") + else: + print("💥 Some tests failed!") + sys.exit(1) diff --git a/api/tests/integration_tests/vdb/clickzetta/test_clickzetta.py b/api/tests/integration_tests/vdb/clickzetta/test_clickzetta.py index 0aa92bc84a..8b57132772 100644 --- a/api/tests/integration_tests/vdb/clickzetta/test_clickzetta.py +++ b/api/tests/integration_tests/vdb/clickzetta/test_clickzetta.py @@ -39,10 +39,7 @@ class TestClickzettaVector(AbstractVectorTest): ) with setup_mock_redis(): - vector = ClickzettaVector( - collection_name="test_collection_" + str(os.getpid()), - config=config - ) + vector = ClickzettaVector(collection_name="test_collection_" + str(os.getpid()), config=config) yield vector @@ -114,7 +111,7 @@ class TestClickzettaVector(AbstractVectorTest): "category": "technical" if i % 2 == 0 else "general", "document_id": f"doc_{i // 3}", # Group documents "importance": i, - } + }, ) documents.append(doc) # Create varied embeddings @@ -124,22 +121,14 @@ class TestClickzettaVector(AbstractVectorTest): # Test vector search with document filter query_vector = [0.5, 1.0, 1.5, 2.0] - results = vector_store.search_by_vector( - query_vector, - top_k=5, - document_ids_filter=["doc_0", "doc_1"] - ) + results = vector_store.search_by_vector(query_vector, top_k=5, document_ids_filter=["doc_0", "doc_1"]) assert len(results) > 0 # All results should belong to doc_0 or doc_1 groups for result in results: assert result.metadata["document_id"] in ["doc_0", "doc_1"] # Test score threshold - results = vector_store.search_by_vector( - query_vector, - top_k=10, - score_threshold=0.5 - ) + results = vector_store.search_by_vector(query_vector, top_k=10, score_threshold=0.5) # Check that all results have a score above threshold for result in results: assert result.metadata.get("score", 0) >= 0.5 @@ -154,7 +143,7 @@ class TestClickzettaVector(AbstractVectorTest): for i in range(batch_size): doc = Document( page_content=f"Batch document {i}: This is a test document for batch processing.", - metadata={"doc_id": f"batch_doc_{i}", "batch": "test_batch"} + metadata={"doc_id": f"batch_doc_{i}", "batch": "test_batch"}, ) documents.append(doc) embeddings.append([0.1 * (i % 10), 0.2 * (i % 10), 0.3 * (i % 10), 0.4 * (i % 10)]) @@ -179,7 +168,7 @@ class TestClickzettaVector(AbstractVectorTest): # Test special characters in content special_doc = Document( page_content="Special chars: 'quotes', \"double\", \\backslash, \n newline", - metadata={"doc_id": "special_doc", "test": "edge_case"} + metadata={"doc_id": "special_doc", "test": "edge_case"}, ) embeddings = [[0.1, 0.2, 0.3, 0.4]] @@ -199,20 +188,18 @@ class TestClickzettaVector(AbstractVectorTest): # Prepare documents with various language content documents = [ Document( - page_content="云器科技提供强大的Lakehouse解决方案", - metadata={"doc_id": "cn_doc_1", "lang": "chinese"} + page_content="云器科技提供强大的Lakehouse解决方案", metadata={"doc_id": "cn_doc_1", "lang": "chinese"} ), Document( page_content="Clickzetta provides powerful Lakehouse solutions", - metadata={"doc_id": "en_doc_1", "lang": "english"} + metadata={"doc_id": "en_doc_1", "lang": "english"}, ), Document( - page_content="Lakehouse是现代数据架构的重要组成部分", - metadata={"doc_id": "cn_doc_2", "lang": "chinese"} + page_content="Lakehouse是现代数据架构的重要组成部分", metadata={"doc_id": "cn_doc_2", "lang": "chinese"} ), Document( page_content="Modern data architecture includes Lakehouse technology", - metadata={"doc_id": "en_doc_2", "lang": "english"} + metadata={"doc_id": "en_doc_2", "lang": "english"}, ), ] diff --git a/api/tests/integration_tests/vdb/clickzetta/test_docker_integration.py b/api/tests/integration_tests/vdb/clickzetta/test_docker_integration.py index 5f2e290ad4..ef54eaa174 100644 --- a/api/tests/integration_tests/vdb/clickzetta/test_docker_integration.py +++ b/api/tests/integration_tests/vdb/clickzetta/test_docker_integration.py @@ -2,6 +2,7 @@ """ Test Clickzetta integration in Docker environment """ + import os import time @@ -20,7 +21,7 @@ def test_clickzetta_connection(): service=os.getenv("CLICKZETTA_SERVICE", "api.clickzetta.com"), workspace=os.getenv("CLICKZETTA_WORKSPACE", "test_workspace"), vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default"), - database=os.getenv("CLICKZETTA_SCHEMA", "dify") + database=os.getenv("CLICKZETTA_SCHEMA", "dify"), ) with conn.cursor() as cursor: @@ -36,7 +37,7 @@ def test_clickzetta_connection(): # Check if test collection exists test_collection = "collection_test_dataset" - if test_collection in [t[1] for t in tables if t[0] == 'dify']: + if test_collection in [t[1] for t in tables if t[0] == "dify"]: cursor.execute(f"DESCRIBE dify.{test_collection}") columns = cursor.fetchall() print(f"✓ Table structure for {test_collection}:") @@ -55,6 +56,7 @@ def test_clickzetta_connection(): print(f"✗ Connection test failed: {e}") return False + def test_dify_api(): """Test Dify API with Clickzetta backend""" print("\n=== Testing Dify API ===") @@ -83,6 +85,7 @@ def test_dify_api(): print(f"✗ API test failed: {e}") return False + def verify_table_structure(): """Verify the table structure meets Dify requirements""" print("\n=== Verifying Table Structure ===") @@ -91,15 +94,10 @@ def verify_table_structure(): "id": "VARCHAR", "page_content": "VARCHAR", "metadata": "VARCHAR", # JSON stored as VARCHAR in Clickzetta - "vector": "ARRAY" + "vector": "ARRAY", } - expected_metadata_fields = [ - "doc_id", - "doc_hash", - "document_id", - "dataset_id" - ] + expected_metadata_fields = ["doc_id", "doc_hash", "document_id", "dataset_id"] print("✓ Expected table structure:") for col, dtype in expected_columns.items(): @@ -117,6 +115,7 @@ def verify_table_structure(): return True + def main(): """Run all tests""" print("Starting Clickzetta integration tests for Dify Docker\n") @@ -137,9 +136,9 @@ def main(): results.append((test_name, False)) # Summary - print("\n" + "="*50) + print("\n" + "=" * 50) print("Test Summary:") - print("="*50) + print("=" * 50) passed = sum(1 for _, success in results if success) total = len(results) @@ -161,5 +160,6 @@ def main(): print("\n⚠️ Some tests failed. Please check the errors above.") return 1 + if __name__ == "__main__": exit(main()) diff --git a/api/tests/test_containers_integration_tests/services/test_annotation_service.py b/api/tests/test_containers_integration_tests/services/test_annotation_service.py new file mode 100644 index 0000000000..0ab5f398e3 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_annotation_service.py @@ -0,0 +1,1252 @@ +from unittest.mock import patch + +import pytest +from faker import Faker +from werkzeug.exceptions import NotFound + +from models.model import MessageAnnotation +from services.annotation_service import AppAnnotationService +from services.app_service import AppService + + +class TestAnnotationService: + """Integration tests for AnnotationService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.account_service.FeatureService") as mock_account_feature_service, + patch("services.annotation_service.FeatureService") as mock_feature_service, + patch("services.annotation_service.add_annotation_to_index_task") as mock_add_task, + patch("services.annotation_service.update_annotation_to_index_task") as mock_update_task, + patch("services.annotation_service.delete_annotation_index_task") as mock_delete_task, + patch("services.annotation_service.enable_annotation_reply_task") as mock_enable_task, + patch("services.annotation_service.disable_annotation_reply_task") as mock_disable_task, + patch("services.annotation_service.batch_import_annotations_task") as mock_batch_import_task, + patch("services.annotation_service.current_user") as mock_current_user, + ): + # Setup default mock returns + mock_account_feature_service.get_features.return_value.billing.enabled = False + mock_add_task.delay.return_value = None + mock_update_task.delay.return_value = None + mock_delete_task.delay.return_value = None + mock_enable_task.delay.return_value = None + mock_disable_task.delay.return_value = None + mock_batch_import_task.delay.return_value = None + + yield { + "account_feature_service": mock_account_feature_service, + "feature_service": mock_feature_service, + "add_task": mock_add_task, + "update_task": mock_update_task, + "delete_task": mock_delete_task, + "enable_task": mock_enable_task, + "disable_task": mock_disable_task, + "batch_import_task": mock_batch_import_task, + "current_user": mock_current_user, + } + + def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): + """ + Helper method to create a test app and account for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + + Returns: + tuple: (app, account) - Created app and account instances + """ + fake = Faker() + + # Setup mocks for account creation + mock_external_service_dependencies[ + "account_feature_service" + ].get_system_features.return_value.is_allow_register = True + + # Create account and tenant first + from services.account_service import AccountService, TenantService + + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Setup app creation arguments + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#FF6B6B", + "api_rph": 100, + "api_rpm": 10, + } + + # Create app + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + # Setup current_user mock + self._mock_current_user(mock_external_service_dependencies, account.id, tenant.id) + + return app, account + + def _mock_current_user(self, mock_external_service_dependencies, account_id, tenant_id): + """ + Helper method to mock the current user for testing. + """ + mock_external_service_dependencies["current_user"].id = account_id + mock_external_service_dependencies["current_user"].current_tenant_id = tenant_id + + def _create_test_conversation(self, app, account, fake): + """ + Helper method to create a test conversation with all required fields. + """ + from extensions.ext_database import db + from models.model import Conversation + + conversation = Conversation( + app_id=app.id, + app_model_config_id=None, + model_provider=None, + model_id="", + override_model_configs=None, + mode=app.mode, + name=fake.sentence(), + inputs={}, + introduction="", + system_instruction="", + system_instruction_tokens=0, + status="normal", + invoke_from="console", + from_source="console", + from_end_user_id=None, + from_account_id=account.id, + ) + + db.session.add(conversation) + db.session.flush() + return conversation + + def _create_test_message(self, app, conversation, account, fake): + """ + Helper method to create a test message with all required fields. + """ + import json + + from extensions.ext_database import db + from models.model import Message + + message = Message( + app_id=app.id, + model_provider=None, + model_id="", + override_model_configs=None, + conversation_id=conversation.id, + inputs={}, + query=fake.sentence(), + message=json.dumps([{"role": "user", "text": fake.sentence()}]), + message_tokens=0, + message_unit_price=0, + message_price_unit=0.001, + answer=fake.text(max_nb_chars=200), + answer_tokens=0, + answer_unit_price=0, + answer_price_unit=0.001, + parent_message_id=None, + provider_response_latency=0, + total_price=0, + currency="USD", + invoke_from="console", + from_source="console", + from_end_user_id=None, + from_account_id=account.id, + ) + + db.session.add(message) + db.session.commit() + return message + + def test_insert_app_annotation_directly_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful direct insertion of app annotation. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Setup annotation data + annotation_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + + # Insert annotation directly + annotation = AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id) + + # Verify annotation was created correctly + assert annotation.app_id == app.id + assert annotation.question == annotation_args["question"] + assert annotation.content == annotation_args["answer"] + assert annotation.account_id == account.id + assert annotation.hit_count == 0 + assert annotation.id is not None + + # Verify annotation was saved to database + from extensions.ext_database import db + + db.session.refresh(annotation) + assert annotation.id is not None + + # Verify add_annotation_to_index_task was called (when annotation setting exists) + # Note: In this test, no annotation setting exists, so task should not be called + mock_external_service_dependencies["add_task"].delay.assert_not_called() + + def test_insert_app_annotation_directly_app_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test direct insertion of app annotation when app is not found. + """ + fake = Faker() + non_existent_app_id = fake.uuid4() + + # Mock random current user to avoid dependency issues + self._mock_current_user(mock_external_service_dependencies, fake.uuid4(), fake.uuid4()) + + # Setup annotation data + annotation_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + + # Try to insert annotation with non-existent app + with pytest.raises(NotFound, match="App not found"): + AppAnnotationService.insert_app_annotation_directly(annotation_args, non_existent_app_id) + + def test_update_app_annotation_directly_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful direct update of app annotation. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # First, create an annotation + original_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + annotation = AppAnnotationService.insert_app_annotation_directly(original_args, app.id) + + # Update the annotation + updated_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + updated_annotation = AppAnnotationService.update_app_annotation_directly(updated_args, app.id, annotation.id) + + # Verify annotation was updated correctly + assert updated_annotation.id == annotation.id + assert updated_annotation.app_id == app.id + assert updated_annotation.question == updated_args["question"] + assert updated_annotation.content == updated_args["answer"] + assert updated_annotation.account_id == account.id + + # Verify original values were changed + assert updated_annotation.question != original_args["question"] + assert updated_annotation.content != original_args["answer"] + + # Verify update_annotation_to_index_task was called (when annotation setting exists) + # Note: In this test, no annotation setting exists, so task should not be called + mock_external_service_dependencies["update_task"].delay.assert_not_called() + + def test_up_insert_app_annotation_from_message_new( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test creating new annotation from message. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create a conversation and message first + conversation = self._create_test_conversation(app, account, fake) + message = self._create_test_message(app, conversation, account, fake) + + # Setup annotation data with message_id + annotation_args = { + "message_id": message.id, + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + + # Insert annotation from message + annotation = AppAnnotationService.up_insert_app_annotation_from_message(annotation_args, app.id) + + # Verify annotation was created correctly + assert annotation.app_id == app.id + assert annotation.conversation_id == conversation.id + assert annotation.message_id == message.id + assert annotation.question == annotation_args["question"] + assert annotation.content == annotation_args["answer"] + assert annotation.account_id == account.id + + # Verify add_annotation_to_index_task was called (when annotation setting exists) + # Note: In this test, no annotation setting exists, so task should not be called + mock_external_service_dependencies["add_task"].delay.assert_not_called() + + def test_up_insert_app_annotation_from_message_update( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test updating existing annotation from message. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create a conversation and message first + conversation = self._create_test_conversation(app, account, fake) + message = self._create_test_message(app, conversation, account, fake) + + # Create initial annotation + initial_args = { + "message_id": message.id, + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + initial_annotation = AppAnnotationService.up_insert_app_annotation_from_message(initial_args, app.id) + + # Update the annotation + updated_args = { + "message_id": message.id, + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + updated_annotation = AppAnnotationService.up_insert_app_annotation_from_message(updated_args, app.id) + + # Verify annotation was updated correctly (same ID) + assert updated_annotation.id == initial_annotation.id + assert updated_annotation.question == updated_args["question"] + assert updated_annotation.content == updated_args["answer"] + assert updated_annotation.question != initial_args["question"] + assert updated_annotation.content != initial_args["answer"] + + # Verify add_annotation_to_index_task was called (when annotation setting exists) + # Note: In this test, no annotation setting exists, so task should not be called + mock_external_service_dependencies["add_task"].delay.assert_not_called() + + def test_up_insert_app_annotation_from_message_app_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test creating annotation from message when app is not found. + """ + fake = Faker() + non_existent_app_id = fake.uuid4() + + # Mock random current user to avoid dependency issues + self._mock_current_user(mock_external_service_dependencies, fake.uuid4(), fake.uuid4()) + + # Setup annotation data + annotation_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + + # Try to insert annotation with non-existent app + with pytest.raises(NotFound, match="App not found"): + AppAnnotationService.up_insert_app_annotation_from_message(annotation_args, non_existent_app_id) + + def test_get_annotation_list_by_app_id_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful retrieval of annotation list by app ID. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create multiple annotations + annotations = [] + for i in range(3): + annotation_args = { + "question": f"Question {i}: {fake.sentence()}", + "answer": f"Answer {i}: {fake.text(max_nb_chars=200)}", + } + annotation = AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id) + annotations.append(annotation) + + # Get annotation list + annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id( + app.id, page=1, limit=10, keyword="" + ) + + # Verify results + assert len(annotation_list) == 3 + assert total == 3 + + # Verify all annotations belong to the correct app + for annotation in annotation_list: + assert annotation.app_id == app.id + assert annotation.account_id == account.id + + def test_get_annotation_list_by_app_id_with_keyword( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test retrieval of annotation list with keyword search. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create annotations with specific keywords + unique_keyword = fake.word() + annotation_args = { + "question": f"Question with {unique_keyword} keyword", + "answer": f"Answer with {unique_keyword} keyword", + } + AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id) + + # Create another annotation without the keyword + other_args = { + "question": "Question without keyword", + "answer": "Answer without keyword", + } + AppAnnotationService.insert_app_annotation_directly(other_args, app.id) + + # Search with keyword + annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id( + app.id, page=1, limit=10, keyword=unique_keyword + ) + + # Verify only matching annotations are returned + assert len(annotation_list) == 1 + assert total == 1 + assert unique_keyword in annotation_list[0].question or unique_keyword in annotation_list[0].content + + def test_get_annotation_list_by_app_id_app_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test retrieval of annotation list when app is not found. + """ + fake = Faker() + non_existent_app_id = fake.uuid4() + + # Mock random current user to avoid dependency issues + self._mock_current_user(mock_external_service_dependencies, fake.uuid4(), fake.uuid4()) + + # Try to get annotation list with non-existent app + with pytest.raises(NotFound, match="App not found"): + AppAnnotationService.get_annotation_list_by_app_id(non_existent_app_id, page=1, limit=10, keyword="") + + def test_delete_app_annotation_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful deletion of app annotation. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create an annotation first + annotation_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + annotation = AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id) + annotation_id = annotation.id + + # Delete the annotation + AppAnnotationService.delete_app_annotation(app.id, annotation_id) + + # Verify annotation was deleted + from extensions.ext_database import db + + deleted_annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first() + assert deleted_annotation is None + + # Verify delete_annotation_index_task was called (when annotation setting exists) + # Note: In this test, no annotation setting exists, so task should not be called + mock_external_service_dependencies["delete_task"].delay.assert_not_called() + + def test_delete_app_annotation_app_not_found(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test deletion of app annotation when app is not found. + """ + fake = Faker() + non_existent_app_id = fake.uuid4() + annotation_id = fake.uuid4() + + # Mock random current user to avoid dependency issues + self._mock_current_user(mock_external_service_dependencies, fake.uuid4(), fake.uuid4()) + + # Try to delete annotation with non-existent app + with pytest.raises(NotFound, match="App not found"): + AppAnnotationService.delete_app_annotation(non_existent_app_id, annotation_id) + + def test_delete_app_annotation_annotation_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test deletion of app annotation when annotation is not found. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + non_existent_annotation_id = fake.uuid4() + + # Try to delete non-existent annotation + with pytest.raises(NotFound, match="Annotation not found"): + AppAnnotationService.delete_app_annotation(app.id, non_existent_annotation_id) + + def test_enable_app_annotation_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful enabling of app annotation. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Setup enable arguments + enable_args = { + "score_threshold": 0.8, + "embedding_provider_name": "openai", + "embedding_model_name": "text-embedding-ada-002", + } + + # Enable annotation + result = AppAnnotationService.enable_app_annotation(enable_args, app.id) + + # Verify result structure + assert "job_id" in result + assert "job_status" in result + assert result["job_status"] == "waiting" + assert result["job_id"] is not None + + # Verify task was called + mock_external_service_dependencies["enable_task"].delay.assert_called_once() + + def test_disable_app_annotation_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful disabling of app annotation. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Disable annotation + result = AppAnnotationService.disable_app_annotation(app.id) + + # Verify result structure + assert "job_id" in result + assert "job_status" in result + assert result["job_status"] == "waiting" + assert result["job_id"] is not None + + # Verify task was called + mock_external_service_dependencies["disable_task"].delay.assert_called_once() + + def test_enable_app_annotation_cached_job(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test enabling app annotation when job is already cached. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Mock Redis to return cached job + from extensions.ext_redis import redis_client + + cached_job_id = fake.uuid4() + enable_app_annotation_key = f"enable_app_annotation_{app.id}" + redis_client.set(enable_app_annotation_key, cached_job_id) + + # Setup enable arguments + enable_args = { + "score_threshold": 0.8, + "embedding_provider_name": "openai", + "embedding_model_name": "text-embedding-ada-002", + } + + # Enable annotation (should return cached job) + result = AppAnnotationService.enable_app_annotation(enable_args, app.id) + + # Verify cached result + assert cached_job_id == result["job_id"].decode("utf-8") + assert result["job_status"] == "processing" + + # Verify task was not called again + mock_external_service_dependencies["enable_task"].delay.assert_not_called() + + # Clean up + redis_client.delete(enable_app_annotation_key) + + def test_get_annotation_hit_histories_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful retrieval of annotation hit histories. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create an annotation first + annotation_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + annotation = AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id) + + # Add some hit histories + for i in range(3): + AppAnnotationService.add_annotation_history( + annotation_id=annotation.id, + app_id=app.id, + annotation_question=annotation.question, + annotation_content=annotation.content, + query=f"Query {i}: {fake.sentence()}", + user_id=account.id, + message_id=fake.uuid4(), + from_source="console", + score=0.8 + (i * 0.1), + ) + + # Get hit histories + hit_histories, total = AppAnnotationService.get_annotation_hit_histories( + app.id, annotation.id, page=1, limit=10 + ) + + # Verify results + assert len(hit_histories) == 3 + assert total == 3 + + # Verify all histories belong to the correct annotation + for history in hit_histories: + assert history.annotation_id == annotation.id + assert history.app_id == app.id + assert history.account_id == account.id + + def test_add_annotation_history_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful addition of annotation history. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create an annotation first + annotation_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + annotation = AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id) + + # Get initial hit count + initial_hit_count = annotation.hit_count + + # Add annotation history + query = fake.sentence() + message_id = fake.uuid4() + score = 0.85 + + AppAnnotationService.add_annotation_history( + annotation_id=annotation.id, + app_id=app.id, + annotation_question=annotation.question, + annotation_content=annotation.content, + query=query, + user_id=account.id, + message_id=message_id, + from_source="console", + score=score, + ) + + # Verify hit count was incremented + from extensions.ext_database import db + + db.session.refresh(annotation) + assert annotation.hit_count == initial_hit_count + 1 + + # Verify history was created + from models.model import AppAnnotationHitHistory + + history = ( + db.session.query(AppAnnotationHitHistory) + .filter( + AppAnnotationHitHistory.annotation_id == annotation.id, AppAnnotationHitHistory.message_id == message_id + ) + .first() + ) + + assert history is not None + assert history.app_id == app.id + assert history.account_id == account.id + assert history.question == query + assert history.score == score + assert history.source == "console" + + def test_get_annotation_by_id_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful retrieval of annotation by ID. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create an annotation + annotation_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + created_annotation = AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id) + + # Get annotation by ID + retrieved_annotation = AppAnnotationService.get_annotation_by_id(created_annotation.id) + + # Verify annotation was retrieved correctly + assert retrieved_annotation is not None + assert retrieved_annotation.id == created_annotation.id + assert retrieved_annotation.app_id == app.id + assert retrieved_annotation.question == annotation_args["question"] + assert retrieved_annotation.content == annotation_args["answer"] + assert retrieved_annotation.account_id == account.id + + def test_batch_import_app_annotations_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful batch import of app annotations. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create CSV content + csv_content = "Question 1,Answer 1\nQuestion 2,Answer 2\nQuestion 3,Answer 3" + + # Mock FileStorage + from io import BytesIO + + from werkzeug.datastructures import FileStorage + + file_storage = FileStorage( + stream=BytesIO(csv_content.encode("utf-8")), filename="annotations.csv", content_type="text/csv" + ) + + mock_external_service_dependencies["feature_service"].get_features.return_value.billing.enabled = False + + # Mock pandas to return expected DataFrame + import pandas as pd + + with patch("services.annotation_service.pd") as mock_pd: + mock_df = pd.DataFrame( + {0: ["Question 1", "Question 2", "Question 3"], 1: ["Answer 1", "Answer 2", "Answer 3"]} + ) + mock_pd.read_csv.return_value = mock_df + + # Batch import annotations + result = AppAnnotationService.batch_import_app_annotations(app.id, file_storage) + + # Verify result structure + assert "job_id" in result + assert "job_status" in result + assert result["job_status"] == "waiting" + assert result["job_id"] is not None + + # Verify task was called + mock_external_service_dependencies["batch_import_task"].delay.assert_called_once() + + def test_batch_import_app_annotations_empty_file( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test batch import with empty CSV file. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create empty CSV content + csv_content = "" + + # Mock FileStorage + from io import BytesIO + + from werkzeug.datastructures import FileStorage + + file_storage = FileStorage( + stream=BytesIO(csv_content.encode("utf-8")), filename="annotations.csv", content_type="text/csv" + ) + + # Mock pandas to return empty DataFrame + import pandas as pd + + with patch("services.annotation_service.pd") as mock_pd: + mock_df = pd.DataFrame() + mock_pd.read_csv.return_value = mock_df + + # Batch import annotations + result = AppAnnotationService.batch_import_app_annotations(app.id, file_storage) + + # Verify error result + assert "error_msg" in result + assert "empty" in result["error_msg"].lower() + + def test_batch_import_app_annotations_quota_exceeded( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test batch import when quota is exceeded. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create CSV content + csv_content = "Question 1,Answer 1\nQuestion 2,Answer 2\nQuestion 3,Answer 3" + + # Mock FileStorage + from io import BytesIO + + from werkzeug.datastructures import FileStorage + + file_storage = FileStorage( + stream=BytesIO(csv_content.encode("utf-8")), filename="annotations.csv", content_type="text/csv" + ) + + # Mock pandas to return DataFrame + import pandas as pd + + with patch("services.annotation_service.pd") as mock_pd: + mock_df = pd.DataFrame( + {0: ["Question 1", "Question 2", "Question 3"], 1: ["Answer 1", "Answer 2", "Answer 3"]} + ) + mock_pd.read_csv.return_value = mock_df + + # Mock FeatureService to return billing enabled with quota exceeded + mock_external_service_dependencies["feature_service"].get_features.return_value.billing.enabled = True + mock_external_service_dependencies[ + "feature_service" + ].get_features.return_value.annotation_quota_limit.limit = 1 + mock_external_service_dependencies[ + "feature_service" + ].get_features.return_value.annotation_quota_limit.size = 0 + + # Batch import annotations + result = AppAnnotationService.batch_import_app_annotations(app.id, file_storage) + + # Verify error result + assert "error_msg" in result + assert "limit" in result["error_msg"].lower() + + def test_get_app_annotation_setting_by_app_id_enabled( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting enabled app annotation setting by app ID. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create annotation setting + from extensions.ext_database import db + from models.dataset import DatasetCollectionBinding + from models.model import AppAnnotationSetting + + # Create a collection binding first + collection_binding = DatasetCollectionBinding() + collection_binding.id = fake.uuid4() + collection_binding.provider_name = "openai" + collection_binding.model_name = "text-embedding-ada-002" + collection_binding.type = "annotation" + collection_binding.collection_name = f"annotation_collection_{fake.uuid4()}" + db.session.add(collection_binding) + db.session.flush() + + # Create annotation setting + annotation_setting = AppAnnotationSetting() + annotation_setting.app_id = app.id + annotation_setting.score_threshold = 0.8 + annotation_setting.collection_binding_id = collection_binding.id + annotation_setting.created_user_id = account.id + annotation_setting.updated_user_id = account.id + db.session.add(annotation_setting) + db.session.commit() + + # Get annotation setting + result = AppAnnotationService.get_app_annotation_setting_by_app_id(app.id) + + # Verify result structure + assert result["enabled"] is True + assert result["id"] == annotation_setting.id + assert result["score_threshold"] == 0.8 + assert result["embedding_model"]["embedding_provider_name"] == "openai" + assert result["embedding_model"]["embedding_model_name"] == "text-embedding-ada-002" + + def test_get_app_annotation_setting_by_app_id_disabled( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting disabled app annotation setting by app ID. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Get annotation setting (no setting exists) + result = AppAnnotationService.get_app_annotation_setting_by_app_id(app.id) + + # Verify result structure + assert result["enabled"] is False + + def test_update_app_annotation_setting_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful update of app annotation setting. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create annotation setting first + from extensions.ext_database import db + from models.dataset import DatasetCollectionBinding + from models.model import AppAnnotationSetting + + # Create a collection binding first + collection_binding = DatasetCollectionBinding() + collection_binding.id = fake.uuid4() + collection_binding.provider_name = "openai" + collection_binding.model_name = "text-embedding-ada-002" + collection_binding.type = "annotation" + collection_binding.collection_name = f"annotation_collection_{fake.uuid4()}" + db.session.add(collection_binding) + db.session.flush() + + # Create annotation setting + annotation_setting = AppAnnotationSetting() + annotation_setting.app_id = app.id + annotation_setting.score_threshold = 0.8 + annotation_setting.collection_binding_id = collection_binding.id + annotation_setting.created_user_id = account.id + annotation_setting.updated_user_id = account.id + db.session.add(annotation_setting) + db.session.commit() + + # Update annotation setting + update_args = { + "score_threshold": 0.9, + } + + result = AppAnnotationService.update_app_annotation_setting(app.id, annotation_setting.id, update_args) + + # Verify result structure + assert result["enabled"] is True + assert result["id"] == annotation_setting.id + assert result["score_threshold"] == 0.9 + assert result["embedding_model"]["embedding_provider_name"] == "openai" + assert result["embedding_model"]["embedding_model_name"] == "text-embedding-ada-002" + + # Verify database was updated + db.session.refresh(annotation_setting) + assert annotation_setting.score_threshold == 0.9 + + def test_export_annotation_list_by_app_id_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful export of annotation list by app ID. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create multiple annotations + annotations = [] + for i in range(3): + annotation_args = { + "question": f"Question {i}: {fake.sentence()}", + "answer": f"Answer {i}: {fake.text(max_nb_chars=200)}", + } + annotation = AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id) + annotations.append(annotation) + + # Export annotation list + exported_annotations = AppAnnotationService.export_annotation_list_by_app_id(app.id) + + # Verify results + assert len(exported_annotations) == 3 + + # Verify all annotations belong to the correct app and are ordered by created_at desc + for i, annotation in enumerate(exported_annotations): + assert annotation.app_id == app.id + assert annotation.account_id == account.id + if i > 0: + # Verify descending order (newer first) + assert annotation.created_at <= exported_annotations[i - 1].created_at + + def test_export_annotation_list_by_app_id_app_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test export of annotation list when app is not found. + """ + fake = Faker() + non_existent_app_id = fake.uuid4() + + # Mock random current user to avoid dependency issues + self._mock_current_user(mock_external_service_dependencies, fake.uuid4(), fake.uuid4()) + + # Try to export annotation list with non-existent app + with pytest.raises(NotFound, match="App not found"): + AppAnnotationService.export_annotation_list_by_app_id(non_existent_app_id) + + def test_insert_app_annotation_directly_with_setting_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful direct insertion of app annotation with annotation setting enabled. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create annotation setting first + from extensions.ext_database import db + from models.dataset import DatasetCollectionBinding + from models.model import AppAnnotationSetting + + # Create a collection binding first + collection_binding = DatasetCollectionBinding() + collection_binding.id = fake.uuid4() + collection_binding.provider_name = "openai" + collection_binding.model_name = "text-embedding-ada-002" + collection_binding.type = "annotation" + collection_binding.collection_name = f"annotation_collection_{fake.uuid4()}" + db.session.add(collection_binding) + db.session.flush() + + # Create annotation setting + annotation_setting = AppAnnotationSetting() + annotation_setting.app_id = app.id + annotation_setting.score_threshold = 0.8 + annotation_setting.collection_binding_id = collection_binding.id + annotation_setting.created_user_id = account.id + annotation_setting.updated_user_id = account.id + db.session.add(annotation_setting) + db.session.commit() + + # Setup annotation data + annotation_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + + # Insert annotation directly + annotation = AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id) + + # Verify annotation was created correctly + assert annotation.app_id == app.id + assert annotation.question == annotation_args["question"] + assert annotation.content == annotation_args["answer"] + assert annotation.account_id == account.id + assert annotation.hit_count == 0 + assert annotation.id is not None + + # Verify add_annotation_to_index_task was called + mock_external_service_dependencies["add_task"].delay.assert_called_once() + call_args = mock_external_service_dependencies["add_task"].delay.call_args[0] + assert call_args[0] == annotation.id # annotation_id + assert call_args[1] == annotation_args["question"] # question + assert call_args[2] == account.current_tenant_id # tenant_id + assert call_args[3] == app.id # app_id + assert call_args[4] == collection_binding.id # collection_binding_id + + def test_update_app_annotation_directly_with_setting_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful direct update of app annotation with annotation setting enabled. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create annotation setting first + from extensions.ext_database import db + from models.dataset import DatasetCollectionBinding + from models.model import AppAnnotationSetting + + # Create a collection binding first + collection_binding = DatasetCollectionBinding() + collection_binding.id = fake.uuid4() + collection_binding.provider_name = "openai" + collection_binding.model_name = "text-embedding-ada-002" + collection_binding.type = "annotation" + collection_binding.collection_name = f"annotation_collection_{fake.uuid4()}" + db.session.add(collection_binding) + db.session.flush() + + # Create annotation setting + annotation_setting = AppAnnotationSetting() + annotation_setting.app_id = app.id + annotation_setting.score_threshold = 0.8 + annotation_setting.collection_binding_id = collection_binding.id + annotation_setting.created_user_id = account.id + annotation_setting.updated_user_id = account.id + db.session.add(annotation_setting) + db.session.commit() + + # First, create an annotation + original_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + annotation = AppAnnotationService.insert_app_annotation_directly(original_args, app.id) + + # Reset mock to clear previous calls + mock_external_service_dependencies["update_task"].delay.reset_mock() + + # Update the annotation + updated_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + updated_annotation = AppAnnotationService.update_app_annotation_directly(updated_args, app.id, annotation.id) + + # Verify annotation was updated correctly + assert updated_annotation.id == annotation.id + assert updated_annotation.app_id == app.id + assert updated_annotation.question == updated_args["question"] + assert updated_annotation.content == updated_args["answer"] + assert updated_annotation.account_id == account.id + + # Verify original values were changed + assert updated_annotation.question != original_args["question"] + assert updated_annotation.content != original_args["answer"] + + # Verify update_annotation_to_index_task was called + mock_external_service_dependencies["update_task"].delay.assert_called_once() + call_args = mock_external_service_dependencies["update_task"].delay.call_args[0] + assert call_args[0] == annotation.id # annotation_id + assert call_args[1] == updated_args["question"] # question + assert call_args[2] == account.current_tenant_id # tenant_id + assert call_args[3] == app.id # app_id + assert call_args[4] == collection_binding.id # collection_binding_id + + def test_delete_app_annotation_with_setting_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful deletion of app annotation with annotation setting enabled. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create annotation setting first + from extensions.ext_database import db + from models.dataset import DatasetCollectionBinding + from models.model import AppAnnotationSetting + + # Create a collection binding first + collection_binding = DatasetCollectionBinding() + collection_binding.id = fake.uuid4() + collection_binding.provider_name = "openai" + collection_binding.model_name = "text-embedding-ada-002" + collection_binding.type = "annotation" + collection_binding.collection_name = f"annotation_collection_{fake.uuid4()}" + db.session.add(collection_binding) + db.session.flush() + + # Create annotation setting + annotation_setting = AppAnnotationSetting() + annotation_setting.app_id = app.id + annotation_setting.score_threshold = 0.8 + annotation_setting.collection_binding_id = collection_binding.id + annotation_setting.created_user_id = account.id + annotation_setting.updated_user_id = account.id + db.session.add(annotation_setting) + db.session.commit() + + # Create an annotation first + annotation_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + annotation = AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id) + annotation_id = annotation.id + + # Reset mock to clear previous calls + mock_external_service_dependencies["delete_task"].delay.reset_mock() + + # Delete the annotation + AppAnnotationService.delete_app_annotation(app.id, annotation_id) + + # Verify annotation was deleted + deleted_annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first() + assert deleted_annotation is None + + # Verify delete_annotation_index_task was called + mock_external_service_dependencies["delete_task"].delay.assert_called_once() + call_args = mock_external_service_dependencies["delete_task"].delay.call_args[0] + assert call_args[0] == annotation_id # annotation_id + assert call_args[1] == app.id # app_id + assert call_args[2] == account.current_tenant_id # tenant_id + assert call_args[3] == collection_binding.id # collection_binding_id + + def test_up_insert_app_annotation_from_message_with_setting_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test creating annotation from message with annotation setting enabled. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create annotation setting first + from extensions.ext_database import db + from models.dataset import DatasetCollectionBinding + from models.model import AppAnnotationSetting + + # Create a collection binding first + collection_binding = DatasetCollectionBinding() + collection_binding.id = fake.uuid4() + collection_binding.provider_name = "openai" + collection_binding.model_name = "text-embedding-ada-002" + collection_binding.type = "annotation" + collection_binding.collection_name = f"annotation_collection_{fake.uuid4()}" + db.session.add(collection_binding) + db.session.flush() + + # Create annotation setting + annotation_setting = AppAnnotationSetting() + annotation_setting.app_id = app.id + annotation_setting.score_threshold = 0.8 + annotation_setting.collection_binding_id = collection_binding.id + annotation_setting.created_user_id = account.id + annotation_setting.updated_user_id = account.id + db.session.add(annotation_setting) + db.session.commit() + + # Create a conversation and message first + conversation = self._create_test_conversation(app, account, fake) + message = self._create_test_message(app, conversation, account, fake) + + # Setup annotation data with message_id + annotation_args = { + "message_id": message.id, + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + + # Insert annotation from message + annotation = AppAnnotationService.up_insert_app_annotation_from_message(annotation_args, app.id) + + # Verify annotation was created correctly + assert annotation.app_id == app.id + assert annotation.conversation_id == conversation.id + assert annotation.message_id == message.id + assert annotation.question == annotation_args["question"] + assert annotation.content == annotation_args["answer"] + assert annotation.account_id == account.id + + # Verify add_annotation_to_index_task was called + mock_external_service_dependencies["add_task"].delay.assert_called_once() + call_args = mock_external_service_dependencies["add_task"].delay.call_args[0] + assert call_args[0] == annotation.id # annotation_id + assert call_args[1] == annotation_args["question"] # question + assert call_args[2] == account.current_tenant_id # tenant_id + assert call_args[3] == app.id # app_id + assert call_args[4] == collection_binding.id # collection_binding_id diff --git a/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py b/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py new file mode 100644 index 0000000000..38f532fd64 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py @@ -0,0 +1,487 @@ +from unittest.mock import patch + +import pytest +from faker import Faker + +from models.api_based_extension import APIBasedExtension +from services.account_service import AccountService, TenantService +from services.api_based_extension_service import APIBasedExtensionService + + +class TestAPIBasedExtensionService: + """Integration tests for APIBasedExtensionService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.account_service.FeatureService") as mock_account_feature_service, + patch("services.api_based_extension_service.APIBasedExtensionRequestor") as mock_requestor, + ): + # Setup default mock returns + mock_account_feature_service.get_features.return_value.billing.enabled = False + + # Mock successful ping response + mock_requestor_instance = mock_requestor.return_value + mock_requestor_instance.request.return_value = {"result": "pong"} + + yield { + "account_feature_service": mock_account_feature_service, + "requestor": mock_requestor, + "requestor_instance": mock_requestor_instance, + } + + def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + """ + Helper method to create a test account and tenant for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + + Returns: + tuple: (account, tenant) - Created account and tenant instances + """ + fake = Faker() + + # Setup mocks for account creation + mock_external_service_dependencies[ + "account_feature_service" + ].get_system_features.return_value.is_allow_register = True + + # Create account and tenant + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + return account, tenant + + def test_save_extension_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful saving of API-based extension. + """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Setup extension data + extension_data = APIBasedExtension() + extension_data.tenant_id = tenant.id + extension_data.name = fake.company() + extension_data.api_endpoint = f"https://{fake.domain_name()}/api" + extension_data.api_key = fake.password(length=20) + + # Save extension + saved_extension = APIBasedExtensionService.save(extension_data) + + # Verify extension was saved correctly + assert saved_extension.id is not None + assert saved_extension.tenant_id == tenant.id + assert saved_extension.name == extension_data.name + assert saved_extension.api_endpoint == extension_data.api_endpoint + assert saved_extension.api_key == extension_data.api_key # Should be decrypted when retrieved + assert saved_extension.created_at is not None + + # Verify extension was saved to database + from extensions.ext_database import db + + db.session.refresh(saved_extension) + assert saved_extension.id is not None + + # Verify ping connection was called + mock_external_service_dependencies["requestor_instance"].request.assert_called_once() + + def test_save_extension_validation_errors(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test validation errors when saving extension with invalid data. + """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Test empty name + extension_data = APIBasedExtension() + extension_data.tenant_id = tenant.id + extension_data.name = "" + extension_data.api_endpoint = f"https://{fake.domain_name()}/api" + extension_data.api_key = fake.password(length=20) + + with pytest.raises(ValueError, match="name must not be empty"): + APIBasedExtensionService.save(extension_data) + + # Test empty api_endpoint + extension_data.name = fake.company() + extension_data.api_endpoint = "" + + with pytest.raises(ValueError, match="api_endpoint must not be empty"): + APIBasedExtensionService.save(extension_data) + + # Test empty api_key + extension_data.api_endpoint = f"https://{fake.domain_name()}/api" + extension_data.api_key = "" + + with pytest.raises(ValueError, match="api_key must not be empty"): + APIBasedExtensionService.save(extension_data) + + def test_get_all_by_tenant_id_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful retrieval of all extensions by tenant ID. + """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create multiple extensions + extensions = [] + for i in range(3): + extension_data = APIBasedExtension() + extension_data.tenant_id = tenant.id + extension_data.name = f"Extension {i}: {fake.company()}" + extension_data.api_endpoint = f"https://{fake.domain_name()}/api" + extension_data.api_key = fake.password(length=20) + + saved_extension = APIBasedExtensionService.save(extension_data) + extensions.append(saved_extension) + + # Get all extensions for tenant + extension_list = APIBasedExtensionService.get_all_by_tenant_id(tenant.id) + + # Verify results + assert len(extension_list) == 3 + + # Verify all extensions belong to the correct tenant and are ordered by created_at desc + for i, extension in enumerate(extension_list): + assert extension.tenant_id == tenant.id + assert extension.api_key is not None # Should be decrypted + if i > 0: + # Verify descending order (newer first) + assert extension.created_at <= extension_list[i - 1].created_at + + def test_get_with_tenant_id_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful retrieval of extension by tenant ID and extension ID. + """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create an extension + extension_data = APIBasedExtension() + extension_data.tenant_id = tenant.id + extension_data.name = fake.company() + extension_data.api_endpoint = f"https://{fake.domain_name()}/api" + extension_data.api_key = fake.password(length=20) + + created_extension = APIBasedExtensionService.save(extension_data) + + # Get extension by ID + retrieved_extension = APIBasedExtensionService.get_with_tenant_id(tenant.id, created_extension.id) + + # Verify extension was retrieved correctly + assert retrieved_extension is not None + assert retrieved_extension.id == created_extension.id + assert retrieved_extension.tenant_id == tenant.id + assert retrieved_extension.name == extension_data.name + assert retrieved_extension.api_endpoint == extension_data.api_endpoint + assert retrieved_extension.api_key == extension_data.api_key # Should be decrypted + assert retrieved_extension.created_at is not None + + def test_get_with_tenant_id_not_found(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test retrieval of extension when extension is not found. + """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + non_existent_extension_id = fake.uuid4() + + # Try to get non-existent extension + with pytest.raises(ValueError, match="API based extension is not found"): + APIBasedExtensionService.get_with_tenant_id(tenant.id, non_existent_extension_id) + + def test_delete_extension_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful deletion of extension. + """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create an extension first + extension_data = APIBasedExtension() + extension_data.tenant_id = tenant.id + extension_data.name = fake.company() + extension_data.api_endpoint = f"https://{fake.domain_name()}/api" + extension_data.api_key = fake.password(length=20) + + created_extension = APIBasedExtensionService.save(extension_data) + extension_id = created_extension.id + + # Delete the extension + APIBasedExtensionService.delete(created_extension) + + # Verify extension was deleted + from extensions.ext_database import db + + deleted_extension = db.session.query(APIBasedExtension).filter(APIBasedExtension.id == extension_id).first() + assert deleted_extension is None + + def test_save_extension_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test validation error when saving extension with duplicate name. + """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create first extension + extension_data1 = APIBasedExtension() + extension_data1.tenant_id = tenant.id + extension_data1.name = "Test Extension" + extension_data1.api_endpoint = f"https://{fake.domain_name()}/api" + extension_data1.api_key = fake.password(length=20) + + APIBasedExtensionService.save(extension_data1) + + # Try to create second extension with same name + extension_data2 = APIBasedExtension() + extension_data2.tenant_id = tenant.id + extension_data2.name = "Test Extension" # Same name + extension_data2.api_endpoint = f"https://{fake.domain_name()}/api" + extension_data2.api_key = fake.password(length=20) + + with pytest.raises(ValueError, match="name must be unique, it is already existed"): + APIBasedExtensionService.save(extension_data2) + + def test_save_extension_update_existing(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful update of existing extension. + """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create initial extension + extension_data = APIBasedExtension() + extension_data.tenant_id = tenant.id + extension_data.name = fake.company() + extension_data.api_endpoint = f"https://{fake.domain_name()}/api" + extension_data.api_key = fake.password(length=20) + + created_extension = APIBasedExtensionService.save(extension_data) + + # Save original values for later comparison + original_name = created_extension.name + original_endpoint = created_extension.api_endpoint + + # Update the extension + new_name = fake.company() + new_endpoint = f"https://{fake.domain_name()}/api" + new_api_key = fake.password(length=20) + + created_extension.name = new_name + created_extension.api_endpoint = new_endpoint + created_extension.api_key = new_api_key + + updated_extension = APIBasedExtensionService.save(created_extension) + + # Verify extension was updated correctly + assert updated_extension.id == created_extension.id + assert updated_extension.tenant_id == tenant.id + assert updated_extension.name == new_name + assert updated_extension.api_endpoint == new_endpoint + + # Verify original values were changed + assert updated_extension.name != original_name + assert updated_extension.api_endpoint != original_endpoint + + # Verify ping connection was called for both create and update + assert mock_external_service_dependencies["requestor_instance"].request.call_count == 2 + + # Verify the update by retrieving the extension again + retrieved_extension = APIBasedExtensionService.get_with_tenant_id(tenant.id, created_extension.id) + assert retrieved_extension.name == new_name + assert retrieved_extension.api_endpoint == new_endpoint + assert retrieved_extension.api_key == new_api_key # Should be decrypted when retrieved + + def test_save_extension_connection_error(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test connection error when saving extension with invalid endpoint. + """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Mock connection error + mock_external_service_dependencies["requestor_instance"].request.side_effect = ValueError( + "connection error: request timeout" + ) + + # Setup extension data + extension_data = APIBasedExtension() + extension_data.tenant_id = tenant.id + extension_data.name = fake.company() + extension_data.api_endpoint = "https://invalid-endpoint.com/api" + extension_data.api_key = fake.password(length=20) + + # Try to save extension with connection error + with pytest.raises(ValueError, match="connection error: request timeout"): + APIBasedExtensionService.save(extension_data) + + def test_save_extension_invalid_api_key_length( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test validation error when saving extension with API key that is too short. + """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Setup extension data with short API key + extension_data = APIBasedExtension() + extension_data.tenant_id = tenant.id + extension_data.name = fake.company() + extension_data.api_endpoint = f"https://{fake.domain_name()}/api" + extension_data.api_key = "1234" # Less than 5 characters + + # Try to save extension with short API key + with pytest.raises(ValueError, match="api_key must be at least 5 characters"): + APIBasedExtensionService.save(extension_data) + + def test_save_extension_empty_fields(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test validation errors when saving extension with empty required fields. + """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Test with None values + extension_data = APIBasedExtension() + extension_data.tenant_id = tenant.id + extension_data.name = None + extension_data.api_endpoint = f"https://{fake.domain_name()}/api" + extension_data.api_key = fake.password(length=20) + + with pytest.raises(ValueError, match="name must not be empty"): + APIBasedExtensionService.save(extension_data) + + # Test with None api_endpoint + extension_data.name = fake.company() + extension_data.api_endpoint = None + + with pytest.raises(ValueError, match="api_endpoint must not be empty"): + APIBasedExtensionService.save(extension_data) + + # Test with None api_key + extension_data.api_endpoint = f"https://{fake.domain_name()}/api" + extension_data.api_key = None + + with pytest.raises(ValueError, match="api_key must not be empty"): + APIBasedExtensionService.save(extension_data) + + def test_get_all_by_tenant_id_empty_list(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test retrieval of extensions when no extensions exist for tenant. + """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Get all extensions for tenant (none exist) + extension_list = APIBasedExtensionService.get_all_by_tenant_id(tenant.id) + + # Verify empty list is returned + assert len(extension_list) == 0 + assert extension_list == [] + + def test_save_extension_invalid_ping_response(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test validation error when ping response is invalid. + """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Mock invalid ping response + mock_external_service_dependencies["requestor_instance"].request.return_value = {"result": "invalid"} + + # Setup extension data + extension_data = APIBasedExtension() + extension_data.tenant_id = tenant.id + extension_data.name = fake.company() + extension_data.api_endpoint = f"https://{fake.domain_name()}/api" + extension_data.api_key = fake.password(length=20) + + # Try to save extension with invalid ping response + with pytest.raises(ValueError, match="{'result': 'invalid'}"): + APIBasedExtensionService.save(extension_data) + + def test_save_extension_missing_ping_result(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test validation error when ping response is missing result field. + """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Mock ping response without result field + mock_external_service_dependencies["requestor_instance"].request.return_value = {"status": "ok"} + + # Setup extension data + extension_data = APIBasedExtension() + extension_data.tenant_id = tenant.id + extension_data.name = fake.company() + extension_data.api_endpoint = f"https://{fake.domain_name()}/api" + extension_data.api_key = fake.password(length=20) + + # Try to save extension with missing ping result + with pytest.raises(ValueError, match="{'status': 'ok'}"): + APIBasedExtensionService.save(extension_data) + + def test_get_with_tenant_id_wrong_tenant(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test retrieval of extension when tenant ID doesn't match. + """ + fake = Faker() + account1, tenant1 = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create second account and tenant + account2, tenant2 = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create extension in first tenant + extension_data = APIBasedExtension() + extension_data.tenant_id = tenant1.id + extension_data.name = fake.company() + extension_data.api_endpoint = f"https://{fake.domain_name()}/api" + extension_data.api_key = fake.password(length=20) + + created_extension = APIBasedExtensionService.save(extension_data) + + # Try to get extension with wrong tenant ID + with pytest.raises(ValueError, match="API based extension is not found"): + APIBasedExtensionService.get_with_tenant_id(tenant2.id, created_extension.id) diff --git a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py new file mode 100644 index 0000000000..f2bd9f8084 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py @@ -0,0 +1,473 @@ +import json +from unittest.mock import MagicMock, patch + +import pytest +import yaml +from faker import Faker + +from models.model import App, AppModelConfig +from services.account_service import AccountService, TenantService +from services.app_dsl_service import AppDslService, ImportMode, ImportStatus +from services.app_service import AppService + + +class TestAppDslService: + """Integration tests for AppDslService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.app_dsl_service.WorkflowService") as mock_workflow_service, + patch("services.app_dsl_service.DependenciesAnalysisService") as mock_dependencies_service, + patch("services.app_dsl_service.WorkflowDraftVariableService") as mock_draft_variable_service, + patch("services.app_dsl_service.ssrf_proxy") as mock_ssrf_proxy, + patch("services.app_dsl_service.redis_client") as mock_redis_client, + patch("services.app_dsl_service.app_was_created") as mock_app_was_created, + patch("services.app_dsl_service.app_model_config_was_updated") as mock_app_model_config_was_updated, + patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.FeatureService") as mock_feature_service, + patch("services.app_service.EnterpriseService") as mock_enterprise_service, + ): + # Setup default mock returns + mock_workflow_service.return_value.get_draft_workflow.return_value = None + mock_workflow_service.return_value.sync_draft_workflow.return_value = MagicMock() + mock_dependencies_service.generate_latest_dependencies.return_value = [] + mock_dependencies_service.get_leaked_dependencies.return_value = [] + mock_dependencies_service.generate_dependencies.return_value = [] + mock_draft_variable_service.return_value.delete_workflow_variables.return_value = None + mock_ssrf_proxy.get.return_value.content = b"test content" + mock_ssrf_proxy.get.return_value.raise_for_status.return_value = None + mock_redis_client.setex.return_value = None + mock_redis_client.get.return_value = None + mock_redis_client.delete.return_value = None + mock_app_was_created.send.return_value = None + mock_app_model_config_was_updated.send.return_value = None + + # Mock ModelManager for app service + mock_model_instance = mock_model_manager.return_value + mock_model_instance.get_default_model_instance.return_value = None + mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo") + + # Mock FeatureService and EnterpriseService + mock_feature_service.get_system_features.return_value.webapp_auth.enabled = False + mock_enterprise_service.WebAppAuth.update_app_access_mode.return_value = None + mock_enterprise_service.WebAppAuth.cleanup_webapp.return_value = None + + yield { + "workflow_service": mock_workflow_service, + "dependencies_service": mock_dependencies_service, + "draft_variable_service": mock_draft_variable_service, + "ssrf_proxy": mock_ssrf_proxy, + "redis_client": mock_redis_client, + "app_was_created": mock_app_was_created, + "app_model_config_was_updated": mock_app_model_config_was_updated, + "model_manager": mock_model_manager, + "feature_service": mock_feature_service, + "enterprise_service": mock_enterprise_service, + } + + def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): + """ + Helper method to create a test app and account for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + + Returns: + tuple: (app, account) - Created app and account instances + """ + fake = Faker() + + # Setup mocks for account creation + with patch("services.account_service.FeatureService") as mock_account_feature_service: + mock_account_feature_service.get_system_features.return_value.is_allow_register = True + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Setup app creation arguments + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#FF6B6B", + "api_rph": 100, + "api_rpm": 10, + } + + # Create app + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + return app, account + + def _create_simple_yaml_content(self, app_name="Test App", app_mode="chat"): + """ + Helper method to create simple YAML content for testing. + """ + yaml_data = { + "version": "0.3.0", + "kind": "app", + "app": { + "name": app_name, + "mode": app_mode, + "icon": "🤖", + "icon_background": "#FFEAD5", + "description": "Test app description", + "use_icon_as_answer_icon": False, + }, + "model_config": { + "model": { + "provider": "openai", + "name": "gpt-3.5-turbo", + "mode": "chat", + "completion_params": { + "max_tokens": 1000, + "temperature": 0.7, + "top_p": 1.0, + }, + }, + "pre_prompt": "You are a helpful assistant.", + "prompt_type": "simple", + }, + } + return yaml.dump(yaml_data, allow_unicode=True) + + def test_import_app_yaml_content_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful app import from YAML content. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create YAML content + yaml_content = self._create_simple_yaml_content(fake.company(), "chat") + + # Import app + dsl_service = AppDslService(db_session_with_containers) + result = dsl_service.import_app( + account=account, + import_mode=ImportMode.YAML_CONTENT, + yaml_content=yaml_content, + name="Imported App", + description="Imported app description", + ) + + # Verify import result + assert result.status == ImportStatus.COMPLETED + assert result.app_id is not None + assert result.app_mode == "chat" + assert result.imported_dsl_version == "0.3.0" + assert result.error == "" + + # Verify app was created in database + imported_app = db_session_with_containers.query(App).filter(App.id == result.app_id).first() + assert imported_app is not None + assert imported_app.name == "Imported App" + assert imported_app.description == "Imported app description" + assert imported_app.mode == "chat" + assert imported_app.tenant_id == account.current_tenant_id + assert imported_app.created_by == account.id + + # Verify model config was created + model_config = ( + db_session_with_containers.query(AppModelConfig).filter(AppModelConfig.app_id == result.app_id).first() + ) + assert model_config is not None + # The provider and model_id are stored in the model field as JSON + model_dict = model_config.model_dict + assert model_dict["provider"] == "openai" + assert model_dict["name"] == "gpt-3.5-turbo" + + def test_import_app_yaml_url_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful app import from YAML URL. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create YAML content for mock response + yaml_content = self._create_simple_yaml_content(fake.company(), "chat") + + # Setup mock response + mock_response = MagicMock() + mock_response.content = yaml_content.encode("utf-8") + mock_response.raise_for_status.return_value = None + mock_external_service_dependencies["ssrf_proxy"].get.return_value = mock_response + + # Import app from URL + dsl_service = AppDslService(db_session_with_containers) + result = dsl_service.import_app( + account=account, + import_mode=ImportMode.YAML_URL, + yaml_url="https://example.com/app.yaml", + name="URL Imported App", + description="App imported from URL", + ) + + # Verify import result + assert result.status == ImportStatus.COMPLETED + assert result.app_id is not None + assert result.app_mode == "chat" + assert result.imported_dsl_version == "0.3.0" + assert result.error == "" + + # Verify app was created in database + imported_app = db_session_with_containers.query(App).filter(App.id == result.app_id).first() + assert imported_app is not None + assert imported_app.name == "URL Imported App" + assert imported_app.description == "App imported from URL" + assert imported_app.mode == "chat" + assert imported_app.tenant_id == account.current_tenant_id + + # Verify ssrf_proxy was called + mock_external_service_dependencies["ssrf_proxy"].get.assert_called_once_with( + "https://example.com/app.yaml", follow_redirects=True, timeout=(10, 10) + ) + + def test_import_app_invalid_yaml_format(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test app import with invalid YAML format. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create invalid YAML content + invalid_yaml = "invalid: yaml: content: [" + + # Import app with invalid YAML + dsl_service = AppDslService(db_session_with_containers) + result = dsl_service.import_app( + account=account, + import_mode=ImportMode.YAML_CONTENT, + yaml_content=invalid_yaml, + name="Invalid App", + ) + + # Verify import failed + assert result.status == ImportStatus.FAILED + assert result.app_id is None + assert "Invalid YAML format" in result.error + assert result.imported_dsl_version == "" + + # Verify no app was created in database + apps_count = db_session_with_containers.query(App).filter(App.tenant_id == account.current_tenant_id).count() + assert apps_count == 1 # Only the original test app + + def test_import_app_missing_yaml_content(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test app import with missing YAML content. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Import app without YAML content + dsl_service = AppDslService(db_session_with_containers) + result = dsl_service.import_app( + account=account, + import_mode=ImportMode.YAML_CONTENT, + name="Missing Content App", + ) + + # Verify import failed + assert result.status == ImportStatus.FAILED + assert result.app_id is None + assert "yaml_content is required" in result.error + assert result.imported_dsl_version == "" + + # Verify no app was created in database + apps_count = db_session_with_containers.query(App).filter(App.tenant_id == account.current_tenant_id).count() + assert apps_count == 1 # Only the original test app + + def test_import_app_missing_yaml_url(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test app import with missing YAML URL. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Import app without YAML URL + dsl_service = AppDslService(db_session_with_containers) + result = dsl_service.import_app( + account=account, + import_mode=ImportMode.YAML_URL, + name="Missing URL App", + ) + + # Verify import failed + assert result.status == ImportStatus.FAILED + assert result.app_id is None + assert "yaml_url is required" in result.error + assert result.imported_dsl_version == "" + + # Verify no app was created in database + apps_count = db_session_with_containers.query(App).filter(App.tenant_id == account.current_tenant_id).count() + assert apps_count == 1 # Only the original test app + + def test_import_app_invalid_import_mode(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test app import with invalid import mode. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create YAML content + yaml_content = self._create_simple_yaml_content(fake.company(), "chat") + + # Import app with invalid mode should raise ValueError + dsl_service = AppDslService(db_session_with_containers) + with pytest.raises(ValueError, match="Invalid import_mode: invalid-mode"): + dsl_service.import_app( + account=account, + import_mode="invalid-mode", + yaml_content=yaml_content, + name="Invalid Mode App", + ) + + # Verify no app was created in database + apps_count = db_session_with_containers.query(App).filter(App.tenant_id == account.current_tenant_id).count() + assert apps_count == 1 # Only the original test app + + def test_export_dsl_chat_app_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful DSL export for chat app. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create model config for the app + model_config = AppModelConfig() + model_config.id = fake.uuid4() + model_config.app_id = app.id + model_config.provider = "openai" + model_config.model_id = "gpt-3.5-turbo" + model_config.model = json.dumps( + { + "provider": "openai", + "name": "gpt-3.5-turbo", + "mode": "chat", + "completion_params": { + "max_tokens": 1000, + "temperature": 0.7, + }, + } + ) + model_config.pre_prompt = "You are a helpful assistant." + model_config.prompt_type = "simple" + model_config.created_by = account.id + model_config.updated_by = account.id + + # Set the app_model_config_id to link the config + app.app_model_config_id = model_config.id + + db_session_with_containers.add(model_config) + db_session_with_containers.commit() + + # Export DSL + exported_dsl = AppDslService.export_dsl(app, include_secret=False) + + # Parse exported YAML + exported_data = yaml.safe_load(exported_dsl) + + # Verify exported data structure + assert exported_data["kind"] == "app" + assert exported_data["app"]["name"] == app.name + assert exported_data["app"]["mode"] == app.mode + assert exported_data["app"]["icon"] == app.icon + assert exported_data["app"]["icon_background"] == app.icon_background + assert exported_data["app"]["description"] == app.description + + # Verify model config was exported + assert "model_config" in exported_data + # The exported model_config structure may be different from the database structure + # Check that the model config exists and has the expected content + assert exported_data["model_config"] is not None + + # Verify dependencies were exported + assert "dependencies" in exported_data + assert isinstance(exported_data["dependencies"], list) + + def test_export_dsl_workflow_app_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful DSL export for workflow app. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Update app to workflow mode + app.mode = "workflow" + db_session_with_containers.commit() + + # Mock workflow service to return a workflow + mock_workflow = MagicMock() + mock_workflow.to_dict.return_value = { + "graph": {"nodes": [{"id": "start", "type": "start", "data": {"type": "start"}}], "edges": []}, + "features": {}, + "environment_variables": [], + "conversation_variables": [], + } + mock_external_service_dependencies[ + "workflow_service" + ].return_value.get_draft_workflow.return_value = mock_workflow + + # Export DSL + exported_dsl = AppDslService.export_dsl(app, include_secret=False) + + # Parse exported YAML + exported_data = yaml.safe_load(exported_dsl) + + # Verify exported data structure + assert exported_data["kind"] == "app" + assert exported_data["app"]["name"] == app.name + assert exported_data["app"]["mode"] == "workflow" + + # Verify workflow was exported + assert "workflow" in exported_data + assert "graph" in exported_data["workflow"] + assert "nodes" in exported_data["workflow"]["graph"] + + # Verify dependencies were exported + assert "dependencies" in exported_data + assert isinstance(exported_data["dependencies"], list) + + # Verify workflow service was called + mock_external_service_dependencies["workflow_service"].return_value.get_draft_workflow.assert_called_once_with( + app + ) + + def test_check_dependencies_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful dependency checking. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Mock Redis to return dependencies + mock_dependencies_json = '{"app_id": "' + app.id + '", "dependencies": []}' + mock_external_service_dependencies["redis_client"].get.return_value = mock_dependencies_json + + # Check dependencies + dsl_service = AppDslService(db_session_with_containers) + result = dsl_service.check_dependencies(app_model=app) + + # Verify result + assert result.leaked_dependencies == [] + + # Verify Redis was queried + mock_external_service_dependencies["redis_client"].get.assert_called_once_with( + f"app_check_dependencies:{app.id}" + ) + + # Verify dependencies service was called + mock_external_service_dependencies["dependencies_service"].get_leaked_dependencies.assert_called_once() diff --git a/api/tests/unit_tests/controllers/console/app/test_description_validation.py b/api/tests/unit_tests/controllers/console/app/test_description_validation.py new file mode 100644 index 0000000000..178267e560 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_description_validation.py @@ -0,0 +1,252 @@ +import pytest + +from controllers.console.app.app import _validate_description_length as app_validate +from controllers.console.datasets.datasets import _validate_description_length as dataset_validate +from controllers.service_api.dataset.dataset import _validate_description_length as service_dataset_validate + + +class TestDescriptionValidationUnit: + """Unit tests for description validation functions in App and Dataset APIs""" + + def test_app_validate_description_length_valid(self): + """Test App validation function with valid descriptions""" + # Empty string should be valid + assert app_validate("") == "" + + # None should be valid + assert app_validate(None) is None + + # Short description should be valid + short_desc = "Short description" + assert app_validate(short_desc) == short_desc + + # Exactly 400 characters should be valid + exactly_400 = "x" * 400 + assert app_validate(exactly_400) == exactly_400 + + # Just under limit should be valid + just_under = "x" * 399 + assert app_validate(just_under) == just_under + + def test_app_validate_description_length_invalid(self): + """Test App validation function with invalid descriptions""" + # 401 characters should fail + just_over = "x" * 401 + with pytest.raises(ValueError) as exc_info: + app_validate(just_over) + assert "Description cannot exceed 400 characters." in str(exc_info.value) + + # 500 characters should fail + way_over = "x" * 500 + with pytest.raises(ValueError) as exc_info: + app_validate(way_over) + assert "Description cannot exceed 400 characters." in str(exc_info.value) + + # 1000 characters should fail + very_long = "x" * 1000 + with pytest.raises(ValueError) as exc_info: + app_validate(very_long) + assert "Description cannot exceed 400 characters." in str(exc_info.value) + + def test_dataset_validate_description_length_valid(self): + """Test Dataset validation function with valid descriptions""" + # Empty string should be valid + assert dataset_validate("") == "" + + # Short description should be valid + short_desc = "Short description" + assert dataset_validate(short_desc) == short_desc + + # Exactly 400 characters should be valid + exactly_400 = "x" * 400 + assert dataset_validate(exactly_400) == exactly_400 + + # Just under limit should be valid + just_under = "x" * 399 + assert dataset_validate(just_under) == just_under + + def test_dataset_validate_description_length_invalid(self): + """Test Dataset validation function with invalid descriptions""" + # 401 characters should fail + just_over = "x" * 401 + with pytest.raises(ValueError) as exc_info: + dataset_validate(just_over) + assert "Description cannot exceed 400 characters." in str(exc_info.value) + + # 500 characters should fail + way_over = "x" * 500 + with pytest.raises(ValueError) as exc_info: + dataset_validate(way_over) + assert "Description cannot exceed 400 characters." in str(exc_info.value) + + def test_service_dataset_validate_description_length_valid(self): + """Test Service Dataset validation function with valid descriptions""" + # Empty string should be valid + assert service_dataset_validate("") == "" + + # None should be valid + assert service_dataset_validate(None) is None + + # Short description should be valid + short_desc = "Short description" + assert service_dataset_validate(short_desc) == short_desc + + # Exactly 400 characters should be valid + exactly_400 = "x" * 400 + assert service_dataset_validate(exactly_400) == exactly_400 + + # Just under limit should be valid + just_under = "x" * 399 + assert service_dataset_validate(just_under) == just_under + + def test_service_dataset_validate_description_length_invalid(self): + """Test Service Dataset validation function with invalid descriptions""" + # 401 characters should fail + just_over = "x" * 401 + with pytest.raises(ValueError) as exc_info: + service_dataset_validate(just_over) + assert "Description cannot exceed 400 characters." in str(exc_info.value) + + # 500 characters should fail + way_over = "x" * 500 + with pytest.raises(ValueError) as exc_info: + service_dataset_validate(way_over) + assert "Description cannot exceed 400 characters." in str(exc_info.value) + + def test_app_dataset_validation_consistency(self): + """Test that App and Dataset validation functions behave identically""" + test_cases = [ + "", # Empty string + "Short description", # Normal description + "x" * 100, # Medium description + "x" * 400, # Exactly at limit + ] + + # Test valid cases produce same results + for test_desc in test_cases: + assert app_validate(test_desc) == dataset_validate(test_desc) == service_dataset_validate(test_desc) + + # Test invalid cases produce same errors + invalid_cases = [ + "x" * 401, # Just over limit + "x" * 500, # Way over limit + "x" * 1000, # Very long + ] + + for invalid_desc in invalid_cases: + app_error = None + dataset_error = None + service_dataset_error = None + + # Capture App validation error + try: + app_validate(invalid_desc) + except ValueError as e: + app_error = str(e) + + # Capture Dataset validation error + try: + dataset_validate(invalid_desc) + except ValueError as e: + dataset_error = str(e) + + # Capture Service Dataset validation error + try: + service_dataset_validate(invalid_desc) + except ValueError as e: + service_dataset_error = str(e) + + # All should produce errors + assert app_error is not None, f"App validation should fail for {len(invalid_desc)} characters" + assert dataset_error is not None, f"Dataset validation should fail for {len(invalid_desc)} characters" + error_msg = f"Service Dataset validation should fail for {len(invalid_desc)} characters" + assert service_dataset_error is not None, error_msg + + # Errors should be identical + error_msg = f"Error messages should be identical for {len(invalid_desc)} characters" + assert app_error == dataset_error == service_dataset_error, error_msg + assert app_error == "Description cannot exceed 400 characters." + + def test_boundary_values(self): + """Test boundary values around the 400 character limit""" + boundary_tests = [ + (0, True), # Empty + (1, True), # Minimum + (399, True), # Just under limit + (400, True), # Exactly at limit + (401, False), # Just over limit + (402, False), # Over limit + (500, False), # Way over limit + ] + + for length, should_pass in boundary_tests: + test_desc = "x" * length + + if should_pass: + # Should not raise exception + assert app_validate(test_desc) == test_desc + assert dataset_validate(test_desc) == test_desc + assert service_dataset_validate(test_desc) == test_desc + else: + # Should raise ValueError + with pytest.raises(ValueError): + app_validate(test_desc) + with pytest.raises(ValueError): + dataset_validate(test_desc) + with pytest.raises(ValueError): + service_dataset_validate(test_desc) + + def test_special_characters(self): + """Test validation with special characters, Unicode, etc.""" + # Unicode characters + unicode_desc = "测试描述" * 100 # Chinese characters + if len(unicode_desc) <= 400: + assert app_validate(unicode_desc) == unicode_desc + assert dataset_validate(unicode_desc) == unicode_desc + assert service_dataset_validate(unicode_desc) == unicode_desc + + # Special characters + special_desc = "Special chars: !@#$%^&*()_+-=[]{}|;':\",./<>?" * 10 + if len(special_desc) <= 400: + assert app_validate(special_desc) == special_desc + assert dataset_validate(special_desc) == special_desc + assert service_dataset_validate(special_desc) == special_desc + + # Mixed content + mixed_desc = "Mixed content: 测试 123 !@# " * 15 + if len(mixed_desc) <= 400: + assert app_validate(mixed_desc) == mixed_desc + assert dataset_validate(mixed_desc) == mixed_desc + assert service_dataset_validate(mixed_desc) == mixed_desc + elif len(mixed_desc) > 400: + with pytest.raises(ValueError): + app_validate(mixed_desc) + with pytest.raises(ValueError): + dataset_validate(mixed_desc) + with pytest.raises(ValueError): + service_dataset_validate(mixed_desc) + + def test_whitespace_handling(self): + """Test validation with various whitespace scenarios""" + # Leading/trailing whitespace + whitespace_desc = " Description with whitespace " + if len(whitespace_desc) <= 400: + assert app_validate(whitespace_desc) == whitespace_desc + assert dataset_validate(whitespace_desc) == whitespace_desc + assert service_dataset_validate(whitespace_desc) == whitespace_desc + + # Newlines and tabs + multiline_desc = "Line 1\nLine 2\tTabbed content" + if len(multiline_desc) <= 400: + assert app_validate(multiline_desc) == multiline_desc + assert dataset_validate(multiline_desc) == multiline_desc + assert service_dataset_validate(multiline_desc) == multiline_desc + + # Only whitespace over limit + only_spaces = " " * 401 + with pytest.raises(ValueError): + app_validate(only_spaces) + with pytest.raises(ValueError): + dataset_validate(only_spaces) + with pytest.raises(ValueError): + service_dataset_validate(only_spaces) diff --git a/api/tests/unit_tests/controllers/service_api/app/test_file_preview.py b/api/tests/unit_tests/controllers/service_api/app/test_file_preview.py new file mode 100644 index 0000000000..5c484403a6 --- /dev/null +++ b/api/tests/unit_tests/controllers/service_api/app/test_file_preview.py @@ -0,0 +1,336 @@ +""" +Unit tests for Service API File Preview endpoint +""" + +import uuid +from unittest.mock import Mock, patch + +import pytest + +from controllers.service_api.app.error import FileAccessDeniedError, FileNotFoundError +from controllers.service_api.app.file_preview import FilePreviewApi +from models.model import App, EndUser, Message, MessageFile, UploadFile + + +class TestFilePreviewApi: + """Test suite for FilePreviewApi""" + + @pytest.fixture + def file_preview_api(self): + """Create FilePreviewApi instance for testing""" + return FilePreviewApi() + + @pytest.fixture + def mock_app(self): + """Mock App model""" + app = Mock(spec=App) + app.id = str(uuid.uuid4()) + app.tenant_id = str(uuid.uuid4()) + return app + + @pytest.fixture + def mock_end_user(self): + """Mock EndUser model""" + end_user = Mock(spec=EndUser) + end_user.id = str(uuid.uuid4()) + return end_user + + @pytest.fixture + def mock_upload_file(self): + """Mock UploadFile model""" + upload_file = Mock(spec=UploadFile) + upload_file.id = str(uuid.uuid4()) + upload_file.name = "test_file.jpg" + upload_file.mime_type = "image/jpeg" + upload_file.size = 1024 + upload_file.key = "storage/key/test_file.jpg" + upload_file.tenant_id = str(uuid.uuid4()) + return upload_file + + @pytest.fixture + def mock_message_file(self): + """Mock MessageFile model""" + message_file = Mock(spec=MessageFile) + message_file.id = str(uuid.uuid4()) + message_file.upload_file_id = str(uuid.uuid4()) + message_file.message_id = str(uuid.uuid4()) + return message_file + + @pytest.fixture + def mock_message(self): + """Mock Message model""" + message = Mock(spec=Message) + message.id = str(uuid.uuid4()) + message.app_id = str(uuid.uuid4()) + return message + + def test_validate_file_ownership_success( + self, file_preview_api, mock_app, mock_upload_file, mock_message_file, mock_message + ): + """Test successful file ownership validation""" + file_id = str(uuid.uuid4()) + app_id = mock_app.id + + # Set up the mocks + mock_upload_file.tenant_id = mock_app.tenant_id + mock_message.app_id = app_id + mock_message_file.upload_file_id = file_id + mock_message_file.message_id = mock_message.id + + with patch("controllers.service_api.app.file_preview.db") as mock_db: + # Mock database queries + mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_message_file, # MessageFile query + mock_message, # Message query + mock_upload_file, # UploadFile query + mock_app, # App query for tenant validation + ] + + # Execute the method + result_message_file, result_upload_file = file_preview_api._validate_file_ownership(file_id, app_id) + + # Assertions + assert result_message_file == mock_message_file + assert result_upload_file == mock_upload_file + + def test_validate_file_ownership_file_not_found(self, file_preview_api): + """Test file ownership validation when MessageFile not found""" + file_id = str(uuid.uuid4()) + app_id = str(uuid.uuid4()) + + with patch("controllers.service_api.app.file_preview.db") as mock_db: + # Mock MessageFile not found + mock_db.session.query.return_value.where.return_value.first.return_value = None + + # Execute and assert exception + with pytest.raises(FileNotFoundError) as exc_info: + file_preview_api._validate_file_ownership(file_id, app_id) + + assert "File not found in message context" in str(exc_info.value) + + def test_validate_file_ownership_access_denied(self, file_preview_api, mock_message_file): + """Test file ownership validation when Message not owned by app""" + file_id = str(uuid.uuid4()) + app_id = str(uuid.uuid4()) + + with patch("controllers.service_api.app.file_preview.db") as mock_db: + # Mock MessageFile found but Message not owned by app + mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_message_file, # MessageFile query - found + None, # Message query - not found (access denied) + ] + + # Execute and assert exception + with pytest.raises(FileAccessDeniedError) as exc_info: + file_preview_api._validate_file_ownership(file_id, app_id) + + assert "not owned by requesting app" in str(exc_info.value) + + def test_validate_file_ownership_upload_file_not_found(self, file_preview_api, mock_message_file, mock_message): + """Test file ownership validation when UploadFile not found""" + file_id = str(uuid.uuid4()) + app_id = str(uuid.uuid4()) + + with patch("controllers.service_api.app.file_preview.db") as mock_db: + # Mock MessageFile and Message found but UploadFile not found + mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_message_file, # MessageFile query - found + mock_message, # Message query - found + None, # UploadFile query - not found + ] + + # Execute and assert exception + with pytest.raises(FileNotFoundError) as exc_info: + file_preview_api._validate_file_ownership(file_id, app_id) + + assert "Upload file record not found" in str(exc_info.value) + + def test_validate_file_ownership_tenant_mismatch( + self, file_preview_api, mock_app, mock_upload_file, mock_message_file, mock_message + ): + """Test file ownership validation with tenant mismatch""" + file_id = str(uuid.uuid4()) + app_id = mock_app.id + + # Set up tenant mismatch + mock_upload_file.tenant_id = "different_tenant_id" + mock_app.tenant_id = "app_tenant_id" + mock_message.app_id = app_id + mock_message_file.upload_file_id = file_id + mock_message_file.message_id = mock_message.id + + with patch("controllers.service_api.app.file_preview.db") as mock_db: + # Mock database queries + mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_message_file, # MessageFile query + mock_message, # Message query + mock_upload_file, # UploadFile query + mock_app, # App query for tenant validation + ] + + # Execute and assert exception + with pytest.raises(FileAccessDeniedError) as exc_info: + file_preview_api._validate_file_ownership(file_id, app_id) + + assert "tenant mismatch" in str(exc_info.value) + + def test_validate_file_ownership_invalid_input(self, file_preview_api): + """Test file ownership validation with invalid input""" + + # Test with empty file_id + with pytest.raises(FileAccessDeniedError) as exc_info: + file_preview_api._validate_file_ownership("", "app_id") + assert "Invalid file or app identifier" in str(exc_info.value) + + # Test with empty app_id + with pytest.raises(FileAccessDeniedError) as exc_info: + file_preview_api._validate_file_ownership("file_id", "") + assert "Invalid file or app identifier" in str(exc_info.value) + + def test_build_file_response_basic(self, file_preview_api, mock_upload_file): + """Test basic file response building""" + mock_generator = Mock() + + response = file_preview_api._build_file_response(mock_generator, mock_upload_file, False) + + # Check response properties + assert response.mimetype == mock_upload_file.mime_type + assert response.direct_passthrough is True + assert response.headers["Content-Length"] == str(mock_upload_file.size) + assert "Cache-Control" in response.headers + + def test_build_file_response_as_attachment(self, file_preview_api, mock_upload_file): + """Test file response building with attachment flag""" + mock_generator = Mock() + + response = file_preview_api._build_file_response(mock_generator, mock_upload_file, True) + + # Check attachment-specific headers + assert "attachment" in response.headers["Content-Disposition"] + assert mock_upload_file.name in response.headers["Content-Disposition"] + assert response.headers["Content-Type"] == "application/octet-stream" + + def test_build_file_response_audio_video(self, file_preview_api, mock_upload_file): + """Test file response building for audio/video files""" + mock_generator = Mock() + mock_upload_file.mime_type = "video/mp4" + + response = file_preview_api._build_file_response(mock_generator, mock_upload_file, False) + + # Check Range support for media files + assert response.headers["Accept-Ranges"] == "bytes" + + def test_build_file_response_no_size(self, file_preview_api, mock_upload_file): + """Test file response building when size is unknown""" + mock_generator = Mock() + mock_upload_file.size = 0 # Unknown size + + response = file_preview_api._build_file_response(mock_generator, mock_upload_file, False) + + # Content-Length should not be set when size is unknown + assert "Content-Length" not in response.headers + + @patch("controllers.service_api.app.file_preview.storage") + def test_get_method_integration( + self, mock_storage, file_preview_api, mock_app, mock_end_user, mock_upload_file, mock_message_file, mock_message + ): + """Test the full GET method integration (without decorator)""" + file_id = str(uuid.uuid4()) + app_id = mock_app.id + + # Set up mocks + mock_upload_file.tenant_id = mock_app.tenant_id + mock_message.app_id = app_id + mock_message_file.upload_file_id = file_id + mock_message_file.message_id = mock_message.id + + mock_generator = Mock() + mock_storage.load.return_value = mock_generator + + with patch("controllers.service_api.app.file_preview.db") as mock_db: + # Mock database queries + mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_message_file, # MessageFile query + mock_message, # Message query + mock_upload_file, # UploadFile query + mock_app, # App query for tenant validation + ] + + with patch("controllers.service_api.app.file_preview.reqparse") as mock_reqparse: + # Mock request parsing + mock_parser = Mock() + mock_parser.parse_args.return_value = {"as_attachment": False} + mock_reqparse.RequestParser.return_value = mock_parser + + # Test the core logic directly without Flask decorators + # Validate file ownership + result_message_file, result_upload_file = file_preview_api._validate_file_ownership(file_id, app_id) + assert result_message_file == mock_message_file + assert result_upload_file == mock_upload_file + + # Test file response building + response = file_preview_api._build_file_response(mock_generator, mock_upload_file, False) + assert response is not None + + # Verify storage was called correctly + mock_storage.load.assert_not_called() # Since we're testing components separately + + @patch("controllers.service_api.app.file_preview.storage") + def test_storage_error_handling( + self, mock_storage, file_preview_api, mock_app, mock_upload_file, mock_message_file, mock_message + ): + """Test storage error handling in the core logic""" + file_id = str(uuid.uuid4()) + app_id = mock_app.id + + # Set up mocks + mock_upload_file.tenant_id = mock_app.tenant_id + mock_message.app_id = app_id + mock_message_file.upload_file_id = file_id + mock_message_file.message_id = mock_message.id + + # Mock storage error + mock_storage.load.side_effect = Exception("Storage error") + + with patch("controllers.service_api.app.file_preview.db") as mock_db: + # Mock database queries for validation + mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_message_file, # MessageFile query + mock_message, # Message query + mock_upload_file, # UploadFile query + mock_app, # App query for tenant validation + ] + + # First validate file ownership works + result_message_file, result_upload_file = file_preview_api._validate_file_ownership(file_id, app_id) + assert result_message_file == mock_message_file + assert result_upload_file == mock_upload_file + + # Test storage error handling + with pytest.raises(Exception) as exc_info: + mock_storage.load(mock_upload_file.key, stream=True) + + assert "Storage error" in str(exc_info.value) + + @patch("controllers.service_api.app.file_preview.logger") + def test_validate_file_ownership_unexpected_error_logging(self, mock_logger, file_preview_api): + """Test that unexpected errors are logged properly""" + file_id = str(uuid.uuid4()) + app_id = str(uuid.uuid4()) + + with patch("controllers.service_api.app.file_preview.db") as mock_db: + # Mock database query to raise unexpected exception + mock_db.session.query.side_effect = Exception("Unexpected database error") + + # Execute and assert exception + with pytest.raises(FileAccessDeniedError) as exc_info: + file_preview_api._validate_file_ownership(file_id, app_id) + + # Verify error message + assert "File access validation failed" in str(exc_info.value) + + # Verify logging was called + mock_logger.exception.assert_called_once_with( + "Unexpected error during file ownership validation", + extra={"file_id": file_id, "app_id": app_id, "error": "Unexpected database error"}, + ) diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py new file mode 100644 index 0000000000..da175e7ccd --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py @@ -0,0 +1,419 @@ +"""Test conversation variable handling in AdvancedChatAppRunner.""" + +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +from sqlalchemy.orm import Session + +from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom +from core.variables import SegmentType +from factories import variable_factory +from models import ConversationVariable, Workflow + + +class TestAdvancedChatAppRunnerConversationVariables: + """Test that AdvancedChatAppRunner correctly handles conversation variables.""" + + def test_missing_conversation_variables_are_added(self): + """Test that new conversation variables added to workflow are created for existing conversations.""" + # Setup + app_id = str(uuid4()) + conversation_id = str(uuid4()) + workflow_id = str(uuid4()) + + # Create workflow with two conversation variables + workflow_vars = [ + variable_factory.build_conversation_variable_from_mapping( + { + "id": "var1", + "name": "existing_var", + "value_type": SegmentType.STRING, + "value": "default1", + } + ), + variable_factory.build_conversation_variable_from_mapping( + { + "id": "var2", + "name": "new_var", + "value_type": SegmentType.STRING, + "value": "default2", + } + ), + ] + + # Mock workflow with conversation variables + mock_workflow = MagicMock(spec=Workflow) + mock_workflow.conversation_variables = workflow_vars + mock_workflow.tenant_id = str(uuid4()) + mock_workflow.app_id = app_id + mock_workflow.id = workflow_id + mock_workflow.type = "chat" + mock_workflow.graph_dict = {} + mock_workflow.environment_variables = [] + + # Create existing conversation variable (only var1 exists in DB) + existing_db_var = MagicMock(spec=ConversationVariable) + existing_db_var.id = "var1" + existing_db_var.app_id = app_id + existing_db_var.conversation_id = conversation_id + existing_db_var.to_variable = MagicMock(return_value=workflow_vars[0]) + + # Mock conversation and message + mock_conversation = MagicMock() + mock_conversation.app_id = app_id + mock_conversation.id = conversation_id + + mock_message = MagicMock() + mock_message.id = str(uuid4()) + + # Mock app config + mock_app_config = MagicMock() + mock_app_config.app_id = app_id + mock_app_config.workflow_id = workflow_id + mock_app_config.tenant_id = str(uuid4()) + + # Mock app generate entity + mock_app_generate_entity = MagicMock(spec=AdvancedChatAppGenerateEntity) + mock_app_generate_entity.app_config = mock_app_config + mock_app_generate_entity.inputs = {} + mock_app_generate_entity.query = "test query" + mock_app_generate_entity.files = [] + mock_app_generate_entity.user_id = str(uuid4()) + mock_app_generate_entity.invoke_from = InvokeFrom.SERVICE_API + mock_app_generate_entity.workflow_run_id = str(uuid4()) + mock_app_generate_entity.call_depth = 0 + mock_app_generate_entity.single_iteration_run = None + mock_app_generate_entity.single_loop_run = None + mock_app_generate_entity.trace_manager = None + + # Create runner + runner = AdvancedChatAppRunner( + application_generate_entity=mock_app_generate_entity, + queue_manager=MagicMock(), + conversation=mock_conversation, + message=mock_message, + dialogue_count=1, + variable_loader=MagicMock(), + workflow=mock_workflow, + system_user_id=str(uuid4()), + app=MagicMock(), + ) + + # Mock database session + mock_session = MagicMock(spec=Session) + + # First query returns only existing variable + mock_scalars_result = MagicMock() + mock_scalars_result.all.return_value = [existing_db_var] + mock_session.scalars.return_value = mock_scalars_result + + # Track what gets added to session + added_items = [] + + def track_add_all(items): + added_items.extend(items) + + mock_session.add_all.side_effect = track_add_all + + # Patch the necessary components + with ( + patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class, + patch("core.app.apps.advanced_chat.app_runner.select") as mock_select, + patch("core.app.apps.advanced_chat.app_runner.db") as mock_db, + patch.object(runner, "_init_graph") as mock_init_graph, + patch.object(runner, "handle_input_moderation", return_value=False), + patch.object(runner, "handle_annotation_reply", return_value=False), + patch("core.app.apps.advanced_chat.app_runner.WorkflowEntry") as mock_workflow_entry_class, + patch("core.app.apps.advanced_chat.app_runner.VariablePool") as mock_variable_pool_class, + ): + # Setup mocks + mock_session_class.return_value.__enter__.return_value = mock_session + mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists + mock_db.engine = MagicMock() + + # Mock graph initialization + mock_init_graph.return_value = MagicMock() + + # Mock workflow entry + mock_workflow_entry = MagicMock() + mock_workflow_entry.run.return_value = iter([]) # Empty generator + mock_workflow_entry_class.return_value = mock_workflow_entry + + # Run the method + runner.run() + + # Verify that the missing variable was added + assert len(added_items) == 1, "Should have added exactly one missing variable" + + # Check that the added item is the missing variable (var2) + added_var = added_items[0] + assert hasattr(added_var, "id"), "Added item should be a ConversationVariable" + # Note: Since we're mocking ConversationVariable.from_variable, + # we can't directly check the id, but we can verify add_all was called + assert mock_session.add_all.called, "Session add_all should have been called" + assert mock_session.commit.called, "Session commit should have been called" + + def test_no_variables_creates_all(self): + """Test that all conversation variables are created when none exist in DB.""" + # Setup + app_id = str(uuid4()) + conversation_id = str(uuid4()) + workflow_id = str(uuid4()) + + # Create workflow with conversation variables + workflow_vars = [ + variable_factory.build_conversation_variable_from_mapping( + { + "id": "var1", + "name": "var1", + "value_type": SegmentType.STRING, + "value": "default1", + } + ), + variable_factory.build_conversation_variable_from_mapping( + { + "id": "var2", + "name": "var2", + "value_type": SegmentType.STRING, + "value": "default2", + } + ), + ] + + # Mock workflow + mock_workflow = MagicMock(spec=Workflow) + mock_workflow.conversation_variables = workflow_vars + mock_workflow.tenant_id = str(uuid4()) + mock_workflow.app_id = app_id + mock_workflow.id = workflow_id + mock_workflow.type = "chat" + mock_workflow.graph_dict = {} + mock_workflow.environment_variables = [] + + # Mock conversation and message + mock_conversation = MagicMock() + mock_conversation.app_id = app_id + mock_conversation.id = conversation_id + + mock_message = MagicMock() + mock_message.id = str(uuid4()) + + # Mock app config + mock_app_config = MagicMock() + mock_app_config.app_id = app_id + mock_app_config.workflow_id = workflow_id + mock_app_config.tenant_id = str(uuid4()) + + # Mock app generate entity + mock_app_generate_entity = MagicMock(spec=AdvancedChatAppGenerateEntity) + mock_app_generate_entity.app_config = mock_app_config + mock_app_generate_entity.inputs = {} + mock_app_generate_entity.query = "test query" + mock_app_generate_entity.files = [] + mock_app_generate_entity.user_id = str(uuid4()) + mock_app_generate_entity.invoke_from = InvokeFrom.SERVICE_API + mock_app_generate_entity.workflow_run_id = str(uuid4()) + mock_app_generate_entity.call_depth = 0 + mock_app_generate_entity.single_iteration_run = None + mock_app_generate_entity.single_loop_run = None + mock_app_generate_entity.trace_manager = None + + # Create runner + runner = AdvancedChatAppRunner( + application_generate_entity=mock_app_generate_entity, + queue_manager=MagicMock(), + conversation=mock_conversation, + message=mock_message, + dialogue_count=1, + variable_loader=MagicMock(), + workflow=mock_workflow, + system_user_id=str(uuid4()), + app=MagicMock(), + ) + + # Mock database session + mock_session = MagicMock(spec=Session) + + # Query returns empty list (no existing variables) + mock_scalars_result = MagicMock() + mock_scalars_result.all.return_value = [] + mock_session.scalars.return_value = mock_scalars_result + + # Track what gets added to session + added_items = [] + + def track_add_all(items): + added_items.extend(items) + + mock_session.add_all.side_effect = track_add_all + + # Patch the necessary components + with ( + patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class, + patch("core.app.apps.advanced_chat.app_runner.select") as mock_select, + patch("core.app.apps.advanced_chat.app_runner.db") as mock_db, + patch.object(runner, "_init_graph") as mock_init_graph, + patch.object(runner, "handle_input_moderation", return_value=False), + patch.object(runner, "handle_annotation_reply", return_value=False), + patch("core.app.apps.advanced_chat.app_runner.WorkflowEntry") as mock_workflow_entry_class, + patch("core.app.apps.advanced_chat.app_runner.VariablePool") as mock_variable_pool_class, + patch("core.app.apps.advanced_chat.app_runner.ConversationVariable") as mock_conv_var_class, + ): + # Setup mocks + mock_session_class.return_value.__enter__.return_value = mock_session + mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists + mock_db.engine = MagicMock() + + # Mock ConversationVariable.from_variable to return mock objects + mock_conv_vars = [] + for var in workflow_vars: + mock_cv = MagicMock() + mock_cv.id = var.id + mock_cv.to_variable.return_value = var + mock_conv_vars.append(mock_cv) + + mock_conv_var_class.from_variable.side_effect = mock_conv_vars + + # Mock graph initialization + mock_init_graph.return_value = MagicMock() + + # Mock workflow entry + mock_workflow_entry = MagicMock() + mock_workflow_entry.run.return_value = iter([]) # Empty generator + mock_workflow_entry_class.return_value = mock_workflow_entry + + # Run the method + runner.run() + + # Verify that all variables were created + assert len(added_items) == 2, "Should have added both variables" + assert mock_session.add_all.called, "Session add_all should have been called" + assert mock_session.commit.called, "Session commit should have been called" + + def test_all_variables_exist_no_changes(self): + """Test that no changes are made when all variables already exist in DB.""" + # Setup + app_id = str(uuid4()) + conversation_id = str(uuid4()) + workflow_id = str(uuid4()) + + # Create workflow with conversation variables + workflow_vars = [ + variable_factory.build_conversation_variable_from_mapping( + { + "id": "var1", + "name": "var1", + "value_type": SegmentType.STRING, + "value": "default1", + } + ), + variable_factory.build_conversation_variable_from_mapping( + { + "id": "var2", + "name": "var2", + "value_type": SegmentType.STRING, + "value": "default2", + } + ), + ] + + # Mock workflow + mock_workflow = MagicMock(spec=Workflow) + mock_workflow.conversation_variables = workflow_vars + mock_workflow.tenant_id = str(uuid4()) + mock_workflow.app_id = app_id + mock_workflow.id = workflow_id + mock_workflow.type = "chat" + mock_workflow.graph_dict = {} + mock_workflow.environment_variables = [] + + # Create existing conversation variables (both exist in DB) + existing_db_vars = [] + for var in workflow_vars: + db_var = MagicMock(spec=ConversationVariable) + db_var.id = var.id + db_var.app_id = app_id + db_var.conversation_id = conversation_id + db_var.to_variable = MagicMock(return_value=var) + existing_db_vars.append(db_var) + + # Mock conversation and message + mock_conversation = MagicMock() + mock_conversation.app_id = app_id + mock_conversation.id = conversation_id + + mock_message = MagicMock() + mock_message.id = str(uuid4()) + + # Mock app config + mock_app_config = MagicMock() + mock_app_config.app_id = app_id + mock_app_config.workflow_id = workflow_id + mock_app_config.tenant_id = str(uuid4()) + + # Mock app generate entity + mock_app_generate_entity = MagicMock(spec=AdvancedChatAppGenerateEntity) + mock_app_generate_entity.app_config = mock_app_config + mock_app_generate_entity.inputs = {} + mock_app_generate_entity.query = "test query" + mock_app_generate_entity.files = [] + mock_app_generate_entity.user_id = str(uuid4()) + mock_app_generate_entity.invoke_from = InvokeFrom.SERVICE_API + mock_app_generate_entity.workflow_run_id = str(uuid4()) + mock_app_generate_entity.call_depth = 0 + mock_app_generate_entity.single_iteration_run = None + mock_app_generate_entity.single_loop_run = None + mock_app_generate_entity.trace_manager = None + + # Create runner + runner = AdvancedChatAppRunner( + application_generate_entity=mock_app_generate_entity, + queue_manager=MagicMock(), + conversation=mock_conversation, + message=mock_message, + dialogue_count=1, + variable_loader=MagicMock(), + workflow=mock_workflow, + system_user_id=str(uuid4()), + app=MagicMock(), + ) + + # Mock database session + mock_session = MagicMock(spec=Session) + + # Query returns all existing variables + mock_scalars_result = MagicMock() + mock_scalars_result.all.return_value = existing_db_vars + mock_session.scalars.return_value = mock_scalars_result + + # Patch the necessary components + with ( + patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class, + patch("core.app.apps.advanced_chat.app_runner.select") as mock_select, + patch("core.app.apps.advanced_chat.app_runner.db") as mock_db, + patch.object(runner, "_init_graph") as mock_init_graph, + patch.object(runner, "handle_input_moderation", return_value=False), + patch.object(runner, "handle_annotation_reply", return_value=False), + patch("core.app.apps.advanced_chat.app_runner.WorkflowEntry") as mock_workflow_entry_class, + patch("core.app.apps.advanced_chat.app_runner.VariablePool") as mock_variable_pool_class, + ): + # Setup mocks + mock_session_class.return_value.__enter__.return_value = mock_session + mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists + mock_db.engine = MagicMock() + + # Mock graph initialization + mock_init_graph.return_value = MagicMock() + + # Mock workflow entry + mock_workflow_entry = MagicMock() + mock_workflow_entry.run.return_value = iter([]) # Empty generator + mock_workflow_entry_class.return_value = mock_workflow_entry + + # Run the method + runner.run() + + # Verify that no variables were added + assert not mock_session.add_all.called, "Session add_all should not have been called" + assert mock_session.commit.called, "Session commit should still be called" diff --git a/api/tests/unit_tests/services/test_conversation_service.py b/api/tests/unit_tests/services/test_conversation_service.py new file mode 100644 index 0000000000..9c1c044f03 --- /dev/null +++ b/api/tests/unit_tests/services/test_conversation_service.py @@ -0,0 +1,127 @@ +import uuid +from unittest.mock import MagicMock, patch + +from core.app.entities.app_invoke_entities import InvokeFrom +from services.conversation_service import ConversationService + + +class TestConversationService: + def test_pagination_with_empty_include_ids(self): + """Test that empty include_ids returns empty result""" + mock_session = MagicMock() + mock_app_model = MagicMock(id=str(uuid.uuid4())) + mock_user = MagicMock(id=str(uuid.uuid4())) + + result = ConversationService.pagination_by_last_id( + session=mock_session, + app_model=mock_app_model, + user=mock_user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + include_ids=[], # Empty include_ids should return empty result + exclude_ids=None, + ) + + assert result.data == [] + assert result.has_more is False + assert result.limit == 20 + + def test_pagination_with_non_empty_include_ids(self): + """Test that non-empty include_ids filters properly""" + mock_session = MagicMock() + mock_app_model = MagicMock(id=str(uuid.uuid4())) + mock_user = MagicMock(id=str(uuid.uuid4())) + + # Mock the query results + mock_conversations = [MagicMock(id=str(uuid.uuid4())) for _ in range(3)] + mock_session.scalars.return_value.all.return_value = mock_conversations + mock_session.scalar.return_value = 0 + + with patch("services.conversation_service.select") as mock_select: + mock_stmt = MagicMock() + mock_select.return_value = mock_stmt + mock_stmt.where.return_value = mock_stmt + mock_stmt.order_by.return_value = mock_stmt + mock_stmt.limit.return_value = mock_stmt + mock_stmt.subquery.return_value = MagicMock() + + result = ConversationService.pagination_by_last_id( + session=mock_session, + app_model=mock_app_model, + user=mock_user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + include_ids=["conv1", "conv2"], # Non-empty include_ids + exclude_ids=None, + ) + + # Verify the where clause was called with id.in_ + assert mock_stmt.where.called + + def test_pagination_with_empty_exclude_ids(self): + """Test that empty exclude_ids doesn't filter""" + mock_session = MagicMock() + mock_app_model = MagicMock(id=str(uuid.uuid4())) + mock_user = MagicMock(id=str(uuid.uuid4())) + + # Mock the query results + mock_conversations = [MagicMock(id=str(uuid.uuid4())) for _ in range(5)] + mock_session.scalars.return_value.all.return_value = mock_conversations + mock_session.scalar.return_value = 0 + + with patch("services.conversation_service.select") as mock_select: + mock_stmt = MagicMock() + mock_select.return_value = mock_stmt + mock_stmt.where.return_value = mock_stmt + mock_stmt.order_by.return_value = mock_stmt + mock_stmt.limit.return_value = mock_stmt + mock_stmt.subquery.return_value = MagicMock() + + result = ConversationService.pagination_by_last_id( + session=mock_session, + app_model=mock_app_model, + user=mock_user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + include_ids=None, + exclude_ids=[], # Empty exclude_ids should not filter + ) + + # Result should contain the mocked conversations + assert len(result.data) == 5 + + def test_pagination_with_non_empty_exclude_ids(self): + """Test that non-empty exclude_ids filters properly""" + mock_session = MagicMock() + mock_app_model = MagicMock(id=str(uuid.uuid4())) + mock_user = MagicMock(id=str(uuid.uuid4())) + + # Mock the query results + mock_conversations = [MagicMock(id=str(uuid.uuid4())) for _ in range(3)] + mock_session.scalars.return_value.all.return_value = mock_conversations + mock_session.scalar.return_value = 0 + + with patch("services.conversation_service.select") as mock_select: + mock_stmt = MagicMock() + mock_select.return_value = mock_stmt + mock_stmt.where.return_value = mock_stmt + mock_stmt.order_by.return_value = mock_stmt + mock_stmt.limit.return_value = mock_stmt + mock_stmt.subquery.return_value = MagicMock() + + result = ConversationService.pagination_by_last_id( + session=mock_session, + app_model=mock_app_model, + user=mock_user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + include_ids=None, + exclude_ids=["conv1", "conv2"], # Non-empty exclude_ids + ) + + # Verify the where clause was called for exclusion + assert mock_stmt.where.called diff --git a/web/__tests__/description-validation.test.tsx b/web/__tests__/description-validation.test.tsx new file mode 100644 index 0000000000..85263b035f --- /dev/null +++ b/web/__tests__/description-validation.test.tsx @@ -0,0 +1,97 @@ +/** + * Description Validation Test + * + * Tests for the 400-character description validation across App and Dataset + * creation and editing workflows to ensure consistent validation behavior. + */ + +describe('Description Validation Logic', () => { + // Simulate backend validation function + const validateDescriptionLength = (description?: string | null) => { + if (description && description.length > 400) + throw new Error('Description cannot exceed 400 characters.') + + return description + } + + describe('Backend Validation Function', () => { + test('allows description within 400 characters', () => { + const validDescription = 'x'.repeat(400) + expect(() => validateDescriptionLength(validDescription)).not.toThrow() + expect(validateDescriptionLength(validDescription)).toBe(validDescription) + }) + + test('allows empty description', () => { + expect(() => validateDescriptionLength('')).not.toThrow() + expect(() => validateDescriptionLength(null)).not.toThrow() + expect(() => validateDescriptionLength(undefined)).not.toThrow() + }) + + test('rejects description exceeding 400 characters', () => { + const invalidDescription = 'x'.repeat(401) + expect(() => validateDescriptionLength(invalidDescription)).toThrow( + 'Description cannot exceed 400 characters.', + ) + }) + }) + + describe('Backend Validation Consistency', () => { + test('App and Dataset have consistent validation limits', () => { + const maxLength = 400 + const validDescription = 'x'.repeat(maxLength) + const invalidDescription = 'x'.repeat(maxLength + 1) + + // Both should accept exactly 400 characters + expect(validDescription.length).toBe(400) + expect(() => validateDescriptionLength(validDescription)).not.toThrow() + + // Both should reject 401 characters + expect(invalidDescription.length).toBe(401) + expect(() => validateDescriptionLength(invalidDescription)).toThrow() + }) + + test('validation error messages are consistent', () => { + const expectedErrorMessage = 'Description cannot exceed 400 characters.' + + // This would be the error message from both App and Dataset backend validation + expect(expectedErrorMessage).toBe('Description cannot exceed 400 characters.') + + const invalidDescription = 'x'.repeat(401) + try { + validateDescriptionLength(invalidDescription) + } + catch (error) { + expect((error as Error).message).toBe(expectedErrorMessage) + } + }) + }) + + describe('Character Length Edge Cases', () => { + const testCases = [ + { length: 0, shouldPass: true, description: 'empty description' }, + { length: 1, shouldPass: true, description: '1 character' }, + { length: 399, shouldPass: true, description: '399 characters' }, + { length: 400, shouldPass: true, description: '400 characters (boundary)' }, + { length: 401, shouldPass: false, description: '401 characters (over limit)' }, + { length: 500, shouldPass: false, description: '500 characters' }, + { length: 1000, shouldPass: false, description: '1000 characters' }, + ] + + testCases.forEach(({ length, shouldPass, description }) => { + test(`handles ${description} correctly`, () => { + const testDescription = length > 0 ? 'x'.repeat(length) : '' + expect(testDescription.length).toBe(length) + + if (shouldPass) { + expect(() => validateDescriptionLength(testDescription)).not.toThrow() + expect(validateDescriptionLength(testDescription)).toBe(testDescription) + } + else { + expect(() => validateDescriptionLength(testDescription)).toThrow( + 'Description cannot exceed 400 characters.', + ) + } + }) + }) + }) +}) diff --git a/web/app/(commonLayout)/layout.tsx b/web/app/(commonLayout)/layout.tsx index 64186a1b10..ed1c995e25 100644 --- a/web/app/(commonLayout)/layout.tsx +++ b/web/app/(commonLayout)/layout.tsx @@ -8,6 +8,7 @@ import Header from '@/app/components/header' import { EventEmitterContextProvider } from '@/context/event-emitter' import { ProviderContextProvider } from '@/context/provider-context' import { ModalContextProvider } from '@/context/modal-context' +import GotoAnything from '@/app/components/goto-anything' const Layout = ({ children }: { children: ReactNode }) => { return ( @@ -22,6 +23,7 @@ const Layout = ({ children }: { children: ReactNode }) => {
{children} + diff --git a/web/app/account/account-page/AvatarWithEdit.tsx b/web/app/account/account-page/AvatarWithEdit.tsx index 8250789def..41a6971bf5 100644 --- a/web/app/account/account-page/AvatarWithEdit.tsx +++ b/web/app/account/account-page/AvatarWithEdit.tsx @@ -87,7 +87,7 @@ const AvatarWithEdit = ({ onSave, ...props }: AvatarWithEditProps) => {
{ setIsShowAvatarPicker(true) }} - className="absolute inset-0 flex cursor-pointer items-center justify-center rounded-full bg-black bg-opacity-50 opacity-0 transition-opacity group-hover:opacity-100" + className="absolute inset-0 flex cursor-pointer items-center justify-center rounded-full bg-black/50 opacity-0 transition-opacity group-hover:opacity-100" > diff --git a/web/app/components/app-sidebar/app-info.tsx b/web/app/components/app-sidebar/app-info.tsx index a024403368..4835a03ad0 100644 --- a/web/app/components/app-sidebar/app-info.tsx +++ b/web/app/components/app-sidebar/app-info.tsx @@ -12,7 +12,6 @@ import { RiFileUploadLine, } from '@remixicon/react' import AppIcon from '../base/app-icon' -import cn from '@/utils/classnames' import { useStore as useAppStore } from '@/app/components/app/store' import { ToastContext } from '@/app/components/base/toast' import { useAppContext } from '@/context/app-context' @@ -31,6 +30,7 @@ import Divider from '../base/divider' import type { Operation } from './app-operations' import AppOperations from './app-operations' import dynamic from 'next/dynamic' +import cn from '@/utils/classnames' const SwitchAppModal = dynamic(() => import('@/app/components/app/switch-app-modal'), { ssr: false, @@ -256,32 +256,40 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx }} className='block w-full' > -
-
- -
-
+
+
+
+ +
+ {expand && ( +
+
+ +
+
+ )} +
+ {!expand && ( +
+
-
-
-
-
{appDetail.name}
+ )} + {expand && ( +
+
+
{appDetail.name}
+
+
{appDetail.mode === 'advanced-chat' ? t('app.types.advanced') : appDetail.mode === 'agent-chat' ? t('app.types.agent') : appDetail.mode === 'chat' ? t('app.types.chatbot') : appDetail.mode === 'completion' ? t('app.types.completion') : t('app.types.workflow')}
-
{appDetail.mode === 'advanced-chat' ? t('app.types.advanced') : appDetail.mode === 'agent-chat' ? t('app.types.agent') : appDetail.mode === 'chat' ? t('app.types.chatbot') : appDetail.mode === 'completion' ? t('app.types.completion') : t('app.types.workflow')}
-
+ )}
)} diff --git a/web/app/components/app/app-access-control/access-control-dialog.tsx b/web/app/components/app/app-access-control/access-control-dialog.tsx index 72dd33c72e..479eedc9cf 100644 --- a/web/app/components/app/app-access-control/access-control-dialog.tsx +++ b/web/app/components/app/app-access-control/access-control-dialog.tsx @@ -32,7 +32,7 @@ const AccessControlDialog = ({ leaveFrom="opacity-100" leaveTo="opacity-0" > -
+
diff --git a/web/app/components/app/app-access-control/add-member-or-group-pop.tsx b/web/app/components/app/app-access-control/add-member-or-group-pop.tsx index da4a25c1d8..0fad6cc740 100644 --- a/web/app/components/app/app-access-control/add-member-or-group-pop.tsx +++ b/web/app/components/app/app-access-control/add-member-or-group-pop.tsx @@ -106,7 +106,7 @@ function SelectedGroupsBreadCrumb() { setSelectedGroupsForBreadcrumb([]) }, [setSelectedGroupsForBreadcrumb]) return
- 0 && 'text-text-accent cursor-pointer')} onClick={handleReset}>{t('app.accessControlDialog.operateGroupAndMember.allMembers')} + 0 && 'cursor-pointer text-text-accent')} onClick={handleReset}>{t('app.accessControlDialog.operateGroupAndMember.allMembers')} {selectedGroupsForBreadcrumb.map((group, index) => { return
/ @@ -198,7 +198,7 @@ type BaseItemProps = { children: React.ReactNode } function BaseItem({ children, className }: BaseItemProps) { - return
+ return
{children}
} diff --git a/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx b/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx index 7ba22907dd..feb7a38165 100644 --- a/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx +++ b/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx @@ -4,7 +4,6 @@ import React, { useRef, useState } from 'react' import { useGetState, useInfiniteScroll } from 'ahooks' import { useTranslation } from 'react-i18next' import Link from 'next/link' -import produce from 'immer' import Modal from '@/app/components/base/modal' import type { DataSet } from '@/models/datasets' import Button from '@/app/components/base/button' @@ -29,9 +28,10 @@ const SelectDataSet: FC = ({ onSelect, }) => { const { t } = useTranslation() - const [selected, setSelected] = React.useState(selectedIds.map(id => ({ id }) as any)) + const [selected, setSelected] = React.useState([]) const [loaded, setLoaded] = React.useState(false) const [datasets, setDataSets] = React.useState(null) + const [hasInitialized, setHasInitialized] = React.useState(false) const hasNoData = !datasets || datasets?.length === 0 const canSelectMulti = true @@ -49,19 +49,17 @@ const SelectDataSet: FC = ({ const newList = [...(datasets || []), ...data.filter(item => item.indexing_technique || item.provider === 'external')] setDataSets(newList) setLoaded(true) - if (!selected.find(item => !item.name)) - return { list: [] } - const newSelected = produce(selected, (draft) => { - selected.forEach((item, index) => { - if (!item.name) { // not fetched database - const newItem = newList.find(i => i.id === item.id) - if (newItem) - draft[index] = newItem - } - }) - }) - setSelected(newSelected) + // Initialize selected datasets based on selectedIds and available datasets + if (!hasInitialized) { + if (selectedIds.length > 0) { + const validSelectedDatasets = selectedIds + .map(id => newList.find(item => item.id === id)) + .filter(Boolean) as DataSet[] + setSelected(validSelectedDatasets) + } + setHasInitialized(true) + } } return { list: [] } }, diff --git a/web/app/components/app/create-app-dialog/app-list/sidebar.tsx b/web/app/components/app/create-app-dialog/app-list/sidebar.tsx index 346de078b4..85c55c5385 100644 --- a/web/app/components/app/create-app-dialog/app-list/sidebar.tsx +++ b/web/app/components/app/create-app-dialog/app-list/sidebar.tsx @@ -40,13 +40,13 @@ type CategoryItemProps = { } function CategoryItem({ category, active, onClick }: CategoryItemProps) { return
  • { onClick?.(category) }}> {category === AppCategories.RECOMMENDED &&
    } + className={classNames('system-sm-medium text-components-menu-item-text group-hover:text-components-menu-item-text-hover group-[.active]:text-components-menu-item-text-active', active && 'system-sm-semibold')} />
  • } diff --git a/web/app/components/app/create-app-modal/index.tsx b/web/app/components/app/create-app-modal/index.tsx index c37f7b051a..70a45a4bbe 100644 --- a/web/app/components/app/create-app-modal/index.tsx +++ b/web/app/components/app/create-app-modal/index.tsx @@ -82,8 +82,11 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate }: CreateAppProps) localStorage.setItem(NEED_REFRESH_APP_LIST_KEY, '1') getRedirection(isCurrentWorkspaceEditor, app, push) } - catch { - notify({ type: 'error', message: t('app.newApp.appCreateFailed') }) + catch (e: any) { + notify({ + type: 'error', + message: e.message || t('app.newApp.appCreateFailed'), + }) } isCreatingRef.current = false }, [name, notify, t, appMode, appIcon, description, onSuccess, onClose, push, isCurrentWorkspaceEditor]) diff --git a/web/app/components/apps/app-card.tsx b/web/app/components/apps/app-card.tsx index a91c2edf1e..688da4c25d 100644 --- a/web/app/components/apps/app-card.tsx +++ b/web/app/components/apps/app-card.tsx @@ -117,8 +117,11 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { if (onRefresh) onRefresh() } - catch { - notify({ type: 'error', message: t('app.editFailed') }) + catch (e: any) { + notify({ + type: 'error', + message: e.message || t('app.editFailed'), + }) } }, [app.id, notify, onRefresh, t]) @@ -364,7 +367,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
    {app.description} diff --git a/web/app/components/apps/footer.tsx b/web/app/components/apps/footer.tsx index c5efb2b8b4..9fed4c8757 100644 --- a/web/app/components/apps/footer.tsx +++ b/web/app/components/apps/footer.tsx @@ -1,6 +1,6 @@ -import React, { useState } from 'react' +import React from 'react' import Link from 'next/link' -import { RiCloseLine, RiDiscordFill, RiGithubFill } from '@remixicon/react' +import { RiDiscordFill, RiGithubFill } from '@remixicon/react' import { useTranslation } from 'react-i18next' type CustomLinkProps = { @@ -26,24 +26,9 @@ const CustomLink = React.memo(({ const Footer = () => { const { t } = useTranslation() - const [isVisible, setIsVisible] = useState(true) - - const handleClose = () => { - setIsVisible(false) - } - - if (!isVisible) - return null return (