mirror of https://github.com/langgenius/dify.git
Merge branch 'main' into pydantic-remaining
This commit is contained in:
commit
ef6d73c3f4
|
|
@ -13,7 +13,7 @@ from core.model_runtime.entities.model_entities import ModelType
|
||||||
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
|
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
|
||||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||||
from core.rag.datasource.vdb.vector_factory import Vector
|
from core.rag.datasource.vdb.vector_factory import Vector
|
||||||
from core.rag.embedding.retrieval import RetrievalSegments
|
from core.rag.embedding.retrieval import RetrievalChildChunk, RetrievalSegments
|
||||||
from core.rag.entities.metadata_entities import MetadataCondition
|
from core.rag.entities.metadata_entities import MetadataCondition
|
||||||
from core.rag.index_processor.constant.doc_type import DocType
|
from core.rag.index_processor.constant.doc_type import DocType
|
||||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||||
|
|
@ -381,10 +381,9 @@ class RetrievalService:
|
||||||
records = []
|
records = []
|
||||||
include_segment_ids = set()
|
include_segment_ids = set()
|
||||||
segment_child_map = {}
|
segment_child_map = {}
|
||||||
segment_file_map = {}
|
|
||||||
|
|
||||||
valid_dataset_documents = {}
|
valid_dataset_documents = {}
|
||||||
image_doc_ids = []
|
image_doc_ids: list[Any] = []
|
||||||
child_index_node_ids = []
|
child_index_node_ids = []
|
||||||
index_node_ids = []
|
index_node_ids = []
|
||||||
doc_to_document_map = {}
|
doc_to_document_map = {}
|
||||||
|
|
@ -417,28 +416,39 @@ class RetrievalService:
|
||||||
child_index_node_ids = [i for i in child_index_node_ids if i]
|
child_index_node_ids = [i for i in child_index_node_ids if i]
|
||||||
index_node_ids = [i for i in index_node_ids if i]
|
index_node_ids = [i for i in index_node_ids if i]
|
||||||
|
|
||||||
segment_ids = []
|
segment_ids: list[str] = []
|
||||||
index_node_segments: list[DocumentSegment] = []
|
index_node_segments: list[DocumentSegment] = []
|
||||||
segments: list[DocumentSegment] = []
|
segments: list[DocumentSegment] = []
|
||||||
attachment_map = {}
|
attachment_map: dict[str, list[dict[str, Any]]] = {}
|
||||||
child_chunk_map = {}
|
child_chunk_map: dict[str, list[ChildChunk]] = {}
|
||||||
doc_segment_map = {}
|
doc_segment_map: dict[str, list[str]] = {}
|
||||||
|
|
||||||
with session_factory.create_session() as session:
|
with session_factory.create_session() as session:
|
||||||
attachments = cls.get_segment_attachment_infos(image_doc_ids, session)
|
attachments = cls.get_segment_attachment_infos(image_doc_ids, session)
|
||||||
|
|
||||||
for attachment in attachments:
|
for attachment in attachments:
|
||||||
segment_ids.append(attachment["segment_id"])
|
segment_ids.append(attachment["segment_id"])
|
||||||
attachment_map[attachment["segment_id"]] = attachment
|
if attachment["segment_id"] in attachment_map:
|
||||||
doc_segment_map[attachment["segment_id"]] = attachment["attachment_id"]
|
attachment_map[attachment["segment_id"]].append(attachment["attachment_info"])
|
||||||
|
else:
|
||||||
|
attachment_map[attachment["segment_id"]] = [attachment["attachment_info"]]
|
||||||
|
if attachment["segment_id"] in doc_segment_map:
|
||||||
|
doc_segment_map[attachment["segment_id"]].append(attachment["attachment_id"])
|
||||||
|
else:
|
||||||
|
doc_segment_map[attachment["segment_id"]] = [attachment["attachment_id"]]
|
||||||
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id.in_(child_index_node_ids))
|
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id.in_(child_index_node_ids))
|
||||||
child_index_nodes = session.execute(child_chunk_stmt).scalars().all()
|
child_index_nodes = session.execute(child_chunk_stmt).scalars().all()
|
||||||
|
|
||||||
for i in child_index_nodes:
|
for i in child_index_nodes:
|
||||||
segment_ids.append(i.segment_id)
|
segment_ids.append(i.segment_id)
|
||||||
child_chunk_map[i.segment_id] = i
|
if i.segment_id in child_chunk_map:
|
||||||
doc_segment_map[i.segment_id] = i.index_node_id
|
child_chunk_map[i.segment_id].append(i)
|
||||||
|
else:
|
||||||
|
child_chunk_map[i.segment_id] = [i]
|
||||||
|
if i.segment_id in doc_segment_map:
|
||||||
|
doc_segment_map[i.segment_id].append(i.index_node_id)
|
||||||
|
else:
|
||||||
|
doc_segment_map[i.segment_id] = [i.index_node_id]
|
||||||
|
|
||||||
if index_node_ids:
|
if index_node_ids:
|
||||||
document_segment_stmt = select(DocumentSegment).where(
|
document_segment_stmt = select(DocumentSegment).where(
|
||||||
|
|
@ -448,7 +458,7 @@ class RetrievalService:
|
||||||
)
|
)
|
||||||
index_node_segments = session.execute(document_segment_stmt).scalars().all() # type: ignore
|
index_node_segments = session.execute(document_segment_stmt).scalars().all() # type: ignore
|
||||||
for index_node_segment in index_node_segments:
|
for index_node_segment in index_node_segments:
|
||||||
doc_segment_map[index_node_segment.id] = index_node_segment.index_node_id
|
doc_segment_map[index_node_segment.id] = [index_node_segment.index_node_id]
|
||||||
if segment_ids:
|
if segment_ids:
|
||||||
document_segment_stmt = select(DocumentSegment).where(
|
document_segment_stmt = select(DocumentSegment).where(
|
||||||
DocumentSegment.enabled == True,
|
DocumentSegment.enabled == True,
|
||||||
|
|
@ -461,95 +471,86 @@ class RetrievalService:
|
||||||
segments.extend(index_node_segments)
|
segments.extend(index_node_segments)
|
||||||
|
|
||||||
for segment in segments:
|
for segment in segments:
|
||||||
doc_id = doc_segment_map.get(segment.id)
|
child_chunks: list[ChildChunk] = child_chunk_map.get(segment.id, [])
|
||||||
child_chunk = child_chunk_map.get(segment.id)
|
attachment_infos: list[dict[str, Any]] = attachment_map.get(segment.id, [])
|
||||||
attachment_info = attachment_map.get(segment.id)
|
ds_dataset_document: DatasetDocument | None = valid_dataset_documents.get(segment.document_id)
|
||||||
|
|
||||||
if doc_id:
|
|
||||||
document = doc_to_document_map[doc_id]
|
|
||||||
ds_dataset_document: DatasetDocument | None = valid_dataset_documents.get(
|
|
||||||
document.metadata.get("document_id")
|
|
||||||
)
|
|
||||||
|
|
||||||
if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||||
if segment.id not in include_segment_ids:
|
if segment.id not in include_segment_ids:
|
||||||
include_segment_ids.add(segment.id)
|
include_segment_ids.add(segment.id)
|
||||||
if child_chunk:
|
if child_chunks or attachment_infos:
|
||||||
|
child_chunk_details = []
|
||||||
|
max_score = 0.0
|
||||||
|
for child_chunk in child_chunks:
|
||||||
|
document = doc_to_document_map[child_chunk.index_node_id]
|
||||||
child_chunk_detail = {
|
child_chunk_detail = {
|
||||||
"id": child_chunk.id,
|
"id": child_chunk.id,
|
||||||
"content": child_chunk.content,
|
"content": child_chunk.content,
|
||||||
"position": child_chunk.position,
|
"position": child_chunk.position,
|
||||||
"score": document.metadata.get("score", 0.0) if document else 0.0,
|
"score": document.metadata.get("score", 0.0) if document else 0.0,
|
||||||
}
|
}
|
||||||
|
child_chunk_details.append(child_chunk_detail)
|
||||||
|
max_score = max(max_score, document.metadata.get("score", 0.0) if document else 0.0)
|
||||||
|
for attachment_info in attachment_infos:
|
||||||
|
file_document = doc_to_document_map[attachment_info["id"]]
|
||||||
|
max_score = max(
|
||||||
|
max_score, file_document.metadata.get("score", 0.0) if file_document else 0.0
|
||||||
|
)
|
||||||
|
|
||||||
map_detail = {
|
map_detail = {
|
||||||
"max_score": document.metadata.get("score", 0.0) if document else 0.0,
|
"max_score": max_score,
|
||||||
"child_chunks": [child_chunk_detail],
|
"child_chunks": child_chunk_details,
|
||||||
}
|
}
|
||||||
segment_child_map[segment.id] = map_detail
|
segment_child_map[segment.id] = map_detail
|
||||||
record = {
|
record: dict[str, Any] = {
|
||||||
"segment": segment,
|
"segment": segment,
|
||||||
}
|
}
|
||||||
if attachment_info:
|
|
||||||
segment_file_map[segment.id] = [attachment_info]
|
|
||||||
records.append(record)
|
records.append(record)
|
||||||
else:
|
|
||||||
if child_chunk:
|
|
||||||
child_chunk_detail = {
|
|
||||||
"id": child_chunk.id,
|
|
||||||
"content": child_chunk.content,
|
|
||||||
"position": child_chunk.position,
|
|
||||||
"score": document.metadata.get("score", 0.0),
|
|
||||||
}
|
|
||||||
if segment.id in segment_child_map:
|
|
||||||
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail) # type: ignore
|
|
||||||
segment_child_map[segment.id]["max_score"] = max(
|
|
||||||
segment_child_map[segment.id]["max_score"],
|
|
||||||
document.metadata.get("score", 0.0) if document else 0.0,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
segment_child_map[segment.id] = {
|
|
||||||
"max_score": document.metadata.get("score", 0.0) if document else 0.0,
|
|
||||||
"child_chunks": [child_chunk_detail],
|
|
||||||
}
|
|
||||||
if attachment_info:
|
|
||||||
if segment.id in segment_file_map:
|
|
||||||
segment_file_map[segment.id].append(attachment_info)
|
|
||||||
else:
|
|
||||||
segment_file_map[segment.id] = [attachment_info]
|
|
||||||
else:
|
else:
|
||||||
if segment.id not in include_segment_ids:
|
if segment.id not in include_segment_ids:
|
||||||
include_segment_ids.add(segment.id)
|
include_segment_ids.add(segment.id)
|
||||||
|
max_score = 0.0
|
||||||
|
segment_document = doc_to_document_map.get(segment.index_node_id)
|
||||||
|
if segment_document:
|
||||||
|
max_score = max(max_score, segment_document.metadata.get("score", 0.0))
|
||||||
|
for attachment_info in attachment_infos:
|
||||||
|
file_doc = doc_to_document_map.get(attachment_info["id"])
|
||||||
|
if file_doc:
|
||||||
|
max_score = max(max_score, file_doc.metadata.get("score", 0.0))
|
||||||
record = {
|
record = {
|
||||||
"segment": segment,
|
"segment": segment,
|
||||||
"score": document.metadata.get("score", 0.0), # type: ignore
|
"score": max_score,
|
||||||
}
|
}
|
||||||
if attachment_info:
|
|
||||||
segment_file_map[segment.id] = [attachment_info]
|
|
||||||
records.append(record)
|
records.append(record)
|
||||||
else:
|
|
||||||
if attachment_info:
|
|
||||||
attachment_infos = segment_file_map.get(segment.id, [])
|
|
||||||
if attachment_info not in attachment_infos:
|
|
||||||
attachment_infos.append(attachment_info)
|
|
||||||
segment_file_map[segment.id] = attachment_infos
|
|
||||||
|
|
||||||
# Add child chunks information to records
|
# Add child chunks information to records
|
||||||
for record in records:
|
for record in records:
|
||||||
if record["segment"].id in segment_child_map:
|
if record["segment"].id in segment_child_map:
|
||||||
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore
|
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore
|
||||||
record["score"] = segment_child_map[record["segment"].id]["max_score"] # type: ignore
|
record["score"] = segment_child_map[record["segment"].id]["max_score"] # type: ignore
|
||||||
if record["segment"].id in segment_file_map:
|
if record["segment"].id in attachment_map:
|
||||||
record["files"] = segment_file_map[record["segment"].id] # type: ignore[assignment]
|
record["files"] = attachment_map[record["segment"].id] # type: ignore[assignment]
|
||||||
|
|
||||||
result = []
|
result: list[RetrievalSegments] = []
|
||||||
for record in records:
|
for record in records:
|
||||||
# Extract segment
|
# Extract segment
|
||||||
segment = record["segment"]
|
segment = record["segment"]
|
||||||
|
|
||||||
# Extract child_chunks, ensuring it's a list or None
|
# Extract child_chunks, ensuring it's a list or None
|
||||||
child_chunks = record.get("child_chunks")
|
raw_child_chunks = record.get("child_chunks")
|
||||||
if not isinstance(child_chunks, list):
|
child_chunks_list: list[RetrievalChildChunk] | None = None
|
||||||
child_chunks = None
|
if isinstance(raw_child_chunks, list):
|
||||||
|
# Sort by score descending
|
||||||
|
sorted_chunks = sorted(raw_child_chunks, key=lambda x: x.get("score", 0.0), reverse=True)
|
||||||
|
child_chunks_list = [
|
||||||
|
RetrievalChildChunk(
|
||||||
|
id=chunk["id"],
|
||||||
|
content=chunk["content"],
|
||||||
|
score=chunk.get("score", 0.0),
|
||||||
|
position=chunk["position"],
|
||||||
|
)
|
||||||
|
for chunk in sorted_chunks
|
||||||
|
]
|
||||||
|
|
||||||
# Extract files, ensuring it's a list or None
|
# Extract files, ensuring it's a list or None
|
||||||
files = record.get("files")
|
files = record.get("files")
|
||||||
|
|
@ -566,11 +567,11 @@ class RetrievalService:
|
||||||
|
|
||||||
# Create RetrievalSegments object
|
# Create RetrievalSegments object
|
||||||
retrieval_segment = RetrievalSegments(
|
retrieval_segment = RetrievalSegments(
|
||||||
segment=segment, child_chunks=child_chunks, score=score, files=files
|
segment=segment, child_chunks=child_chunks_list, score=score, files=files
|
||||||
)
|
)
|
||||||
result.append(retrieval_segment)
|
result.append(retrieval_segment)
|
||||||
|
|
||||||
return result
|
return sorted(result, key=lambda x: x.score if x.score is not None else 0.0, reverse=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.session.rollback()
|
db.session.rollback()
|
||||||
raise e
|
raise e
|
||||||
|
|
|
||||||
|
|
@ -255,7 +255,10 @@ class PGVector(BaseVector):
|
||||||
return
|
return
|
||||||
|
|
||||||
with self._get_cursor() as cur:
|
with self._get_cursor() as cur:
|
||||||
cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
|
cur.execute("SELECT 1 FROM pg_extension WHERE extname = 'vector'")
|
||||||
|
if not cur.fetchone():
|
||||||
|
cur.execute("CREATE EXTENSION vector")
|
||||||
|
|
||||||
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
|
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
|
||||||
# PG hnsw index only support 2000 dimension or less
|
# PG hnsw index only support 2000 dimension or less
|
||||||
# ref: https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing
|
# ref: https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ from collections.abc import Generator, Mapping
|
||||||
from typing import Any, Union, cast
|
from typing import Any, Union, cast
|
||||||
|
|
||||||
from flask import Flask, current_app
|
from flask import Flask, current_app
|
||||||
from sqlalchemy import and_, or_, select
|
from sqlalchemy import and_, literal, or_, select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from core.app.app_config.entities import (
|
from core.app.app_config.entities import (
|
||||||
|
|
@ -1036,7 +1036,7 @@ class DatasetRetrieval:
|
||||||
if automatic_metadata_filters:
|
if automatic_metadata_filters:
|
||||||
conditions = []
|
conditions = []
|
||||||
for sequence, filter in enumerate(automatic_metadata_filters):
|
for sequence, filter in enumerate(automatic_metadata_filters):
|
||||||
self._process_metadata_filter_func(
|
self.process_metadata_filter_func(
|
||||||
sequence,
|
sequence,
|
||||||
filter.get("condition"), # type: ignore
|
filter.get("condition"), # type: ignore
|
||||||
filter.get("metadata_name"), # type: ignore
|
filter.get("metadata_name"), # type: ignore
|
||||||
|
|
@ -1072,7 +1072,7 @@ class DatasetRetrieval:
|
||||||
value=expected_value,
|
value=expected_value,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
filters = self._process_metadata_filter_func(
|
filters = self.process_metadata_filter_func(
|
||||||
sequence,
|
sequence,
|
||||||
condition.comparison_operator,
|
condition.comparison_operator,
|
||||||
metadata_name,
|
metadata_name,
|
||||||
|
|
@ -1168,8 +1168,9 @@ class DatasetRetrieval:
|
||||||
return None
|
return None
|
||||||
return automatic_metadata_filters
|
return automatic_metadata_filters
|
||||||
|
|
||||||
def _process_metadata_filter_func(
|
@classmethod
|
||||||
self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
|
def process_metadata_filter_func(
|
||||||
|
cls, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
|
||||||
):
|
):
|
||||||
if value is None and condition not in ("empty", "not empty"):
|
if value is None and condition not in ("empty", "not empty"):
|
||||||
return filters
|
return filters
|
||||||
|
|
@ -1218,6 +1219,20 @@ class DatasetRetrieval:
|
||||||
|
|
||||||
case "≥" | ">=":
|
case "≥" | ">=":
|
||||||
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() >= value)
|
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() >= value)
|
||||||
|
case "in" | "not in":
|
||||||
|
if isinstance(value, str):
|
||||||
|
value_list = [v.strip() for v in value.split(",") if v.strip()]
|
||||||
|
elif isinstance(value, (list, tuple)):
|
||||||
|
value_list = [str(v) for v in value if v is not None]
|
||||||
|
else:
|
||||||
|
value_list = [str(value)] if value is not None else []
|
||||||
|
|
||||||
|
if not value_list:
|
||||||
|
# `field in []` is False, `field not in []` is True
|
||||||
|
filters.append(literal(condition == "not in"))
|
||||||
|
else:
|
||||||
|
op = json_field.in_ if condition == "in" else json_field.notin_
|
||||||
|
filters.append(op(value_list))
|
||||||
case _:
|
case _:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ from collections import defaultdict
|
||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from typing import TYPE_CHECKING, Any, cast
|
from typing import TYPE_CHECKING, Any, cast
|
||||||
|
|
||||||
from sqlalchemy import and_, func, literal, or_, select
|
from sqlalchemy import and_, func, or_, select
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity
|
from core.app.app_config.entities import DatasetRetrieveConfigEntity
|
||||||
|
|
@ -460,7 +460,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||||
if automatic_metadata_filters:
|
if automatic_metadata_filters:
|
||||||
conditions = []
|
conditions = []
|
||||||
for sequence, filter in enumerate(automatic_metadata_filters):
|
for sequence, filter in enumerate(automatic_metadata_filters):
|
||||||
self._process_metadata_filter_func(
|
DatasetRetrieval.process_metadata_filter_func(
|
||||||
sequence,
|
sequence,
|
||||||
filter.get("condition", ""),
|
filter.get("condition", ""),
|
||||||
filter.get("metadata_name", ""),
|
filter.get("metadata_name", ""),
|
||||||
|
|
@ -504,7 +504,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||||
value=expected_value,
|
value=expected_value,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
filters = self._process_metadata_filter_func(
|
filters = DatasetRetrieval.process_metadata_filter_func(
|
||||||
sequence,
|
sequence,
|
||||||
condition.comparison_operator,
|
condition.comparison_operator,
|
||||||
metadata_name,
|
metadata_name,
|
||||||
|
|
@ -603,87 +603,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||||
return [], usage
|
return [], usage
|
||||||
return automatic_metadata_filters, usage
|
return automatic_metadata_filters, usage
|
||||||
|
|
||||||
def _process_metadata_filter_func(
|
|
||||||
self, sequence: int, condition: str, metadata_name: str, value: Any, filters: list[Any]
|
|
||||||
) -> list[Any]:
|
|
||||||
if value is None and condition not in ("empty", "not empty"):
|
|
||||||
return filters
|
|
||||||
|
|
||||||
json_field = Document.doc_metadata[metadata_name].as_string()
|
|
||||||
|
|
||||||
match condition:
|
|
||||||
case "contains":
|
|
||||||
filters.append(json_field.like(f"%{value}%"))
|
|
||||||
|
|
||||||
case "not contains":
|
|
||||||
filters.append(json_field.notlike(f"%{value}%"))
|
|
||||||
|
|
||||||
case "start with":
|
|
||||||
filters.append(json_field.like(f"{value}%"))
|
|
||||||
|
|
||||||
case "end with":
|
|
||||||
filters.append(json_field.like(f"%{value}"))
|
|
||||||
case "in":
|
|
||||||
if isinstance(value, str):
|
|
||||||
value_list = [v.strip() for v in value.split(",") if v.strip()]
|
|
||||||
elif isinstance(value, (list, tuple)):
|
|
||||||
value_list = [str(v) for v in value if v is not None]
|
|
||||||
else:
|
|
||||||
value_list = [str(value)] if value is not None else []
|
|
||||||
|
|
||||||
if not value_list:
|
|
||||||
filters.append(literal(False))
|
|
||||||
else:
|
|
||||||
filters.append(json_field.in_(value_list))
|
|
||||||
|
|
||||||
case "not in":
|
|
||||||
if isinstance(value, str):
|
|
||||||
value_list = [v.strip() for v in value.split(",") if v.strip()]
|
|
||||||
elif isinstance(value, (list, tuple)):
|
|
||||||
value_list = [str(v) for v in value if v is not None]
|
|
||||||
else:
|
|
||||||
value_list = [str(value)] if value is not None else []
|
|
||||||
|
|
||||||
if not value_list:
|
|
||||||
filters.append(literal(True))
|
|
||||||
else:
|
|
||||||
filters.append(json_field.notin_(value_list))
|
|
||||||
|
|
||||||
case "is" | "=":
|
|
||||||
if isinstance(value, str):
|
|
||||||
filters.append(json_field == value)
|
|
||||||
elif isinstance(value, (int, float)):
|
|
||||||
filters.append(Document.doc_metadata[metadata_name].as_float() == value)
|
|
||||||
|
|
||||||
case "is not" | "≠":
|
|
||||||
if isinstance(value, str):
|
|
||||||
filters.append(json_field != value)
|
|
||||||
elif isinstance(value, (int, float)):
|
|
||||||
filters.append(Document.doc_metadata[metadata_name].as_float() != value)
|
|
||||||
|
|
||||||
case "empty":
|
|
||||||
filters.append(Document.doc_metadata[metadata_name].is_(None))
|
|
||||||
|
|
||||||
case "not empty":
|
|
||||||
filters.append(Document.doc_metadata[metadata_name].isnot(None))
|
|
||||||
|
|
||||||
case "before" | "<":
|
|
||||||
filters.append(Document.doc_metadata[metadata_name].as_float() < value)
|
|
||||||
|
|
||||||
case "after" | ">":
|
|
||||||
filters.append(Document.doc_metadata[metadata_name].as_float() > value)
|
|
||||||
|
|
||||||
case "≤" | "<=":
|
|
||||||
filters.append(Document.doc_metadata[metadata_name].as_float() <= value)
|
|
||||||
|
|
||||||
case "≥" | ">=":
|
|
||||||
filters.append(Document.doc_metadata[metadata_name].as_float() >= value)
|
|
||||||
|
|
||||||
case _:
|
|
||||||
pass
|
|
||||||
|
|
||||||
return filters
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _extract_variable_selector_to_variable_mapping(
|
def _extract_variable_selector_to_variable_mapping(
|
||||||
cls,
|
cls,
|
||||||
|
|
|
||||||
|
|
@ -110,5 +110,5 @@ class EnterpriseService:
|
||||||
if not app_id:
|
if not app_id:
|
||||||
raise ValueError("app_id must be provided.")
|
raise ValueError("app_id must be provided.")
|
||||||
|
|
||||||
body = {"appId": app_id}
|
params = {"appId": app_id}
|
||||||
EnterpriseRequest.send_request("DELETE", "/webapp/clean", json=body)
|
EnterpriseRequest.send_request("DELETE", "/webapp/clean", params=params)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,327 @@
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.rag.datasource.vdb.pgvector.pgvector import (
|
||||||
|
PGVector,
|
||||||
|
PGVectorConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPGVector(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.config = PGVectorConfig(
|
||||||
|
host="localhost",
|
||||||
|
port=5432,
|
||||||
|
user="test_user",
|
||||||
|
password="test_password",
|
||||||
|
database="test_db",
|
||||||
|
min_connection=1,
|
||||||
|
max_connection=5,
|
||||||
|
pg_bigm=False,
|
||||||
|
)
|
||||||
|
self.collection_name = "test_collection"
|
||||||
|
|
||||||
|
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
|
||||||
|
def test_init(self, mock_pool_class):
|
||||||
|
"""Test PGVector initialization."""
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool_class.return_value = mock_pool
|
||||||
|
|
||||||
|
pgvector = PGVector(self.collection_name, self.config)
|
||||||
|
|
||||||
|
assert pgvector._collection_name == self.collection_name
|
||||||
|
assert pgvector.table_name == f"embedding_{self.collection_name}"
|
||||||
|
assert pgvector.get_type() == "pgvector"
|
||||||
|
assert pgvector.pool is not None
|
||||||
|
assert pgvector.pg_bigm is False
|
||||||
|
assert pgvector.index_hash is not None
|
||||||
|
|
||||||
|
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
|
||||||
|
def test_init_with_pg_bigm(self, mock_pool_class):
|
||||||
|
"""Test PGVector initialization with pg_bigm enabled."""
|
||||||
|
config = PGVectorConfig(
|
||||||
|
host="localhost",
|
||||||
|
port=5432,
|
||||||
|
user="test_user",
|
||||||
|
password="test_password",
|
||||||
|
database="test_db",
|
||||||
|
min_connection=1,
|
||||||
|
max_connection=5,
|
||||||
|
pg_bigm=True,
|
||||||
|
)
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool_class.return_value = mock_pool
|
||||||
|
|
||||||
|
pgvector = PGVector(self.collection_name, config)
|
||||||
|
|
||||||
|
assert pgvector.pg_bigm is True
|
||||||
|
|
||||||
|
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
|
||||||
|
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
|
||||||
|
def test_create_collection_basic(self, mock_redis, mock_pool_class):
|
||||||
|
"""Test basic collection creation."""
|
||||||
|
# Mock Redis operations
|
||||||
|
mock_lock = MagicMock()
|
||||||
|
mock_lock.__enter__ = MagicMock()
|
||||||
|
mock_lock.__exit__ = MagicMock()
|
||||||
|
mock_redis.lock.return_value = mock_lock
|
||||||
|
mock_redis.get.return_value = None
|
||||||
|
mock_redis.set.return_value = None
|
||||||
|
|
||||||
|
# Mock the connection pool
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool_class.return_value = mock_pool
|
||||||
|
|
||||||
|
# Mock connection and cursor
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_cursor = MagicMock()
|
||||||
|
mock_pool.getconn.return_value = mock_conn
|
||||||
|
mock_conn.cursor.return_value = mock_cursor
|
||||||
|
mock_cursor.fetchone.return_value = [1] # vector extension exists
|
||||||
|
|
||||||
|
pgvector = PGVector(self.collection_name, self.config)
|
||||||
|
pgvector._create_collection(1536)
|
||||||
|
|
||||||
|
# Verify SQL execution calls
|
||||||
|
assert mock_cursor.execute.called
|
||||||
|
|
||||||
|
# Check that CREATE TABLE was called with correct dimension
|
||||||
|
create_table_calls = [call for call in mock_cursor.execute.call_args_list if "CREATE TABLE" in str(call)]
|
||||||
|
assert len(create_table_calls) == 1
|
||||||
|
assert "vector(1536)" in create_table_calls[0][0][0]
|
||||||
|
|
||||||
|
# Check that CREATE INDEX was called (dimension <= 2000)
|
||||||
|
create_index_calls = [
|
||||||
|
call for call in mock_cursor.execute.call_args_list if "CREATE INDEX" in str(call) and "hnsw" in str(call)
|
||||||
|
]
|
||||||
|
assert len(create_index_calls) == 1
|
||||||
|
|
||||||
|
# Verify Redis cache was set
|
||||||
|
mock_redis.set.assert_called_once()
|
||||||
|
|
||||||
|
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
|
||||||
|
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
|
||||||
|
def test_create_collection_with_large_dimension(self, mock_redis, mock_pool_class):
|
||||||
|
"""Test collection creation with dimension > 2000 (no HNSW index)."""
|
||||||
|
# Mock Redis operations
|
||||||
|
mock_lock = MagicMock()
|
||||||
|
mock_lock.__enter__ = MagicMock()
|
||||||
|
mock_lock.__exit__ = MagicMock()
|
||||||
|
mock_redis.lock.return_value = mock_lock
|
||||||
|
mock_redis.get.return_value = None
|
||||||
|
mock_redis.set.return_value = None
|
||||||
|
|
||||||
|
# Mock the connection pool
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool_class.return_value = mock_pool
|
||||||
|
|
||||||
|
# Mock connection and cursor
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_cursor = MagicMock()
|
||||||
|
mock_pool.getconn.return_value = mock_conn
|
||||||
|
mock_conn.cursor.return_value = mock_cursor
|
||||||
|
mock_cursor.fetchone.return_value = [1] # vector extension exists
|
||||||
|
|
||||||
|
pgvector = PGVector(self.collection_name, self.config)
|
||||||
|
pgvector._create_collection(3072) # Dimension > 2000
|
||||||
|
|
||||||
|
# Check that CREATE TABLE was called
|
||||||
|
create_table_calls = [call for call in mock_cursor.execute.call_args_list if "CREATE TABLE" in str(call)]
|
||||||
|
assert len(create_table_calls) == 1
|
||||||
|
assert "vector(3072)" in create_table_calls[0][0][0]
|
||||||
|
|
||||||
|
# Check that HNSW index was NOT created (dimension > 2000)
|
||||||
|
hnsw_index_calls = [call for call in mock_cursor.execute.call_args_list if "hnsw" in str(call)]
|
||||||
|
assert len(hnsw_index_calls) == 0
|
||||||
|
|
||||||
|
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
|
||||||
|
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
|
||||||
|
def test_create_collection_with_pg_bigm(self, mock_redis, mock_pool_class):
|
||||||
|
"""Test collection creation with pg_bigm enabled."""
|
||||||
|
config = PGVectorConfig(
|
||||||
|
host="localhost",
|
||||||
|
port=5432,
|
||||||
|
user="test_user",
|
||||||
|
password="test_password",
|
||||||
|
database="test_db",
|
||||||
|
min_connection=1,
|
||||||
|
max_connection=5,
|
||||||
|
pg_bigm=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock Redis operations
|
||||||
|
mock_lock = MagicMock()
|
||||||
|
mock_lock.__enter__ = MagicMock()
|
||||||
|
mock_lock.__exit__ = MagicMock()
|
||||||
|
mock_redis.lock.return_value = mock_lock
|
||||||
|
mock_redis.get.return_value = None
|
||||||
|
mock_redis.set.return_value = None
|
||||||
|
|
||||||
|
# Mock the connection pool
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool_class.return_value = mock_pool
|
||||||
|
|
||||||
|
# Mock connection and cursor
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_cursor = MagicMock()
|
||||||
|
mock_pool.getconn.return_value = mock_conn
|
||||||
|
mock_conn.cursor.return_value = mock_cursor
|
||||||
|
mock_cursor.fetchone.return_value = [1] # vector extension exists
|
||||||
|
|
||||||
|
pgvector = PGVector(self.collection_name, config)
|
||||||
|
pgvector._create_collection(1536)
|
||||||
|
|
||||||
|
# Check that pg_bigm index was created
|
||||||
|
bigm_index_calls = [call for call in mock_cursor.execute.call_args_list if "gin_bigm_ops" in str(call)]
|
||||||
|
assert len(bigm_index_calls) == 1
|
||||||
|
|
||||||
|
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
|
||||||
|
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
|
||||||
|
def test_create_collection_creates_vector_extension(self, mock_redis, mock_pool_class):
|
||||||
|
"""Test that vector extension is created if it doesn't exist."""
|
||||||
|
# Mock Redis operations
|
||||||
|
mock_lock = MagicMock()
|
||||||
|
mock_lock.__enter__ = MagicMock()
|
||||||
|
mock_lock.__exit__ = MagicMock()
|
||||||
|
mock_redis.lock.return_value = mock_lock
|
||||||
|
mock_redis.get.return_value = None
|
||||||
|
mock_redis.set.return_value = None
|
||||||
|
|
||||||
|
# Mock the connection pool
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool_class.return_value = mock_pool
|
||||||
|
|
||||||
|
# Mock connection and cursor
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_cursor = MagicMock()
|
||||||
|
mock_pool.getconn.return_value = mock_conn
|
||||||
|
mock_conn.cursor.return_value = mock_cursor
|
||||||
|
# First call: vector extension doesn't exist
|
||||||
|
mock_cursor.fetchone.return_value = None
|
||||||
|
|
||||||
|
pgvector = PGVector(self.collection_name, self.config)
|
||||||
|
pgvector._create_collection(1536)
|
||||||
|
|
||||||
|
# Check that CREATE EXTENSION was called
|
||||||
|
create_extension_calls = [
|
||||||
|
call for call in mock_cursor.execute.call_args_list if "CREATE EXTENSION vector" in str(call)
|
||||||
|
]
|
||||||
|
assert len(create_extension_calls) == 1
|
||||||
|
|
||||||
|
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
|
||||||
|
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
|
||||||
|
def test_create_collection_with_cache_hit(self, mock_redis, mock_pool_class):
|
||||||
|
"""Test that collection creation is skipped when cache exists."""
|
||||||
|
# Mock Redis operations - cache exists
|
||||||
|
mock_lock = MagicMock()
|
||||||
|
mock_lock.__enter__ = MagicMock()
|
||||||
|
mock_lock.__exit__ = MagicMock()
|
||||||
|
mock_redis.lock.return_value = mock_lock
|
||||||
|
mock_redis.get.return_value = 1 # Cache exists
|
||||||
|
|
||||||
|
# Mock the connection pool
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool_class.return_value = mock_pool
|
||||||
|
|
||||||
|
# Mock connection and cursor
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_cursor = MagicMock()
|
||||||
|
mock_pool.getconn.return_value = mock_conn
|
||||||
|
mock_conn.cursor.return_value = mock_cursor
|
||||||
|
|
||||||
|
pgvector = PGVector(self.collection_name, self.config)
|
||||||
|
pgvector._create_collection(1536)
|
||||||
|
|
||||||
|
# Check that no SQL was executed (early return due to cache)
|
||||||
|
assert mock_cursor.execute.call_count == 0
|
||||||
|
|
||||||
|
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
|
||||||
|
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
|
||||||
|
def test_create_collection_with_redis_lock(self, mock_redis, mock_pool_class):
|
||||||
|
"""Test that Redis lock is used during collection creation."""
|
||||||
|
# Mock Redis operations
|
||||||
|
mock_lock = MagicMock()
|
||||||
|
mock_lock.__enter__ = MagicMock()
|
||||||
|
mock_lock.__exit__ = MagicMock()
|
||||||
|
mock_redis.lock.return_value = mock_lock
|
||||||
|
mock_redis.get.return_value = None
|
||||||
|
mock_redis.set.return_value = None
|
||||||
|
|
||||||
|
# Mock the connection pool
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool_class.return_value = mock_pool
|
||||||
|
|
||||||
|
# Mock connection and cursor
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_cursor = MagicMock()
|
||||||
|
mock_pool.getconn.return_value = mock_conn
|
||||||
|
mock_conn.cursor.return_value = mock_cursor
|
||||||
|
mock_cursor.fetchone.return_value = [1] # vector extension exists
|
||||||
|
|
||||||
|
pgvector = PGVector(self.collection_name, self.config)
|
||||||
|
pgvector._create_collection(1536)
|
||||||
|
|
||||||
|
# Verify Redis lock was acquired with correct lock name
|
||||||
|
mock_redis.lock.assert_called_once_with("vector_indexing_test_collection_lock", timeout=20)
|
||||||
|
|
||||||
|
# Verify lock context manager was entered and exited
|
||||||
|
mock_lock.__enter__.assert_called_once()
|
||||||
|
mock_lock.__exit__.assert_called_once()
|
||||||
|
|
||||||
|
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
|
||||||
|
def test_get_cursor_context_manager(self, mock_pool_class):
|
||||||
|
"""Test that _get_cursor properly manages connection lifecycle."""
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool_class.return_value = mock_pool
|
||||||
|
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_cursor = MagicMock()
|
||||||
|
mock_pool.getconn.return_value = mock_conn
|
||||||
|
mock_conn.cursor.return_value = mock_cursor
|
||||||
|
|
||||||
|
pgvector = PGVector(self.collection_name, self.config)
|
||||||
|
|
||||||
|
with pgvector._get_cursor() as cur:
|
||||||
|
assert cur == mock_cursor
|
||||||
|
|
||||||
|
# Verify connection lifecycle methods were called
|
||||||
|
mock_pool.getconn.assert_called_once()
|
||||||
|
mock_cursor.close.assert_called_once()
|
||||||
|
mock_conn.commit.assert_called_once()
|
||||||
|
mock_pool.putconn.assert_called_once_with(mock_conn)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"invalid_config_override",
|
||||||
|
[
|
||||||
|
{"host": ""}, # Test empty host
|
||||||
|
{"port": 0}, # Test invalid port
|
||||||
|
{"user": ""}, # Test empty user
|
||||||
|
{"password": ""}, # Test empty password
|
||||||
|
{"database": ""}, # Test empty database
|
||||||
|
{"min_connection": 0}, # Test invalid min_connection
|
||||||
|
{"max_connection": 0}, # Test invalid max_connection
|
||||||
|
{"min_connection": 10, "max_connection": 5}, # Test min > max
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_config_validation_parametrized(invalid_config_override):
|
||||||
|
"""Test configuration validation for various invalid inputs using parametrize."""
|
||||||
|
config = {
|
||||||
|
"host": "localhost",
|
||||||
|
"port": 5432,
|
||||||
|
"user": "test_user",
|
||||||
|
"password": "test_password",
|
||||||
|
"database": "test_db",
|
||||||
|
"min_connection": 1,
|
||||||
|
"max_connection": 5,
|
||||||
|
}
|
||||||
|
config.update(invalid_config_override)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
PGVectorConfig(**config)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
|
|
@ -0,0 +1,873 @@
|
||||||
|
"""
|
||||||
|
Unit tests for DatasetRetrieval.process_metadata_filter_func.
|
||||||
|
|
||||||
|
This module provides comprehensive test coverage for the process_metadata_filter_func
|
||||||
|
method in the DatasetRetrieval class, which is responsible for building SQLAlchemy
|
||||||
|
filter expressions based on metadata filtering conditions.
|
||||||
|
|
||||||
|
Conditions Tested:
|
||||||
|
==================
|
||||||
|
1. **String Conditions**: contains, not contains, start with, end with
|
||||||
|
2. **Equality Conditions**: is / =, is not / ≠
|
||||||
|
3. **Null Conditions**: empty, not empty
|
||||||
|
4. **Numeric Comparisons**: before / <, after / >, ≤ / <=, ≥ / >=
|
||||||
|
5. **List Conditions**: in
|
||||||
|
6. **Edge Cases**: None values, different data types (str, int, float)
|
||||||
|
|
||||||
|
Test Architecture:
|
||||||
|
==================
|
||||||
|
- Direct instantiation of DatasetRetrieval
|
||||||
|
- Mocking of DatasetDocument model attributes
|
||||||
|
- Verification of SQLAlchemy filter expressions
|
||||||
|
- Follows Arrange-Act-Assert (AAA) pattern
|
||||||
|
|
||||||
|
Running Tests:
|
||||||
|
==============
|
||||||
|
# Run all tests in this module
|
||||||
|
uv run --project api pytest \
|
||||||
|
api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py -v
|
||||||
|
|
||||||
|
# Run a specific test
|
||||||
|
uv run --project api pytest \
|
||||||
|
api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py::\
|
||||||
|
TestProcessMetadataFilterFunc::test_contains_condition -v
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||||
|
|
||||||
|
|
||||||
|
class TestProcessMetadataFilterFunc:
|
||||||
|
"""
|
||||||
|
Comprehensive test suite for process_metadata_filter_func method.
|
||||||
|
|
||||||
|
This test class validates all metadata filtering conditions supported by
|
||||||
|
the DatasetRetrieval class, including string operations, numeric comparisons,
|
||||||
|
null checks, and list operations.
|
||||||
|
|
||||||
|
Method Signature:
|
||||||
|
==================
|
||||||
|
def process_metadata_filter_func(
|
||||||
|
self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
|
||||||
|
) -> list:
|
||||||
|
|
||||||
|
The method builds SQLAlchemy filter expressions by:
|
||||||
|
1. Validating value is not None (except for empty/not empty conditions)
|
||||||
|
2. Using DatasetDocument.doc_metadata JSON field operations
|
||||||
|
3. Adding appropriate SQLAlchemy expressions to the filters list
|
||||||
|
4. Returning the updated filters list
|
||||||
|
|
||||||
|
Mocking Strategy:
|
||||||
|
==================
|
||||||
|
- Mock DatasetDocument.doc_metadata to avoid database dependencies
|
||||||
|
- Verify filter expressions are created correctly
|
||||||
|
- Test with various data types (str, int, float, list)
|
||||||
|
"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def retrieval(self):
|
||||||
|
"""
|
||||||
|
Create a DatasetRetrieval instance for testing.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DatasetRetrieval: Instance to test process_metadata_filter_func
|
||||||
|
"""
|
||||||
|
return DatasetRetrieval()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_doc_metadata(self):
|
||||||
|
"""
|
||||||
|
Mock the DatasetDocument.doc_metadata JSON field.
|
||||||
|
|
||||||
|
The method uses DatasetDocument.doc_metadata[metadata_name] to access
|
||||||
|
JSON fields. We mock this to avoid database dependencies.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Mock: Mocked doc_metadata attribute
|
||||||
|
"""
|
||||||
|
mock_metadata_field = MagicMock()
|
||||||
|
|
||||||
|
# Create mock for string access
|
||||||
|
mock_string_access = MagicMock()
|
||||||
|
mock_string_access.like = MagicMock()
|
||||||
|
mock_string_access.notlike = MagicMock()
|
||||||
|
mock_string_access.__eq__ = MagicMock(return_value=MagicMock())
|
||||||
|
mock_string_access.__ne__ = MagicMock(return_value=MagicMock())
|
||||||
|
mock_string_access.in_ = MagicMock(return_value=MagicMock())
|
||||||
|
|
||||||
|
# Create mock for float access (for numeric comparisons)
|
||||||
|
mock_float_access = MagicMock()
|
||||||
|
mock_float_access.__eq__ = MagicMock(return_value=MagicMock())
|
||||||
|
mock_float_access.__ne__ = MagicMock(return_value=MagicMock())
|
||||||
|
mock_float_access.__lt__ = MagicMock(return_value=MagicMock())
|
||||||
|
mock_float_access.__gt__ = MagicMock(return_value=MagicMock())
|
||||||
|
mock_float_access.__le__ = MagicMock(return_value=MagicMock())
|
||||||
|
mock_float_access.__ge__ = MagicMock(return_value=MagicMock())
|
||||||
|
|
||||||
|
# Create mock for null checks
|
||||||
|
mock_null_access = MagicMock()
|
||||||
|
mock_null_access.is_ = MagicMock(return_value=MagicMock())
|
||||||
|
mock_null_access.isnot = MagicMock(return_value=MagicMock())
|
||||||
|
|
||||||
|
# Setup __getitem__ to return appropriate mock based on usage
|
||||||
|
def getitem_side_effect(name):
|
||||||
|
if name in ["author", "title", "category"]:
|
||||||
|
return mock_string_access
|
||||||
|
elif name in ["year", "price", "rating"]:
|
||||||
|
return mock_float_access
|
||||||
|
else:
|
||||||
|
return mock_string_access
|
||||||
|
|
||||||
|
mock_metadata_field.__getitem__ = MagicMock(side_effect=getitem_side_effect)
|
||||||
|
mock_metadata_field.as_string.return_value = mock_string_access
|
||||||
|
mock_metadata_field.as_float.return_value = mock_float_access
|
||||||
|
mock_metadata_field[metadata_name:str].is_ = mock_null_access.is_
|
||||||
|
mock_metadata_field[metadata_name:str].isnot = mock_null_access.isnot
|
||||||
|
|
||||||
|
return mock_metadata_field
|
||||||
|
|
||||||
|
# ==================== String Condition Tests ====================
|
||||||
|
|
||||||
|
def test_contains_condition_string_value(self, retrieval):
|
||||||
|
"""
|
||||||
|
Test 'contains' condition with string value.
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- Filters list is populated with LIKE expression
|
||||||
|
- Pattern matching uses %value% syntax
|
||||||
|
"""
|
||||||
|
filters = []
|
||||||
|
sequence = 0
|
||||||
|
condition = "contains"
|
||||||
|
metadata_name = "author"
|
||||||
|
value = "John"
|
||||||
|
|
||||||
|
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||||
|
|
||||||
|
assert result == filters
|
||||||
|
assert len(filters) == 1
|
||||||
|
|
||||||
|
def test_not_contains_condition(self, retrieval):
|
||||||
|
"""
|
||||||
|
Test 'not contains' condition.
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- Filters list is populated with NOT LIKE expression
|
||||||
|
- Pattern matching uses %value% syntax with negation
|
||||||
|
"""
|
||||||
|
filters = []
|
||||||
|
sequence = 0
|
||||||
|
condition = "not contains"
|
||||||
|
metadata_name = "title"
|
||||||
|
value = "banned"
|
||||||
|
|
||||||
|
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||||
|
|
||||||
|
assert result == filters
|
||||||
|
assert len(filters) == 1
|
||||||
|
|
||||||
|
def test_start_with_condition(self, retrieval):
|
||||||
|
"""
|
||||||
|
Test 'start with' condition.
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- Filters list is populated with LIKE expression
|
||||||
|
- Pattern matching uses value% syntax
|
||||||
|
"""
|
||||||
|
filters = []
|
||||||
|
sequence = 0
|
||||||
|
condition = "start with"
|
||||||
|
metadata_name = "category"
|
||||||
|
value = "tech"
|
||||||
|
|
||||||
|
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||||
|
|
||||||
|
assert result == filters
|
||||||
|
assert len(filters) == 1
|
||||||
|
|
||||||
|
def test_end_with_condition(self, retrieval):
|
||||||
|
"""
|
||||||
|
Test 'end with' condition.
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- Filters list is populated with LIKE expression
|
||||||
|
- Pattern matching uses %value syntax
|
||||||
|
"""
|
||||||
|
filters = []
|
||||||
|
sequence = 0
|
||||||
|
condition = "end with"
|
||||||
|
metadata_name = "filename"
|
||||||
|
value = ".pdf"
|
||||||
|
|
||||||
|
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||||
|
|
||||||
|
assert result == filters
|
||||||
|
assert len(filters) == 1
|
||||||
|
|
||||||
|
# ==================== Equality Condition Tests ====================
|
||||||
|
|
||||||
|
def test_is_condition_with_string_value(self, retrieval):
|
||||||
|
"""
|
||||||
|
Test 'is' (=) condition with string value.
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- Filters list is populated with equality expression
|
||||||
|
- String comparison is used
|
||||||
|
"""
|
||||||
|
filters = []
|
||||||
|
sequence = 0
|
||||||
|
condition = "is"
|
||||||
|
metadata_name = "author"
|
||||||
|
value = "Jane Doe"
|
||||||
|
|
||||||
|
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||||
|
|
||||||
|
assert result == filters
|
||||||
|
assert len(filters) == 1
|
||||||
|
|
||||||
|
def test_equals_condition_with_string_value(self, retrieval):
|
||||||
|
"""
|
||||||
|
Test '=' condition with string value.
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- Same behavior as 'is' condition
|
||||||
|
- String comparison is used
|
||||||
|
"""
|
||||||
|
filters = []
|
||||||
|
sequence = 0
|
||||||
|
condition = "="
|
||||||
|
metadata_name = "category"
|
||||||
|
value = "technology"
|
||||||
|
|
||||||
|
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||||
|
|
||||||
|
assert result == filters
|
||||||
|
assert len(filters) == 1
|
||||||
|
|
||||||
|
def test_is_condition_with_int_value(self, retrieval):
|
||||||
|
"""
|
||||||
|
Test 'is' condition with integer value.
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- Numeric comparison is used
|
||||||
|
- as_float() is called on the metadata field
|
||||||
|
"""
|
||||||
|
filters = []
|
||||||
|
sequence = 0
|
||||||
|
condition = "is"
|
||||||
|
metadata_name = "year"
|
||||||
|
value = 2023
|
||||||
|
|
||||||
|
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||||
|
|
||||||
|
assert result == filters
|
||||||
|
assert len(filters) == 1
|
||||||
|
|
||||||
|
def test_is_condition_with_float_value(self, retrieval):
|
||||||
|
"""
|
||||||
|
Test 'is' condition with float value.
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- Numeric comparison is used
|
||||||
|
- as_float() is called on the metadata field
|
||||||
|
"""
|
||||||
|
filters = []
|
||||||
|
sequence = 0
|
||||||
|
condition = "is"
|
||||||
|
metadata_name = "price"
|
||||||
|
value = 19.99
|
||||||
|
|
||||||
|
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||||
|
|
||||||
|
assert result == filters
|
||||||
|
assert len(filters) == 1
|
||||||
|
|
||||||
|
def test_is_not_condition_with_string_value(self, retrieval):
|
||||||
|
"""
|
||||||
|
Test 'is not' (≠) condition with string value.
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- Filters list is populated with inequality expression
|
||||||
|
- String comparison is used
|
||||||
|
"""
|
||||||
|
filters = []
|
||||||
|
sequence = 0
|
||||||
|
condition = "is not"
|
||||||
|
metadata_name = "author"
|
||||||
|
value = "Unknown"
|
||||||
|
|
||||||
|
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||||
|
|
||||||
|
assert result == filters
|
||||||
|
assert len(filters) == 1
|
||||||
|
|
||||||
|
def test_not_equals_condition(self, retrieval):
|
||||||
|
"""
|
||||||
|
Test '≠' condition with string value.
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- Same behavior as 'is not' condition
|
||||||
|
- Inequality expression is used
|
||||||
|
"""
|
||||||
|
filters = []
|
||||||
|
sequence = 0
|
||||||
|
condition = "≠"
|
||||||
|
metadata_name = "category"
|
||||||
|
value = "archived"
|
||||||
|
|
||||||
|
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||||
|
|
||||||
|
assert result == filters
|
||||||
|
assert len(filters) == 1
|
||||||
|
|
||||||
|
def test_is_not_condition_with_numeric_value(self, retrieval):
|
||||||
|
"""
|
||||||
|
Test 'is not' condition with numeric value.
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- Numeric inequality comparison is used
|
||||||
|
- as_float() is called on the metadata field
|
||||||
|
"""
|
||||||
|
filters = []
|
||||||
|
sequence = 0
|
||||||
|
condition = "is not"
|
||||||
|
metadata_name = "year"
|
||||||
|
value = 2000
|
||||||
|
|
||||||
|
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||||
|
|
||||||
|
assert result == filters
|
||||||
|
assert len(filters) == 1
|
||||||
|
|
||||||
|
# ==================== Null Condition Tests ====================
|
||||||
|
|
||||||
|
def test_empty_condition(self, retrieval):
|
||||||
|
"""
|
||||||
|
Test 'empty' condition (null check).
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- Filters list is populated with IS NULL expression
|
||||||
|
- Value can be None for this condition
|
||||||
|
"""
|
||||||
|
filters = []
|
||||||
|
sequence = 0
|
||||||
|
condition = "empty"
|
||||||
|
metadata_name = "author"
|
||||||
|
value = None
|
||||||
|
|
||||||
|
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||||
|
|
||||||
|
assert result == filters
|
||||||
|
assert len(filters) == 1
|
||||||
|
|
||||||
|
def test_not_empty_condition(self, retrieval):
|
||||||
|
"""
|
||||||
|
Test 'not empty' condition (not null check).
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- Filters list is populated with IS NOT NULL expression
|
||||||
|
- Value can be None for this condition
|
||||||
|
"""
|
||||||
|
filters = []
|
||||||
|
sequence = 0
|
||||||
|
condition = "not empty"
|
||||||
|
metadata_name = "description"
|
||||||
|
value = None
|
||||||
|
|
||||||
|
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||||
|
|
||||||
|
assert result == filters
|
||||||
|
assert len(filters) == 1
|
||||||
|
|
||||||
|
# ==================== Numeric Comparison Tests ====================
|
||||||
|
|
||||||
|
def test_before_condition(self, retrieval):
|
||||||
|
"""
|
||||||
|
Test 'before' (<) condition.
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- Filters list is populated with less than expression
|
||||||
|
- Numeric comparison is used
|
||||||
|
"""
|
||||||
|
filters = []
|
||||||
|
sequence = 0
|
||||||
|
condition = "before"
|
||||||
|
metadata_name = "year"
|
||||||
|
value = 2020
|
||||||
|
|
||||||
|
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||||
|
|
||||||
|
assert result == filters
|
||||||
|
assert len(filters) == 1
|
||||||
|
|
||||||
|
def test_less_than_condition(self, retrieval):
|
||||||
|
"""
|
||||||
|
Test '<' condition.
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- Same behavior as 'before' condition
|
||||||
|
- Less than expression is used
|
||||||
|
"""
|
||||||
|
filters = []
|
||||||
|
sequence = 0
|
||||||
|
condition = "<"
|
||||||
|
metadata_name = "price"
|
||||||
|
value = 100.0
|
||||||
|
|
||||||
|
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||||
|
|
||||||
|
assert result == filters
|
||||||
|
assert len(filters) == 1
|
||||||
|
|
||||||
|
def test_after_condition(self, retrieval):
|
||||||
|
"""
|
||||||
|
Test 'after' (>) condition.
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- Filters list is populated with greater than expression
|
||||||
|
- Numeric comparison is used
|
||||||
|
"""
|
||||||
|
filters = []
|
||||||
|
sequence = 0
|
||||||
|
condition = "after"
|
||||||
|
metadata_name = "year"
|
||||||
|
value = 2020
|
||||||
|
|
||||||
|
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||||
|
|
||||||
|
assert result == filters
|
||||||
|
assert len(filters) == 1
|
||||||
|
|
||||||
|
def test_greater_than_condition(self, retrieval):
|
||||||
|
"""
|
||||||
|
Test '>' condition.
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- Same behavior as 'after' condition
|
||||||
|
- Greater than expression is used
|
||||||
|
"""
|
||||||
|
filters = []
|
||||||
|
sequence = 0
|
||||||
|
condition = ">"
|
||||||
|
metadata_name = "rating"
|
||||||
|
value = 4.5
|
||||||
|
|
||||||
|
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||||
|
|
||||||
|
assert result == filters
|
||||||
|
assert len(filters) == 1
|
||||||
|
|
||||||
|
def test_less_than_or_equal_condition_unicode(self, retrieval):
|
||||||
|
"""
|
||||||
|
Test '≤' condition.
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- Filters list is populated with less than or equal expression
|
||||||
|
- Numeric comparison is used
|
||||||
|
"""
|
||||||
|
filters = []
|
||||||
|
sequence = 0
|
||||||
|
condition = "≤"
|
||||||
|
metadata_name = "price"
|
||||||
|
value = 50.0
|
||||||
|
|
||||||
|
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||||
|
|
||||||
|
assert result == filters
|
||||||
|
assert len(filters) == 1
|
||||||
|
|
||||||
|
def test_less_than_or_equal_condition_ascii(self, retrieval):
|
||||||
|
"""
|
||||||
|
Test '<=' condition.
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- Same behavior as '≤' condition
|
||||||
|
- Less than or equal expression is used
|
||||||
|
"""
|
||||||
|
filters = []
|
||||||
|
sequence = 0
|
||||||
|
condition = "<="
|
||||||
|
metadata_name = "year"
|
||||||
|
value = 2023
|
||||||
|
|
||||||
|
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||||
|
|
||||||
|
assert result == filters
|
||||||
|
assert len(filters) == 1
|
||||||
|
|
||||||
|
def test_greater_than_or_equal_condition_unicode(self, retrieval):
|
||||||
|
"""
|
||||||
|
Test '≥' condition.
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- Filters list is populated with greater than or equal expression
|
||||||
|
- Numeric comparison is used
|
||||||
|
"""
|
||||||
|
filters = []
|
||||||
|
sequence = 0
|
||||||
|
condition = "≥"
|
||||||
|
metadata_name = "rating"
|
||||||
|
value = 3.5
|
||||||
|
|
||||||
|
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||||
|
|
||||||
|
assert result == filters
|
||||||
|
assert len(filters) == 1
|
||||||
|
|
||||||
|
def test_greater_than_or_equal_condition_ascii(self, retrieval):
|
||||||
|
"""
|
||||||
|
Test '>=' condition.
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- Same behavior as '≥' condition
|
||||||
|
- Greater than or equal expression is used
|
||||||
|
"""
|
||||||
|
filters = []
|
||||||
|
sequence = 0
|
||||||
|
condition = ">="
|
||||||
|
metadata_name = "year"
|
||||||
|
value = 2000
|
||||||
|
|
||||||
|
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||||
|
|
||||||
|
assert result == filters
|
||||||
|
assert len(filters) == 1
|
||||||
|
|
||||||
|
# ==================== List/In Condition Tests ====================
|
||||||
|
|
||||||
|
def test_in_condition_with_comma_separated_string(self, retrieval):
|
||||||
|
"""
|
||||||
|
Test 'in' condition with comma-separated string value.
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- String is split into list
|
||||||
|
- Whitespace is trimmed from each value
|
||||||
|
- IN expression is created
|
||||||
|
"""
|
||||||
|
filters = []
|
||||||
|
sequence = 0
|
||||||
|
condition = "in"
|
||||||
|
metadata_name = "category"
|
||||||
|
value = "tech, science, AI "
|
||||||
|
|
||||||
|
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||||
|
|
||||||
|
assert result == filters
|
||||||
|
assert len(filters) == 1
|
||||||
|
|
||||||
|
def test_in_condition_with_list_value(self, retrieval):
|
||||||
|
"""
|
||||||
|
Test 'in' condition with list value.
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- List is processed correctly
|
||||||
|
- None values are filtered out
|
||||||
|
- IN expression is created with valid values
|
||||||
|
"""
|
||||||
|
filters = []
|
||||||
|
sequence = 0
|
||||||
|
condition = "in"
|
||||||
|
metadata_name = "tags"
|
||||||
|
value = ["python", "javascript", None, "golang"]
|
||||||
|
|
||||||
|
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||||
|
|
||||||
|
assert result == filters
|
||||||
|
assert len(filters) == 1
|
||||||
|
|
||||||
|
def test_in_condition_with_tuple_value(self, retrieval):
|
||||||
|
"""
|
||||||
|
Test 'in' condition with tuple value.
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- Tuple is processed like a list
|
||||||
|
- IN expression is created
|
||||||
|
"""
|
||||||
|
filters = []
|
||||||
|
sequence = 0
|
||||||
|
condition = "in"
|
||||||
|
metadata_name = "category"
|
||||||
|
value = ("tech", "science", "ai")
|
||||||
|
|
||||||
|
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||||
|
|
||||||
|
assert result == filters
|
||||||
|
assert len(filters) == 1
|
||||||
|
|
||||||
|
def test_in_condition_with_empty_string(self, retrieval):
|
||||||
|
"""
|
||||||
|
Test 'in' condition with empty string value.
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- Empty string results in literal(False) filter
|
||||||
|
- No valid values to match
|
||||||
|
"""
|
||||||
|
filters = []
|
||||||
|
sequence = 0
|
||||||
|
condition = "in"
|
||||||
|
metadata_name = "category"
|
||||||
|
value = ""
|
||||||
|
|
||||||
|
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||||
|
|
||||||
|
assert result == filters
|
||||||
|
assert len(filters) == 1
|
||||||
|
# Verify it's a literal(False) expression
|
||||||
|
# This is a bit tricky to test without access to the actual expression
|
||||||
|
|
||||||
|
def test_in_condition_with_only_whitespace(self, retrieval):
|
||||||
|
"""
|
||||||
|
Test 'in' condition with whitespace-only string value.
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- Whitespace-only string results in literal(False) filter
|
||||||
|
- All values are stripped and filtered out
|
||||||
|
"""
|
||||||
|
filters = []
|
||||||
|
sequence = 0
|
||||||
|
condition = "in"
|
||||||
|
metadata_name = "category"
|
||||||
|
value = " , , "
|
||||||
|
|
||||||
|
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||||
|
|
||||||
|
assert result == filters
|
||||||
|
assert len(filters) == 1
|
||||||
|
|
||||||
|
def test_in_condition_with_single_string(self, retrieval):
|
||||||
|
"""
|
||||||
|
Test 'in' condition with single non-comma string.
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- Single string is treated as single-item list
|
||||||
|
- IN expression is created with one value
|
||||||
|
"""
|
||||||
|
filters = []
|
||||||
|
sequence = 0
|
||||||
|
condition = "in"
|
||||||
|
metadata_name = "category"
|
||||||
|
value = "technology"
|
||||||
|
|
||||||
|
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||||
|
|
||||||
|
assert result == filters
|
||||||
|
assert len(filters) == 1
|
||||||
|
|
||||||
|
# ==================== Edge Case Tests ====================
|
||||||
|
|
||||||
|
def test_none_value_with_non_empty_condition(self, retrieval):
|
||||||
|
"""
|
||||||
|
Test None value with conditions that require value.
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- Original filters list is returned unchanged
|
||||||
|
- No filter is added for None values (except empty/not empty)
|
||||||
|
"""
|
||||||
|
filters = []
|
||||||
|
sequence = 0
|
||||||
|
condition = "contains"
|
||||||
|
metadata_name = "author"
|
||||||
|
value = None
|
||||||
|
|
||||||
|
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||||
|
|
||||||
|
assert result == filters
|
||||||
|
assert len(filters) == 0 # No filter added
|
||||||
|
|
||||||
|
def test_none_value_with_equals_condition(self, retrieval):
|
||||||
|
"""
|
||||||
|
Test None value with 'is' (=) condition.
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- Original filters list is returned unchanged
|
||||||
|
- No filter is added for None values
|
||||||
|
"""
|
||||||
|
filters = []
|
||||||
|
sequence = 0
|
||||||
|
condition = "is"
|
||||||
|
metadata_name = "author"
|
||||||
|
value = None
|
||||||
|
|
||||||
|
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||||
|
|
||||||
|
assert result == filters
|
||||||
|
assert len(filters) == 0
|
||||||
|
|
||||||
|
def test_none_value_with_numeric_condition(self, retrieval):
|
||||||
|
"""
|
||||||
|
Test None value with numeric comparison condition.
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- Original filters list is returned unchanged
|
||||||
|
- No filter is added for None values
|
||||||
|
"""
|
||||||
|
filters = []
|
||||||
|
sequence = 0
|
||||||
|
condition = ">"
|
||||||
|
metadata_name = "year"
|
||||||
|
value = None
|
||||||
|
|
||||||
|
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||||
|
|
||||||
|
assert result == filters
|
||||||
|
assert len(filters) == 0
|
||||||
|
|
||||||
|
def test_existing_filters_preserved(self, retrieval):
|
||||||
|
"""
|
||||||
|
Test that existing filters are preserved.
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- Existing filters in the list are not removed
|
||||||
|
- New filters are appended to the list
|
||||||
|
"""
|
||||||
|
existing_filter = MagicMock()
|
||||||
|
filters = [existing_filter]
|
||||||
|
sequence = 0
|
||||||
|
condition = "contains"
|
||||||
|
metadata_name = "author"
|
||||||
|
value = "test"
|
||||||
|
|
||||||
|
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||||
|
|
||||||
|
assert result == filters
|
||||||
|
assert len(filters) == 2
|
||||||
|
assert filters[0] == existing_filter
|
||||||
|
|
||||||
|
def test_multiple_filters_accumulated(self, retrieval):
|
||||||
|
"""
|
||||||
|
Test multiple calls to accumulate filters.
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- Each call adds a new filter to the list
|
||||||
|
- All filters are preserved across calls
|
||||||
|
"""
|
||||||
|
filters = []
|
||||||
|
|
||||||
|
# First filter
|
||||||
|
retrieval.process_metadata_filter_func(0, "contains", "author", "John", filters)
|
||||||
|
assert len(filters) == 1
|
||||||
|
|
||||||
|
# Second filter
|
||||||
|
retrieval.process_metadata_filter_func(1, ">", "year", 2020, filters)
|
||||||
|
assert len(filters) == 2
|
||||||
|
|
||||||
|
# Third filter
|
||||||
|
retrieval.process_metadata_filter_func(2, "is", "category", "tech", filters)
|
||||||
|
assert len(filters) == 3
|
||||||
|
|
||||||
|
def test_unknown_condition(self, retrieval):
|
||||||
|
"""
|
||||||
|
Test unknown/unsupported condition.
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- Original filters list is returned unchanged
|
||||||
|
- No filter is added for unknown conditions
|
||||||
|
"""
|
||||||
|
filters = []
|
||||||
|
sequence = 0
|
||||||
|
condition = "unknown_condition"
|
||||||
|
metadata_name = "author"
|
||||||
|
value = "test"
|
||||||
|
|
||||||
|
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||||
|
|
||||||
|
assert result == filters
|
||||||
|
assert len(filters) == 0
|
||||||
|
|
||||||
|
def test_empty_string_value_with_contains(self, retrieval):
|
||||||
|
"""
|
||||||
|
Test empty string value with 'contains' condition.
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- Filter is added even with empty string
|
||||||
|
- LIKE expression is created
|
||||||
|
"""
|
||||||
|
filters = []
|
||||||
|
sequence = 0
|
||||||
|
condition = "contains"
|
||||||
|
metadata_name = "author"
|
||||||
|
value = ""
|
||||||
|
|
||||||
|
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||||
|
|
||||||
|
assert result == filters
|
||||||
|
assert len(filters) == 1
|
||||||
|
|
||||||
|
def test_special_characters_in_value(self, retrieval):
|
||||||
|
"""
|
||||||
|
Test special characters in value string.
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- Special characters are handled in value
|
||||||
|
- LIKE expression is created correctly
|
||||||
|
"""
|
||||||
|
filters = []
|
||||||
|
sequence = 0
|
||||||
|
condition = "contains"
|
||||||
|
metadata_name = "title"
|
||||||
|
value = "C++ & Python's features"
|
||||||
|
|
||||||
|
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||||
|
|
||||||
|
assert result == filters
|
||||||
|
assert len(filters) == 1
|
||||||
|
|
||||||
|
def test_zero_value_with_numeric_condition(self, retrieval):
|
||||||
|
"""
|
||||||
|
Test zero value with numeric comparison condition.
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- Zero is treated as valid value
|
||||||
|
- Numeric comparison is performed
|
||||||
|
"""
|
||||||
|
filters = []
|
||||||
|
sequence = 0
|
||||||
|
condition = ">"
|
||||||
|
metadata_name = "price"
|
||||||
|
value = 0
|
||||||
|
|
||||||
|
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||||
|
|
||||||
|
assert result == filters
|
||||||
|
assert len(filters) == 1
|
||||||
|
|
||||||
|
def test_negative_value_with_numeric_condition(self, retrieval):
|
||||||
|
"""
|
||||||
|
Test negative value with numeric comparison condition.
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- Negative numbers are handled correctly
|
||||||
|
- Numeric comparison is performed
|
||||||
|
"""
|
||||||
|
filters = []
|
||||||
|
sequence = 0
|
||||||
|
condition = "<"
|
||||||
|
metadata_name = "temperature"
|
||||||
|
value = -10.5
|
||||||
|
|
||||||
|
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||||
|
|
||||||
|
assert result == filters
|
||||||
|
assert len(filters) == 1
|
||||||
|
|
||||||
|
def test_float_value_with_integer_comparison(self, retrieval):
|
||||||
|
"""
|
||||||
|
Test float value with numeric comparison condition.
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- Float values work correctly
|
||||||
|
- Numeric comparison is performed
|
||||||
|
"""
|
||||||
|
filters = []
|
||||||
|
sequence = 0
|
||||||
|
condition = ">="
|
||||||
|
metadata_name = "rating"
|
||||||
|
value = 4.5
|
||||||
|
|
||||||
|
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||||
|
|
||||||
|
assert result == filters
|
||||||
|
assert len(filters) == 1
|
||||||
|
|
@ -0,0 +1,141 @@
|
||||||
|
import type { DataSet } from '@/models/datasets'
|
||||||
|
import { act, fireEvent, render, screen } from '@testing-library/react'
|
||||||
|
import * as React from 'react'
|
||||||
|
|
||||||
|
import { describe, expect, it, vi } from 'vitest'
|
||||||
|
import { IndexingType } from '@/app/components/datasets/create/step-two'
|
||||||
|
import { DatasetPermission } from '@/models/datasets'
|
||||||
|
import { RETRIEVE_METHOD } from '@/types/app'
|
||||||
|
import SelectDataSet from './index'
|
||||||
|
|
||||||
|
vi.mock('@/i18n-config/i18next-config', () => ({
|
||||||
|
__esModule: true,
|
||||||
|
default: {
|
||||||
|
changeLanguage: vi.fn(),
|
||||||
|
addResourceBundle: vi.fn(),
|
||||||
|
use: vi.fn().mockReturnThis(),
|
||||||
|
init: vi.fn(),
|
||||||
|
addResource: vi.fn(),
|
||||||
|
hasResourceBundle: vi.fn().mockReturnValue(true),
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
const mockUseInfiniteScroll = vi.fn()
|
||||||
|
vi.mock('ahooks', async (importOriginal) => {
|
||||||
|
const actual = await importOriginal()
|
||||||
|
return {
|
||||||
|
...(typeof actual === 'object' && actual !== null ? actual : {}),
|
||||||
|
useInfiniteScroll: (...args: any[]) => mockUseInfiniteScroll(...args),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
const mockUseInfiniteDatasets = vi.fn()
|
||||||
|
vi.mock('@/service/knowledge/use-dataset', () => ({
|
||||||
|
useInfiniteDatasets: (...args: any[]) => mockUseInfiniteDatasets(...args),
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@/hooks/use-knowledge', () => ({
|
||||||
|
useKnowledge: () => ({
|
||||||
|
formatIndexingTechniqueAndMethod: (tech: string, method: string) => `${tech}:${method}`,
|
||||||
|
}),
|
||||||
|
}))
|
||||||
|
|
||||||
|
const baseProps = {
|
||||||
|
isShow: true,
|
||||||
|
onClose: vi.fn(),
|
||||||
|
selectedIds: [] as string[],
|
||||||
|
onSelect: vi.fn(),
|
||||||
|
}
|
||||||
|
|
||||||
|
const makeDataset = (overrides: Partial<DataSet>): DataSet => ({
|
||||||
|
id: 'dataset-id',
|
||||||
|
name: 'Dataset Name',
|
||||||
|
provider: 'internal',
|
||||||
|
icon_info: {
|
||||||
|
icon_type: 'emoji',
|
||||||
|
icon: '💾',
|
||||||
|
icon_background: '#fff',
|
||||||
|
icon_url: '',
|
||||||
|
},
|
||||||
|
embedding_available: true,
|
||||||
|
is_multimodal: false,
|
||||||
|
description: '',
|
||||||
|
permission: DatasetPermission.allTeamMembers,
|
||||||
|
indexing_technique: IndexingType.ECONOMICAL,
|
||||||
|
retrieval_model_dict: {
|
||||||
|
search_method: RETRIEVE_METHOD.fullText,
|
||||||
|
top_k: 5,
|
||||||
|
reranking_enable: false,
|
||||||
|
reranking_model: {
|
||||||
|
reranking_model_name: '',
|
||||||
|
reranking_provider_name: '',
|
||||||
|
},
|
||||||
|
score_threshold_enabled: false,
|
||||||
|
score_threshold: 0,
|
||||||
|
},
|
||||||
|
...overrides,
|
||||||
|
} as DataSet)
|
||||||
|
|
||||||
|
describe('SelectDataSet', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
vi.clearAllMocks()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('renders dataset entries, allows selection, and fires onSelect', async () => {
|
||||||
|
const datasetOne = makeDataset({
|
||||||
|
id: 'set-1',
|
||||||
|
name: 'Dataset One',
|
||||||
|
is_multimodal: true,
|
||||||
|
indexing_technique: IndexingType.ECONOMICAL,
|
||||||
|
})
|
||||||
|
const datasetTwo = makeDataset({
|
||||||
|
id: 'set-2',
|
||||||
|
name: 'Hidden Dataset',
|
||||||
|
embedding_available: false,
|
||||||
|
provider: 'external',
|
||||||
|
})
|
||||||
|
mockUseInfiniteDatasets.mockReturnValue({
|
||||||
|
data: { pages: [{ data: [datasetOne, datasetTwo] }] },
|
||||||
|
isLoading: false,
|
||||||
|
isFetchingNextPage: false,
|
||||||
|
fetchNextPage: vi.fn(),
|
||||||
|
hasNextPage: false,
|
||||||
|
})
|
||||||
|
|
||||||
|
const onSelect = vi.fn()
|
||||||
|
await act(async () => {
|
||||||
|
render(<SelectDataSet {...baseProps} onSelect={onSelect} selectedIds={[]} />)
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(screen.getByText('Dataset One')).toBeInTheDocument()
|
||||||
|
expect(screen.getByText('Hidden Dataset')).toBeInTheDocument()
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
fireEvent.click(screen.getByText('Dataset One'))
|
||||||
|
})
|
||||||
|
expect(screen.getByText('1 appDebug.feature.dataSet.selected')).toBeInTheDocument()
|
||||||
|
|
||||||
|
const addButton = screen.getByRole('button', { name: 'common.operation.add' })
|
||||||
|
await act(async () => {
|
||||||
|
fireEvent.click(addButton)
|
||||||
|
})
|
||||||
|
expect(onSelect).toHaveBeenCalledWith([datasetOne])
|
||||||
|
})
|
||||||
|
|
||||||
|
it('shows empty state when no datasets are available and disables add', async () => {
|
||||||
|
mockUseInfiniteDatasets.mockReturnValue({
|
||||||
|
data: { pages: [{ data: [] }] },
|
||||||
|
isLoading: false,
|
||||||
|
isFetchingNextPage: false,
|
||||||
|
fetchNextPage: vi.fn(),
|
||||||
|
hasNextPage: false,
|
||||||
|
})
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
render(<SelectDataSet {...baseProps} onSelect={vi.fn()} selectedIds={[]} />)
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(screen.getByText('appDebug.feature.dataSet.noDataSet')).toBeInTheDocument()
|
||||||
|
expect(screen.getByRole('link', { name: 'appDebug.feature.dataSet.toCreate' })).toHaveAttribute('href', '/datasets/create')
|
||||||
|
expect(screen.getByRole('button', { name: 'common.operation.add' })).toBeDisabled()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
@ -0,0 +1,125 @@
|
||||||
|
import type { IPromptValuePanelProps } from './index'
|
||||||
|
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||||
|
import * as React from 'react'
|
||||||
|
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||||
|
import { useStore } from '@/app/components/app/store'
|
||||||
|
import ConfigContext from '@/context/debug-configuration'
|
||||||
|
import { AppModeEnum, ModelModeType, Resolution } from '@/types/app'
|
||||||
|
import PromptValuePanel from './index'
|
||||||
|
|
||||||
|
vi.mock('@/app/components/app/store', () => ({
|
||||||
|
useStore: vi.fn(),
|
||||||
|
}))
|
||||||
|
vi.mock('@/app/components/base/features/new-feature-panel/feature-bar', () => ({
|
||||||
|
__esModule: true,
|
||||||
|
default: ({ onFeatureBarClick }: { onFeatureBarClick: () => void }) => (
|
||||||
|
<button type="button" onClick={onFeatureBarClick}>
|
||||||
|
feature bar
|
||||||
|
</button>
|
||||||
|
),
|
||||||
|
}))
|
||||||
|
|
||||||
|
const mockSetShowAppConfigureFeaturesModal = vi.fn()
|
||||||
|
const mockUseStore = vi.mocked(useStore)
|
||||||
|
const mockSetInputs = vi.fn()
|
||||||
|
const mockOnSend = vi.fn()
|
||||||
|
|
||||||
|
const promptVariables = [
|
||||||
|
{ key: 'textVar', name: 'Text Var', type: 'string', required: true },
|
||||||
|
{ key: 'boolVar', name: 'Boolean Var', type: 'checkbox' },
|
||||||
|
] as const
|
||||||
|
|
||||||
|
const baseContextValue: any = {
|
||||||
|
modelModeType: ModelModeType.completion,
|
||||||
|
modelConfig: {
|
||||||
|
configs: {
|
||||||
|
prompt_template: 'prompt template',
|
||||||
|
prompt_variables: promptVariables,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
setInputs: mockSetInputs,
|
||||||
|
mode: AppModeEnum.COMPLETION,
|
||||||
|
isAdvancedMode: false,
|
||||||
|
completionPromptConfig: {
|
||||||
|
prompt: { text: 'completion' },
|
||||||
|
conversation_histories_role: { user_prefix: 'user', assistant_prefix: 'assistant' },
|
||||||
|
},
|
||||||
|
chatPromptConfig: { prompt: [] },
|
||||||
|
} as any
|
||||||
|
|
||||||
|
const defaultProps: IPromptValuePanelProps = {
|
||||||
|
appType: AppModeEnum.COMPLETION,
|
||||||
|
onSend: mockOnSend,
|
||||||
|
inputs: { textVar: 'initial', boolVar: false },
|
||||||
|
visionConfig: { enabled: false, number_limits: 0, detail: Resolution.low, transfer_methods: [] },
|
||||||
|
onVisionFilesChange: vi.fn(),
|
||||||
|
}
|
||||||
|
|
||||||
|
const renderPanel = (options: {
|
||||||
|
context?: Partial<typeof baseContextValue>
|
||||||
|
props?: Partial<IPromptValuePanelProps>
|
||||||
|
} = {}) => {
|
||||||
|
const contextValue = { ...baseContextValue, ...options.context }
|
||||||
|
const props = { ...defaultProps, ...options.props }
|
||||||
|
return render(
|
||||||
|
<ConfigContext.Provider value={contextValue}>
|
||||||
|
<PromptValuePanel {...props} />
|
||||||
|
</ConfigContext.Provider>,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
describe('PromptValuePanel', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
mockUseStore.mockImplementation(selector => selector({
|
||||||
|
setShowAppConfigureFeaturesModal: mockSetShowAppConfigureFeaturesModal,
|
||||||
|
appSidebarExpand: '',
|
||||||
|
currentLogModalActiveTab: 'prompt',
|
||||||
|
showPromptLogModal: false,
|
||||||
|
showAgentLogModal: false,
|
||||||
|
setShowPromptLogModal: vi.fn(),
|
||||||
|
setShowAgentLogModal: vi.fn(),
|
||||||
|
showMessageLogModal: false,
|
||||||
|
showAppConfigureFeaturesModal: false,
|
||||||
|
} as any))
|
||||||
|
mockSetInputs.mockClear()
|
||||||
|
mockOnSend.mockClear()
|
||||||
|
mockSetShowAppConfigureFeaturesModal.mockClear()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('updates inputs, clears values, and triggers run when ready', async () => {
|
||||||
|
renderPanel()
|
||||||
|
|
||||||
|
const textInput = screen.getByPlaceholderText('Text Var')
|
||||||
|
fireEvent.change(textInput, { target: { value: 'updated' } })
|
||||||
|
expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ textVar: 'updated' }))
|
||||||
|
|
||||||
|
const clearButton = screen.getByRole('button', { name: 'common.operation.clear' })
|
||||||
|
fireEvent.click(clearButton)
|
||||||
|
|
||||||
|
expect(mockSetInputs).toHaveBeenLastCalledWith({
|
||||||
|
textVar: '',
|
||||||
|
boolVar: '',
|
||||||
|
})
|
||||||
|
|
||||||
|
const runButton = screen.getByRole('button', { name: 'appDebug.inputs.run' })
|
||||||
|
expect(runButton).not.toBeDisabled()
|
||||||
|
fireEvent.click(runButton)
|
||||||
|
await waitFor(() => expect(mockOnSend).toHaveBeenCalledTimes(1))
|
||||||
|
})
|
||||||
|
|
||||||
|
it('disables run when mode is not completion', () => {
|
||||||
|
renderPanel({
|
||||||
|
context: {
|
||||||
|
mode: AppModeEnum.CHAT,
|
||||||
|
},
|
||||||
|
props: {
|
||||||
|
appType: AppModeEnum.CHAT,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
const runButton = screen.getByRole('button', { name: 'appDebug.inputs.run' })
|
||||||
|
expect(runButton).toBeDisabled()
|
||||||
|
fireEvent.click(runButton)
|
||||||
|
expect(mockOnSend).not.toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
@ -0,0 +1,29 @@
|
||||||
|
import type { PromptVariable } from '@/models/debug'
|
||||||
|
|
||||||
|
import { describe, expect, it } from 'vitest'
|
||||||
|
import { replaceStringWithValues } from './utils'
|
||||||
|
|
||||||
|
const promptVariables: PromptVariable[] = [
|
||||||
|
{ key: 'user', name: 'User', type: 'string' },
|
||||||
|
{ key: 'topic', name: 'Topic', type: 'string' },
|
||||||
|
]
|
||||||
|
|
||||||
|
describe('replaceStringWithValues', () => {
|
||||||
|
it('should replace placeholders when inputs have values', () => {
|
||||||
|
const template = 'Hello {{user}} talking about {{topic}}'
|
||||||
|
const result = replaceStringWithValues(template, promptVariables, { user: 'Alice', topic: 'cats' })
|
||||||
|
expect(result).toBe('Hello Alice talking about cats')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should use prompt variable name when value is missing', () => {
|
||||||
|
const template = 'Hi {{user}} from {{topic}}'
|
||||||
|
const result = replaceStringWithValues(template, promptVariables, {})
|
||||||
|
expect(result).toBe('Hi {{User}} from {{Topic}}')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should leave placeholder untouched when no variable is defined', () => {
|
||||||
|
const template = 'Unknown {{missing}} placeholder'
|
||||||
|
const result = replaceStringWithValues(template, promptVariables, {})
|
||||||
|
expect(result).toBe('Unknown {{missing}} placeholder')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
@ -0,0 +1,162 @@
|
||||||
|
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||||
|
import { useRouter } from 'next/navigation'
|
||||||
|
import { afterAll, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||||
|
import { trackEvent } from '@/app/components/base/amplitude'
|
||||||
|
|
||||||
|
import { ToastContext } from '@/app/components/base/toast'
|
||||||
|
import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
|
||||||
|
import { useAppContext } from '@/context/app-context'
|
||||||
|
import { useProviderContext } from '@/context/provider-context'
|
||||||
|
import { createApp } from '@/service/apps'
|
||||||
|
import { AppModeEnum } from '@/types/app'
|
||||||
|
import { getRedirection } from '@/utils/app-redirection'
|
||||||
|
import CreateAppModal from './index'
|
||||||
|
|
||||||
|
vi.mock('ahooks', () => ({
|
||||||
|
useDebounceFn: (fn: (...args: any[]) => any) => {
|
||||||
|
const run = (...args: any[]) => fn(...args)
|
||||||
|
const cancel = vi.fn()
|
||||||
|
const flush = vi.fn()
|
||||||
|
return { run, cancel, flush }
|
||||||
|
},
|
||||||
|
useKeyPress: vi.fn(),
|
||||||
|
useHover: () => false,
|
||||||
|
}))
|
||||||
|
vi.mock('next/navigation', () => ({
|
||||||
|
useRouter: vi.fn(),
|
||||||
|
}))
|
||||||
|
vi.mock('@/app/components/base/amplitude', () => ({
|
||||||
|
trackEvent: vi.fn(),
|
||||||
|
}))
|
||||||
|
vi.mock('@/service/apps', () => ({
|
||||||
|
createApp: vi.fn(),
|
||||||
|
}))
|
||||||
|
vi.mock('@/utils/app-redirection', () => ({
|
||||||
|
getRedirection: vi.fn(),
|
||||||
|
}))
|
||||||
|
vi.mock('@/context/provider-context', () => ({
|
||||||
|
useProviderContext: vi.fn(),
|
||||||
|
}))
|
||||||
|
vi.mock('@/context/app-context', () => ({
|
||||||
|
useAppContext: vi.fn(),
|
||||||
|
}))
|
||||||
|
vi.mock('@/context/i18n', () => ({
|
||||||
|
useDocLink: () => () => '/guides',
|
||||||
|
}))
|
||||||
|
vi.mock('@/hooks/use-theme', () => ({
|
||||||
|
__esModule: true,
|
||||||
|
default: () => ({ theme: 'light' }),
|
||||||
|
}))
|
||||||
|
|
||||||
|
const mockNotify = vi.fn()
|
||||||
|
const mockUseRouter = vi.mocked(useRouter)
|
||||||
|
const mockPush = vi.fn()
|
||||||
|
const mockCreateApp = vi.mocked(createApp)
|
||||||
|
const mockTrackEvent = vi.mocked(trackEvent)
|
||||||
|
const mockGetRedirection = vi.mocked(getRedirection)
|
||||||
|
const mockUseProviderContext = vi.mocked(useProviderContext)
|
||||||
|
const mockUseAppContext = vi.mocked(useAppContext)
|
||||||
|
|
||||||
|
const defaultPlanUsage = {
|
||||||
|
buildApps: 0,
|
||||||
|
teamMembers: 0,
|
||||||
|
annotatedResponse: 0,
|
||||||
|
documentsUploadQuota: 0,
|
||||||
|
apiRateLimit: 0,
|
||||||
|
triggerEvents: 0,
|
||||||
|
vectorSpace: 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
const renderModal = () => {
|
||||||
|
const onClose = vi.fn()
|
||||||
|
const onSuccess = vi.fn()
|
||||||
|
render(
|
||||||
|
<ToastContext.Provider value={{ notify: mockNotify, close: vi.fn() }}>
|
||||||
|
<CreateAppModal show onClose={onClose} onSuccess={onSuccess} defaultAppMode={AppModeEnum.ADVANCED_CHAT} />
|
||||||
|
</ToastContext.Provider>,
|
||||||
|
)
|
||||||
|
return { onClose, onSuccess }
|
||||||
|
}
|
||||||
|
|
||||||
|
describe('CreateAppModal', () => {
|
||||||
|
const mockSetItem = vi.fn()
|
||||||
|
const originalLocalStorage = window.localStorage
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
vi.clearAllMocks()
|
||||||
|
mockUseRouter.mockReturnValue({ push: mockPush } as any)
|
||||||
|
mockUseProviderContext.mockReturnValue({
|
||||||
|
plan: {
|
||||||
|
type: AppModeEnum.ADVANCED_CHAT,
|
||||||
|
usage: defaultPlanUsage,
|
||||||
|
total: { ...defaultPlanUsage, buildApps: 1 },
|
||||||
|
reset: {},
|
||||||
|
},
|
||||||
|
enableBilling: true,
|
||||||
|
} as any)
|
||||||
|
mockUseAppContext.mockReturnValue({
|
||||||
|
isCurrentWorkspaceEditor: true,
|
||||||
|
} as any)
|
||||||
|
mockSetItem.mockClear()
|
||||||
|
Object.defineProperty(window, 'localStorage', {
|
||||||
|
value: {
|
||||||
|
setItem: mockSetItem,
|
||||||
|
getItem: vi.fn(),
|
||||||
|
removeItem: vi.fn(),
|
||||||
|
clear: vi.fn(),
|
||||||
|
key: vi.fn(),
|
||||||
|
length: 0,
|
||||||
|
},
|
||||||
|
writable: true,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
afterAll(() => {
|
||||||
|
Object.defineProperty(window, 'localStorage', {
|
||||||
|
value: originalLocalStorage,
|
||||||
|
writable: true,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('creates an app, notifies success, and fires callbacks', async () => {
|
||||||
|
const mockApp = { id: 'app-1', mode: AppModeEnum.ADVANCED_CHAT }
|
||||||
|
mockCreateApp.mockResolvedValue(mockApp as any)
|
||||||
|
const { onClose, onSuccess } = renderModal()
|
||||||
|
|
||||||
|
const nameInput = screen.getByPlaceholderText('app.newApp.appNamePlaceholder')
|
||||||
|
fireEvent.change(nameInput, { target: { value: 'My App' } })
|
||||||
|
fireEvent.click(screen.getByRole('button', { name: 'app.newApp.Create' }))
|
||||||
|
|
||||||
|
await waitFor(() => expect(mockCreateApp).toHaveBeenCalledWith({
|
||||||
|
name: 'My App',
|
||||||
|
description: '',
|
||||||
|
icon_type: 'emoji',
|
||||||
|
icon: '🤖',
|
||||||
|
icon_background: '#FFEAD5',
|
||||||
|
mode: AppModeEnum.ADVANCED_CHAT,
|
||||||
|
}))
|
||||||
|
|
||||||
|
expect(mockTrackEvent).toHaveBeenCalledWith('create_app', {
|
||||||
|
app_mode: AppModeEnum.ADVANCED_CHAT,
|
||||||
|
description: '',
|
||||||
|
})
|
||||||
|
expect(mockNotify).toHaveBeenCalledWith({ type: 'success', message: 'app.newApp.appCreated' })
|
||||||
|
expect(onSuccess).toHaveBeenCalled()
|
||||||
|
expect(onClose).toHaveBeenCalled()
|
||||||
|
await waitFor(() => expect(mockSetItem).toHaveBeenCalledWith(NEED_REFRESH_APP_LIST_KEY, '1'))
|
||||||
|
await waitFor(() => expect(mockGetRedirection).toHaveBeenCalledWith(true, mockApp, mockPush))
|
||||||
|
})
|
||||||
|
|
||||||
|
it('shows error toast when creation fails', async () => {
|
||||||
|
mockCreateApp.mockRejectedValue(new Error('boom'))
|
||||||
|
const { onClose } = renderModal()
|
||||||
|
|
||||||
|
const nameInput = screen.getByPlaceholderText('app.newApp.appNamePlaceholder')
|
||||||
|
fireEvent.change(nameInput, { target: { value: 'My App' } })
|
||||||
|
fireEvent.click(screen.getByRole('button', { name: 'app.newApp.Create' }))
|
||||||
|
|
||||||
|
await waitFor(() => expect(mockCreateApp).toHaveBeenCalled())
|
||||||
|
expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: 'boom' })
|
||||||
|
expect(onClose).not.toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
@ -0,0 +1,121 @@
|
||||||
|
import type { SiteInfo } from '@/models/share'
|
||||||
|
import { fireEvent, render, screen } from '@testing-library/react'
|
||||||
|
import copy from 'copy-to-clipboard'
|
||||||
|
import * as React from 'react'
|
||||||
|
|
||||||
|
import { act } from 'react'
|
||||||
|
import { afterAll, afterEach, describe, expect, it, vi } from 'vitest'
|
||||||
|
import Embedded from './index'
|
||||||
|
|
||||||
|
vi.mock('./style.module.css', () => ({
|
||||||
|
__esModule: true,
|
||||||
|
default: {
|
||||||
|
option: 'option',
|
||||||
|
active: 'active',
|
||||||
|
iframeIcon: 'iframeIcon',
|
||||||
|
scriptsIcon: 'scriptsIcon',
|
||||||
|
chromePluginIcon: 'chromePluginIcon',
|
||||||
|
pluginInstallIcon: 'pluginInstallIcon',
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
const mockThemeBuilder = {
|
||||||
|
buildTheme: vi.fn(),
|
||||||
|
theme: {
|
||||||
|
primaryColor: '#123456',
|
||||||
|
},
|
||||||
|
}
|
||||||
|
const mockUseAppContext = vi.fn(() => ({
|
||||||
|
langGeniusVersionInfo: {
|
||||||
|
current_env: 'PRODUCTION',
|
||||||
|
current_version: '',
|
||||||
|
latest_version: '',
|
||||||
|
release_date: '',
|
||||||
|
release_notes: '',
|
||||||
|
version: '',
|
||||||
|
can_auto_update: false,
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('copy-to-clipboard', () => ({
|
||||||
|
__esModule: true,
|
||||||
|
default: vi.fn(),
|
||||||
|
}))
|
||||||
|
vi.mock('@/app/components/base/chat/embedded-chatbot/theme/theme-context', () => ({
|
||||||
|
useThemeContext: () => mockThemeBuilder,
|
||||||
|
}))
|
||||||
|
vi.mock('@/context/app-context', () => ({
|
||||||
|
useAppContext: () => mockUseAppContext(),
|
||||||
|
}))
|
||||||
|
const mockWindowOpen = vi.spyOn(window, 'open').mockImplementation(() => null)
|
||||||
|
const mockedCopy = vi.mocked(copy)
|
||||||
|
|
||||||
|
const siteInfo: SiteInfo = {
|
||||||
|
title: 'test site',
|
||||||
|
chat_color_theme: '#000000',
|
||||||
|
chat_color_theme_inverted: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
const baseProps = {
|
||||||
|
isShow: true,
|
||||||
|
siteInfo,
|
||||||
|
onClose: vi.fn(),
|
||||||
|
appBaseUrl: 'https://app.example.com',
|
||||||
|
accessToken: 'token',
|
||||||
|
className: 'custom-modal',
|
||||||
|
}
|
||||||
|
|
||||||
|
const getCopyButton = () => {
|
||||||
|
const buttons = screen.getAllByRole('button')
|
||||||
|
const actionButton = buttons.find(button => button.className.includes('action-btn'))
|
||||||
|
expect(actionButton).toBeDefined()
|
||||||
|
return actionButton!
|
||||||
|
}
|
||||||
|
|
||||||
|
describe('Embedded', () => {
|
||||||
|
afterEach(() => {
|
||||||
|
vi.clearAllMocks()
|
||||||
|
mockWindowOpen.mockClear()
|
||||||
|
})
|
||||||
|
|
||||||
|
afterAll(() => {
|
||||||
|
mockWindowOpen.mockRestore()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('builds theme and copies iframe snippet', async () => {
|
||||||
|
await act(async () => {
|
||||||
|
render(<Embedded {...baseProps} />)
|
||||||
|
})
|
||||||
|
|
||||||
|
const actionButton = getCopyButton()
|
||||||
|
const innerDiv = actionButton.querySelector('div')
|
||||||
|
act(() => {
|
||||||
|
fireEvent.click(innerDiv ?? actionButton)
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(mockThemeBuilder.buildTheme).toHaveBeenCalledWith(siteInfo.chat_color_theme, siteInfo.chat_color_theme_inverted)
|
||||||
|
expect(mockedCopy).toHaveBeenCalledWith(expect.stringContaining('/chatbot/token'))
|
||||||
|
})
|
||||||
|
|
||||||
|
it('opens chrome plugin store link when chrome option selected', async () => {
|
||||||
|
await act(async () => {
|
||||||
|
render(<Embedded {...baseProps} />)
|
||||||
|
})
|
||||||
|
|
||||||
|
const optionButtons = document.body.querySelectorAll('[class*="option"]')
|
||||||
|
expect(optionButtons.length).toBeGreaterThanOrEqual(3)
|
||||||
|
act(() => {
|
||||||
|
fireEvent.click(optionButtons[2])
|
||||||
|
})
|
||||||
|
|
||||||
|
const [chromeText] = screen.getAllByText('appOverview.overview.appInfo.embedded.chromePlugin')
|
||||||
|
act(() => {
|
||||||
|
fireEvent.click(chromeText)
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(mockWindowOpen).toHaveBeenCalledWith(
|
||||||
|
'https://chrome.google.com/webstore/detail/dify-chatbot/ceehdapohffmjmkdcifjofadiaoeggaf',
|
||||||
|
'_blank',
|
||||||
|
'noopener,noreferrer',
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
@ -0,0 +1,67 @@
|
||||||
|
import type { ISavedItemsProps } from './index'
|
||||||
|
import { fireEvent, render, screen } from '@testing-library/react'
|
||||||
|
import copy from 'copy-to-clipboard'
|
||||||
|
|
||||||
|
import * as React from 'react'
|
||||||
|
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||||
|
import Toast from '@/app/components/base/toast'
|
||||||
|
import SavedItems from './index'
|
||||||
|
|
||||||
|
vi.mock('copy-to-clipboard', () => ({
|
||||||
|
__esModule: true,
|
||||||
|
default: vi.fn(),
|
||||||
|
}))
|
||||||
|
vi.mock('next/navigation', () => ({
|
||||||
|
useParams: () => ({}),
|
||||||
|
usePathname: () => '/',
|
||||||
|
}))
|
||||||
|
|
||||||
|
const mockCopy = vi.mocked(copy)
|
||||||
|
const toastNotifySpy = vi.spyOn(Toast, 'notify')
|
||||||
|
|
||||||
|
const baseProps: ISavedItemsProps = {
|
||||||
|
list: [
|
||||||
|
{ id: '1', answer: 'hello world' },
|
||||||
|
],
|
||||||
|
isShowTextToSpeech: true,
|
||||||
|
onRemove: vi.fn(),
|
||||||
|
onStartCreateContent: vi.fn(),
|
||||||
|
}
|
||||||
|
|
||||||
|
describe('SavedItems', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
vi.clearAllMocks()
|
||||||
|
toastNotifySpy.mockClear()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('renders saved answers with metadata and controls', () => {
|
||||||
|
const { container } = render(<SavedItems {...baseProps} />)
|
||||||
|
|
||||||
|
const markdownElement = container.querySelector('.markdown-body')
|
||||||
|
expect(markdownElement).toBeInTheDocument()
|
||||||
|
expect(screen.getByText('11 common.unit.char')).toBeInTheDocument()
|
||||||
|
|
||||||
|
const actionArea = container.querySelector('[class*="bg-components-actionbar-bg"]')
|
||||||
|
const actionButtons = actionArea?.querySelectorAll('button') ?? []
|
||||||
|
expect(actionButtons.length).toBeGreaterThanOrEqual(3)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('copies content and notifies, and triggers remove callback', () => {
|
||||||
|
const handleRemove = vi.fn()
|
||||||
|
const { container } = render(<SavedItems {...baseProps} onRemove={handleRemove} />)
|
||||||
|
|
||||||
|
const actionArea = container.querySelector('[class*="bg-components-actionbar-bg"]')
|
||||||
|
const actionButtons = actionArea?.querySelectorAll('button') ?? []
|
||||||
|
expect(actionButtons.length).toBeGreaterThanOrEqual(3)
|
||||||
|
|
||||||
|
const copyButton = actionButtons[1]
|
||||||
|
const deleteButton = actionButtons[2]
|
||||||
|
|
||||||
|
fireEvent.click(copyButton)
|
||||||
|
expect(mockCopy).toHaveBeenCalledWith('hello world')
|
||||||
|
expect(toastNotifySpy).toHaveBeenCalledWith({ type: 'success', message: 'common.actionMsg.copySuccessfully' })
|
||||||
|
|
||||||
|
fireEvent.click(deleteButton)
|
||||||
|
expect(handleRemove).toHaveBeenCalledWith('1')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
@ -0,0 +1,22 @@
|
||||||
|
import { fireEvent, render, screen } from '@testing-library/react'
|
||||||
|
import { describe, expect, it, vi } from 'vitest'
|
||||||
|
|
||||||
|
import NoData from './index'
|
||||||
|
|
||||||
|
describe('NoData', () => {
|
||||||
|
it('renders title/description and calls callback when button clicked', () => {
|
||||||
|
const handleStart = vi.fn()
|
||||||
|
render(<NoData onStartCreateContent={handleStart} />)
|
||||||
|
|
||||||
|
const title = screen.getByText('share.generation.savedNoData.title')
|
||||||
|
const description = screen.getByText('share.generation.savedNoData.description')
|
||||||
|
const button = screen.getByRole('button', { name: 'share.generation.savedNoData.startCreateContent' })
|
||||||
|
|
||||||
|
expect(title).toBeInTheDocument()
|
||||||
|
expect(description).toBeInTheDocument()
|
||||||
|
expect(button).toBeInTheDocument()
|
||||||
|
|
||||||
|
fireEvent.click(button)
|
||||||
|
expect(handleStart).toHaveBeenCalledTimes(1)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
@ -0,0 +1,147 @@
|
||||||
|
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||||
|
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||||
|
import { getImageUploadErrorMessage, imageUpload } from '@/app/components/base/image-uploader/utils'
|
||||||
|
import { useToastContext } from '@/app/components/base/toast'
|
||||||
|
import { Plan } from '@/app/components/billing/type'
|
||||||
|
import { useAppContext } from '@/context/app-context'
|
||||||
|
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||||
|
import { useProviderContext } from '@/context/provider-context'
|
||||||
|
import { updateCurrentWorkspace } from '@/service/common'
|
||||||
|
import CustomWebAppBrand from './index'
|
||||||
|
|
||||||
|
vi.mock('@/app/components/base/toast', () => ({
|
||||||
|
useToastContext: vi.fn(),
|
||||||
|
}))
|
||||||
|
vi.mock('@/service/common', () => ({
|
||||||
|
updateCurrentWorkspace: vi.fn(),
|
||||||
|
}))
|
||||||
|
vi.mock('@/context/app-context', () => ({
|
||||||
|
useAppContext: vi.fn(),
|
||||||
|
}))
|
||||||
|
vi.mock('@/context/provider-context', () => ({
|
||||||
|
useProviderContext: vi.fn(),
|
||||||
|
}))
|
||||||
|
vi.mock('@/context/global-public-context', () => ({
|
||||||
|
useGlobalPublicStore: vi.fn(),
|
||||||
|
}))
|
||||||
|
vi.mock('@/app/components/base/image-uploader/utils', () => ({
|
||||||
|
imageUpload: vi.fn(),
|
||||||
|
getImageUploadErrorMessage: vi.fn(),
|
||||||
|
}))
|
||||||
|
|
||||||
|
const mockNotify = vi.fn()
|
||||||
|
const mockUseToastContext = vi.mocked(useToastContext)
|
||||||
|
const mockUpdateCurrentWorkspace = vi.mocked(updateCurrentWorkspace)
|
||||||
|
const mockUseAppContext = vi.mocked(useAppContext)
|
||||||
|
const mockUseProviderContext = vi.mocked(useProviderContext)
|
||||||
|
const mockUseGlobalPublicStore = vi.mocked(useGlobalPublicStore)
|
||||||
|
const mockImageUpload = vi.mocked(imageUpload)
|
||||||
|
const mockGetImageUploadErrorMessage = vi.mocked(getImageUploadErrorMessage)
|
||||||
|
|
||||||
|
const defaultPlanUsage = {
|
||||||
|
buildApps: 0,
|
||||||
|
teamMembers: 0,
|
||||||
|
annotatedResponse: 0,
|
||||||
|
documentsUploadQuota: 0,
|
||||||
|
apiRateLimit: 0,
|
||||||
|
triggerEvents: 0,
|
||||||
|
vectorSpace: 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
const renderComponent = () => render(<CustomWebAppBrand />)
|
||||||
|
|
||||||
|
describe('CustomWebAppBrand', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
vi.clearAllMocks()
|
||||||
|
mockUseToastContext.mockReturnValue({ notify: mockNotify } as any)
|
||||||
|
mockUpdateCurrentWorkspace.mockResolvedValue({} as any)
|
||||||
|
mockUseAppContext.mockReturnValue({
|
||||||
|
currentWorkspace: {
|
||||||
|
custom_config: {
|
||||||
|
replace_webapp_logo: 'https://example.com/replace.png',
|
||||||
|
remove_webapp_brand: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
mutateCurrentWorkspace: vi.fn(),
|
||||||
|
isCurrentWorkspaceManager: true,
|
||||||
|
} as any)
|
||||||
|
mockUseProviderContext.mockReturnValue({
|
||||||
|
plan: {
|
||||||
|
type: Plan.professional,
|
||||||
|
usage: defaultPlanUsage,
|
||||||
|
total: defaultPlanUsage,
|
||||||
|
reset: {},
|
||||||
|
},
|
||||||
|
enableBilling: false,
|
||||||
|
} as any)
|
||||||
|
const systemFeaturesState = {
|
||||||
|
branding: {
|
||||||
|
enabled: true,
|
||||||
|
workspace_logo: 'https://example.com/workspace-logo.png',
|
||||||
|
},
|
||||||
|
}
|
||||||
|
mockUseGlobalPublicStore.mockImplementation(selector => selector ? selector({ systemFeatures: systemFeaturesState } as any) : { systemFeatures: systemFeaturesState })
|
||||||
|
mockGetImageUploadErrorMessage.mockReturnValue('upload error')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('disables upload controls when the user cannot manage the workspace', () => {
|
||||||
|
mockUseAppContext.mockReturnValue({
|
||||||
|
currentWorkspace: {
|
||||||
|
custom_config: {
|
||||||
|
replace_webapp_logo: '',
|
||||||
|
remove_webapp_brand: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
mutateCurrentWorkspace: vi.fn(),
|
||||||
|
isCurrentWorkspaceManager: false,
|
||||||
|
} as any)
|
||||||
|
|
||||||
|
const { container } = renderComponent()
|
||||||
|
const fileInput = container.querySelector('input[type="file"]') as HTMLInputElement
|
||||||
|
expect(fileInput).toBeDisabled()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('toggles remove brand switch and calls the backend + mutate', async () => {
|
||||||
|
const mutateMock = vi.fn()
|
||||||
|
mockUseAppContext.mockReturnValue({
|
||||||
|
currentWorkspace: {
|
||||||
|
custom_config: {
|
||||||
|
replace_webapp_logo: '',
|
||||||
|
remove_webapp_brand: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
mutateCurrentWorkspace: mutateMock,
|
||||||
|
isCurrentWorkspaceManager: true,
|
||||||
|
} as any)
|
||||||
|
|
||||||
|
renderComponent()
|
||||||
|
const switchInput = screen.getByRole('switch')
|
||||||
|
fireEvent.click(switchInput)
|
||||||
|
|
||||||
|
await waitFor(() => expect(mockUpdateCurrentWorkspace).toHaveBeenCalledWith({
|
||||||
|
url: '/workspaces/custom-config',
|
||||||
|
body: { remove_webapp_brand: true },
|
||||||
|
}))
|
||||||
|
await waitFor(() => expect(mutateMock).toHaveBeenCalled())
|
||||||
|
})
|
||||||
|
|
||||||
|
it('shows cancel/apply buttons after successful upload and cancels properly', async () => {
|
||||||
|
mockImageUpload.mockImplementation(({ onProgressCallback, onSuccessCallback }) => {
|
||||||
|
onProgressCallback(50)
|
||||||
|
onSuccessCallback({ id: 'new-logo' })
|
||||||
|
})
|
||||||
|
|
||||||
|
const { container } = renderComponent()
|
||||||
|
const fileInput = container.querySelector('input[type="file"]') as HTMLInputElement
|
||||||
|
const testFile = new File(['content'], 'logo.png', { type: 'image/png' })
|
||||||
|
fireEvent.change(fileInput, { target: { files: [testFile] } })
|
||||||
|
|
||||||
|
await waitFor(() => expect(mockImageUpload).toHaveBeenCalled())
|
||||||
|
await waitFor(() => screen.getByRole('button', { name: 'custom.apply' }))
|
||||||
|
|
||||||
|
const cancelButton = screen.getByRole('button', { name: 'common.operation.cancel' })
|
||||||
|
fireEvent.click(cancelButton)
|
||||||
|
|
||||||
|
await waitFor(() => expect(screen.queryByRole('button', { name: 'custom.apply' })).toBeNull())
|
||||||
|
})
|
||||||
|
})
|
||||||
Loading…
Reference in New Issue