diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 9807cb4e6a..43912cd75d 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -13,7 +13,7 @@ from core.model_runtime.entities.model_entities import ModelType from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector -from core.rag.embedding.retrieval import RetrievalSegments +from core.rag.embedding.retrieval import RetrievalChildChunk, RetrievalSegments from core.rag.entities.metadata_entities import MetadataCondition from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.index_type import IndexStructureType @@ -381,10 +381,9 @@ class RetrievalService: records = [] include_segment_ids = set() segment_child_map = {} - segment_file_map = {} valid_dataset_documents = {} - image_doc_ids = [] + image_doc_ids: list[Any] = [] child_index_node_ids = [] index_node_ids = [] doc_to_document_map = {} @@ -417,28 +416,39 @@ class RetrievalService: child_index_node_ids = [i for i in child_index_node_ids if i] index_node_ids = [i for i in index_node_ids if i] - segment_ids = [] + segment_ids: list[str] = [] index_node_segments: list[DocumentSegment] = [] segments: list[DocumentSegment] = [] - attachment_map = {} - child_chunk_map = {} - doc_segment_map = {} + attachment_map: dict[str, list[dict[str, Any]]] = {} + child_chunk_map: dict[str, list[ChildChunk]] = {} + doc_segment_map: dict[str, list[str]] = {} with session_factory.create_session() as session: attachments = cls.get_segment_attachment_infos(image_doc_ids, session) for attachment in attachments: segment_ids.append(attachment["segment_id"]) - attachment_map[attachment["segment_id"]] = attachment - doc_segment_map[attachment["segment_id"]] = attachment["attachment_id"] - + if attachment["segment_id"] in attachment_map: + attachment_map[attachment["segment_id"]].append(attachment["attachment_info"]) + else: + attachment_map[attachment["segment_id"]] = [attachment["attachment_info"]] + if attachment["segment_id"] in doc_segment_map: + doc_segment_map[attachment["segment_id"]].append(attachment["attachment_id"]) + else: + doc_segment_map[attachment["segment_id"]] = [attachment["attachment_id"]] child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id.in_(child_index_node_ids)) child_index_nodes = session.execute(child_chunk_stmt).scalars().all() for i in child_index_nodes: segment_ids.append(i.segment_id) - child_chunk_map[i.segment_id] = i - doc_segment_map[i.segment_id] = i.index_node_id + if i.segment_id in child_chunk_map: + child_chunk_map[i.segment_id].append(i) + else: + child_chunk_map[i.segment_id] = [i] + if i.segment_id in doc_segment_map: + doc_segment_map[i.segment_id].append(i.index_node_id) + else: + doc_segment_map[i.segment_id] = [i.index_node_id] if index_node_ids: document_segment_stmt = select(DocumentSegment).where( @@ -448,7 +458,7 @@ class RetrievalService: ) index_node_segments = session.execute(document_segment_stmt).scalars().all() # type: ignore for index_node_segment in index_node_segments: - doc_segment_map[index_node_segment.id] = index_node_segment.index_node_id + doc_segment_map[index_node_segment.id] = [index_node_segment.index_node_id] if segment_ids: document_segment_stmt = select(DocumentSegment).where( DocumentSegment.enabled == True, @@ -461,95 
+471,86 @@ class RetrievalService: segments.extend(index_node_segments) for segment in segments: - doc_id = doc_segment_map.get(segment.id) - child_chunk = child_chunk_map.get(segment.id) - attachment_info = attachment_map.get(segment.id) + child_chunks: list[ChildChunk] = child_chunk_map.get(segment.id, []) + attachment_infos: list[dict[str, Any]] = attachment_map.get(segment.id, []) + ds_dataset_document: DatasetDocument | None = valid_dataset_documents.get(segment.document_id) - if doc_id: - document = doc_to_document_map[doc_id] - ds_dataset_document: DatasetDocument | None = valid_dataset_documents.get( - document.metadata.get("document_id") - ) - - if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: - if segment.id not in include_segment_ids: - include_segment_ids.add(segment.id) - if child_chunk: + if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: + if segment.id not in include_segment_ids: + include_segment_ids.add(segment.id) + if child_chunks or attachment_infos: + child_chunk_details = [] + max_score = 0.0 + for child_chunk in child_chunks: + document = doc_to_document_map[child_chunk.index_node_id] child_chunk_detail = { "id": child_chunk.id, "content": child_chunk.content, "position": child_chunk.position, "score": document.metadata.get("score", 0.0) if document else 0.0, } - map_detail = { - "max_score": document.metadata.get("score", 0.0) if document else 0.0, - "child_chunks": [child_chunk_detail], - } - segment_child_map[segment.id] = map_detail - record = { - "segment": segment, + child_chunk_details.append(child_chunk_detail) + max_score = max(max_score, document.metadata.get("score", 0.0) if document else 0.0) + for attachment_info in attachment_infos: + file_document = doc_to_document_map[attachment_info["id"]] + max_score = max( + max_score, file_document.metadata.get("score", 0.0) if file_document else 0.0 + ) + + map_detail = { + "max_score": max_score, + "child_chunks": child_chunk_details, } - if attachment_info: - segment_file_map[segment.id] = [attachment_info] - records.append(record) - else: - if child_chunk: - child_chunk_detail = { - "id": child_chunk.id, - "content": child_chunk.content, - "position": child_chunk.position, - "score": document.metadata.get("score", 0.0), - } - if segment.id in segment_child_map: - segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail) # type: ignore - segment_child_map[segment.id]["max_score"] = max( - segment_child_map[segment.id]["max_score"], - document.metadata.get("score", 0.0) if document else 0.0, - ) - else: - segment_child_map[segment.id] = { - "max_score": document.metadata.get("score", 0.0) if document else 0.0, - "child_chunks": [child_chunk_detail], - } - if attachment_info: - if segment.id in segment_file_map: - segment_file_map[segment.id].append(attachment_info) - else: - segment_file_map[segment.id] = [attachment_info] - else: - if segment.id not in include_segment_ids: - include_segment_ids.add(segment.id) - record = { - "segment": segment, - "score": document.metadata.get("score", 0.0), # type: ignore - } - if attachment_info: - segment_file_map[segment.id] = [attachment_info] - records.append(record) - else: - if attachment_info: - attachment_infos = segment_file_map.get(segment.id, []) - if attachment_info not in attachment_infos: - attachment_infos.append(attachment_info) - segment_file_map[segment.id] = attachment_infos + segment_child_map[segment.id] = map_detail + record: dict[str, Any] = { 
+ "segment": segment, + } + records.append(record) + else: + if segment.id not in include_segment_ids: + include_segment_ids.add(segment.id) + max_score = 0.0 + segment_document = doc_to_document_map.get(segment.index_node_id) + if segment_document: + max_score = max(max_score, segment_document.metadata.get("score", 0.0)) + for attachment_info in attachment_infos: + file_doc = doc_to_document_map.get(attachment_info["id"]) + if file_doc: + max_score = max(max_score, file_doc.metadata.get("score", 0.0)) + record = { + "segment": segment, + "score": max_score, + } + records.append(record) # Add child chunks information to records for record in records: if record["segment"].id in segment_child_map: record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore record["score"] = segment_child_map[record["segment"].id]["max_score"] # type: ignore - if record["segment"].id in segment_file_map: - record["files"] = segment_file_map[record["segment"].id] # type: ignore[assignment] + if record["segment"].id in attachment_map: + record["files"] = attachment_map[record["segment"].id] # type: ignore[assignment] - result = [] + result: list[RetrievalSegments] = [] for record in records: # Extract segment segment = record["segment"] # Extract child_chunks, ensuring it's a list or None - child_chunks = record.get("child_chunks") - if not isinstance(child_chunks, list): - child_chunks = None + raw_child_chunks = record.get("child_chunks") + child_chunks_list: list[RetrievalChildChunk] | None = None + if isinstance(raw_child_chunks, list): + # Sort by score descending + sorted_chunks = sorted(raw_child_chunks, key=lambda x: x.get("score", 0.0), reverse=True) + child_chunks_list = [ + RetrievalChildChunk( + id=chunk["id"], + content=chunk["content"], + score=chunk.get("score", 0.0), + position=chunk["position"], + ) + for chunk in sorted_chunks + ] # Extract files, ensuring it's a list or None files = record.get("files") @@ -566,11 +567,11 @@ class RetrievalService: # Create RetrievalSegments object retrieval_segment = RetrievalSegments( - segment=segment, child_chunks=child_chunks, score=score, files=files + segment=segment, child_chunks=child_chunks_list, score=score, files=files ) result.append(retrieval_segment) - return result + return sorted(result, key=lambda x: x.score if x.score is not None else 0.0, reverse=True) except Exception as e: db.session.rollback() raise e diff --git a/api/core/rag/datasource/vdb/pgvector/pgvector.py b/api/core/rag/datasource/vdb/pgvector/pgvector.py index 445a0a7f8b..0615b8312c 100644 --- a/api/core/rag/datasource/vdb/pgvector/pgvector.py +++ b/api/core/rag/datasource/vdb/pgvector/pgvector.py @@ -255,7 +255,10 @@ class PGVector(BaseVector): return with self._get_cursor() as cur: - cur.execute("CREATE EXTENSION IF NOT EXISTS vector") + cur.execute("SELECT 1 FROM pg_extension WHERE extname = 'vector'") + if not cur.fetchone(): + cur.execute("CREATE EXTENSION vector") + cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension)) # PG hnsw index only support 2000 dimension or less # ref: https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index baf879df95..2c3fc5ab75 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -7,7 +7,7 @@ from collections.abc import Generator, Mapping from typing import Any, Union, cast from flask import Flask, 
current_app -from sqlalchemy import and_, or_, select +from sqlalchemy import and_, literal, or_, select from sqlalchemy.orm import Session from core.app.app_config.entities import ( @@ -1036,7 +1036,7 @@ class DatasetRetrieval: if automatic_metadata_filters: conditions = [] for sequence, filter in enumerate(automatic_metadata_filters): - self._process_metadata_filter_func( + self.process_metadata_filter_func( sequence, filter.get("condition"), # type: ignore filter.get("metadata_name"), # type: ignore @@ -1072,7 +1072,7 @@ class DatasetRetrieval: value=expected_value, ) ) - filters = self._process_metadata_filter_func( + filters = self.process_metadata_filter_func( sequence, condition.comparison_operator, metadata_name, @@ -1168,8 +1168,9 @@ class DatasetRetrieval: return None return automatic_metadata_filters - def _process_metadata_filter_func( - self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list + @classmethod + def process_metadata_filter_func( + cls, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list ): if value is None and condition not in ("empty", "not empty"): return filters @@ -1218,6 +1219,20 @@ class DatasetRetrieval: case "≥" | ">=": filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() >= value) + case "in" | "not in": + if isinstance(value, str): + value_list = [v.strip() for v in value.split(",") if v.strip()] + elif isinstance(value, (list, tuple)): + value_list = [str(v) for v in value if v is not None] + else: + value_list = [str(value)] if value is not None else [] + + if not value_list: + # `field in []` is False, `field not in []` is True + filters.append(literal(condition == "not in")) + else: + op = json_field.in_ if condition == "in" else json_field.notin_ + filters.append(op(value_list)) case _: pass diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index adc474bd60..8670a71aa3 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -6,7 +6,7 @@ from collections import defaultdict from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any, cast -from sqlalchemy import and_, func, literal, or_, select +from sqlalchemy import and_, func, or_, select from sqlalchemy.orm import sessionmaker from core.app.app_config.entities import DatasetRetrieveConfigEntity @@ -460,7 +460,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD if automatic_metadata_filters: conditions = [] for sequence, filter in enumerate(automatic_metadata_filters): - self._process_metadata_filter_func( + DatasetRetrieval.process_metadata_filter_func( sequence, filter.get("condition", ""), filter.get("metadata_name", ""), @@ -504,7 +504,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD value=expected_value, ) ) - filters = self._process_metadata_filter_func( + filters = DatasetRetrieval.process_metadata_filter_func( sequence, condition.comparison_operator, metadata_name, @@ -603,87 +603,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD return [], usage return automatic_metadata_filters, usage - def _process_metadata_filter_func( - self, sequence: int, condition: str, metadata_name: str, value: Any, filters: list[Any] - ) -> list[Any]: - if value is None and condition not 
in ("empty", "not empty"): - return filters - - json_field = Document.doc_metadata[metadata_name].as_string() - - match condition: - case "contains": - filters.append(json_field.like(f"%{value}%")) - - case "not contains": - filters.append(json_field.notlike(f"%{value}%")) - - case "start with": - filters.append(json_field.like(f"{value}%")) - - case "end with": - filters.append(json_field.like(f"%{value}")) - case "in": - if isinstance(value, str): - value_list = [v.strip() for v in value.split(",") if v.strip()] - elif isinstance(value, (list, tuple)): - value_list = [str(v) for v in value if v is not None] - else: - value_list = [str(value)] if value is not None else [] - - if not value_list: - filters.append(literal(False)) - else: - filters.append(json_field.in_(value_list)) - - case "not in": - if isinstance(value, str): - value_list = [v.strip() for v in value.split(",") if v.strip()] - elif isinstance(value, (list, tuple)): - value_list = [str(v) for v in value if v is not None] - else: - value_list = [str(value)] if value is not None else [] - - if not value_list: - filters.append(literal(True)) - else: - filters.append(json_field.notin_(value_list)) - - case "is" | "=": - if isinstance(value, str): - filters.append(json_field == value) - elif isinstance(value, (int, float)): - filters.append(Document.doc_metadata[metadata_name].as_float() == value) - - case "is not" | "≠": - if isinstance(value, str): - filters.append(json_field != value) - elif isinstance(value, (int, float)): - filters.append(Document.doc_metadata[metadata_name].as_float() != value) - - case "empty": - filters.append(Document.doc_metadata[metadata_name].is_(None)) - - case "not empty": - filters.append(Document.doc_metadata[metadata_name].isnot(None)) - - case "before" | "<": - filters.append(Document.doc_metadata[metadata_name].as_float() < value) - - case "after" | ">": - filters.append(Document.doc_metadata[metadata_name].as_float() > value) - - case "≤" | "<=": - filters.append(Document.doc_metadata[metadata_name].as_float() <= value) - - case "≥" | ">=": - filters.append(Document.doc_metadata[metadata_name].as_float() >= value) - - case _: - pass - - return filters - @classmethod def _extract_variable_selector_to_variable_mapping( cls, diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index 83d0fcf296..c0cc0e5233 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -110,5 +110,5 @@ class EnterpriseService: if not app_id: raise ValueError("app_id must be provided.") - body = {"appId": app_id} - EnterpriseRequest.send_request("DELETE", "/webapp/clean", json=body) + params = {"appId": app_id} + EnterpriseRequest.send_request("DELETE", "/webapp/clean", params=params) diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/pgvector/__init__.py b/api/tests/unit_tests/core/rag/datasource/vdb/pgvector/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/pgvector/test_pgvector.py b/api/tests/unit_tests/core/rag/datasource/vdb/pgvector/test_pgvector.py new file mode 100644 index 0000000000..4998a9858f --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/pgvector/test_pgvector.py @@ -0,0 +1,327 @@ +import unittest +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.datasource.vdb.pgvector.pgvector import ( + PGVector, + PGVectorConfig, +) + + +class TestPGVector(unittest.TestCase): + def 
setUp(self): + self.config = PGVectorConfig( + host="localhost", + port=5432, + user="test_user", + password="test_password", + database="test_db", + min_connection=1, + max_connection=5, + pg_bigm=False, + ) + self.collection_name = "test_collection" + + @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") + def test_init(self, mock_pool_class): + """Test PGVector initialization.""" + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + pgvector = PGVector(self.collection_name, self.config) + + assert pgvector._collection_name == self.collection_name + assert pgvector.table_name == f"embedding_{self.collection_name}" + assert pgvector.get_type() == "pgvector" + assert pgvector.pool is not None + assert pgvector.pg_bigm is False + assert pgvector.index_hash is not None + + @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") + def test_init_with_pg_bigm(self, mock_pool_class): + """Test PGVector initialization with pg_bigm enabled.""" + config = PGVectorConfig( + host="localhost", + port=5432, + user="test_user", + password="test_password", + database="test_db", + min_connection=1, + max_connection=5, + pg_bigm=True, + ) + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + pgvector = PGVector(self.collection_name, config) + + assert pgvector.pg_bigm is True + + @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") + @patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client") + def test_create_collection_basic(self, mock_redis, mock_pool_class): + """Test basic collection creation.""" + # Mock Redis operations + mock_lock = MagicMock() + mock_lock.__enter__ = MagicMock() + mock_lock.__exit__ = MagicMock() + mock_redis.lock.return_value = mock_lock + mock_redis.get.return_value = None + mock_redis.set.return_value = None + + # Mock the connection pool + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + # Mock connection and cursor + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_pool.getconn.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + mock_cursor.fetchone.return_value = [1] # vector extension exists + + pgvector = PGVector(self.collection_name, self.config) + pgvector._create_collection(1536) + + # Verify SQL execution calls + assert mock_cursor.execute.called + + # Check that CREATE TABLE was called with correct dimension + create_table_calls = [call for call in mock_cursor.execute.call_args_list if "CREATE TABLE" in str(call)] + assert len(create_table_calls) == 1 + assert "vector(1536)" in create_table_calls[0][0][0] + + # Check that CREATE INDEX was called (dimension <= 2000) + create_index_calls = [ + call for call in mock_cursor.execute.call_args_list if "CREATE INDEX" in str(call) and "hnsw" in str(call) + ] + assert len(create_index_calls) == 1 + + # Verify Redis cache was set + mock_redis.set.assert_called_once() + + @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") + @patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client") + def test_create_collection_with_large_dimension(self, mock_redis, mock_pool_class): + """Test collection creation with dimension > 2000 (no HNSW index).""" + # Mock Redis operations + mock_lock = MagicMock() + mock_lock.__enter__ = MagicMock() + mock_lock.__exit__ = MagicMock() + mock_redis.lock.return_value = mock_lock + mock_redis.get.return_value = None + mock_redis.set.return_value = None + + # Mock the 
connection pool + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + # Mock connection and cursor + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_pool.getconn.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + mock_cursor.fetchone.return_value = [1] # vector extension exists + + pgvector = PGVector(self.collection_name, self.config) + pgvector._create_collection(3072) # Dimension > 2000 + + # Check that CREATE TABLE was called + create_table_calls = [call for call in mock_cursor.execute.call_args_list if "CREATE TABLE" in str(call)] + assert len(create_table_calls) == 1 + assert "vector(3072)" in create_table_calls[0][0][0] + + # Check that HNSW index was NOT created (dimension > 2000) + hnsw_index_calls = [call for call in mock_cursor.execute.call_args_list if "hnsw" in str(call)] + assert len(hnsw_index_calls) == 0 + + @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") + @patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client") + def test_create_collection_with_pg_bigm(self, mock_redis, mock_pool_class): + """Test collection creation with pg_bigm enabled.""" + config = PGVectorConfig( + host="localhost", + port=5432, + user="test_user", + password="test_password", + database="test_db", + min_connection=1, + max_connection=5, + pg_bigm=True, + ) + + # Mock Redis operations + mock_lock = MagicMock() + mock_lock.__enter__ = MagicMock() + mock_lock.__exit__ = MagicMock() + mock_redis.lock.return_value = mock_lock + mock_redis.get.return_value = None + mock_redis.set.return_value = None + + # Mock the connection pool + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + # Mock connection and cursor + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_pool.getconn.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + mock_cursor.fetchone.return_value = [1] # vector extension exists + + pgvector = PGVector(self.collection_name, config) + pgvector._create_collection(1536) + + # Check that pg_bigm index was created + bigm_index_calls = [call for call in mock_cursor.execute.call_args_list if "gin_bigm_ops" in str(call)] + assert len(bigm_index_calls) == 1 + + @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") + @patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client") + def test_create_collection_creates_vector_extension(self, mock_redis, mock_pool_class): + """Test that vector extension is created if it doesn't exist.""" + # Mock Redis operations + mock_lock = MagicMock() + mock_lock.__enter__ = MagicMock() + mock_lock.__exit__ = MagicMock() + mock_redis.lock.return_value = mock_lock + mock_redis.get.return_value = None + mock_redis.set.return_value = None + + # Mock the connection pool + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + # Mock connection and cursor + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_pool.getconn.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + # First call: vector extension doesn't exist + mock_cursor.fetchone.return_value = None + + pgvector = PGVector(self.collection_name, self.config) + pgvector._create_collection(1536) + + # Check that CREATE EXTENSION was called + create_extension_calls = [ + call for call in mock_cursor.execute.call_args_list if "CREATE EXTENSION vector" in str(call) + ] + assert len(create_extension_calls) == 1 + + @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") + 
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client") + def test_create_collection_with_cache_hit(self, mock_redis, mock_pool_class): + """Test that collection creation is skipped when cache exists.""" + # Mock Redis operations - cache exists + mock_lock = MagicMock() + mock_lock.__enter__ = MagicMock() + mock_lock.__exit__ = MagicMock() + mock_redis.lock.return_value = mock_lock + mock_redis.get.return_value = 1 # Cache exists + + # Mock the connection pool + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + # Mock connection and cursor + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_pool.getconn.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + + pgvector = PGVector(self.collection_name, self.config) + pgvector._create_collection(1536) + + # Check that no SQL was executed (early return due to cache) + assert mock_cursor.execute.call_count == 0 + + @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") + @patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client") + def test_create_collection_with_redis_lock(self, mock_redis, mock_pool_class): + """Test that Redis lock is used during collection creation.""" + # Mock Redis operations + mock_lock = MagicMock() + mock_lock.__enter__ = MagicMock() + mock_lock.__exit__ = MagicMock() + mock_redis.lock.return_value = mock_lock + mock_redis.get.return_value = None + mock_redis.set.return_value = None + + # Mock the connection pool + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + # Mock connection and cursor + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_pool.getconn.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + mock_cursor.fetchone.return_value = [1] # vector extension exists + + pgvector = PGVector(self.collection_name, self.config) + pgvector._create_collection(1536) + + # Verify Redis lock was acquired with correct lock name + mock_redis.lock.assert_called_once_with("vector_indexing_test_collection_lock", timeout=20) + + # Verify lock context manager was entered and exited + mock_lock.__enter__.assert_called_once() + mock_lock.__exit__.assert_called_once() + + @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") + def test_get_cursor_context_manager(self, mock_pool_class): + """Test that _get_cursor properly manages connection lifecycle.""" + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_pool.getconn.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + + pgvector = PGVector(self.collection_name, self.config) + + with pgvector._get_cursor() as cur: + assert cur == mock_cursor + + # Verify connection lifecycle methods were called + mock_pool.getconn.assert_called_once() + mock_cursor.close.assert_called_once() + mock_conn.commit.assert_called_once() + mock_pool.putconn.assert_called_once_with(mock_conn) + + +@pytest.mark.parametrize( + "invalid_config_override", + [ + {"host": ""}, # Test empty host + {"port": 0}, # Test invalid port + {"user": ""}, # Test empty user + {"password": ""}, # Test empty password + {"database": ""}, # Test empty database + {"min_connection": 0}, # Test invalid min_connection + {"max_connection": 0}, # Test invalid max_connection + {"min_connection": 10, "max_connection": 5}, # Test min > max + ], +) +def test_config_validation_parametrized(invalid_config_override): + """Test configuration validation for various 
invalid inputs using parametrize.""" + config = { + "host": "localhost", + "port": 5432, + "user": "test_user", + "password": "test_password", + "database": "test_db", + "min_connection": 1, + "max_connection": 5, + } + config.update(invalid_config_override) + + with pytest.raises(ValueError): + PGVectorConfig(**config) + + +if __name__ == "__main__": + unittest.main() diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py new file mode 100644 index 0000000000..07d6e51e4b --- /dev/null +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py @@ -0,0 +1,873 @@ +""" +Unit tests for DatasetRetrieval.process_metadata_filter_func. + +This module provides comprehensive test coverage for the process_metadata_filter_func +method in the DatasetRetrieval class, which is responsible for building SQLAlchemy +filter expressions based on metadata filtering conditions. + +Conditions Tested: +================== +1. **String Conditions**: contains, not contains, start with, end with +2. **Equality Conditions**: is / =, is not / ≠ +3. **Null Conditions**: empty, not empty +4. **Numeric Comparisons**: before / <, after / >, ≤ / <=, ≥ / >= +5. **List Conditions**: in +6. **Edge Cases**: None values, different data types (str, int, float) + +Test Architecture: +================== +- Direct instantiation of DatasetRetrieval +- Mocking of DatasetDocument model attributes +- Verification of SQLAlchemy filter expressions +- Follows Arrange-Act-Assert (AAA) pattern + +Running Tests: +============== + # Run all tests in this module + uv run --project api pytest \ + api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py -v + + # Run a specific test + uv run --project api pytest \ + api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py::\ +TestProcessMetadataFilterFunc::test_contains_condition -v +""" + +from unittest.mock import MagicMock + +import pytest + +from core.rag.retrieval.dataset_retrieval import DatasetRetrieval + + +class TestProcessMetadataFilterFunc: + """ + Comprehensive test suite for process_metadata_filter_func method. + + This test class validates all metadata filtering conditions supported by + the DatasetRetrieval class, including string operations, numeric comparisons, + null checks, and list operations. + + Method Signature: + ================== + def process_metadata_filter_func( + self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list + ) -> list: + + The method builds SQLAlchemy filter expressions by: + 1. Validating value is not None (except for empty/not empty conditions) + 2. Using DatasetDocument.doc_metadata JSON field operations + 3. Adding appropriate SQLAlchemy expressions to the filters list + 4. Returning the updated filters list + + Mocking Strategy: + ================== + - Mock DatasetDocument.doc_metadata to avoid database dependencies + - Verify filter expressions are created correctly + - Test with various data types (str, int, float, list) + """ + + @pytest.fixture + def retrieval(self): + """ + Create a DatasetRetrieval instance for testing. + + Returns: + DatasetRetrieval: Instance to test process_metadata_filter_func + """ + return DatasetRetrieval() + + @pytest.fixture + def mock_doc_metadata(self): + """ + Mock the DatasetDocument.doc_metadata JSON field. 
+
+        The method uses DatasetDocument.doc_metadata[metadata_name] to access
+        JSON fields. We mock this to avoid database dependencies.
+
+        Returns:
+            Mock: Mocked doc_metadata attribute
+        """
+        mock_metadata_field = MagicMock()
+
+        # Create mock for string access
+        mock_string_access = MagicMock()
+        mock_string_access.like = MagicMock()
+        mock_string_access.notlike = MagicMock()
+        mock_string_access.__eq__ = MagicMock(return_value=MagicMock())
+        mock_string_access.__ne__ = MagicMock(return_value=MagicMock())
+        mock_string_access.in_ = MagicMock(return_value=MagicMock())
+
+        # Create mock for float access (for numeric comparisons)
+        mock_float_access = MagicMock()
+        mock_float_access.__eq__ = MagicMock(return_value=MagicMock())
+        mock_float_access.__ne__ = MagicMock(return_value=MagicMock())
+        mock_float_access.__lt__ = MagicMock(return_value=MagicMock())
+        mock_float_access.__gt__ = MagicMock(return_value=MagicMock())
+        mock_float_access.__le__ = MagicMock(return_value=MagicMock())
+        mock_float_access.__ge__ = MagicMock(return_value=MagicMock())
+
+        # Create mock for null checks
+        mock_null_access = MagicMock()
+        mock_null_access.is_ = MagicMock(return_value=MagicMock())
+        mock_null_access.isnot = MagicMock(return_value=MagicMock())
+
+        # Setup __getitem__ to return appropriate mock based on usage
+        def getitem_side_effect(name):
+            if name in ["author", "title", "category"]:
+                return mock_string_access
+            elif name in ["year", "price", "rating"]:
+                return mock_float_access
+            else:
+                return mock_string_access
+
+        mock_metadata_field.__getitem__ = MagicMock(side_effect=getitem_side_effect)
+        mock_metadata_field.as_string.return_value = mock_string_access
+        mock_metadata_field.as_float.return_value = mock_float_access
+        # Route is_/isnot through the dedicated null-check mock so null-check
+        # calls on the accessed field can be asserted against mock_null_access
+        mock_string_access.is_ = mock_null_access.is_
+        mock_string_access.isnot = mock_null_access.isnot
+        mock_float_access.is_ = mock_null_access.is_
+        mock_float_access.isnot = mock_null_access.isnot
+
+        return mock_metadata_field
+
+    # ==================== String Condition Tests ====================
+
+    def test_contains_condition_string_value(self, retrieval):
+        """
+        Test 'contains' condition with string value.
+
+        Verifies:
+        - Filters list is populated with LIKE expression
+        - Pattern matching uses %value% syntax
+        """
+        filters = []
+        sequence = 0
+        condition = "contains"
+        metadata_name = "author"
+        value = "John"
+
+        result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
+
+        assert result == filters
+        assert len(filters) == 1
+
+    def test_not_contains_condition(self, retrieval):
+        """
+        Test 'not contains' condition.
+
+        Verifies:
+        - Filters list is populated with NOT LIKE expression
+        - Pattern matching uses %value% syntax with negation
+        """
+        filters = []
+        sequence = 0
+        condition = "not contains"
+        metadata_name = "title"
+        value = "banned"
+
+        result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
+
+        assert result == filters
+        assert len(filters) == 1
+
+    def test_start_with_condition(self, retrieval):
+        """
+        Test 'start with' condition.
+
+        Verifies:
+        - Filters list is populated with LIKE expression
+        - Pattern matching uses value% syntax
+        """
+        filters = []
+        sequence = 0
+        condition = "start with"
+        metadata_name = "category"
+        value = "tech"
+
+        result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
+
+        assert result == filters
+        assert len(filters) == 1
+
+    def test_end_with_condition(self, retrieval):
+        """
+        Test 'end with' condition.
+ + Verifies: + - Filters list is populated with LIKE expression + - Pattern matching uses %value syntax + """ + filters = [] + sequence = 0 + condition = "end with" + metadata_name = "filename" + value = ".pdf" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + # ==================== Equality Condition Tests ==================== + + def test_is_condition_with_string_value(self, retrieval): + """ + Test 'is' (=) condition with string value. + + Verifies: + - Filters list is populated with equality expression + - String comparison is used + """ + filters = [] + sequence = 0 + condition = "is" + metadata_name = "author" + value = "Jane Doe" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_equals_condition_with_string_value(self, retrieval): + """ + Test '=' condition with string value. + + Verifies: + - Same behavior as 'is' condition + - String comparison is used + """ + filters = [] + sequence = 0 + condition = "=" + metadata_name = "category" + value = "technology" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_is_condition_with_int_value(self, retrieval): + """ + Test 'is' condition with integer value. + + Verifies: + - Numeric comparison is used + - as_float() is called on the metadata field + """ + filters = [] + sequence = 0 + condition = "is" + metadata_name = "year" + value = 2023 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_is_condition_with_float_value(self, retrieval): + """ + Test 'is' condition with float value. + + Verifies: + - Numeric comparison is used + - as_float() is called on the metadata field + """ + filters = [] + sequence = 0 + condition = "is" + metadata_name = "price" + value = 19.99 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_is_not_condition_with_string_value(self, retrieval): + """ + Test 'is not' (≠) condition with string value. + + Verifies: + - Filters list is populated with inequality expression + - String comparison is used + """ + filters = [] + sequence = 0 + condition = "is not" + metadata_name = "author" + value = "Unknown" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_not_equals_condition(self, retrieval): + """ + Test '≠' condition with string value. + + Verifies: + - Same behavior as 'is not' condition + - Inequality expression is used + """ + filters = [] + sequence = 0 + condition = "≠" + metadata_name = "category" + value = "archived" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_is_not_condition_with_numeric_value(self, retrieval): + """ + Test 'is not' condition with numeric value. 
+ + Verifies: + - Numeric inequality comparison is used + - as_float() is called on the metadata field + """ + filters = [] + sequence = 0 + condition = "is not" + metadata_name = "year" + value = 2000 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + # ==================== Null Condition Tests ==================== + + def test_empty_condition(self, retrieval): + """ + Test 'empty' condition (null check). + + Verifies: + - Filters list is populated with IS NULL expression + - Value can be None for this condition + """ + filters = [] + sequence = 0 + condition = "empty" + metadata_name = "author" + value = None + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_not_empty_condition(self, retrieval): + """ + Test 'not empty' condition (not null check). + + Verifies: + - Filters list is populated with IS NOT NULL expression + - Value can be None for this condition + """ + filters = [] + sequence = 0 + condition = "not empty" + metadata_name = "description" + value = None + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + # ==================== Numeric Comparison Tests ==================== + + def test_before_condition(self, retrieval): + """ + Test 'before' (<) condition. + + Verifies: + - Filters list is populated with less than expression + - Numeric comparison is used + """ + filters = [] + sequence = 0 + condition = "before" + metadata_name = "year" + value = 2020 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_less_than_condition(self, retrieval): + """ + Test '<' condition. + + Verifies: + - Same behavior as 'before' condition + - Less than expression is used + """ + filters = [] + sequence = 0 + condition = "<" + metadata_name = "price" + value = 100.0 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_after_condition(self, retrieval): + """ + Test 'after' (>) condition. + + Verifies: + - Filters list is populated with greater than expression + - Numeric comparison is used + """ + filters = [] + sequence = 0 + condition = "after" + metadata_name = "year" + value = 2020 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_greater_than_condition(self, retrieval): + """ + Test '>' condition. + + Verifies: + - Same behavior as 'after' condition + - Greater than expression is used + """ + filters = [] + sequence = 0 + condition = ">" + metadata_name = "rating" + value = 4.5 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_less_than_or_equal_condition_unicode(self, retrieval): + """ + Test '≤' condition. 
+ + Verifies: + - Filters list is populated with less than or equal expression + - Numeric comparison is used + """ + filters = [] + sequence = 0 + condition = "≤" + metadata_name = "price" + value = 50.0 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_less_than_or_equal_condition_ascii(self, retrieval): + """ + Test '<=' condition. + + Verifies: + - Same behavior as '≤' condition + - Less than or equal expression is used + """ + filters = [] + sequence = 0 + condition = "<=" + metadata_name = "year" + value = 2023 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_greater_than_or_equal_condition_unicode(self, retrieval): + """ + Test '≥' condition. + + Verifies: + - Filters list is populated with greater than or equal expression + - Numeric comparison is used + """ + filters = [] + sequence = 0 + condition = "≥" + metadata_name = "rating" + value = 3.5 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_greater_than_or_equal_condition_ascii(self, retrieval): + """ + Test '>=' condition. + + Verifies: + - Same behavior as '≥' condition + - Greater than or equal expression is used + """ + filters = [] + sequence = 0 + condition = ">=" + metadata_name = "year" + value = 2000 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + # ==================== List/In Condition Tests ==================== + + def test_in_condition_with_comma_separated_string(self, retrieval): + """ + Test 'in' condition with comma-separated string value. + + Verifies: + - String is split into list + - Whitespace is trimmed from each value + - IN expression is created + """ + filters = [] + sequence = 0 + condition = "in" + metadata_name = "category" + value = "tech, science, AI " + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_in_condition_with_list_value(self, retrieval): + """ + Test 'in' condition with list value. + + Verifies: + - List is processed correctly + - None values are filtered out + - IN expression is created with valid values + """ + filters = [] + sequence = 0 + condition = "in" + metadata_name = "tags" + value = ["python", "javascript", None, "golang"] + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_in_condition_with_tuple_value(self, retrieval): + """ + Test 'in' condition with tuple value. + + Verifies: + - Tuple is processed like a list + - IN expression is created + """ + filters = [] + sequence = 0 + condition = "in" + metadata_name = "category" + value = ("tech", "science", "ai") + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_in_condition_with_empty_string(self, retrieval): + """ + Test 'in' condition with empty string value. 
+ + Verifies: + - Empty string results in literal(False) filter + - No valid values to match + """ + filters = [] + sequence = 0 + condition = "in" + metadata_name = "category" + value = "" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + # Verify it's a literal(False) expression + # This is a bit tricky to test without access to the actual expression + + def test_in_condition_with_only_whitespace(self, retrieval): + """ + Test 'in' condition with whitespace-only string value. + + Verifies: + - Whitespace-only string results in literal(False) filter + - All values are stripped and filtered out + """ + filters = [] + sequence = 0 + condition = "in" + metadata_name = "category" + value = " , , " + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_in_condition_with_single_string(self, retrieval): + """ + Test 'in' condition with single non-comma string. + + Verifies: + - Single string is treated as single-item list + - IN expression is created with one value + """ + filters = [] + sequence = 0 + condition = "in" + metadata_name = "category" + value = "technology" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + # ==================== Edge Case Tests ==================== + + def test_none_value_with_non_empty_condition(self, retrieval): + """ + Test None value with conditions that require value. + + Verifies: + - Original filters list is returned unchanged + - No filter is added for None values (except empty/not empty) + """ + filters = [] + sequence = 0 + condition = "contains" + metadata_name = "author" + value = None + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 0 # No filter added + + def test_none_value_with_equals_condition(self, retrieval): + """ + Test None value with 'is' (=) condition. + + Verifies: + - Original filters list is returned unchanged + - No filter is added for None values + """ + filters = [] + sequence = 0 + condition = "is" + metadata_name = "author" + value = None + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 0 + + def test_none_value_with_numeric_condition(self, retrieval): + """ + Test None value with numeric comparison condition. + + Verifies: + - Original filters list is returned unchanged + - No filter is added for None values + """ + filters = [] + sequence = 0 + condition = ">" + metadata_name = "year" + value = None + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 0 + + def test_existing_filters_preserved(self, retrieval): + """ + Test that existing filters are preserved. 
+ + Verifies: + - Existing filters in the list are not removed + - New filters are appended to the list + """ + existing_filter = MagicMock() + filters = [existing_filter] + sequence = 0 + condition = "contains" + metadata_name = "author" + value = "test" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 2 + assert filters[0] == existing_filter + + def test_multiple_filters_accumulated(self, retrieval): + """ + Test multiple calls to accumulate filters. + + Verifies: + - Each call adds a new filter to the list + - All filters are preserved across calls + """ + filters = [] + + # First filter + retrieval.process_metadata_filter_func(0, "contains", "author", "John", filters) + assert len(filters) == 1 + + # Second filter + retrieval.process_metadata_filter_func(1, ">", "year", 2020, filters) + assert len(filters) == 2 + + # Third filter + retrieval.process_metadata_filter_func(2, "is", "category", "tech", filters) + assert len(filters) == 3 + + def test_unknown_condition(self, retrieval): + """ + Test unknown/unsupported condition. + + Verifies: + - Original filters list is returned unchanged + - No filter is added for unknown conditions + """ + filters = [] + sequence = 0 + condition = "unknown_condition" + metadata_name = "author" + value = "test" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 0 + + def test_empty_string_value_with_contains(self, retrieval): + """ + Test empty string value with 'contains' condition. + + Verifies: + - Filter is added even with empty string + - LIKE expression is created + """ + filters = [] + sequence = 0 + condition = "contains" + metadata_name = "author" + value = "" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_special_characters_in_value(self, retrieval): + """ + Test special characters in value string. + + Verifies: + - Special characters are handled in value + - LIKE expression is created correctly + """ + filters = [] + sequence = 0 + condition = "contains" + metadata_name = "title" + value = "C++ & Python's features" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_zero_value_with_numeric_condition(self, retrieval): + """ + Test zero value with numeric comparison condition. + + Verifies: + - Zero is treated as valid value + - Numeric comparison is performed + """ + filters = [] + sequence = 0 + condition = ">" + metadata_name = "price" + value = 0 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_negative_value_with_numeric_condition(self, retrieval): + """ + Test negative value with numeric comparison condition. 
+ + Verifies: + - Negative numbers are handled correctly + - Numeric comparison is performed + """ + filters = [] + sequence = 0 + condition = "<" + metadata_name = "temperature" + value = -10.5 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_float_value_with_integer_comparison(self, retrieval): + """ + Test float value with numeric comparison condition. + + Verifies: + - Float values work correctly + - Numeric comparison is performed + """ + filters = [] + sequence = 0 + condition = ">=" + metadata_name = "rating" + value = 4.5 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 diff --git a/web/app/components/app/configuration/dataset-config/select-dataset/index.spec.tsx b/web/app/components/app/configuration/dataset-config/select-dataset/index.spec.tsx new file mode 100644 index 0000000000..e7c3d4a3c9 --- /dev/null +++ b/web/app/components/app/configuration/dataset-config/select-dataset/index.spec.tsx @@ -0,0 +1,141 @@ +import type { DataSet } from '@/models/datasets' +import { act, fireEvent, render, screen } from '@testing-library/react' +import * as React from 'react' + +import { describe, expect, it, vi } from 'vitest' +import { IndexingType } from '@/app/components/datasets/create/step-two' +import { DatasetPermission } from '@/models/datasets' +import { RETRIEVE_METHOD } from '@/types/app' +import SelectDataSet from './index' + +vi.mock('@/i18n-config/i18next-config', () => ({ + __esModule: true, + default: { + changeLanguage: vi.fn(), + addResourceBundle: vi.fn(), + use: vi.fn().mockReturnThis(), + init: vi.fn(), + addResource: vi.fn(), + hasResourceBundle: vi.fn().mockReturnValue(true), + }, +})) +const mockUseInfiniteScroll = vi.fn() +vi.mock('ahooks', async (importOriginal) => { + const actual = await importOriginal() + return { + ...(typeof actual === 'object' && actual !== null ? 
actual : {}),
+    useInfiniteScroll: (...args: any[]) => mockUseInfiniteScroll(...args),
+  }
+})
+
+const mockUseInfiniteDatasets = vi.fn()
+vi.mock('@/service/knowledge/use-dataset', () => ({
+  useInfiniteDatasets: (...args: any[]) => mockUseInfiniteDatasets(...args),
+}))
+
+vi.mock('@/hooks/use-knowledge', () => ({
+  useKnowledge: () => ({
+    formatIndexingTechniqueAndMethod: (tech: string, method: string) => `${tech}:${method}`,
+  }),
+}))
+
+const baseProps = {
+  isShow: true,
+  onClose: vi.fn(),
+  selectedIds: [] as string[],
+  onSelect: vi.fn(),
+}
+
+const makeDataset = (overrides: Partial<DataSet>): DataSet => ({
+  id: 'dataset-id',
+  name: 'Dataset Name',
+  provider: 'internal',
+  icon_info: {
+    icon_type: 'emoji',
+    icon: '💾',
+    icon_background: '#fff',
+    icon_url: '',
+  },
+  embedding_available: true,
+  is_multimodal: false,
+  description: '',
+  permission: DatasetPermission.allTeamMembers,
+  indexing_technique: IndexingType.ECONOMICAL,
+  retrieval_model_dict: {
+    search_method: RETRIEVE_METHOD.fullText,
+    top_k: 5,
+    reranking_enable: false,
+    reranking_model: {
+      reranking_model_name: '',
+      reranking_provider_name: '',
+    },
+    score_threshold_enabled: false,
+    score_threshold: 0,
+  },
+  ...overrides,
+} as DataSet)
+
+describe('SelectDataSet', () => {
+  beforeEach(() => {
+    vi.clearAllMocks()
+  })
+
+  it('renders dataset entries, allows selection, and fires onSelect', async () => {
+    const datasetOne = makeDataset({
+      id: 'set-1',
+      name: 'Dataset One',
+      is_multimodal: true,
+      indexing_technique: IndexingType.ECONOMICAL,
+    })
+    const datasetTwo = makeDataset({
+      id: 'set-2',
+      name: 'Hidden Dataset',
+      embedding_available: false,
+      provider: 'external',
+    })
+    mockUseInfiniteDatasets.mockReturnValue({
+      data: { pages: [{ data: [datasetOne, datasetTwo] }] },
+      isLoading: false,
+      isFetchingNextPage: false,
+      fetchNextPage: vi.fn(),
+      hasNextPage: false,
+    })
+
+    const onSelect = vi.fn()
+    await act(async () => {
+      render(<SelectDataSet {...baseProps} onSelect={onSelect} />)
+    })
+
+    expect(screen.getByText('Dataset One')).toBeInTheDocument()
+    expect(screen.getByText('Hidden Dataset')).toBeInTheDocument()
+
+    await act(async () => {
+      fireEvent.click(screen.getByText('Dataset One'))
+    })
+    expect(screen.getByText('1 appDebug.feature.dataSet.selected')).toBeInTheDocument()
+
+    const addButton = screen.getByRole('button', { name: 'common.operation.add' })
+    await act(async () => {
+      fireEvent.click(addButton)
+    })
+    expect(onSelect).toHaveBeenCalledWith([datasetOne])
+  })
+
+  it('shows empty state when no datasets are available and disables add', async () => {
+    mockUseInfiniteDatasets.mockReturnValue({
+      data: { pages: [{ data: [] }] },
+      isLoading: false,
+      isFetchingNextPage: false,
+      fetchNextPage: vi.fn(),
+      hasNextPage: false,
+    })
+
+    await act(async () => {
+      render(<SelectDataSet {...baseProps} />)
+    })
+
+    expect(screen.getByText('appDebug.feature.dataSet.noDataSet')).toBeInTheDocument()
+    expect(screen.getByRole('link', { name: 'appDebug.feature.dataSet.toCreate' })).toHaveAttribute('href', '/datasets/create')
+    expect(screen.getByRole('button', { name: 'common.operation.add' })).toBeDisabled()
+  })
+})
diff --git a/web/app/components/app/configuration/prompt-value-panel/index.spec.tsx b/web/app/components/app/configuration/prompt-value-panel/index.spec.tsx
new file mode 100644
index 0000000000..039ed078d7
--- /dev/null
+++ b/web/app/components/app/configuration/prompt-value-panel/index.spec.tsx
@@ -0,0 +1,125 @@
+import type { IPromptValuePanelProps } from './index'
+import { fireEvent, render, screen, waitFor } from '@testing-library/react'
+import * as React
from 'react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { useStore } from '@/app/components/app/store' +import ConfigContext from '@/context/debug-configuration' +import { AppModeEnum, ModelModeType, Resolution } from '@/types/app' +import PromptValuePanel from './index' + +vi.mock('@/app/components/app/store', () => ({ + useStore: vi.fn(), +})) +vi.mock('@/app/components/base/features/new-feature-panel/feature-bar', () => ({ + __esModule: true, + default: ({ onFeatureBarClick }: { onFeatureBarClick: () => void }) => ( + + ), +})) + +const mockSetShowAppConfigureFeaturesModal = vi.fn() +const mockUseStore = vi.mocked(useStore) +const mockSetInputs = vi.fn() +const mockOnSend = vi.fn() + +const promptVariables = [ + { key: 'textVar', name: 'Text Var', type: 'string', required: true }, + { key: 'boolVar', name: 'Boolean Var', type: 'checkbox' }, +] as const + +const baseContextValue: any = { + modelModeType: ModelModeType.completion, + modelConfig: { + configs: { + prompt_template: 'prompt template', + prompt_variables: promptVariables, + }, + }, + setInputs: mockSetInputs, + mode: AppModeEnum.COMPLETION, + isAdvancedMode: false, + completionPromptConfig: { + prompt: { text: 'completion' }, + conversation_histories_role: { user_prefix: 'user', assistant_prefix: 'assistant' }, + }, + chatPromptConfig: { prompt: [] }, +} as any + +const defaultProps: IPromptValuePanelProps = { + appType: AppModeEnum.COMPLETION, + onSend: mockOnSend, + inputs: { textVar: 'initial', boolVar: false }, + visionConfig: { enabled: false, number_limits: 0, detail: Resolution.low, transfer_methods: [] }, + onVisionFilesChange: vi.fn(), +} + +const renderPanel = (options: { + context?: Partial + props?: Partial +} = {}) => { + const contextValue = { ...baseContextValue, ...options.context } + const props = { ...defaultProps, ...options.props } + return render( + + + , + ) +} + +describe('PromptValuePanel', () => { + beforeEach(() => { + mockUseStore.mockImplementation(selector => selector({ + setShowAppConfigureFeaturesModal: mockSetShowAppConfigureFeaturesModal, + appSidebarExpand: '', + currentLogModalActiveTab: 'prompt', + showPromptLogModal: false, + showAgentLogModal: false, + setShowPromptLogModal: vi.fn(), + setShowAgentLogModal: vi.fn(), + showMessageLogModal: false, + showAppConfigureFeaturesModal: false, + } as any)) + mockSetInputs.mockClear() + mockOnSend.mockClear() + mockSetShowAppConfigureFeaturesModal.mockClear() + }) + + it('updates inputs, clears values, and triggers run when ready', async () => { + renderPanel() + + const textInput = screen.getByPlaceholderText('Text Var') + fireEvent.change(textInput, { target: { value: 'updated' } }) + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ textVar: 'updated' })) + + const clearButton = screen.getByRole('button', { name: 'common.operation.clear' }) + fireEvent.click(clearButton) + + expect(mockSetInputs).toHaveBeenLastCalledWith({ + textVar: '', + boolVar: '', + }) + + const runButton = screen.getByRole('button', { name: 'appDebug.inputs.run' }) + expect(runButton).not.toBeDisabled() + fireEvent.click(runButton) + await waitFor(() => expect(mockOnSend).toHaveBeenCalledTimes(1)) + }) + + it('disables run when mode is not completion', () => { + renderPanel({ + context: { + mode: AppModeEnum.CHAT, + }, + props: { + appType: AppModeEnum.CHAT, + }, + }) + + const runButton = screen.getByRole('button', { name: 'appDebug.inputs.run' }) + expect(runButton).toBeDisabled() + fireEvent.click(runButton) + 
expect(mockOnSend).not.toHaveBeenCalled() + }) +}) diff --git a/web/app/components/app/configuration/prompt-value-panel/utils.spec.ts b/web/app/components/app/configuration/prompt-value-panel/utils.spec.ts new file mode 100644 index 0000000000..7a7e0da9a9 --- /dev/null +++ b/web/app/components/app/configuration/prompt-value-panel/utils.spec.ts @@ -0,0 +1,29 @@ +import type { PromptVariable } from '@/models/debug' + +import { describe, expect, it } from 'vitest' +import { replaceStringWithValues } from './utils' + +const promptVariables: PromptVariable[] = [ + { key: 'user', name: 'User', type: 'string' }, + { key: 'topic', name: 'Topic', type: 'string' }, +] + +describe('replaceStringWithValues', () => { + it('should replace placeholders when inputs have values', () => { + const template = 'Hello {{user}} talking about {{topic}}' + const result = replaceStringWithValues(template, promptVariables, { user: 'Alice', topic: 'cats' }) + expect(result).toBe('Hello Alice talking about cats') + }) + + it('should use prompt variable name when value is missing', () => { + const template = 'Hi {{user}} from {{topic}}' + const result = replaceStringWithValues(template, promptVariables, {}) + expect(result).toBe('Hi {{User}} from {{Topic}}') + }) + + it('should leave placeholder untouched when no variable is defined', () => { + const template = 'Unknown {{missing}} placeholder' + const result = replaceStringWithValues(template, promptVariables, {}) + expect(result).toBe('Unknown {{missing}} placeholder') + }) +}) diff --git a/web/app/components/app/create-app-modal/index.spec.tsx b/web/app/components/app/create-app-modal/index.spec.tsx new file mode 100644 index 0000000000..02c00ed3fd --- /dev/null +++ b/web/app/components/app/create-app-modal/index.spec.tsx @@ -0,0 +1,162 @@ +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import { useRouter } from 'next/navigation' +import { afterAll, beforeEach, describe, expect, it, vi } from 'vitest' +import { trackEvent } from '@/app/components/base/amplitude' + +import { ToastContext } from '@/app/components/base/toast' +import { NEED_REFRESH_APP_LIST_KEY } from '@/config' +import { useAppContext } from '@/context/app-context' +import { useProviderContext } from '@/context/provider-context' +import { createApp } from '@/service/apps' +import { AppModeEnum } from '@/types/app' +import { getRedirection } from '@/utils/app-redirection' +import CreateAppModal from './index' + +vi.mock('ahooks', () => ({ + useDebounceFn: (fn: (...args: any[]) => any) => { + const run = (...args: any[]) => fn(...args) + const cancel = vi.fn() + const flush = vi.fn() + return { run, cancel, flush } + }, + useKeyPress: vi.fn(), + useHover: () => false, +})) +vi.mock('next/navigation', () => ({ + useRouter: vi.fn(), +})) +vi.mock('@/app/components/base/amplitude', () => ({ + trackEvent: vi.fn(), +})) +vi.mock('@/service/apps', () => ({ + createApp: vi.fn(), +})) +vi.mock('@/utils/app-redirection', () => ({ + getRedirection: vi.fn(), +})) +vi.mock('@/context/provider-context', () => ({ + useProviderContext: vi.fn(), +})) +vi.mock('@/context/app-context', () => ({ + useAppContext: vi.fn(), +})) +vi.mock('@/context/i18n', () => ({ + useDocLink: () => () => '/guides', +})) +vi.mock('@/hooks/use-theme', () => ({ + __esModule: true, + default: () => ({ theme: 'light' }), +})) + +const mockNotify = vi.fn() +const mockUseRouter = vi.mocked(useRouter) +const mockPush = vi.fn() +const mockCreateApp = vi.mocked(createApp) +const mockTrackEvent = vi.mocked(trackEvent) 
+const mockGetRedirection = vi.mocked(getRedirection)
+const mockUseProviderContext = vi.mocked(useProviderContext)
+const mockUseAppContext = vi.mocked(useAppContext)
+
+const defaultPlanUsage = {
+  buildApps: 0,
+  teamMembers: 0,
+  annotatedResponse: 0,
+  documentsUploadQuota: 0,
+  apiRateLimit: 0,
+  triggerEvents: 0,
+  vectorSpace: 0,
+}
+
+const renderModal = () => {
+  const onClose = vi.fn()
+  const onSuccess = vi.fn()
+  render(
+    <ToastContext.Provider value={{ notify: mockNotify } as any}>
+      <CreateAppModal show onClose={onClose} onSuccess={onSuccess} />
+    </ToastContext.Provider>,
+  )
+  return { onClose, onSuccess }
+}
+
+describe('CreateAppModal', () => {
+  const mockSetItem = vi.fn()
+  const originalLocalStorage = window.localStorage
+
+  beforeEach(() => {
+    vi.clearAllMocks()
+    mockUseRouter.mockReturnValue({ push: mockPush } as any)
+    mockUseProviderContext.mockReturnValue({
+      plan: {
+        type: AppModeEnum.ADVANCED_CHAT,
+        usage: defaultPlanUsage,
+        total: { ...defaultPlanUsage, buildApps: 1 },
+        reset: {},
+      },
+      enableBilling: true,
+    } as any)
+    mockUseAppContext.mockReturnValue({
+      isCurrentWorkspaceEditor: true,
+    } as any)
+    mockSetItem.mockClear()
+    Object.defineProperty(window, 'localStorage', {
+      value: {
+        setItem: mockSetItem,
+        getItem: vi.fn(),
+        removeItem: vi.fn(),
+        clear: vi.fn(),
+        key: vi.fn(),
+        length: 0,
+      },
+      writable: true,
+    })
+  })
+
+  afterAll(() => {
+    Object.defineProperty(window, 'localStorage', {
+      value: originalLocalStorage,
+      writable: true,
+    })
+  })
+
+  it('creates an app, notifies success, and fires callbacks', async () => {
+    const mockApp = { id: 'app-1', mode: AppModeEnum.ADVANCED_CHAT }
+    mockCreateApp.mockResolvedValue(mockApp as any)
+    const { onClose, onSuccess } = renderModal()
+
+    const nameInput = screen.getByPlaceholderText('app.newApp.appNamePlaceholder')
+    fireEvent.change(nameInput, { target: { value: 'My App' } })
+    fireEvent.click(screen.getByRole('button', { name: 'app.newApp.Create' }))
+
+    await waitFor(() => expect(mockCreateApp).toHaveBeenCalledWith({
+      name: 'My App',
+      description: '',
+      icon_type: 'emoji',
+      icon: '🤖',
+      icon_background: '#FFEAD5',
+      mode: AppModeEnum.ADVANCED_CHAT,
+    }))
+
+    expect(mockTrackEvent).toHaveBeenCalledWith('create_app', {
+      app_mode: AppModeEnum.ADVANCED_CHAT,
+      description: '',
+    })
+    expect(mockNotify).toHaveBeenCalledWith({ type: 'success', message: 'app.newApp.appCreated' })
+    expect(onSuccess).toHaveBeenCalled()
+    expect(onClose).toHaveBeenCalled()
+    await waitFor(() => expect(mockSetItem).toHaveBeenCalledWith(NEED_REFRESH_APP_LIST_KEY, '1'))
+    await waitFor(() => expect(mockGetRedirection).toHaveBeenCalledWith(true, mockApp, mockPush))
+  })
+
+  it('shows error toast when creation fails', async () => {
+    mockCreateApp.mockRejectedValue(new Error('boom'))
+    const { onClose } = renderModal()
+
+    const nameInput = screen.getByPlaceholderText('app.newApp.appNamePlaceholder')
+    fireEvent.change(nameInput, { target: { value: 'My App' } })
+    fireEvent.click(screen.getByRole('button', { name: 'app.newApp.Create' }))
+
+    await waitFor(() => expect(mockCreateApp).toHaveBeenCalled())
+    expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: 'boom' })
+    expect(onClose).not.toHaveBeenCalled()
+  })
+})
diff --git a/web/app/components/app/overview/embedded/index.spec.tsx b/web/app/components/app/overview/embedded/index.spec.tsx
new file mode 100644
index 0000000000..36f2e980c4
--- /dev/null
+++ b/web/app/components/app/overview/embedded/index.spec.tsx
@@ -0,0 +1,121 @@
+import type { SiteInfo } from '@/models/share'
+import { fireEvent, render, screen } from '@testing-library/react'
+import copy from 'copy-to-clipboard'
+import * as React from 'react'
+
+import { act } from 'react'
+import { afterAll, afterEach, describe, expect, it, vi } from 'vitest'
+import Embedded from './index'
+
+vi.mock('./style.module.css', () => ({
+  __esModule: true,
+  default: {
+    option: 'option',
+    active: 'active',
+    iframeIcon: 'iframeIcon',
+    scriptsIcon: 'scriptsIcon',
+    chromePluginIcon: 'chromePluginIcon',
+    pluginInstallIcon: 'pluginInstallIcon',
+  },
+}))
+const mockThemeBuilder = {
+  buildTheme: vi.fn(),
+  theme: {
+    primaryColor: '#123456',
+  },
+}
+const mockUseAppContext = vi.fn(() => ({
+  langGeniusVersionInfo: {
+    current_env: 'PRODUCTION',
+    current_version: '',
+    latest_version: '',
+    release_date: '',
+    release_notes: '',
+    version: '',
+    can_auto_update: false,
+  },
+}))
+
+vi.mock('copy-to-clipboard', () => ({
+  __esModule: true,
+  default: vi.fn(),
+}))
+vi.mock('@/app/components/base/chat/embedded-chatbot/theme/theme-context', () => ({
+  useThemeContext: () => mockThemeBuilder,
+}))
+vi.mock('@/context/app-context', () => ({
+  useAppContext: () => mockUseAppContext(),
+}))
+const mockWindowOpen = vi.spyOn(window, 'open').mockImplementation(() => null)
+const mockedCopy = vi.mocked(copy)
+
+const siteInfo: SiteInfo = {
+  title: 'test site',
+  chat_color_theme: '#000000',
+  chat_color_theme_inverted: false,
+}
+
+const baseProps = {
+  isShow: true,
+  siteInfo,
+  onClose: vi.fn(),
+  appBaseUrl: 'https://app.example.com',
+  accessToken: 'token',
+  className: 'custom-modal',
+}
+
+const getCopyButton = () => {
+  const buttons = screen.getAllByRole('button')
+  const actionButton = buttons.find(button => button.className.includes('action-btn'))
+  expect(actionButton).toBeDefined()
+  return actionButton!
+}
+
+describe('Embedded', () => {
+  afterEach(() => {
+    vi.clearAllMocks()
+    mockWindowOpen.mockClear()
+  })
+
+  afterAll(() => {
+    mockWindowOpen.mockRestore()
+  })
+
+  it('builds theme and copies iframe snippet', async () => {
+    await act(async () => {
+      render(<Embedded {...baseProps} />)
+    })
+
+    const actionButton = getCopyButton()
+    const innerDiv = actionButton.querySelector('div')
+    act(() => {
+      fireEvent.click(innerDiv ?? actionButton)
+    })
+
+    expect(mockThemeBuilder.buildTheme).toHaveBeenCalledWith(siteInfo.chat_color_theme, siteInfo.chat_color_theme_inverted)
+    expect(mockedCopy).toHaveBeenCalledWith(expect.stringContaining('/chatbot/token'))
+  })
+
+  it('opens chrome plugin store link when chrome option selected', async () => {
+    await act(async () => {
+      render(<Embedded {...baseProps} />)
+    })
+
+    const optionButtons = document.body.querySelectorAll('[class*="option"]')
+    expect(optionButtons.length).toBeGreaterThanOrEqual(3)
+    act(() => {
+      fireEvent.click(optionButtons[2])
+    })
+
+    const [chromeText] = screen.getAllByText('appOverview.overview.appInfo.embedded.chromePlugin')
+    act(() => {
+      fireEvent.click(chromeText)
+    })
+
+    expect(mockWindowOpen).toHaveBeenCalledWith(
+      'https://chrome.google.com/webstore/detail/dify-chatbot/ceehdapohffmjmkdcifjofadiaoeggaf',
+      '_blank',
+      'noopener,noreferrer',
+    )
+  })
+})
diff --git a/web/app/components/app/text-generate/saved-items/index.spec.tsx b/web/app/components/app/text-generate/saved-items/index.spec.tsx
new file mode 100644
index 0000000000..b83c812c19
--- /dev/null
+++ b/web/app/components/app/text-generate/saved-items/index.spec.tsx
@@ -0,0 +1,67 @@
+import type { ISavedItemsProps } from './index'
+import { fireEvent, render, screen } from '@testing-library/react'
+import copy from 'copy-to-clipboard'
+
+import * as React from 'react'
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+import Toast from '@/app/components/base/toast'
+import SavedItems from './index'
+
+vi.mock('copy-to-clipboard', () => ({
+  __esModule: true,
+  default: vi.fn(),
+}))
+vi.mock('next/navigation', () => ({
+  useParams: () => ({}),
+  usePathname: () => '/',
+}))
+
+const mockCopy = vi.mocked(copy)
+const toastNotifySpy = vi.spyOn(Toast, 'notify')
+
+const baseProps: ISavedItemsProps = {
+  list: [
+    { id: '1', answer: 'hello world' },
+  ],
+  isShowTextToSpeech: true,
+  onRemove: vi.fn(),
+  onStartCreateContent: vi.fn(),
+}
+
+describe('SavedItems', () => {
+  beforeEach(() => {
+    vi.clearAllMocks()
+    toastNotifySpy.mockClear()
+  })
+
+  it('renders saved answers with metadata and controls', () => {
+    const { container } = render(<SavedItems {...baseProps} />)
+
+    const markdownElement = container.querySelector('.markdown-body')
+    expect(markdownElement).toBeInTheDocument()
+    expect(screen.getByText('11 common.unit.char')).toBeInTheDocument()
+
+    const actionArea = container.querySelector('[class*="bg-components-actionbar-bg"]')
+    const actionButtons = actionArea?.querySelectorAll('button') ?? []
+    expect(actionButtons.length).toBeGreaterThanOrEqual(3)
+  })
+
+  it('copies content and notifies, and triggers remove callback', () => {
+    const handleRemove = vi.fn()
+    const { container } = render(<SavedItems {...baseProps} onRemove={handleRemove} />)
+
+    const actionArea = container.querySelector('[class*="bg-components-actionbar-bg"]')
+    const actionButtons = actionArea?.querySelectorAll('button') ?? []
+    expect(actionButtons.length).toBeGreaterThanOrEqual(3)
+
+    const copyButton = actionButtons[1]
+    const deleteButton = actionButtons[2]
+
+    fireEvent.click(copyButton)
+    expect(mockCopy).toHaveBeenCalledWith('hello world')
+    expect(toastNotifySpy).toHaveBeenCalledWith({ type: 'success', message: 'common.actionMsg.copySuccessfully' })
+
+    fireEvent.click(deleteButton)
+    expect(handleRemove).toHaveBeenCalledWith('1')
+  })
+})
diff --git a/web/app/components/app/text-generate/saved-items/no-data/index.spec.tsx b/web/app/components/app/text-generate/saved-items/no-data/index.spec.tsx
new file mode 100644
index 0000000000..59b950054c
--- /dev/null
+++ b/web/app/components/app/text-generate/saved-items/no-data/index.spec.tsx
@@ -0,0 +1,22 @@
+import { fireEvent, render, screen } from '@testing-library/react'
+import { describe, expect, it, vi } from 'vitest'
+
+import NoData from './index'
+
+describe('NoData', () => {
+  it('renders title/description and calls callback when button clicked', () => {
+    const handleStart = vi.fn()
+    render(<NoData onStartCreateContent={handleStart} />)
+
+    const title = screen.getByText('share.generation.savedNoData.title')
+    const description = screen.getByText('share.generation.savedNoData.description')
+    const button = screen.getByRole('button', { name: 'share.generation.savedNoData.startCreateContent' })
+
+    expect(title).toBeInTheDocument()
+    expect(description).toBeInTheDocument()
+    expect(button).toBeInTheDocument()
+
+    fireEvent.click(button)
+    expect(handleStart).toHaveBeenCalledTimes(1)
+  })
+})
diff --git a/web/app/components/custom/custom-web-app-brand/index.spec.tsx b/web/app/components/custom/custom-web-app-brand/index.spec.tsx
new file mode 100644
index 0000000000..e50ca4e9b2
--- /dev/null
+++ b/web/app/components/custom/custom-web-app-brand/index.spec.tsx
@@ -0,0 +1,147 @@
+import { fireEvent, render, screen, waitFor } from '@testing-library/react'
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+import { getImageUploadErrorMessage, imageUpload } from '@/app/components/base/image-uploader/utils'
+import { useToastContext } from '@/app/components/base/toast'
+import { Plan } from '@/app/components/billing/type'
+import { useAppContext } from '@/context/app-context'
+import { useGlobalPublicStore } from '@/context/global-public-context'
+import { useProviderContext } from '@/context/provider-context'
+import { updateCurrentWorkspace } from '@/service/common'
+import CustomWebAppBrand from './index'
+
+vi.mock('@/app/components/base/toast', () => ({
+  useToastContext: vi.fn(),
+}))
+vi.mock('@/service/common', () => ({
+  updateCurrentWorkspace: vi.fn(),
+}))
+vi.mock('@/context/app-context', () => ({
+  useAppContext: vi.fn(),
+}))
+vi.mock('@/context/provider-context', () => ({
+  useProviderContext: vi.fn(),
+}))
+vi.mock('@/context/global-public-context', () => ({
+  useGlobalPublicStore: vi.fn(),
+}))
+vi.mock('@/app/components/base/image-uploader/utils', () => ({
+  imageUpload: vi.fn(),
+  getImageUploadErrorMessage: vi.fn(),
+}))
+
+const mockNotify = vi.fn()
+const mockUseToastContext = vi.mocked(useToastContext)
+const mockUpdateCurrentWorkspace = vi.mocked(updateCurrentWorkspace)
+const mockUseAppContext = vi.mocked(useAppContext)
+const mockUseProviderContext = vi.mocked(useProviderContext)
+const mockUseGlobalPublicStore = vi.mocked(useGlobalPublicStore)
+const mockImageUpload = vi.mocked(imageUpload)
+const mockGetImageUploadErrorMessage = vi.mocked(getImageUploadErrorMessage)
+
+const defaultPlanUsage = {
+  buildApps: 0,
+  teamMembers: 0,
+  annotatedResponse: 0,
+  documentsUploadQuota: 0,
+  apiRateLimit: 0,
+  triggerEvents: 0,
+  vectorSpace: 0,
+}
+
+const renderComponent = () => render(<CustomWebAppBrand />)
+
+describe('CustomWebAppBrand', () => {
+  beforeEach(() => {
+    vi.clearAllMocks()
+    mockUseToastContext.mockReturnValue({ notify: mockNotify } as any)
+    mockUpdateCurrentWorkspace.mockResolvedValue({} as any)
+    mockUseAppContext.mockReturnValue({
+      currentWorkspace: {
+        custom_config: {
+          replace_webapp_logo: 'https://example.com/replace.png',
+          remove_webapp_brand: false,
+        },
+      },
+      mutateCurrentWorkspace: vi.fn(),
+      isCurrentWorkspaceManager: true,
+    } as any)
+    mockUseProviderContext.mockReturnValue({
+      plan: {
+        type: Plan.professional,
+        usage: defaultPlanUsage,
+        total: defaultPlanUsage,
+        reset: {},
+      },
+      enableBilling: false,
+    } as any)
+    const systemFeaturesState = {
+      branding: {
+        enabled: true,
+        workspace_logo: 'https://example.com/workspace-logo.png',
+      },
+    }
+    mockUseGlobalPublicStore.mockImplementation(selector => selector ? selector({ systemFeatures: systemFeaturesState } as any) : { systemFeatures: systemFeaturesState })
+    mockGetImageUploadErrorMessage.mockReturnValue('upload error')
+  })
+
+  it('disables upload controls when the user cannot manage the workspace', () => {
+    mockUseAppContext.mockReturnValue({
+      currentWorkspace: {
+        custom_config: {
+          replace_webapp_logo: '',
+          remove_webapp_brand: false,
+        },
+      },
+      mutateCurrentWorkspace: vi.fn(),
+      isCurrentWorkspaceManager: false,
+    } as any)
+
+    const { container } = renderComponent()
+    const fileInput = container.querySelector('input[type="file"]') as HTMLInputElement
+    expect(fileInput).toBeDisabled()
+  })
+
+  it('toggles remove brand switch and calls the backend + mutate', async () => {
+    const mutateMock = vi.fn()
+    mockUseAppContext.mockReturnValue({
+      currentWorkspace: {
+        custom_config: {
+          replace_webapp_logo: '',
+          remove_webapp_brand: false,
+        },
+      },
+      mutateCurrentWorkspace: mutateMock,
+      isCurrentWorkspaceManager: true,
+    } as any)
+
+    renderComponent()
+    const switchInput = screen.getByRole('switch')
+    fireEvent.click(switchInput)
+
+    await waitFor(() => expect(mockUpdateCurrentWorkspace).toHaveBeenCalledWith({
+      url: '/workspaces/custom-config',
+      body: { remove_webapp_brand: true },
+    }))
+    await waitFor(() => expect(mutateMock).toHaveBeenCalled())
+  })
+
+  it('shows cancel/apply buttons after successful upload and cancels properly', async () => {
+    mockImageUpload.mockImplementation(({ onProgressCallback, onSuccessCallback }) => {
+      onProgressCallback(50)
+      onSuccessCallback({ id: 'new-logo' })
+    })
+
+    const { container } = renderComponent()
+    const fileInput = container.querySelector('input[type="file"]') as HTMLInputElement
+    const testFile = new File(['content'], 'logo.png', { type: 'image/png' })
+    fireEvent.change(fileInput, { target: { files: [testFile] } })
+
+    await waitFor(() => expect(mockImageUpload).toHaveBeenCalled())
+    await waitFor(() => screen.getByRole('button', { name: 'custom.apply' }))
+
+    const cancelButton = screen.getByRole('button', { name: 'common.operation.cancel' })
+    fireEvent.click(cancelButton)
+
+    await waitFor(() => expect(screen.queryByRole('button', { name: 'custom.apply' })).toBeNull())
+  })
+})