mirror of https://github.com/langgenius/dify.git
Merge branch 'main' into feat/hitl-frontend
This commit is contained in:
commit
c716c4ccbe
|
|
@ -0,0 +1,8 @@
|
|||
{
|
||||
"enabledPlugins": {
|
||||
"feature-dev@claude-plugins-official": true,
|
||||
"context7@claude-plugins-official": true,
|
||||
"typescript-lsp@claude-plugins-official": true,
|
||||
"pyright-lsp@claude-plugins-official": true
|
||||
}
|
||||
}
|
||||
|
|
@ -1,19 +0,0 @@
|
|||
{
|
||||
"permissions": {
|
||||
"allow": [],
|
||||
"deny": []
|
||||
},
|
||||
"env": {
|
||||
"__comment": "Environment variables for MCP servers. Override in .claude/settings.local.json with actual values.",
|
||||
"GITHUB_PERSONAL_ACCESS_TOKEN": "ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
|
||||
},
|
||||
"enabledMcpjsonServers": [
|
||||
"context7",
|
||||
"sequential-thinking",
|
||||
"github",
|
||||
"fetch",
|
||||
"playwright",
|
||||
"ide"
|
||||
],
|
||||
"enableAllProjectMcpServers": true
|
||||
}
|
||||
34
.mcp.json
34
.mcp.json
|
|
@ -1,34 +0,0 @@
|
|||
{
|
||||
"mcpServers": {
|
||||
"context7": {
|
||||
"type": "http",
|
||||
"url": "https://mcp.context7.com/mcp"
|
||||
},
|
||||
"sequential-thinking": {
|
||||
"type": "stdio",
|
||||
"command": "npx",
|
||||
"args": ["-y", "@modelcontextprotocol/server-sequential-thinking"],
|
||||
"env": {}
|
||||
},
|
||||
"github": {
|
||||
"type": "stdio",
|
||||
"command": "npx",
|
||||
"args": ["-y", "@modelcontextprotocol/server-github"],
|
||||
"env": {
|
||||
"GITHUB_PERSONAL_ACCESS_TOKEN": "${GITHUB_PERSONAL_ACCESS_TOKEN}"
|
||||
}
|
||||
},
|
||||
"fetch": {
|
||||
"type": "stdio",
|
||||
"command": "uvx",
|
||||
"args": ["mcp-server-fetch"],
|
||||
"env": {}
|
||||
},
|
||||
"playwright": {
|
||||
"type": "stdio",
|
||||
"command": "npx",
|
||||
"args": ["-y", "@playwright/mcp@latest"],
|
||||
"env": {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -313,17 +313,20 @@ class StreamableHTTPTransport:
|
|||
if is_initialization:
|
||||
self._maybe_extract_session_id_from_response(response)
|
||||
|
||||
content_type = cast(str, response.headers.get(CONTENT_TYPE, "").lower())
|
||||
# Per https://modelcontextprotocol.io/specification/2025-06-18/basic#notifications:
|
||||
# The server MUST NOT send a response to notifications.
|
||||
if isinstance(message.root, JSONRPCRequest):
|
||||
content_type = cast(str, response.headers.get(CONTENT_TYPE, "").lower())
|
||||
|
||||
if content_type.startswith(JSON):
|
||||
self._handle_json_response(response, ctx.server_to_client_queue)
|
||||
elif content_type.startswith(SSE):
|
||||
self._handle_sse_response(response, ctx)
|
||||
else:
|
||||
self._handle_unexpected_content_type(
|
||||
content_type,
|
||||
ctx.server_to_client_queue,
|
||||
)
|
||||
if content_type.startswith(JSON):
|
||||
self._handle_json_response(response, ctx.server_to_client_queue)
|
||||
elif content_type.startswith(SSE):
|
||||
self._handle_sse_response(response, ctx)
|
||||
else:
|
||||
self._handle_unexpected_content_type(
|
||||
content_type,
|
||||
ctx.server_to_client_queue,
|
||||
)
|
||||
|
||||
def _handle_json_response(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from core.model_runtime.entities.model_entities import ModelType
|
|||
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
|
||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.embedding.retrieval import RetrievalSegments
|
||||
from core.rag.embedding.retrieval import RetrievalChildChunk, RetrievalSegments
|
||||
from core.rag.entities.metadata_entities import MetadataCondition
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
|
|
@ -381,10 +381,9 @@ class RetrievalService:
|
|||
records = []
|
||||
include_segment_ids = set()
|
||||
segment_child_map = {}
|
||||
segment_file_map = {}
|
||||
|
||||
valid_dataset_documents = {}
|
||||
image_doc_ids = []
|
||||
image_doc_ids: list[Any] = []
|
||||
child_index_node_ids = []
|
||||
index_node_ids = []
|
||||
doc_to_document_map = {}
|
||||
|
|
@ -417,28 +416,39 @@ class RetrievalService:
|
|||
child_index_node_ids = [i for i in child_index_node_ids if i]
|
||||
index_node_ids = [i for i in index_node_ids if i]
|
||||
|
||||
segment_ids = []
|
||||
segment_ids: list[str] = []
|
||||
index_node_segments: list[DocumentSegment] = []
|
||||
segments: list[DocumentSegment] = []
|
||||
attachment_map = {}
|
||||
child_chunk_map = {}
|
||||
doc_segment_map = {}
|
||||
attachment_map: dict[str, list[dict[str, Any]]] = {}
|
||||
child_chunk_map: dict[str, list[ChildChunk]] = {}
|
||||
doc_segment_map: dict[str, list[str]] = {}
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
attachments = cls.get_segment_attachment_infos(image_doc_ids, session)
|
||||
|
||||
for attachment in attachments:
|
||||
segment_ids.append(attachment["segment_id"])
|
||||
attachment_map[attachment["segment_id"]] = attachment
|
||||
doc_segment_map[attachment["segment_id"]] = attachment["attachment_id"]
|
||||
|
||||
if attachment["segment_id"] in attachment_map:
|
||||
attachment_map[attachment["segment_id"]].append(attachment["attachment_info"])
|
||||
else:
|
||||
attachment_map[attachment["segment_id"]] = [attachment["attachment_info"]]
|
||||
if attachment["segment_id"] in doc_segment_map:
|
||||
doc_segment_map[attachment["segment_id"]].append(attachment["attachment_id"])
|
||||
else:
|
||||
doc_segment_map[attachment["segment_id"]] = [attachment["attachment_id"]]
|
||||
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id.in_(child_index_node_ids))
|
||||
child_index_nodes = session.execute(child_chunk_stmt).scalars().all()
|
||||
|
||||
for i in child_index_nodes:
|
||||
segment_ids.append(i.segment_id)
|
||||
child_chunk_map[i.segment_id] = i
|
||||
doc_segment_map[i.segment_id] = i.index_node_id
|
||||
if i.segment_id in child_chunk_map:
|
||||
child_chunk_map[i.segment_id].append(i)
|
||||
else:
|
||||
child_chunk_map[i.segment_id] = [i]
|
||||
if i.segment_id in doc_segment_map:
|
||||
doc_segment_map[i.segment_id].append(i.index_node_id)
|
||||
else:
|
||||
doc_segment_map[i.segment_id] = [i.index_node_id]
|
||||
|
||||
if index_node_ids:
|
||||
document_segment_stmt = select(DocumentSegment).where(
|
||||
|
|
@ -448,7 +458,7 @@ class RetrievalService:
|
|||
)
|
||||
index_node_segments = session.execute(document_segment_stmt).scalars().all() # type: ignore
|
||||
for index_node_segment in index_node_segments:
|
||||
doc_segment_map[index_node_segment.id] = index_node_segment.index_node_id
|
||||
doc_segment_map[index_node_segment.id] = [index_node_segment.index_node_id]
|
||||
if segment_ids:
|
||||
document_segment_stmt = select(DocumentSegment).where(
|
||||
DocumentSegment.enabled == True,
|
||||
|
|
@ -461,95 +471,86 @@ class RetrievalService:
|
|||
segments.extend(index_node_segments)
|
||||
|
||||
for segment in segments:
|
||||
doc_id = doc_segment_map.get(segment.id)
|
||||
child_chunk = child_chunk_map.get(segment.id)
|
||||
attachment_info = attachment_map.get(segment.id)
|
||||
child_chunks: list[ChildChunk] = child_chunk_map.get(segment.id, [])
|
||||
attachment_infos: list[dict[str, Any]] = attachment_map.get(segment.id, [])
|
||||
ds_dataset_document: DatasetDocument | None = valid_dataset_documents.get(segment.document_id)
|
||||
|
||||
if doc_id:
|
||||
document = doc_to_document_map[doc_id]
|
||||
ds_dataset_document: DatasetDocument | None = valid_dataset_documents.get(
|
||||
document.metadata.get("document_id")
|
||||
)
|
||||
|
||||
if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||
if segment.id not in include_segment_ids:
|
||||
include_segment_ids.add(segment.id)
|
||||
if child_chunk:
|
||||
if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||
if segment.id not in include_segment_ids:
|
||||
include_segment_ids.add(segment.id)
|
||||
if child_chunks or attachment_infos:
|
||||
child_chunk_details = []
|
||||
max_score = 0.0
|
||||
for child_chunk in child_chunks:
|
||||
document = doc_to_document_map[child_chunk.index_node_id]
|
||||
child_chunk_detail = {
|
||||
"id": child_chunk.id,
|
||||
"content": child_chunk.content,
|
||||
"position": child_chunk.position,
|
||||
"score": document.metadata.get("score", 0.0) if document else 0.0,
|
||||
}
|
||||
map_detail = {
|
||||
"max_score": document.metadata.get("score", 0.0) if document else 0.0,
|
||||
"child_chunks": [child_chunk_detail],
|
||||
}
|
||||
segment_child_map[segment.id] = map_detail
|
||||
record = {
|
||||
"segment": segment,
|
||||
child_chunk_details.append(child_chunk_detail)
|
||||
max_score = max(max_score, document.metadata.get("score", 0.0) if document else 0.0)
|
||||
for attachment_info in attachment_infos:
|
||||
file_document = doc_to_document_map[attachment_info["id"]]
|
||||
max_score = max(
|
||||
max_score, file_document.metadata.get("score", 0.0) if file_document else 0.0
|
||||
)
|
||||
|
||||
map_detail = {
|
||||
"max_score": max_score,
|
||||
"child_chunks": child_chunk_details,
|
||||
}
|
||||
if attachment_info:
|
||||
segment_file_map[segment.id] = [attachment_info]
|
||||
records.append(record)
|
||||
else:
|
||||
if child_chunk:
|
||||
child_chunk_detail = {
|
||||
"id": child_chunk.id,
|
||||
"content": child_chunk.content,
|
||||
"position": child_chunk.position,
|
||||
"score": document.metadata.get("score", 0.0),
|
||||
}
|
||||
if segment.id in segment_child_map:
|
||||
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail) # type: ignore
|
||||
segment_child_map[segment.id]["max_score"] = max(
|
||||
segment_child_map[segment.id]["max_score"],
|
||||
document.metadata.get("score", 0.0) if document else 0.0,
|
||||
)
|
||||
else:
|
||||
segment_child_map[segment.id] = {
|
||||
"max_score": document.metadata.get("score", 0.0) if document else 0.0,
|
||||
"child_chunks": [child_chunk_detail],
|
||||
}
|
||||
if attachment_info:
|
||||
if segment.id in segment_file_map:
|
||||
segment_file_map[segment.id].append(attachment_info)
|
||||
else:
|
||||
segment_file_map[segment.id] = [attachment_info]
|
||||
else:
|
||||
if segment.id not in include_segment_ids:
|
||||
include_segment_ids.add(segment.id)
|
||||
record = {
|
||||
"segment": segment,
|
||||
"score": document.metadata.get("score", 0.0), # type: ignore
|
||||
}
|
||||
if attachment_info:
|
||||
segment_file_map[segment.id] = [attachment_info]
|
||||
records.append(record)
|
||||
else:
|
||||
if attachment_info:
|
||||
attachment_infos = segment_file_map.get(segment.id, [])
|
||||
if attachment_info not in attachment_infos:
|
||||
attachment_infos.append(attachment_info)
|
||||
segment_file_map[segment.id] = attachment_infos
|
||||
segment_child_map[segment.id] = map_detail
|
||||
record: dict[str, Any] = {
|
||||
"segment": segment,
|
||||
}
|
||||
records.append(record)
|
||||
else:
|
||||
if segment.id not in include_segment_ids:
|
||||
include_segment_ids.add(segment.id)
|
||||
max_score = 0.0
|
||||
segment_document = doc_to_document_map.get(segment.index_node_id)
|
||||
if segment_document:
|
||||
max_score = max(max_score, segment_document.metadata.get("score", 0.0))
|
||||
for attachment_info in attachment_infos:
|
||||
file_doc = doc_to_document_map.get(attachment_info["id"])
|
||||
if file_doc:
|
||||
max_score = max(max_score, file_doc.metadata.get("score", 0.0))
|
||||
record = {
|
||||
"segment": segment,
|
||||
"score": max_score,
|
||||
}
|
||||
records.append(record)
|
||||
|
||||
# Add child chunks information to records
|
||||
for record in records:
|
||||
if record["segment"].id in segment_child_map:
|
||||
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore
|
||||
record["score"] = segment_child_map[record["segment"].id]["max_score"] # type: ignore
|
||||
if record["segment"].id in segment_file_map:
|
||||
record["files"] = segment_file_map[record["segment"].id] # type: ignore[assignment]
|
||||
if record["segment"].id in attachment_map:
|
||||
record["files"] = attachment_map[record["segment"].id] # type: ignore[assignment]
|
||||
|
||||
result = []
|
||||
result: list[RetrievalSegments] = []
|
||||
for record in records:
|
||||
# Extract segment
|
||||
segment = record["segment"]
|
||||
|
||||
# Extract child_chunks, ensuring it's a list or None
|
||||
child_chunks = record.get("child_chunks")
|
||||
if not isinstance(child_chunks, list):
|
||||
child_chunks = None
|
||||
raw_child_chunks = record.get("child_chunks")
|
||||
child_chunks_list: list[RetrievalChildChunk] | None = None
|
||||
if isinstance(raw_child_chunks, list):
|
||||
# Sort by score descending
|
||||
sorted_chunks = sorted(raw_child_chunks, key=lambda x: x.get("score", 0.0), reverse=True)
|
||||
child_chunks_list = [
|
||||
RetrievalChildChunk(
|
||||
id=chunk["id"],
|
||||
content=chunk["content"],
|
||||
score=chunk.get("score", 0.0),
|
||||
position=chunk["position"],
|
||||
)
|
||||
for chunk in sorted_chunks
|
||||
]
|
||||
|
||||
# Extract files, ensuring it's a list or None
|
||||
files = record.get("files")
|
||||
|
|
@ -566,11 +567,11 @@ class RetrievalService:
|
|||
|
||||
# Create RetrievalSegments object
|
||||
retrieval_segment = RetrievalSegments(
|
||||
segment=segment, child_chunks=child_chunks, score=score, files=files
|
||||
segment=segment, child_chunks=child_chunks_list, score=score, files=files
|
||||
)
|
||||
result.append(retrieval_segment)
|
||||
|
||||
return result
|
||||
return sorted(result, key=lambda x: x.score if x.score is not None else 0.0, reverse=True)
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
raise e
|
||||
|
|
|
|||
|
|
@ -255,7 +255,10 @@ class PGVector(BaseVector):
|
|||
return
|
||||
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
|
||||
cur.execute("SELECT 1 FROM pg_extension WHERE extname = 'vector'")
|
||||
if not cur.fetchone():
|
||||
cur.execute("CREATE EXTENSION vector")
|
||||
|
||||
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
|
||||
# PG hnsw index only support 2000 dimension or less
|
||||
# ref: https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from collections.abc import Generator, Mapping
|
|||
from typing import Any, Union, cast
|
||||
|
||||
from flask import Flask, current_app
|
||||
from sqlalchemy import and_, or_, select
|
||||
from sqlalchemy import and_, literal, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.app_config.entities import (
|
||||
|
|
@ -1036,7 +1036,7 @@ class DatasetRetrieval:
|
|||
if automatic_metadata_filters:
|
||||
conditions = []
|
||||
for sequence, filter in enumerate(automatic_metadata_filters):
|
||||
self._process_metadata_filter_func(
|
||||
self.process_metadata_filter_func(
|
||||
sequence,
|
||||
filter.get("condition"), # type: ignore
|
||||
filter.get("metadata_name"), # type: ignore
|
||||
|
|
@ -1072,7 +1072,7 @@ class DatasetRetrieval:
|
|||
value=expected_value,
|
||||
)
|
||||
)
|
||||
filters = self._process_metadata_filter_func(
|
||||
filters = self.process_metadata_filter_func(
|
||||
sequence,
|
||||
condition.comparison_operator,
|
||||
metadata_name,
|
||||
|
|
@ -1168,8 +1168,9 @@ class DatasetRetrieval:
|
|||
return None
|
||||
return automatic_metadata_filters
|
||||
|
||||
def _process_metadata_filter_func(
|
||||
self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
|
||||
@classmethod
|
||||
def process_metadata_filter_func(
|
||||
cls, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
|
||||
):
|
||||
if value is None and condition not in ("empty", "not empty"):
|
||||
return filters
|
||||
|
|
@ -1218,6 +1219,20 @@ class DatasetRetrieval:
|
|||
|
||||
case "≥" | ">=":
|
||||
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() >= value)
|
||||
case "in" | "not in":
|
||||
if isinstance(value, str):
|
||||
value_list = [v.strip() for v in value.split(",") if v.strip()]
|
||||
elif isinstance(value, (list, tuple)):
|
||||
value_list = [str(v) for v in value if v is not None]
|
||||
else:
|
||||
value_list = [str(value)] if value is not None else []
|
||||
|
||||
if not value_list:
|
||||
# `field in []` is False, `field not in []` is True
|
||||
filters.append(literal(condition == "not in"))
|
||||
else:
|
||||
op = json_field.in_ if condition == "in" else json_field.notin_
|
||||
filters.append(op(value_list))
|
||||
case _:
|
||||
pass
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,15 @@ from typing import Any
|
|||
|
||||
from core.mcp.auth_client import MCPClientWithAuthRetry
|
||||
from core.mcp.error import MCPConnectionError
|
||||
from core.mcp.types import AudioContent, CallToolResult, ImageContent, TextContent
|
||||
from core.mcp.types import (
|
||||
AudioContent,
|
||||
BlobResourceContents,
|
||||
CallToolResult,
|
||||
EmbeddedResource,
|
||||
ImageContent,
|
||||
TextContent,
|
||||
TextResourceContents,
|
||||
)
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
|
||||
|
|
@ -53,10 +61,19 @@ class MCPTool(Tool):
|
|||
for content in result.content:
|
||||
if isinstance(content, TextContent):
|
||||
yield from self._process_text_content(content)
|
||||
elif isinstance(content, ImageContent):
|
||||
yield self._process_image_content(content)
|
||||
elif isinstance(content, AudioContent):
|
||||
yield self._process_audio_content(content)
|
||||
elif isinstance(content, ImageContent | AudioContent):
|
||||
yield self.create_blob_message(
|
||||
blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType}
|
||||
)
|
||||
elif isinstance(content, EmbeddedResource):
|
||||
resource = content.resource
|
||||
if isinstance(resource, TextResourceContents):
|
||||
yield self.create_text_message(resource.text)
|
||||
elif isinstance(resource, BlobResourceContents):
|
||||
mime_type = resource.mimeType or "application/octet-stream"
|
||||
yield self.create_blob_message(blob=base64.b64decode(resource.blob), meta={"mime_type": mime_type})
|
||||
else:
|
||||
raise ToolInvokeError(f"Unsupported embedded resource type: {type(resource)}")
|
||||
else:
|
||||
logger.warning("Unsupported content type=%s", type(content))
|
||||
|
||||
|
|
@ -101,14 +118,6 @@ class MCPTool(Tool):
|
|||
for item in json_list:
|
||||
yield self.create_json_message(item)
|
||||
|
||||
def _process_image_content(self, content: ImageContent) -> ToolInvokeMessage:
|
||||
"""Process image content and return a blob message."""
|
||||
return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType})
|
||||
|
||||
def _process_audio_content(self, content: AudioContent) -> ToolInvokeMessage:
|
||||
"""Process audio content and return a blob message."""
|
||||
return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType})
|
||||
|
||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> "MCPTool":
|
||||
return MCPTool(
|
||||
entity=self.entity,
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from collections import defaultdict
|
|||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from sqlalchemy import and_, func, literal, or_, select
|
||||
from sqlalchemy import and_, func, or_, select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity
|
||||
|
|
@ -460,7 +460,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
|||
if automatic_metadata_filters:
|
||||
conditions = []
|
||||
for sequence, filter in enumerate(automatic_metadata_filters):
|
||||
self._process_metadata_filter_func(
|
||||
DatasetRetrieval.process_metadata_filter_func(
|
||||
sequence,
|
||||
filter.get("condition", ""),
|
||||
filter.get("metadata_name", ""),
|
||||
|
|
@ -504,7 +504,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
|||
value=expected_value,
|
||||
)
|
||||
)
|
||||
filters = self._process_metadata_filter_func(
|
||||
filters = DatasetRetrieval.process_metadata_filter_func(
|
||||
sequence,
|
||||
condition.comparison_operator,
|
||||
metadata_name,
|
||||
|
|
@ -603,87 +603,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
|||
return [], usage
|
||||
return automatic_metadata_filters, usage
|
||||
|
||||
def _process_metadata_filter_func(
|
||||
self, sequence: int, condition: str, metadata_name: str, value: Any, filters: list[Any]
|
||||
) -> list[Any]:
|
||||
if value is None and condition not in ("empty", "not empty"):
|
||||
return filters
|
||||
|
||||
json_field = Document.doc_metadata[metadata_name].as_string()
|
||||
|
||||
match condition:
|
||||
case "contains":
|
||||
filters.append(json_field.like(f"%{value}%"))
|
||||
|
||||
case "not contains":
|
||||
filters.append(json_field.notlike(f"%{value}%"))
|
||||
|
||||
case "start with":
|
||||
filters.append(json_field.like(f"{value}%"))
|
||||
|
||||
case "end with":
|
||||
filters.append(json_field.like(f"%{value}"))
|
||||
case "in":
|
||||
if isinstance(value, str):
|
||||
value_list = [v.strip() for v in value.split(",") if v.strip()]
|
||||
elif isinstance(value, (list, tuple)):
|
||||
value_list = [str(v) for v in value if v is not None]
|
||||
else:
|
||||
value_list = [str(value)] if value is not None else []
|
||||
|
||||
if not value_list:
|
||||
filters.append(literal(False))
|
||||
else:
|
||||
filters.append(json_field.in_(value_list))
|
||||
|
||||
case "not in":
|
||||
if isinstance(value, str):
|
||||
value_list = [v.strip() for v in value.split(",") if v.strip()]
|
||||
elif isinstance(value, (list, tuple)):
|
||||
value_list = [str(v) for v in value if v is not None]
|
||||
else:
|
||||
value_list = [str(value)] if value is not None else []
|
||||
|
||||
if not value_list:
|
||||
filters.append(literal(True))
|
||||
else:
|
||||
filters.append(json_field.notin_(value_list))
|
||||
|
||||
case "is" | "=":
|
||||
if isinstance(value, str):
|
||||
filters.append(json_field == value)
|
||||
elif isinstance(value, (int, float)):
|
||||
filters.append(Document.doc_metadata[metadata_name].as_float() == value)
|
||||
|
||||
case "is not" | "≠":
|
||||
if isinstance(value, str):
|
||||
filters.append(json_field != value)
|
||||
elif isinstance(value, (int, float)):
|
||||
filters.append(Document.doc_metadata[metadata_name].as_float() != value)
|
||||
|
||||
case "empty":
|
||||
filters.append(Document.doc_metadata[metadata_name].is_(None))
|
||||
|
||||
case "not empty":
|
||||
filters.append(Document.doc_metadata[metadata_name].isnot(None))
|
||||
|
||||
case "before" | "<":
|
||||
filters.append(Document.doc_metadata[metadata_name].as_float() < value)
|
||||
|
||||
case "after" | ">":
|
||||
filters.append(Document.doc_metadata[metadata_name].as_float() > value)
|
||||
|
||||
case "≤" | "<=":
|
||||
filters.append(Document.doc_metadata[metadata_name].as_float() <= value)
|
||||
|
||||
case "≥" | ">=":
|
||||
filters.append(Document.doc_metadata[metadata_name].as_float() >= value)
|
||||
|
||||
case _:
|
||||
pass
|
||||
|
||||
return filters
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
|
|
|
|||
|
|
@ -14,7 +14,8 @@ from enums.quota_type import QuotaType, unlimited
|
|||
from extensions.otel import AppGenerateHandler, trace_span
|
||||
from models.model import Account, App, AppMode, EndUser
|
||||
from models.workflow import Workflow
|
||||
from services.errors.app import InvokeRateLimitError, QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError
|
||||
from services.errors.app import QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ from models.model import App, EndUser
|
|||
from models.trigger import WorkflowTriggerLog
|
||||
from models.workflow import Workflow
|
||||
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
|
||||
from services.errors.app import InvokeRateLimitError, QuotaExceededError, WorkflowNotFoundError
|
||||
from services.errors.app import QuotaExceededError, WorkflowNotFoundError, WorkflowQuotaLimitError
|
||||
from services.workflow.entities import AsyncTriggerResponse, TriggerData, WorkflowTaskData
|
||||
from services.workflow.queue_dispatcher import QueueDispatcherManager, QueuePriority
|
||||
from services.workflow_service import WorkflowService
|
||||
|
|
@ -141,7 +141,7 @@ class AsyncWorkflowService:
|
|||
trigger_log_repo.update(trigger_log)
|
||||
session.commit()
|
||||
|
||||
raise InvokeRateLimitError(
|
||||
raise WorkflowQuotaLimitError(
|
||||
f"Workflow execution quota limit reached for tenant {trigger_data.tenant_id}"
|
||||
) from e
|
||||
|
||||
|
|
|
|||
|
|
@ -110,5 +110,5 @@ class EnterpriseService:
|
|||
if not app_id:
|
||||
raise ValueError("app_id must be provided.")
|
||||
|
||||
body = {"appId": app_id}
|
||||
EnterpriseRequest.send_request("DELETE", "/webapp/clean", json=body)
|
||||
params = {"appId": app_id}
|
||||
EnterpriseRequest.send_request("DELETE", "/webapp/clean", params=params)
|
||||
|
|
|
|||
|
|
@ -18,8 +18,8 @@ class WorkflowIdFormatError(Exception):
|
|||
pass
|
||||
|
||||
|
||||
class InvokeRateLimitError(Exception):
|
||||
"""Raised when rate limit is exceeded for workflow invocations."""
|
||||
class WorkflowQuotaLimitError(Exception):
|
||||
"""Raised when workflow execution quota is exceeded (for async/background workflows)."""
|
||||
|
||||
pass
|
||||
|
||||
|
|
|
|||
|
|
@ -146,7 +146,7 @@ class PluginParameterService:
|
|||
provider,
|
||||
action,
|
||||
resolved_credentials,
|
||||
CredentialType.API_KEY.value,
|
||||
original_subscription.credential_type or CredentialType.UNAUTHORIZED.value,
|
||||
parameter,
|
||||
)
|
||||
.options
|
||||
|
|
|
|||
|
|
@ -868,48 +868,111 @@ class TriggerProviderService:
|
|||
if not provider_controller:
|
||||
raise ValueError(f"Provider {provider_id} not found")
|
||||
|
||||
subscription = TriggerProviderService.get_subscription_by_id(
|
||||
tenant_id=tenant_id,
|
||||
subscription_id=subscription_id,
|
||||
)
|
||||
if not subscription:
|
||||
raise ValueError(f"Subscription {subscription_id} not found")
|
||||
# Use distributed lock to prevent race conditions on the same subscription
|
||||
lock_key = f"trigger_subscription_rebuild_lock:{tenant_id}_{subscription_id}"
|
||||
with redis_client.lock(lock_key, timeout=20):
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
try:
|
||||
# Get subscription within the transaction
|
||||
subscription: TriggerSubscription | None = (
|
||||
session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
|
||||
)
|
||||
if not subscription:
|
||||
raise ValueError(f"Subscription {subscription_id} not found")
|
||||
|
||||
credential_type = CredentialType.of(subscription.credential_type)
|
||||
if credential_type not in [CredentialType.OAUTH2, CredentialType.API_KEY]:
|
||||
raise ValueError("Credential type not supported for rebuild")
|
||||
credential_type = CredentialType.of(subscription.credential_type)
|
||||
if credential_type not in [CredentialType.OAUTH2, CredentialType.API_KEY]:
|
||||
raise ValueError("Credential type not supported for rebuild")
|
||||
|
||||
# TODO: Trying to invoke update api of the plugin trigger provider
|
||||
# Decrypt existing credentials for merging
|
||||
credential_encrypter, _ = create_trigger_provider_encrypter_for_subscription(
|
||||
tenant_id=tenant_id,
|
||||
controller=provider_controller,
|
||||
subscription=subscription,
|
||||
)
|
||||
decrypted_credentials = dict(credential_encrypter.decrypt(subscription.credentials))
|
||||
|
||||
# FALLBACK: If the update api is not implemented, delete the previous subscription and create a new one
|
||||
# Merge credentials: if caller passed HIDDEN_VALUE, retain existing decrypted value
|
||||
merged_credentials: dict[str, Any] = {
|
||||
key: value if value != HIDDEN_VALUE else decrypted_credentials.get(key, UNKNOWN_VALUE)
|
||||
for key, value in credentials.items()
|
||||
}
|
||||
|
||||
# Delete the previous subscription
|
||||
user_id = subscription.user_id
|
||||
TriggerManager.unsubscribe_trigger(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider_id=provider_id,
|
||||
subscription=subscription.to_entity(),
|
||||
credentials=subscription.credentials,
|
||||
credential_type=credential_type,
|
||||
)
|
||||
user_id = subscription.user_id
|
||||
|
||||
# Create a new subscription with the same subscription_id and endpoint_id
|
||||
new_subscription: TriggerSubscriptionEntity = TriggerManager.subscribe_trigger(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider_id=provider_id,
|
||||
endpoint=generate_plugin_trigger_endpoint_url(subscription.endpoint_id),
|
||||
parameters=parameters,
|
||||
credentials=credentials,
|
||||
credential_type=credential_type,
|
||||
)
|
||||
TriggerProviderService.update_trigger_subscription(
|
||||
tenant_id=tenant_id,
|
||||
subscription_id=subscription.id,
|
||||
name=name,
|
||||
parameters=parameters,
|
||||
credentials=credentials,
|
||||
properties=new_subscription.properties,
|
||||
expires_at=new_subscription.expires_at,
|
||||
)
|
||||
# TODO: Trying to invoke update api of the plugin trigger provider
|
||||
|
||||
# FALLBACK: If the update api is not implemented,
|
||||
# delete the previous subscription and create a new one
|
||||
|
||||
# Unsubscribe the previous subscription (external call, but we'll handle errors)
|
||||
try:
|
||||
TriggerManager.unsubscribe_trigger(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider_id=provider_id,
|
||||
subscription=subscription.to_entity(),
|
||||
credentials=decrypted_credentials,
|
||||
credential_type=credential_type,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Error unsubscribing trigger during rebuild", exc_info=e)
|
||||
# Continue anyway - the subscription might already be deleted externally
|
||||
|
||||
# Create a new subscription with the same subscription_id and endpoint_id (external call)
|
||||
new_subscription: TriggerSubscriptionEntity = TriggerManager.subscribe_trigger(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider_id=provider_id,
|
||||
endpoint=generate_plugin_trigger_endpoint_url(subscription.endpoint_id),
|
||||
parameters=parameters,
|
||||
credentials=merged_credentials,
|
||||
credential_type=credential_type,
|
||||
)
|
||||
|
||||
# Update the subscription in the same transaction
|
||||
# Inline update logic to reuse the same session
|
||||
if name is not None and name != subscription.name:
|
||||
existing = (
|
||||
session.query(TriggerSubscription)
|
||||
.filter_by(tenant_id=tenant_id, provider_id=str(provider_id), name=name)
|
||||
.first()
|
||||
)
|
||||
if existing and existing.id != subscription.id:
|
||||
raise ValueError(f"Subscription name '{name}' already exists for this provider")
|
||||
subscription.name = name
|
||||
|
||||
# Update parameters
|
||||
subscription.parameters = dict(parameters)
|
||||
|
||||
# Update credentials with merged (and encrypted) values
|
||||
subscription.credentials = dict(credential_encrypter.encrypt(merged_credentials))
|
||||
|
||||
# Update properties
|
||||
if new_subscription.properties:
|
||||
properties_encrypter, _ = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=provider_controller.get_properties_schema(),
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
subscription.properties = dict(properties_encrypter.encrypt(dict(new_subscription.properties)))
|
||||
|
||||
# Update expiration timestamp
|
||||
if new_subscription.expires_at is not None:
|
||||
subscription.expires_at = new_subscription.expires_at
|
||||
|
||||
# Commit the transaction
|
||||
session.commit()
|
||||
|
||||
# Clear subscription cache
|
||||
delete_cache_for_subscription(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=subscription.provider_id,
|
||||
subscription_id=subscription.id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# Rollback on any error
|
||||
session.rollback()
|
||||
logger.exception("Failed to rebuild trigger subscription", exc_info=e)
|
||||
raise
|
||||
|
|
|
|||
|
|
@ -863,10 +863,18 @@ class WebhookService:
|
|||
not_found_in_cache.append(node_id)
|
||||
continue
|
||||
|
||||
with Session(db.engine) as session:
|
||||
try:
|
||||
# lock the concurrent webhook trigger creation
|
||||
redis_client.lock(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:apps:{app.id}:lock", timeout=10)
|
||||
lock_key = f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:apps:{app.id}:lock"
|
||||
lock = redis_client.lock(lock_key, timeout=10)
|
||||
lock_acquired = False
|
||||
|
||||
try:
|
||||
# acquire the lock with blocking and timeout
|
||||
lock_acquired = lock.acquire(blocking=True, blocking_timeout=10)
|
||||
if not lock_acquired:
|
||||
logger.warning("Failed to acquire lock for webhook sync, app %s", app.id)
|
||||
raise RuntimeError("Failed to acquire lock for webhook trigger synchronization")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
# fetch the non-cached nodes from DB
|
||||
all_records = session.scalars(
|
||||
select(WorkflowWebhookTrigger).where(
|
||||
|
|
@ -903,11 +911,16 @@ class WebhookService:
|
|||
session.delete(nodes_id_in_db[node_id])
|
||||
redis_client.delete(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:{app.id}:{node_id}")
|
||||
session.commit()
|
||||
except Exception:
|
||||
logger.exception("Failed to sync webhook relationships for app %s", app.id)
|
||||
raise
|
||||
finally:
|
||||
redis_client.delete(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:apps:{app.id}:lock")
|
||||
except Exception:
|
||||
logger.exception("Failed to sync webhook relationships for app %s", app.id)
|
||||
raise
|
||||
finally:
|
||||
# release the lock only if it was acquired
|
||||
if lock_acquired:
|
||||
try:
|
||||
lock.release()
|
||||
except Exception:
|
||||
logger.exception("Failed to release lock for webhook sync, app %s", app.id)
|
||||
|
||||
@classmethod
|
||||
def generate_webhook_id(cls) -> str:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,682 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.trigger.entities.entities import Subscription as TriggerSubscriptionEntity
|
||||
from extensions.ext_database import db
|
||||
from models.provider_ids import TriggerProviderID
|
||||
from models.trigger import TriggerSubscription
|
||||
from services.trigger.trigger_provider_service import TriggerProviderService
|
||||
|
||||
|
||||
class TestTriggerProviderService:
|
||||
"""Integration tests for TriggerProviderService using testcontainers."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_external_service_dependencies(self):
|
||||
"""Mock setup for external service dependencies."""
|
||||
with (
|
||||
patch("services.trigger.trigger_provider_service.TriggerManager") as mock_trigger_manager,
|
||||
patch("services.trigger.trigger_provider_service.redis_client") as mock_redis_client,
|
||||
patch("services.trigger.trigger_provider_service.delete_cache_for_subscription") as mock_delete_cache,
|
||||
patch("services.account_service.FeatureService") as mock_account_feature_service,
|
||||
):
|
||||
# Setup default mock returns
|
||||
mock_provider_controller = MagicMock()
|
||||
mock_provider_controller.get_credential_schema_config.return_value = MagicMock()
|
||||
mock_provider_controller.get_properties_schema.return_value = MagicMock()
|
||||
mock_trigger_manager.get_trigger_provider.return_value = mock_provider_controller
|
||||
|
||||
# Mock redis lock
|
||||
mock_lock = MagicMock()
|
||||
mock_lock.__enter__ = MagicMock(return_value=None)
|
||||
mock_lock.__exit__ = MagicMock(return_value=None)
|
||||
mock_redis_client.lock.return_value = mock_lock
|
||||
|
||||
# Setup account feature service mock
|
||||
mock_account_feature_service.get_system_features.return_value.is_allow_register = True
|
||||
|
||||
yield {
|
||||
"trigger_manager": mock_trigger_manager,
|
||||
"redis_client": mock_redis_client,
|
||||
"delete_cache": mock_delete_cache,
|
||||
"provider_controller": mock_provider_controller,
|
||||
"account_feature_service": mock_account_feature_service,
|
||||
}
|
||||
|
||||
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Helper method to create a test account and tenant for testing.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
mock_external_service_dependencies: Mock dependencies
|
||||
|
||||
Returns:
|
||||
tuple: (account, tenant) - Created account and tenant instances
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
from services.account_service import AccountService, TenantService
|
||||
|
||||
# Setup mocks for account creation
|
||||
mock_external_service_dependencies[
|
||||
"account_feature_service"
|
||||
].get_system_features.return_value.is_allow_register = True
|
||||
mock_external_service_dependencies[
|
||||
"trigger_manager"
|
||||
].get_trigger_provider.return_value = mock_external_service_dependencies["provider_controller"]
|
||||
|
||||
# Create account and tenant
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
return account, tenant
|
||||
|
||||
def _create_test_subscription(
|
||||
self,
|
||||
db_session_with_containers,
|
||||
tenant_id,
|
||||
user_id,
|
||||
provider_id,
|
||||
credential_type,
|
||||
credentials,
|
||||
mock_external_service_dependencies,
|
||||
):
|
||||
"""
|
||||
Helper method to create a test trigger subscription.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session
|
||||
tenant_id: Tenant ID
|
||||
user_id: User ID
|
||||
provider_id: Provider ID
|
||||
credential_type: Credential type
|
||||
credentials: Credentials dict
|
||||
mock_external_service_dependencies: Mock dependencies
|
||||
|
||||
Returns:
|
||||
TriggerSubscription: Created subscription instance
|
||||
"""
|
||||
fake = Faker()
|
||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||
from core.helper.provider_encryption import create_provider_encrypter
|
||||
|
||||
# Use mock provider controller to encrypt credentials
|
||||
provider_controller = mock_external_service_dependencies["provider_controller"]
|
||||
|
||||
# Create encrypter for credentials
|
||||
credential_encrypter, _ = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=provider_controller.get_credential_schema_config(credential_type),
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
|
||||
subscription = TriggerSubscription(
|
||||
name=fake.word(),
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider_id=str(provider_id),
|
||||
endpoint_id=fake.uuid4(),
|
||||
parameters={"param1": "value1"},
|
||||
properties={"prop1": "value1"},
|
||||
credentials=dict(credential_encrypter.encrypt(credentials)),
|
||||
credential_type=credential_type.value,
|
||||
credential_expires_at=-1,
|
||||
expires_at=-1,
|
||||
)
|
||||
|
||||
db.session.add(subscription)
|
||||
db.session.commit()
|
||||
db.session.refresh(subscription)
|
||||
|
||||
return subscription
|
||||
|
||||
def test_rebuild_trigger_subscription_success_with_merged_credentials(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful rebuild with credential merging (HIDDEN_VALUE handling).
|
||||
|
||||
This test verifies:
|
||||
- Credentials are properly merged (HIDDEN_VALUE replaced with existing values)
|
||||
- Single transaction wraps all operations
|
||||
- Merged credentials are used for subscribe and update
|
||||
- Database state is correctly updated
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||
credential_type = CredentialType.API_KEY
|
||||
|
||||
# Create initial subscription with credentials
|
||||
original_credentials = {"api_key": "original-secret-key", "api_secret": "original-secret"}
|
||||
subscription = self._create_test_subscription(
|
||||
db_session_with_containers,
|
||||
tenant.id,
|
||||
account.id,
|
||||
provider_id,
|
||||
credential_type,
|
||||
original_credentials,
|
||||
mock_external_service_dependencies,
|
||||
)
|
||||
|
||||
# Prepare new credentials with HIDDEN_VALUE for api_key (should keep original)
|
||||
# and new value for api_secret (should update)
|
||||
new_credentials = {
|
||||
"api_key": HIDDEN_VALUE, # Should be replaced with original
|
||||
"api_secret": "new-secret-value", # Should be updated
|
||||
}
|
||||
|
||||
# Mock subscribe_trigger to return a new subscription entity
|
||||
new_subscription_entity = TriggerSubscriptionEntity(
|
||||
endpoint=subscription.endpoint_id,
|
||||
parameters={"param1": "value1"},
|
||||
properties={"prop1": "new_prop_value"},
|
||||
expires_at=1234567890,
|
||||
)
|
||||
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
|
||||
|
||||
# Mock unsubscribe_trigger
|
||||
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock()
|
||||
|
||||
# Execute rebuild
|
||||
TriggerProviderService.rebuild_trigger_subscription(
|
||||
tenant_id=tenant.id,
|
||||
provider_id=provider_id,
|
||||
subscription_id=subscription.id,
|
||||
credentials=new_credentials,
|
||||
parameters={"param1": "updated_value"},
|
||||
name="updated_name",
|
||||
)
|
||||
|
||||
# Verify unsubscribe was called with decrypted original credentials
|
||||
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.assert_called_once()
|
||||
unsubscribe_call_args = mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.call_args
|
||||
assert unsubscribe_call_args.kwargs["tenant_id"] == tenant.id
|
||||
assert unsubscribe_call_args.kwargs["provider_id"] == provider_id
|
||||
assert unsubscribe_call_args.kwargs["credential_type"] == credential_type
|
||||
|
||||
# Verify subscribe was called with merged credentials (api_key from original, api_secret new)
|
||||
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.assert_called_once()
|
||||
subscribe_call_args = mock_external_service_dependencies["trigger_manager"].subscribe_trigger.call_args
|
||||
subscribe_credentials = subscribe_call_args.kwargs["credentials"]
|
||||
assert subscribe_credentials["api_key"] == original_credentials["api_key"] # Merged from original
|
||||
assert subscribe_credentials["api_secret"] == "new-secret-value" # New value
|
||||
|
||||
# Verify database state was updated
|
||||
db.session.refresh(subscription)
|
||||
assert subscription.name == "updated_name"
|
||||
assert subscription.parameters == {"param1": "updated_value"}
|
||||
|
||||
# Verify credentials in DB were updated with merged values (decrypt to check)
|
||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||
from core.helper.provider_encryption import create_provider_encrypter
|
||||
|
||||
# Use mock provider controller to decrypt credentials
|
||||
provider_controller = mock_external_service_dependencies["provider_controller"]
|
||||
credential_encrypter, _ = create_provider_encrypter(
|
||||
tenant_id=tenant.id,
|
||||
config=provider_controller.get_credential_schema_config(credential_type),
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
decrypted_db_credentials = dict(credential_encrypter.decrypt(subscription.credentials))
|
||||
assert decrypted_db_credentials["api_key"] == original_credentials["api_key"]
|
||||
assert decrypted_db_credentials["api_secret"] == "new-secret-value"
|
||||
|
||||
# Verify cache was cleared
|
||||
mock_external_service_dependencies["delete_cache"].assert_called_once_with(
|
||||
tenant_id=tenant.id,
|
||||
provider_id=subscription.provider_id,
|
||||
subscription_id=subscription.id,
|
||||
)
|
||||
|
||||
def test_rebuild_trigger_subscription_with_all_new_credentials(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test rebuild when all credentials are new (no HIDDEN_VALUE).
|
||||
|
||||
This test verifies:
|
||||
- All new credentials are used when no HIDDEN_VALUE is present
|
||||
- Merged credentials contain only new values
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||
credential_type = CredentialType.API_KEY
|
||||
|
||||
# Create initial subscription
|
||||
original_credentials = {"api_key": "original-key", "api_secret": "original-secret"}
|
||||
subscription = self._create_test_subscription(
|
||||
db_session_with_containers,
|
||||
tenant.id,
|
||||
account.id,
|
||||
provider_id,
|
||||
credential_type,
|
||||
original_credentials,
|
||||
mock_external_service_dependencies,
|
||||
)
|
||||
|
||||
# All new credentials (no HIDDEN_VALUE)
|
||||
new_credentials = {
|
||||
"api_key": "completely-new-key",
|
||||
"api_secret": "completely-new-secret",
|
||||
}
|
||||
|
||||
new_subscription_entity = TriggerSubscriptionEntity(
|
||||
endpoint=subscription.endpoint_id,
|
||||
parameters={},
|
||||
properties={},
|
||||
expires_at=-1,
|
||||
)
|
||||
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
|
||||
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock()
|
||||
|
||||
# Execute rebuild
|
||||
TriggerProviderService.rebuild_trigger_subscription(
|
||||
tenant_id=tenant.id,
|
||||
provider_id=provider_id,
|
||||
subscription_id=subscription.id,
|
||||
credentials=new_credentials,
|
||||
parameters={},
|
||||
)
|
||||
|
||||
# Verify subscribe was called with all new credentials
|
||||
subscribe_call_args = mock_external_service_dependencies["trigger_manager"].subscribe_trigger.call_args
|
||||
subscribe_credentials = subscribe_call_args.kwargs["credentials"]
|
||||
assert subscribe_credentials["api_key"] == "completely-new-key"
|
||||
assert subscribe_credentials["api_secret"] == "completely-new-secret"
|
||||
|
||||
def test_rebuild_trigger_subscription_with_all_hidden_values(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test rebuild when all credentials are HIDDEN_VALUE (preserve all existing).
|
||||
|
||||
This test verifies:
|
||||
- All HIDDEN_VALUE credentials are replaced with existing values
|
||||
- Original credentials are preserved
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||
credential_type = CredentialType.API_KEY
|
||||
|
||||
original_credentials = {"api_key": "original-key", "api_secret": "original-secret"}
|
||||
subscription = self._create_test_subscription(
|
||||
db_session_with_containers,
|
||||
tenant.id,
|
||||
account.id,
|
||||
provider_id,
|
||||
credential_type,
|
||||
original_credentials,
|
||||
mock_external_service_dependencies,
|
||||
)
|
||||
|
||||
# All HIDDEN_VALUE (should preserve all original)
|
||||
new_credentials = {
|
||||
"api_key": HIDDEN_VALUE,
|
||||
"api_secret": HIDDEN_VALUE,
|
||||
}
|
||||
|
||||
new_subscription_entity = TriggerSubscriptionEntity(
|
||||
endpoint=subscription.endpoint_id,
|
||||
parameters={},
|
||||
properties={},
|
||||
expires_at=-1,
|
||||
)
|
||||
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
|
||||
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock()
|
||||
|
||||
# Execute rebuild
|
||||
TriggerProviderService.rebuild_trigger_subscription(
|
||||
tenant_id=tenant.id,
|
||||
provider_id=provider_id,
|
||||
subscription_id=subscription.id,
|
||||
credentials=new_credentials,
|
||||
parameters={},
|
||||
)
|
||||
|
||||
# Verify subscribe was called with all original credentials
|
||||
subscribe_call_args = mock_external_service_dependencies["trigger_manager"].subscribe_trigger.call_args
|
||||
subscribe_credentials = subscribe_call_args.kwargs["credentials"]
|
||||
assert subscribe_credentials["api_key"] == original_credentials["api_key"]
|
||||
assert subscribe_credentials["api_secret"] == original_credentials["api_secret"]
|
||||
|
||||
def test_rebuild_trigger_subscription_with_missing_key_uses_unknown_value(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test rebuild when HIDDEN_VALUE is used for a key that doesn't exist in original.
|
||||
|
||||
This test verifies:
|
||||
- UNKNOWN_VALUE is used when HIDDEN_VALUE key doesn't exist in original credentials
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||
credential_type = CredentialType.API_KEY
|
||||
|
||||
# Original has only api_key
|
||||
original_credentials = {"api_key": "original-key"}
|
||||
subscription = self._create_test_subscription(
|
||||
db_session_with_containers,
|
||||
tenant.id,
|
||||
account.id,
|
||||
provider_id,
|
||||
credential_type,
|
||||
original_credentials,
|
||||
mock_external_service_dependencies,
|
||||
)
|
||||
|
||||
# HIDDEN_VALUE for non-existent key should use UNKNOWN_VALUE
|
||||
new_credentials = {
|
||||
"api_key": HIDDEN_VALUE,
|
||||
"non_existent_key": HIDDEN_VALUE, # This key doesn't exist in original
|
||||
}
|
||||
|
||||
new_subscription_entity = TriggerSubscriptionEntity(
|
||||
endpoint=subscription.endpoint_id,
|
||||
parameters={},
|
||||
properties={},
|
||||
expires_at=-1,
|
||||
)
|
||||
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
|
||||
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock()
|
||||
|
||||
# Execute rebuild
|
||||
TriggerProviderService.rebuild_trigger_subscription(
|
||||
tenant_id=tenant.id,
|
||||
provider_id=provider_id,
|
||||
subscription_id=subscription.id,
|
||||
credentials=new_credentials,
|
||||
parameters={},
|
||||
)
|
||||
|
||||
# Verify subscribe was called with original api_key and UNKNOWN_VALUE for missing key
|
||||
subscribe_call_args = mock_external_service_dependencies["trigger_manager"].subscribe_trigger.call_args
|
||||
subscribe_credentials = subscribe_call_args.kwargs["credentials"]
|
||||
assert subscribe_credentials["api_key"] == original_credentials["api_key"]
|
||||
assert subscribe_credentials["non_existent_key"] == UNKNOWN_VALUE
|
||||
|
||||
def test_rebuild_trigger_subscription_rollback_on_error(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test that transaction is rolled back on error.
|
||||
|
||||
This test verifies:
|
||||
- Database transaction is rolled back when an error occurs
|
||||
- Original subscription state is preserved
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||
credential_type = CredentialType.API_KEY
|
||||
|
||||
original_credentials = {"api_key": "original-key"}
|
||||
subscription = self._create_test_subscription(
|
||||
db_session_with_containers,
|
||||
tenant.id,
|
||||
account.id,
|
||||
provider_id,
|
||||
credential_type,
|
||||
original_credentials,
|
||||
mock_external_service_dependencies,
|
||||
)
|
||||
|
||||
original_name = subscription.name
|
||||
original_parameters = subscription.parameters.copy()
|
||||
|
||||
# Make subscribe_trigger raise an error
|
||||
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.side_effect = ValueError(
|
||||
"Subscribe failed"
|
||||
)
|
||||
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock()
|
||||
|
||||
# Execute rebuild and expect error
|
||||
with pytest.raises(ValueError, match="Subscribe failed"):
|
||||
TriggerProviderService.rebuild_trigger_subscription(
|
||||
tenant_id=tenant.id,
|
||||
provider_id=provider_id,
|
||||
subscription_id=subscription.id,
|
||||
credentials={"api_key": "new-key"},
|
||||
parameters={},
|
||||
)
|
||||
|
||||
# Verify subscription state was not changed (rolled back)
|
||||
db.session.refresh(subscription)
|
||||
assert subscription.name == original_name
|
||||
assert subscription.parameters == original_parameters
|
||||
|
||||
def test_rebuild_trigger_subscription_unsubscribe_error_continues(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test that unsubscribe errors are handled gracefully and operation continues.
|
||||
|
||||
This test verifies:
|
||||
- Unsubscribe errors are caught and logged but don't stop the rebuild
|
||||
- Rebuild continues even if unsubscribe fails
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||
credential_type = CredentialType.API_KEY
|
||||
|
||||
original_credentials = {"api_key": "original-key"}
|
||||
subscription = self._create_test_subscription(
|
||||
db_session_with_containers,
|
||||
tenant.id,
|
||||
account.id,
|
||||
provider_id,
|
||||
credential_type,
|
||||
original_credentials,
|
||||
mock_external_service_dependencies,
|
||||
)
|
||||
|
||||
# Make unsubscribe_trigger raise an error (should be caught and continue)
|
||||
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.side_effect = ValueError(
|
||||
"Unsubscribe failed"
|
||||
)
|
||||
|
||||
new_subscription_entity = TriggerSubscriptionEntity(
|
||||
endpoint=subscription.endpoint_id,
|
||||
parameters={},
|
||||
properties={},
|
||||
expires_at=-1,
|
||||
)
|
||||
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
|
||||
|
||||
# Execute rebuild - should succeed despite unsubscribe error
|
||||
TriggerProviderService.rebuild_trigger_subscription(
|
||||
tenant_id=tenant.id,
|
||||
provider_id=provider_id,
|
||||
subscription_id=subscription.id,
|
||||
credentials={"api_key": "new-key"},
|
||||
parameters={},
|
||||
)
|
||||
|
||||
# Verify subscribe was still called (operation continued)
|
||||
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.assert_called_once()
|
||||
|
||||
# Verify subscription was updated
|
||||
db.session.refresh(subscription)
|
||||
assert subscription.parameters == {}
|
||||
|
||||
def test_rebuild_trigger_subscription_subscription_not_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error when subscription is not found.
|
||||
|
||||
This test verifies:
|
||||
- Proper error is raised when subscription doesn't exist
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||
fake_subscription_id = fake.uuid4()
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
TriggerProviderService.rebuild_trigger_subscription(
|
||||
tenant_id=tenant.id,
|
||||
provider_id=provider_id,
|
||||
subscription_id=fake_subscription_id,
|
||||
credentials={},
|
||||
parameters={},
|
||||
)
|
||||
|
||||
def test_rebuild_trigger_subscription_provider_not_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error when provider is not found.
|
||||
|
||||
This test verifies:
|
||||
- Proper error is raised when provider doesn't exist
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
provider_id = TriggerProviderID("non_existent_org/non_existent_plugin/non_existent_provider")
|
||||
|
||||
# Make get_trigger_provider return None
|
||||
mock_external_service_dependencies["trigger_manager"].get_trigger_provider.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="Provider.*not found"):
|
||||
TriggerProviderService.rebuild_trigger_subscription(
|
||||
tenant_id=tenant.id,
|
||||
provider_id=provider_id,
|
||||
subscription_id=fake.uuid4(),
|
||||
credentials={},
|
||||
parameters={},
|
||||
)
|
||||
|
||||
def test_rebuild_trigger_subscription_unsupported_credential_type(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error when credential type is not supported for rebuild.
|
||||
|
||||
This test verifies:
|
||||
- Proper error is raised for unsupported credential types (not OAUTH2 or API_KEY)
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||
credential_type = CredentialType.UNAUTHORIZED # Not supported
|
||||
|
||||
subscription = self._create_test_subscription(
|
||||
db_session_with_containers,
|
||||
tenant.id,
|
||||
account.id,
|
||||
provider_id,
|
||||
credential_type,
|
||||
{},
|
||||
mock_external_service_dependencies,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Credential type not supported for rebuild"):
|
||||
TriggerProviderService.rebuild_trigger_subscription(
|
||||
tenant_id=tenant.id,
|
||||
provider_id=provider_id,
|
||||
subscription_id=subscription.id,
|
||||
credentials={},
|
||||
parameters={},
|
||||
)
|
||||
|
||||
def test_rebuild_trigger_subscription_name_uniqueness_check(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test that name uniqueness is checked when updating name.
|
||||
|
||||
This test verifies:
|
||||
- Error is raised when new name conflicts with existing subscription
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||
credential_type = CredentialType.API_KEY
|
||||
|
||||
# Create first subscription
|
||||
subscription1 = self._create_test_subscription(
|
||||
db_session_with_containers,
|
||||
tenant.id,
|
||||
account.id,
|
||||
provider_id,
|
||||
credential_type,
|
||||
{"api_key": "key1"},
|
||||
mock_external_service_dependencies,
|
||||
)
|
||||
|
||||
# Create second subscription with different name
|
||||
subscription2 = self._create_test_subscription(
|
||||
db_session_with_containers,
|
||||
tenant.id,
|
||||
account.id,
|
||||
provider_id,
|
||||
credential_type,
|
||||
{"api_key": "key2"},
|
||||
mock_external_service_dependencies,
|
||||
)
|
||||
|
||||
new_subscription_entity = TriggerSubscriptionEntity(
|
||||
endpoint=subscription2.endpoint_id,
|
||||
parameters={},
|
||||
properties={},
|
||||
expires_at=-1,
|
||||
)
|
||||
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
|
||||
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock()
|
||||
|
||||
# Try to rename subscription2 to subscription1's name (should fail)
|
||||
with pytest.raises(ValueError, match="already exists"):
|
||||
TriggerProviderService.rebuild_trigger_subscription(
|
||||
tenant_id=tenant.id,
|
||||
provider_id=provider_id,
|
||||
subscription_id=subscription2.id,
|
||||
credentials={"api_key": "new-key"},
|
||||
parameters={},
|
||||
name=subscription1.name, # Conflicting name
|
||||
)
|
||||
|
|
@ -0,0 +1,327 @@
|
|||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.datasource.vdb.pgvector.pgvector import (
|
||||
PGVector,
|
||||
PGVectorConfig,
|
||||
)
|
||||
|
||||
|
||||
class TestPGVector(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.config = PGVectorConfig(
|
||||
host="localhost",
|
||||
port=5432,
|
||||
user="test_user",
|
||||
password="test_password",
|
||||
database="test_db",
|
||||
min_connection=1,
|
||||
max_connection=5,
|
||||
pg_bigm=False,
|
||||
)
|
||||
self.collection_name = "test_collection"
|
||||
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
|
||||
def test_init(self, mock_pool_class):
|
||||
"""Test PGVector initialization."""
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
pgvector = PGVector(self.collection_name, self.config)
|
||||
|
||||
assert pgvector._collection_name == self.collection_name
|
||||
assert pgvector.table_name == f"embedding_{self.collection_name}"
|
||||
assert pgvector.get_type() == "pgvector"
|
||||
assert pgvector.pool is not None
|
||||
assert pgvector.pg_bigm is False
|
||||
assert pgvector.index_hash is not None
|
||||
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
|
||||
def test_init_with_pg_bigm(self, mock_pool_class):
|
||||
"""Test PGVector initialization with pg_bigm enabled."""
|
||||
config = PGVectorConfig(
|
||||
host="localhost",
|
||||
port=5432,
|
||||
user="test_user",
|
||||
password="test_password",
|
||||
database="test_db",
|
||||
min_connection=1,
|
||||
max_connection=5,
|
||||
pg_bigm=True,
|
||||
)
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
pgvector = PGVector(self.collection_name, config)
|
||||
|
||||
assert pgvector.pg_bigm is True
|
||||
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
|
||||
def test_create_collection_basic(self, mock_redis, mock_pool_class):
|
||||
"""Test basic collection creation."""
|
||||
# Mock Redis operations
|
||||
mock_lock = MagicMock()
|
||||
mock_lock.__enter__ = MagicMock()
|
||||
mock_lock.__exit__ = MagicMock()
|
||||
mock_redis.lock.return_value = mock_lock
|
||||
mock_redis.get.return_value = None
|
||||
mock_redis.set.return_value = None
|
||||
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
# Mock connection and cursor
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.getconn.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.return_value = [1] # vector extension exists
|
||||
|
||||
pgvector = PGVector(self.collection_name, self.config)
|
||||
pgvector._create_collection(1536)
|
||||
|
||||
# Verify SQL execution calls
|
||||
assert mock_cursor.execute.called
|
||||
|
||||
# Check that CREATE TABLE was called with correct dimension
|
||||
create_table_calls = [call for call in mock_cursor.execute.call_args_list if "CREATE TABLE" in str(call)]
|
||||
assert len(create_table_calls) == 1
|
||||
assert "vector(1536)" in create_table_calls[0][0][0]
|
||||
|
||||
# Check that CREATE INDEX was called (dimension <= 2000)
|
||||
create_index_calls = [
|
||||
call for call in mock_cursor.execute.call_args_list if "CREATE INDEX" in str(call) and "hnsw" in str(call)
|
||||
]
|
||||
assert len(create_index_calls) == 1
|
||||
|
||||
# Verify Redis cache was set
|
||||
mock_redis.set.assert_called_once()
|
||||
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
|
||||
def test_create_collection_with_large_dimension(self, mock_redis, mock_pool_class):
|
||||
"""Test collection creation with dimension > 2000 (no HNSW index)."""
|
||||
# Mock Redis operations
|
||||
mock_lock = MagicMock()
|
||||
mock_lock.__enter__ = MagicMock()
|
||||
mock_lock.__exit__ = MagicMock()
|
||||
mock_redis.lock.return_value = mock_lock
|
||||
mock_redis.get.return_value = None
|
||||
mock_redis.set.return_value = None
|
||||
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
# Mock connection and cursor
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.getconn.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.return_value = [1] # vector extension exists
|
||||
|
||||
pgvector = PGVector(self.collection_name, self.config)
|
||||
pgvector._create_collection(3072) # Dimension > 2000
|
||||
|
||||
# Check that CREATE TABLE was called
|
||||
create_table_calls = [call for call in mock_cursor.execute.call_args_list if "CREATE TABLE" in str(call)]
|
||||
assert len(create_table_calls) == 1
|
||||
assert "vector(3072)" in create_table_calls[0][0][0]
|
||||
|
||||
# Check that HNSW index was NOT created (dimension > 2000)
|
||||
hnsw_index_calls = [call for call in mock_cursor.execute.call_args_list if "hnsw" in str(call)]
|
||||
assert len(hnsw_index_calls) == 0
|
||||
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
|
||||
def test_create_collection_with_pg_bigm(self, mock_redis, mock_pool_class):
|
||||
"""Test collection creation with pg_bigm enabled."""
|
||||
config = PGVectorConfig(
|
||||
host="localhost",
|
||||
port=5432,
|
||||
user="test_user",
|
||||
password="test_password",
|
||||
database="test_db",
|
||||
min_connection=1,
|
||||
max_connection=5,
|
||||
pg_bigm=True,
|
||||
)
|
||||
|
||||
# Mock Redis operations
|
||||
mock_lock = MagicMock()
|
||||
mock_lock.__enter__ = MagicMock()
|
||||
mock_lock.__exit__ = MagicMock()
|
||||
mock_redis.lock.return_value = mock_lock
|
||||
mock_redis.get.return_value = None
|
||||
mock_redis.set.return_value = None
|
||||
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
# Mock connection and cursor
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.getconn.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.return_value = [1] # vector extension exists
|
||||
|
||||
pgvector = PGVector(self.collection_name, config)
|
||||
pgvector._create_collection(1536)
|
||||
|
||||
# Check that pg_bigm index was created
|
||||
bigm_index_calls = [call for call in mock_cursor.execute.call_args_list if "gin_bigm_ops" in str(call)]
|
||||
assert len(bigm_index_calls) == 1
|
||||
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
|
||||
def test_create_collection_creates_vector_extension(self, mock_redis, mock_pool_class):
|
||||
"""Test that vector extension is created if it doesn't exist."""
|
||||
# Mock Redis operations
|
||||
mock_lock = MagicMock()
|
||||
mock_lock.__enter__ = MagicMock()
|
||||
mock_lock.__exit__ = MagicMock()
|
||||
mock_redis.lock.return_value = mock_lock
|
||||
mock_redis.get.return_value = None
|
||||
mock_redis.set.return_value = None
|
||||
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
# Mock connection and cursor
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.getconn.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
# First call: vector extension doesn't exist
|
||||
mock_cursor.fetchone.return_value = None
|
||||
|
||||
pgvector = PGVector(self.collection_name, self.config)
|
||||
pgvector._create_collection(1536)
|
||||
|
||||
# Check that CREATE EXTENSION was called
|
||||
create_extension_calls = [
|
||||
call for call in mock_cursor.execute.call_args_list if "CREATE EXTENSION vector" in str(call)
|
||||
]
|
||||
assert len(create_extension_calls) == 1
|
||||
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
|
||||
def test_create_collection_with_cache_hit(self, mock_redis, mock_pool_class):
|
||||
"""Test that collection creation is skipped when cache exists."""
|
||||
# Mock Redis operations - cache exists
|
||||
mock_lock = MagicMock()
|
||||
mock_lock.__enter__ = MagicMock()
|
||||
mock_lock.__exit__ = MagicMock()
|
||||
mock_redis.lock.return_value = mock_lock
|
||||
mock_redis.get.return_value = 1 # Cache exists
|
||||
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
# Mock connection and cursor
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.getconn.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
|
||||
pgvector = PGVector(self.collection_name, self.config)
|
||||
pgvector._create_collection(1536)
|
||||
|
||||
# Check that no SQL was executed (early return due to cache)
|
||||
assert mock_cursor.execute.call_count == 0
|
||||
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
|
||||
def test_create_collection_with_redis_lock(self, mock_redis, mock_pool_class):
|
||||
"""Test that Redis lock is used during collection creation."""
|
||||
# Mock Redis operations
|
||||
mock_lock = MagicMock()
|
||||
mock_lock.__enter__ = MagicMock()
|
||||
mock_lock.__exit__ = MagicMock()
|
||||
mock_redis.lock.return_value = mock_lock
|
||||
mock_redis.get.return_value = None
|
||||
mock_redis.set.return_value = None
|
||||
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
# Mock connection and cursor
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.getconn.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.return_value = [1] # vector extension exists
|
||||
|
||||
pgvector = PGVector(self.collection_name, self.config)
|
||||
pgvector._create_collection(1536)
|
||||
|
||||
# Verify Redis lock was acquired with correct lock name
|
||||
mock_redis.lock.assert_called_once_with("vector_indexing_test_collection_lock", timeout=20)
|
||||
|
||||
# Verify lock context manager was entered and exited
|
||||
mock_lock.__enter__.assert_called_once()
|
||||
mock_lock.__exit__.assert_called_once()
|
||||
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
|
||||
def test_get_cursor_context_manager(self, mock_pool_class):
|
||||
"""Test that _get_cursor properly manages connection lifecycle."""
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.getconn.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
|
||||
pgvector = PGVector(self.collection_name, self.config)
|
||||
|
||||
with pgvector._get_cursor() as cur:
|
||||
assert cur == mock_cursor
|
||||
|
||||
# Verify connection lifecycle methods were called
|
||||
mock_pool.getconn.assert_called_once()
|
||||
mock_cursor.close.assert_called_once()
|
||||
mock_conn.commit.assert_called_once()
|
||||
mock_pool.putconn.assert_called_once_with(mock_conn)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_config_override",
|
||||
[
|
||||
{"host": ""}, # Test empty host
|
||||
{"port": 0}, # Test invalid port
|
||||
{"user": ""}, # Test empty user
|
||||
{"password": ""}, # Test empty password
|
||||
{"database": ""}, # Test empty database
|
||||
{"min_connection": 0}, # Test invalid min_connection
|
||||
{"max_connection": 0}, # Test invalid max_connection
|
||||
{"min_connection": 10, "max_connection": 5}, # Test min > max
|
||||
],
|
||||
)
|
||||
def test_config_validation_parametrized(invalid_config_override):
|
||||
"""Test configuration validation for various invalid inputs using parametrize."""
|
||||
config = {
|
||||
"host": "localhost",
|
||||
"port": 5432,
|
||||
"user": "test_user",
|
||||
"password": "test_password",
|
||||
"database": "test_db",
|
||||
"min_connection": 1,
|
||||
"max_connection": 5,
|
||||
}
|
||||
config.update(invalid_config_override)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
PGVectorConfig(**config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -0,0 +1,873 @@
|
|||
"""
|
||||
Unit tests for DatasetRetrieval.process_metadata_filter_func.
|
||||
|
||||
This module provides comprehensive test coverage for the process_metadata_filter_func
|
||||
method in the DatasetRetrieval class, which is responsible for building SQLAlchemy
|
||||
filter expressions based on metadata filtering conditions.
|
||||
|
||||
Conditions Tested:
|
||||
==================
|
||||
1. **String Conditions**: contains, not contains, start with, end with
|
||||
2. **Equality Conditions**: is / =, is not / ≠
|
||||
3. **Null Conditions**: empty, not empty
|
||||
4. **Numeric Comparisons**: before / <, after / >, ≤ / <=, ≥ / >=
|
||||
5. **List Conditions**: in
|
||||
6. **Edge Cases**: None values, different data types (str, int, float)
|
||||
|
||||
Test Architecture:
|
||||
==================
|
||||
- Direct instantiation of DatasetRetrieval
|
||||
- Mocking of DatasetDocument model attributes
|
||||
- Verification of SQLAlchemy filter expressions
|
||||
- Follows Arrange-Act-Assert (AAA) pattern
|
||||
|
||||
Running Tests:
|
||||
==============
|
||||
# Run all tests in this module
|
||||
uv run --project api pytest \
|
||||
api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py -v
|
||||
|
||||
# Run a specific test
|
||||
uv run --project api pytest \
|
||||
api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py::\
|
||||
TestProcessMetadataFilterFunc::test_contains_condition -v
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
|
||||
|
||||
class TestProcessMetadataFilterFunc:
|
||||
"""
|
||||
Comprehensive test suite for process_metadata_filter_func method.
|
||||
|
||||
This test class validates all metadata filtering conditions supported by
|
||||
the DatasetRetrieval class, including string operations, numeric comparisons,
|
||||
null checks, and list operations.
|
||||
|
||||
Method Signature:
|
||||
==================
|
||||
def process_metadata_filter_func(
|
||||
self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
|
||||
) -> list:
|
||||
|
||||
The method builds SQLAlchemy filter expressions by:
|
||||
1. Validating value is not None (except for empty/not empty conditions)
|
||||
2. Using DatasetDocument.doc_metadata JSON field operations
|
||||
3. Adding appropriate SQLAlchemy expressions to the filters list
|
||||
4. Returning the updated filters list
|
||||
|
||||
Mocking Strategy:
|
||||
==================
|
||||
- Mock DatasetDocument.doc_metadata to avoid database dependencies
|
||||
- Verify filter expressions are created correctly
|
||||
- Test with various data types (str, int, float, list)
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def retrieval(self):
|
||||
"""
|
||||
Create a DatasetRetrieval instance for testing.
|
||||
|
||||
Returns:
|
||||
DatasetRetrieval: Instance to test process_metadata_filter_func
|
||||
"""
|
||||
return DatasetRetrieval()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_doc_metadata(self):
|
||||
"""
|
||||
Mock the DatasetDocument.doc_metadata JSON field.
|
||||
|
||||
The method uses DatasetDocument.doc_metadata[metadata_name] to access
|
||||
JSON fields. We mock this to avoid database dependencies.
|
||||
|
||||
Returns:
|
||||
Mock: Mocked doc_metadata attribute
|
||||
"""
|
||||
mock_metadata_field = MagicMock()
|
||||
|
||||
# Create mock for string access
|
||||
mock_string_access = MagicMock()
|
||||
mock_string_access.like = MagicMock()
|
||||
mock_string_access.notlike = MagicMock()
|
||||
mock_string_access.__eq__ = MagicMock(return_value=MagicMock())
|
||||
mock_string_access.__ne__ = MagicMock(return_value=MagicMock())
|
||||
mock_string_access.in_ = MagicMock(return_value=MagicMock())
|
||||
|
||||
# Create mock for float access (for numeric comparisons)
|
||||
mock_float_access = MagicMock()
|
||||
mock_float_access.__eq__ = MagicMock(return_value=MagicMock())
|
||||
mock_float_access.__ne__ = MagicMock(return_value=MagicMock())
|
||||
mock_float_access.__lt__ = MagicMock(return_value=MagicMock())
|
||||
mock_float_access.__gt__ = MagicMock(return_value=MagicMock())
|
||||
mock_float_access.__le__ = MagicMock(return_value=MagicMock())
|
||||
mock_float_access.__ge__ = MagicMock(return_value=MagicMock())
|
||||
|
||||
# Create mock for null checks
|
||||
mock_null_access = MagicMock()
|
||||
mock_null_access.is_ = MagicMock(return_value=MagicMock())
|
||||
mock_null_access.isnot = MagicMock(return_value=MagicMock())
|
||||
|
||||
# Setup __getitem__ to return appropriate mock based on usage
|
||||
def getitem_side_effect(name):
|
||||
if name in ["author", "title", "category"]:
|
||||
return mock_string_access
|
||||
elif name in ["year", "price", "rating"]:
|
||||
return mock_float_access
|
||||
else:
|
||||
return mock_string_access
|
||||
|
||||
mock_metadata_field.__getitem__ = MagicMock(side_effect=getitem_side_effect)
|
||||
mock_metadata_field.as_string.return_value = mock_string_access
|
||||
mock_metadata_field.as_float.return_value = mock_float_access
|
||||
mock_metadata_field[metadata_name:str].is_ = mock_null_access.is_
|
||||
mock_metadata_field[metadata_name:str].isnot = mock_null_access.isnot
|
||||
|
||||
return mock_metadata_field
|
||||
|
||||
# ==================== String Condition Tests ====================
|
||||
|
||||
def test_contains_condition_string_value(self, retrieval):
|
||||
"""
|
||||
Test 'contains' condition with string value.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with LIKE expression
|
||||
- Pattern matching uses %value% syntax
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "contains"
|
||||
metadata_name = "author"
|
||||
value = "John"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_not_contains_condition(self, retrieval):
|
||||
"""
|
||||
Test 'not contains' condition.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with NOT LIKE expression
|
||||
- Pattern matching uses %value% syntax with negation
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "not contains"
|
||||
metadata_name = "title"
|
||||
value = "banned"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_start_with_condition(self, retrieval):
|
||||
"""
|
||||
Test 'start with' condition.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with LIKE expression
|
||||
- Pattern matching uses value% syntax
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "start with"
|
||||
metadata_name = "category"
|
||||
value = "tech"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_end_with_condition(self, retrieval):
|
||||
"""
|
||||
Test 'end with' condition.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with LIKE expression
|
||||
- Pattern matching uses %value syntax
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "end with"
|
||||
metadata_name = "filename"
|
||||
value = ".pdf"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
# ==================== Equality Condition Tests ====================
|
||||
|
||||
def test_is_condition_with_string_value(self, retrieval):
|
||||
"""
|
||||
Test 'is' (=) condition with string value.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with equality expression
|
||||
- String comparison is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "is"
|
||||
metadata_name = "author"
|
||||
value = "Jane Doe"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_equals_condition_with_string_value(self, retrieval):
|
||||
"""
|
||||
Test '=' condition with string value.
|
||||
|
||||
Verifies:
|
||||
- Same behavior as 'is' condition
|
||||
- String comparison is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "="
|
||||
metadata_name = "category"
|
||||
value = "technology"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_is_condition_with_int_value(self, retrieval):
|
||||
"""
|
||||
Test 'is' condition with integer value.
|
||||
|
||||
Verifies:
|
||||
- Numeric comparison is used
|
||||
- as_float() is called on the metadata field
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "is"
|
||||
metadata_name = "year"
|
||||
value = 2023
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_is_condition_with_float_value(self, retrieval):
|
||||
"""
|
||||
Test 'is' condition with float value.
|
||||
|
||||
Verifies:
|
||||
- Numeric comparison is used
|
||||
- as_float() is called on the metadata field
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "is"
|
||||
metadata_name = "price"
|
||||
value = 19.99
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_is_not_condition_with_string_value(self, retrieval):
|
||||
"""
|
||||
Test 'is not' (≠) condition with string value.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with inequality expression
|
||||
- String comparison is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "is not"
|
||||
metadata_name = "author"
|
||||
value = "Unknown"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_not_equals_condition(self, retrieval):
|
||||
"""
|
||||
Test '≠' condition with string value.
|
||||
|
||||
Verifies:
|
||||
- Same behavior as 'is not' condition
|
||||
- Inequality expression is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "≠"
|
||||
metadata_name = "category"
|
||||
value = "archived"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_is_not_condition_with_numeric_value(self, retrieval):
|
||||
"""
|
||||
Test 'is not' condition with numeric value.
|
||||
|
||||
Verifies:
|
||||
- Numeric inequality comparison is used
|
||||
- as_float() is called on the metadata field
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "is not"
|
||||
metadata_name = "year"
|
||||
value = 2000
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
# ==================== Null Condition Tests ====================
|
||||
|
||||
def test_empty_condition(self, retrieval):
|
||||
"""
|
||||
Test 'empty' condition (null check).
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with IS NULL expression
|
||||
- Value can be None for this condition
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "empty"
|
||||
metadata_name = "author"
|
||||
value = None
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_not_empty_condition(self, retrieval):
|
||||
"""
|
||||
Test 'not empty' condition (not null check).
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with IS NOT NULL expression
|
||||
- Value can be None for this condition
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "not empty"
|
||||
metadata_name = "description"
|
||||
value = None
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
# ==================== Numeric Comparison Tests ====================
|
||||
|
||||
def test_before_condition(self, retrieval):
|
||||
"""
|
||||
Test 'before' (<) condition.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with less than expression
|
||||
- Numeric comparison is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "before"
|
||||
metadata_name = "year"
|
||||
value = 2020
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_less_than_condition(self, retrieval):
|
||||
"""
|
||||
Test '<' condition.
|
||||
|
||||
Verifies:
|
||||
- Same behavior as 'before' condition
|
||||
- Less than expression is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "<"
|
||||
metadata_name = "price"
|
||||
value = 100.0
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_after_condition(self, retrieval):
|
||||
"""
|
||||
Test 'after' (>) condition.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with greater than expression
|
||||
- Numeric comparison is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "after"
|
||||
metadata_name = "year"
|
||||
value = 2020
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_greater_than_condition(self, retrieval):
|
||||
"""
|
||||
Test '>' condition.
|
||||
|
||||
Verifies:
|
||||
- Same behavior as 'after' condition
|
||||
- Greater than expression is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = ">"
|
||||
metadata_name = "rating"
|
||||
value = 4.5
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_less_than_or_equal_condition_unicode(self, retrieval):
|
||||
"""
|
||||
Test '≤' condition.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with less than or equal expression
|
||||
- Numeric comparison is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "≤"
|
||||
metadata_name = "price"
|
||||
value = 50.0
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_less_than_or_equal_condition_ascii(self, retrieval):
|
||||
"""
|
||||
Test '<=' condition.
|
||||
|
||||
Verifies:
|
||||
- Same behavior as '≤' condition
|
||||
- Less than or equal expression is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "<="
|
||||
metadata_name = "year"
|
||||
value = 2023
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_greater_than_or_equal_condition_unicode(self, retrieval):
|
||||
"""
|
||||
Test '≥' condition.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with greater than or equal expression
|
||||
- Numeric comparison is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "≥"
|
||||
metadata_name = "rating"
|
||||
value = 3.5
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_greater_than_or_equal_condition_ascii(self, retrieval):
|
||||
"""
|
||||
Test '>=' condition.
|
||||
|
||||
Verifies:
|
||||
- Same behavior as '≥' condition
|
||||
- Greater than or equal expression is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = ">="
|
||||
metadata_name = "year"
|
||||
value = 2000
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
# ==================== List/In Condition Tests ====================
|
||||
|
||||
def test_in_condition_with_comma_separated_string(self, retrieval):
|
||||
"""
|
||||
Test 'in' condition with comma-separated string value.
|
||||
|
||||
Verifies:
|
||||
- String is split into list
|
||||
- Whitespace is trimmed from each value
|
||||
- IN expression is created
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "in"
|
||||
metadata_name = "category"
|
||||
value = "tech, science, AI "
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_in_condition_with_list_value(self, retrieval):
|
||||
"""
|
||||
Test 'in' condition with list value.
|
||||
|
||||
Verifies:
|
||||
- List is processed correctly
|
||||
- None values are filtered out
|
||||
- IN expression is created with valid values
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "in"
|
||||
metadata_name = "tags"
|
||||
value = ["python", "javascript", None, "golang"]
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_in_condition_with_tuple_value(self, retrieval):
|
||||
"""
|
||||
Test 'in' condition with tuple value.
|
||||
|
||||
Verifies:
|
||||
- Tuple is processed like a list
|
||||
- IN expression is created
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "in"
|
||||
metadata_name = "category"
|
||||
value = ("tech", "science", "ai")
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_in_condition_with_empty_string(self, retrieval):
|
||||
"""
|
||||
Test 'in' condition with empty string value.
|
||||
|
||||
Verifies:
|
||||
- Empty string results in literal(False) filter
|
||||
- No valid values to match
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "in"
|
||||
metadata_name = "category"
|
||||
value = ""
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
# Verify it's a literal(False) expression
|
||||
# This is a bit tricky to test without access to the actual expression
|
||||
|
||||
def test_in_condition_with_only_whitespace(self, retrieval):
|
||||
"""
|
||||
Test 'in' condition with whitespace-only string value.
|
||||
|
||||
Verifies:
|
||||
- Whitespace-only string results in literal(False) filter
|
||||
- All values are stripped and filtered out
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "in"
|
||||
metadata_name = "category"
|
||||
value = " , , "
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_in_condition_with_single_string(self, retrieval):
|
||||
"""
|
||||
Test 'in' condition with single non-comma string.
|
||||
|
||||
Verifies:
|
||||
- Single string is treated as single-item list
|
||||
- IN expression is created with one value
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "in"
|
||||
metadata_name = "category"
|
||||
value = "technology"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
# ==================== Edge Case Tests ====================
|
||||
|
||||
def test_none_value_with_non_empty_condition(self, retrieval):
|
||||
"""
|
||||
Test None value with conditions that require value.
|
||||
|
||||
Verifies:
|
||||
- Original filters list is returned unchanged
|
||||
- No filter is added for None values (except empty/not empty)
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "contains"
|
||||
metadata_name = "author"
|
||||
value = None
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 0 # No filter added
|
||||
|
||||
def test_none_value_with_equals_condition(self, retrieval):
|
||||
"""
|
||||
Test None value with 'is' (=) condition.
|
||||
|
||||
Verifies:
|
||||
- Original filters list is returned unchanged
|
||||
- No filter is added for None values
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "is"
|
||||
metadata_name = "author"
|
||||
value = None
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 0
|
||||
|
||||
def test_none_value_with_numeric_condition(self, retrieval):
|
||||
"""
|
||||
Test None value with numeric comparison condition.
|
||||
|
||||
Verifies:
|
||||
- Original filters list is returned unchanged
|
||||
- No filter is added for None values
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = ">"
|
||||
metadata_name = "year"
|
||||
value = None
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 0
|
||||
|
||||
def test_existing_filters_preserved(self, retrieval):
|
||||
"""
|
||||
Test that existing filters are preserved.
|
||||
|
||||
Verifies:
|
||||
- Existing filters in the list are not removed
|
||||
- New filters are appended to the list
|
||||
"""
|
||||
existing_filter = MagicMock()
|
||||
filters = [existing_filter]
|
||||
sequence = 0
|
||||
condition = "contains"
|
||||
metadata_name = "author"
|
||||
value = "test"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 2
|
||||
assert filters[0] == existing_filter
|
||||
|
||||
def test_multiple_filters_accumulated(self, retrieval):
|
||||
"""
|
||||
Test multiple calls to accumulate filters.
|
||||
|
||||
Verifies:
|
||||
- Each call adds a new filter to the list
|
||||
- All filters are preserved across calls
|
||||
"""
|
||||
filters = []
|
||||
|
||||
# First filter
|
||||
retrieval.process_metadata_filter_func(0, "contains", "author", "John", filters)
|
||||
assert len(filters) == 1
|
||||
|
||||
# Second filter
|
||||
retrieval.process_metadata_filter_func(1, ">", "year", 2020, filters)
|
||||
assert len(filters) == 2
|
||||
|
||||
# Third filter
|
||||
retrieval.process_metadata_filter_func(2, "is", "category", "tech", filters)
|
||||
assert len(filters) == 3
|
||||
|
||||
def test_unknown_condition(self, retrieval):
|
||||
"""
|
||||
Test unknown/unsupported condition.
|
||||
|
||||
Verifies:
|
||||
- Original filters list is returned unchanged
|
||||
- No filter is added for unknown conditions
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "unknown_condition"
|
||||
metadata_name = "author"
|
||||
value = "test"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 0
|
||||
|
||||
def test_empty_string_value_with_contains(self, retrieval):
|
||||
"""
|
||||
Test empty string value with 'contains' condition.
|
||||
|
||||
Verifies:
|
||||
- Filter is added even with empty string
|
||||
- LIKE expression is created
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "contains"
|
||||
metadata_name = "author"
|
||||
value = ""
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_special_characters_in_value(self, retrieval):
|
||||
"""
|
||||
Test special characters in value string.
|
||||
|
||||
Verifies:
|
||||
- Special characters are handled in value
|
||||
- LIKE expression is created correctly
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "contains"
|
||||
metadata_name = "title"
|
||||
value = "C++ & Python's features"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_zero_value_with_numeric_condition(self, retrieval):
|
||||
"""
|
||||
Test zero value with numeric comparison condition.
|
||||
|
||||
Verifies:
|
||||
- Zero is treated as valid value
|
||||
- Numeric comparison is performed
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = ">"
|
||||
metadata_name = "price"
|
||||
value = 0
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_negative_value_with_numeric_condition(self, retrieval):
|
||||
"""
|
||||
Test negative value with numeric comparison condition.
|
||||
|
||||
Verifies:
|
||||
- Negative numbers are handled correctly
|
||||
- Numeric comparison is performed
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "<"
|
||||
metadata_name = "temperature"
|
||||
value = -10.5
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_float_value_with_integer_comparison(self, retrieval):
|
||||
"""
|
||||
Test float value with numeric comparison condition.
|
||||
|
||||
Verifies:
|
||||
- Float values work correctly
|
||||
- Numeric comparison is performed
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = ">="
|
||||
metadata_name = "rating"
|
||||
value = 4.5
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
|
@ -0,0 +1,122 @@
|
|||
import base64
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.mcp.types import (
|
||||
AudioContent,
|
||||
BlobResourceContents,
|
||||
CallToolResult,
|
||||
EmbeddedResource,
|
||||
ImageContent,
|
||||
TextResourceContents,
|
||||
)
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage
|
||||
from core.tools.mcp_tool.tool import MCPTool
|
||||
|
||||
|
||||
def _make_mcp_tool(output_schema: dict | None = None) -> MCPTool:
|
||||
identity = ToolIdentity(
|
||||
author="test",
|
||||
name="test_mcp_tool",
|
||||
label=I18nObject(en_US="Test MCP Tool", zh_Hans="测试MCP工具"),
|
||||
provider="test_provider",
|
||||
)
|
||||
entity = ToolEntity(identity=identity, output_schema=output_schema or {})
|
||||
runtime = Mock(spec=ToolRuntime)
|
||||
runtime.credentials = {}
|
||||
return MCPTool(
|
||||
entity=entity,
|
||||
runtime=runtime,
|
||||
tenant_id="test_tenant",
|
||||
icon="",
|
||||
server_url="https://server.invalid",
|
||||
provider_id="provider_1",
|
||||
headers={},
|
||||
)
|
||||
|
||||
|
||||
class TestMCPToolInvoke:
|
||||
@pytest.mark.parametrize(
|
||||
("content_factory", "mime_type"),
|
||||
[
|
||||
(
|
||||
lambda b64, mt: ImageContent(type="image", data=b64, mimeType=mt),
|
||||
"image/png",
|
||||
),
|
||||
(
|
||||
lambda b64, mt: AudioContent(type="audio", data=b64, mimeType=mt),
|
||||
"audio/mpeg",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_invoke_image_or_audio_yields_blob(self, content_factory, mime_type) -> None:
|
||||
tool = _make_mcp_tool()
|
||||
raw = b"\x00\x01test-bytes\x02"
|
||||
b64 = base64.b64encode(raw).decode()
|
||||
content = content_factory(b64, mime_type)
|
||||
result = CallToolResult(content=[content])
|
||||
|
||||
with patch.object(tool, "invoke_remote_mcp_tool", return_value=result):
|
||||
messages = list(tool._invoke(user_id="test_user", tool_parameters={}))
|
||||
|
||||
assert len(messages) == 1
|
||||
msg = messages[0]
|
||||
assert msg.type == ToolInvokeMessage.MessageType.BLOB
|
||||
assert isinstance(msg.message, ToolInvokeMessage.BlobMessage)
|
||||
assert msg.message.blob == raw
|
||||
assert msg.meta == {"mime_type": mime_type}
|
||||
|
||||
def test_invoke_embedded_text_resource_yields_text(self) -> None:
|
||||
tool = _make_mcp_tool()
|
||||
text_resource = TextResourceContents(uri="file://test.txt", mimeType="text/plain", text="hello world")
|
||||
content = EmbeddedResource(type="resource", resource=text_resource)
|
||||
result = CallToolResult(content=[content])
|
||||
|
||||
with patch.object(tool, "invoke_remote_mcp_tool", return_value=result):
|
||||
messages = list(tool._invoke(user_id="test_user", tool_parameters={}))
|
||||
|
||||
assert len(messages) == 1
|
||||
msg = messages[0]
|
||||
assert msg.type == ToolInvokeMessage.MessageType.TEXT
|
||||
assert isinstance(msg.message, ToolInvokeMessage.TextMessage)
|
||||
assert msg.message.text == "hello world"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("mime_type", "expected_mime"),
|
||||
[("application/pdf", "application/pdf"), (None, "application/octet-stream")],
|
||||
)
|
||||
def test_invoke_embedded_blob_resource_yields_blob(self, mime_type, expected_mime) -> None:
|
||||
tool = _make_mcp_tool()
|
||||
raw = b"binary-data"
|
||||
b64 = base64.b64encode(raw).decode()
|
||||
blob_resource = BlobResourceContents(uri="file://doc.bin", mimeType=mime_type, blob=b64)
|
||||
content = EmbeddedResource(type="resource", resource=blob_resource)
|
||||
result = CallToolResult(content=[content])
|
||||
|
||||
with patch.object(tool, "invoke_remote_mcp_tool", return_value=result):
|
||||
messages = list(tool._invoke(user_id="test_user", tool_parameters={}))
|
||||
|
||||
assert len(messages) == 1
|
||||
msg = messages[0]
|
||||
assert msg.type == ToolInvokeMessage.MessageType.BLOB
|
||||
assert isinstance(msg.message, ToolInvokeMessage.BlobMessage)
|
||||
assert msg.message.blob == raw
|
||||
assert msg.meta == {"mime_type": expected_mime}
|
||||
|
||||
def test_invoke_yields_variables_when_structured_content_and_schema(self) -> None:
|
||||
tool = _make_mcp_tool(output_schema={"type": "object"})
|
||||
result = CallToolResult(content=[], structuredContent={"a": 1, "b": "x"})
|
||||
|
||||
with patch.object(tool, "invoke_remote_mcp_tool", return_value=result):
|
||||
messages = list(tool._invoke(user_id="test_user", tool_parameters={}))
|
||||
|
||||
# Expect two variable messages corresponding to keys a and b
|
||||
assert len(messages) == 2
|
||||
var_msgs = [m for m in messages if isinstance(m.message, ToolInvokeMessage.VariableMessage)]
|
||||
assert {m.message.variable_name for m in var_msgs} == {"a", "b"}
|
||||
# Validate values
|
||||
values = {m.message.variable_name: m.message.variable_value for m in var_msgs}
|
||||
assert values == {"a": 1, "b": "x"}
|
||||
|
|
@ -3072,11 +3072,11 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "json-repair"
|
||||
version = "0.54.1"
|
||||
version = "0.54.3"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/00/46/d3a4d9a3dad39bb4a2ad16b8adb9fe2e8611b20b71197fe33daa6768e85d/json_repair-0.54.1.tar.gz", hash = "sha256:d010bc31f1fc66e7c36dc33bff5f8902674498ae5cb8e801ad455a53b455ad1d", size = 38555, upload-time = "2025-11-19T14:55:24.265Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/b5/86/48b12ac02032f121ac7e5f11a32143edca6c1e3d19ffc54d6fb9ca0aafd0/json_repair-0.54.3.tar.gz", hash = "sha256:e50feec9725e52ac91f12184609754684ac1656119dfbd31de09bdaf9a1d8bf6", size = 38626, upload-time = "2025-12-15T09:41:58.594Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/db/96/c9aad7ee949cc1bf15df91f347fbc2d3bd10b30b80c7df689ce6fe9332b5/json_repair-0.54.1-py3-none-any.whl", hash = "sha256:016160c5db5d5fe443164927bb58d2dfbba5f43ad85719fa9bc51c713a443ab1", size = 29311, upload-time = "2025-11-19T14:55:22.886Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e9/08/abe317237add63c3e62f18a981bccf92112b431835b43d844aedaf61f4a0/json_repair-0.54.3-py3-none-any.whl", hash = "sha256:4cdc132ee27d4780576f71bf27a113877046224a808bfc17392e079cb344fb81", size = 29357, upload-time = "2025-12-15T09:41:57.436Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
|
|||
|
|
@ -54,17 +54,17 @@
|
|||
"publish:npm": "./scripts/publish.sh"
|
||||
},
|
||||
"dependencies": {
|
||||
"axios": "^1.3.5"
|
||||
"axios": "^1.13.2"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@eslint/js": "^9.2.0",
|
||||
"@types/node": "^20.11.30",
|
||||
"@eslint/js": "^9.39.2",
|
||||
"@types/node": "^25.0.3",
|
||||
"@typescript-eslint/eslint-plugin": "^8.50.1",
|
||||
"@typescript-eslint/parser": "^8.50.1",
|
||||
"@vitest/coverage-v8": "1.6.1",
|
||||
"eslint": "^9.2.0",
|
||||
"@vitest/coverage-v8": "4.0.16",
|
||||
"eslint": "^9.39.2",
|
||||
"tsup": "^8.5.1",
|
||||
"typescript": "^5.4.5",
|
||||
"vitest": "^1.5.0"
|
||||
"typescript": "^5.9.3",
|
||||
"vitest": "^4.0.16"
|
||||
}
|
||||
}
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,2 @@
|
|||
onlyBuiltDependencies:
|
||||
- esbuild
|
||||
|
|
@ -47,6 +47,12 @@ const getCheckboxDefaultSelectValue = (value: InputVar['default']) => {
|
|||
const parseCheckboxSelectValue = (value: string) =>
|
||||
value === CHECKBOX_DEFAULT_TRUE_VALUE
|
||||
|
||||
const normalizeSelectDefaultValue = (inputVar: InputVar) => {
|
||||
if (inputVar.type === InputVarType.select && inputVar.default === '')
|
||||
return { ...inputVar, default: undefined }
|
||||
return inputVar
|
||||
}
|
||||
|
||||
export type IConfigModalProps = {
|
||||
isCreate?: boolean
|
||||
payload?: InputVar
|
||||
|
|
@ -67,7 +73,7 @@ const ConfigModal: FC<IConfigModalProps> = ({
|
|||
}) => {
|
||||
const { modelConfig } = useContext(ConfigContext)
|
||||
const { t } = useTranslation()
|
||||
const [tempPayload, setTempPayload] = useState<InputVar>(() => payload || getNewVarInWorkflow('') as any)
|
||||
const [tempPayload, setTempPayload] = useState<InputVar>(() => normalizeSelectDefaultValue(payload || getNewVarInWorkflow('') as any))
|
||||
const { type, label, variable, options, max_length } = tempPayload
|
||||
const modalRef = useRef<HTMLDivElement>(null)
|
||||
const appDetail = useAppStore(state => state.appDetail)
|
||||
|
|
@ -182,6 +188,8 @@ const ConfigModal: FC<IConfigModalProps> = ({
|
|||
|
||||
const newPayload = produce(tempPayload, (draft) => {
|
||||
draft.type = type
|
||||
if (type === InputVarType.select)
|
||||
draft.default = undefined
|
||||
if ([InputVarType.singleFile, InputVarType.multiFiles].includes(type)) {
|
||||
(Object.keys(DEFAULT_FILE_UPLOAD_SETTING)).forEach((key) => {
|
||||
if (key !== 'max_length')
|
||||
|
|
|
|||
|
|
@ -0,0 +1,141 @@
|
|||
import type { DataSet } from '@/models/datasets'
|
||||
import { act, fireEvent, render, screen } from '@testing-library/react'
|
||||
import * as React from 'react'
|
||||
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
import { IndexingType } from '@/app/components/datasets/create/step-two'
|
||||
import { DatasetPermission } from '@/models/datasets'
|
||||
import { RETRIEVE_METHOD } from '@/types/app'
|
||||
import SelectDataSet from './index'
|
||||
|
||||
vi.mock('@/i18n-config/i18next-config', () => ({
|
||||
__esModule: true,
|
||||
default: {
|
||||
changeLanguage: vi.fn(),
|
||||
addResourceBundle: vi.fn(),
|
||||
use: vi.fn().mockReturnThis(),
|
||||
init: vi.fn(),
|
||||
addResource: vi.fn(),
|
||||
hasResourceBundle: vi.fn().mockReturnValue(true),
|
||||
},
|
||||
}))
|
||||
const mockUseInfiniteScroll = vi.fn()
|
||||
vi.mock('ahooks', async (importOriginal) => {
|
||||
const actual = await importOriginal()
|
||||
return {
|
||||
...(typeof actual === 'object' && actual !== null ? actual : {}),
|
||||
useInfiniteScroll: (...args: any[]) => mockUseInfiniteScroll(...args),
|
||||
}
|
||||
})
|
||||
|
||||
const mockUseInfiniteDatasets = vi.fn()
|
||||
vi.mock('@/service/knowledge/use-dataset', () => ({
|
||||
useInfiniteDatasets: (...args: any[]) => mockUseInfiniteDatasets(...args),
|
||||
}))
|
||||
|
||||
vi.mock('@/hooks/use-knowledge', () => ({
|
||||
useKnowledge: () => ({
|
||||
formatIndexingTechniqueAndMethod: (tech: string, method: string) => `${tech}:${method}`,
|
||||
}),
|
||||
}))
|
||||
|
||||
const baseProps = {
|
||||
isShow: true,
|
||||
onClose: vi.fn(),
|
||||
selectedIds: [] as string[],
|
||||
onSelect: vi.fn(),
|
||||
}
|
||||
|
||||
const makeDataset = (overrides: Partial<DataSet>): DataSet => ({
|
||||
id: 'dataset-id',
|
||||
name: 'Dataset Name',
|
||||
provider: 'internal',
|
||||
icon_info: {
|
||||
icon_type: 'emoji',
|
||||
icon: '💾',
|
||||
icon_background: '#fff',
|
||||
icon_url: '',
|
||||
},
|
||||
embedding_available: true,
|
||||
is_multimodal: false,
|
||||
description: '',
|
||||
permission: DatasetPermission.allTeamMembers,
|
||||
indexing_technique: IndexingType.ECONOMICAL,
|
||||
retrieval_model_dict: {
|
||||
search_method: RETRIEVE_METHOD.fullText,
|
||||
top_k: 5,
|
||||
reranking_enable: false,
|
||||
reranking_model: {
|
||||
reranking_model_name: '',
|
||||
reranking_provider_name: '',
|
||||
},
|
||||
score_threshold_enabled: false,
|
||||
score_threshold: 0,
|
||||
},
|
||||
...overrides,
|
||||
} as DataSet)
|
||||
|
||||
describe('SelectDataSet', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('renders dataset entries, allows selection, and fires onSelect', async () => {
|
||||
const datasetOne = makeDataset({
|
||||
id: 'set-1',
|
||||
name: 'Dataset One',
|
||||
is_multimodal: true,
|
||||
indexing_technique: IndexingType.ECONOMICAL,
|
||||
})
|
||||
const datasetTwo = makeDataset({
|
||||
id: 'set-2',
|
||||
name: 'Hidden Dataset',
|
||||
embedding_available: false,
|
||||
provider: 'external',
|
||||
})
|
||||
mockUseInfiniteDatasets.mockReturnValue({
|
||||
data: { pages: [{ data: [datasetOne, datasetTwo] }] },
|
||||
isLoading: false,
|
||||
isFetchingNextPage: false,
|
||||
fetchNextPage: vi.fn(),
|
||||
hasNextPage: false,
|
||||
})
|
||||
|
||||
const onSelect = vi.fn()
|
||||
await act(async () => {
|
||||
render(<SelectDataSet {...baseProps} onSelect={onSelect} selectedIds={[]} />)
|
||||
})
|
||||
|
||||
expect(screen.getByText('Dataset One')).toBeInTheDocument()
|
||||
expect(screen.getByText('Hidden Dataset')).toBeInTheDocument()
|
||||
|
||||
await act(async () => {
|
||||
fireEvent.click(screen.getByText('Dataset One'))
|
||||
})
|
||||
expect(screen.getByText('1 appDebug.feature.dataSet.selected')).toBeInTheDocument()
|
||||
|
||||
const addButton = screen.getByRole('button', { name: 'common.operation.add' })
|
||||
await act(async () => {
|
||||
fireEvent.click(addButton)
|
||||
})
|
||||
expect(onSelect).toHaveBeenCalledWith([datasetOne])
|
||||
})
|
||||
|
||||
it('shows empty state when no datasets are available and disables add', async () => {
|
||||
mockUseInfiniteDatasets.mockReturnValue({
|
||||
data: { pages: [{ data: [] }] },
|
||||
isLoading: false,
|
||||
isFetchingNextPage: false,
|
||||
fetchNextPage: vi.fn(),
|
||||
hasNextPage: false,
|
||||
})
|
||||
|
||||
await act(async () => {
|
||||
render(<SelectDataSet {...baseProps} onSelect={vi.fn()} selectedIds={[]} />)
|
||||
})
|
||||
|
||||
expect(screen.getByText('appDebug.feature.dataSet.noDataSet')).toBeInTheDocument()
|
||||
expect(screen.getByRole('link', { name: 'appDebug.feature.dataSet.toCreate' })).toHaveAttribute('href', '/datasets/create')
|
||||
expect(screen.getByRole('button', { name: 'common.operation.add' })).toBeDisabled()
|
||||
})
|
||||
})
|
||||
|
|
@ -0,0 +1,125 @@
|
|||
import type { IPromptValuePanelProps } from './index'
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import * as React from 'react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { useStore } from '@/app/components/app/store'
|
||||
import ConfigContext from '@/context/debug-configuration'
|
||||
import { AppModeEnum, ModelModeType, Resolution } from '@/types/app'
|
||||
import PromptValuePanel from './index'
|
||||
|
||||
vi.mock('@/app/components/app/store', () => ({
|
||||
useStore: vi.fn(),
|
||||
}))
|
||||
vi.mock('@/app/components/base/features/new-feature-panel/feature-bar', () => ({
|
||||
__esModule: true,
|
||||
default: ({ onFeatureBarClick }: { onFeatureBarClick: () => void }) => (
|
||||
<button type="button" onClick={onFeatureBarClick}>
|
||||
feature bar
|
||||
</button>
|
||||
),
|
||||
}))
|
||||
|
||||
const mockSetShowAppConfigureFeaturesModal = vi.fn()
|
||||
const mockUseStore = vi.mocked(useStore)
|
||||
const mockSetInputs = vi.fn()
|
||||
const mockOnSend = vi.fn()
|
||||
|
||||
const promptVariables = [
|
||||
{ key: 'textVar', name: 'Text Var', type: 'string', required: true },
|
||||
{ key: 'boolVar', name: 'Boolean Var', type: 'checkbox' },
|
||||
] as const
|
||||
|
||||
const baseContextValue: any = {
|
||||
modelModeType: ModelModeType.completion,
|
||||
modelConfig: {
|
||||
configs: {
|
||||
prompt_template: 'prompt template',
|
||||
prompt_variables: promptVariables,
|
||||
},
|
||||
},
|
||||
setInputs: mockSetInputs,
|
||||
mode: AppModeEnum.COMPLETION,
|
||||
isAdvancedMode: false,
|
||||
completionPromptConfig: {
|
||||
prompt: { text: 'completion' },
|
||||
conversation_histories_role: { user_prefix: 'user', assistant_prefix: 'assistant' },
|
||||
},
|
||||
chatPromptConfig: { prompt: [] },
|
||||
} as any
|
||||
|
||||
const defaultProps: IPromptValuePanelProps = {
|
||||
appType: AppModeEnum.COMPLETION,
|
||||
onSend: mockOnSend,
|
||||
inputs: { textVar: 'initial', boolVar: false },
|
||||
visionConfig: { enabled: false, number_limits: 0, detail: Resolution.low, transfer_methods: [] },
|
||||
onVisionFilesChange: vi.fn(),
|
||||
}
|
||||
|
||||
const renderPanel = (options: {
|
||||
context?: Partial<typeof baseContextValue>
|
||||
props?: Partial<IPromptValuePanelProps>
|
||||
} = {}) => {
|
||||
const contextValue = { ...baseContextValue, ...options.context }
|
||||
const props = { ...defaultProps, ...options.props }
|
||||
return render(
|
||||
<ConfigContext.Provider value={contextValue}>
|
||||
<PromptValuePanel {...props} />
|
||||
</ConfigContext.Provider>,
|
||||
)
|
||||
}
|
||||
|
||||
describe('PromptValuePanel', () => {
|
||||
beforeEach(() => {
|
||||
mockUseStore.mockImplementation(selector => selector({
|
||||
setShowAppConfigureFeaturesModal: mockSetShowAppConfigureFeaturesModal,
|
||||
appSidebarExpand: '',
|
||||
currentLogModalActiveTab: 'prompt',
|
||||
showPromptLogModal: false,
|
||||
showAgentLogModal: false,
|
||||
setShowPromptLogModal: vi.fn(),
|
||||
setShowAgentLogModal: vi.fn(),
|
||||
showMessageLogModal: false,
|
||||
showAppConfigureFeaturesModal: false,
|
||||
} as any))
|
||||
mockSetInputs.mockClear()
|
||||
mockOnSend.mockClear()
|
||||
mockSetShowAppConfigureFeaturesModal.mockClear()
|
||||
})
|
||||
|
||||
it('updates inputs, clears values, and triggers run when ready', async () => {
|
||||
renderPanel()
|
||||
|
||||
const textInput = screen.getByPlaceholderText('Text Var')
|
||||
fireEvent.change(textInput, { target: { value: 'updated' } })
|
||||
expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ textVar: 'updated' }))
|
||||
|
||||
const clearButton = screen.getByRole('button', { name: 'common.operation.clear' })
|
||||
fireEvent.click(clearButton)
|
||||
|
||||
expect(mockSetInputs).toHaveBeenLastCalledWith({
|
||||
textVar: '',
|
||||
boolVar: '',
|
||||
})
|
||||
|
||||
const runButton = screen.getByRole('button', { name: 'appDebug.inputs.run' })
|
||||
expect(runButton).not.toBeDisabled()
|
||||
fireEvent.click(runButton)
|
||||
await waitFor(() => expect(mockOnSend).toHaveBeenCalledTimes(1))
|
||||
})
|
||||
|
||||
it('disables run when mode is not completion', () => {
|
||||
renderPanel({
|
||||
context: {
|
||||
mode: AppModeEnum.CHAT,
|
||||
},
|
||||
props: {
|
||||
appType: AppModeEnum.CHAT,
|
||||
},
|
||||
})
|
||||
|
||||
const runButton = screen.getByRole('button', { name: 'appDebug.inputs.run' })
|
||||
expect(runButton).toBeDisabled()
|
||||
fireEvent.click(runButton)
|
||||
expect(mockOnSend).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
|
@ -0,0 +1,29 @@
|
|||
import type { PromptVariable } from '@/models/debug'
|
||||
|
||||
import { describe, expect, it } from 'vitest'
|
||||
import { replaceStringWithValues } from './utils'
|
||||
|
||||
const promptVariables: PromptVariable[] = [
|
||||
{ key: 'user', name: 'User', type: 'string' },
|
||||
{ key: 'topic', name: 'Topic', type: 'string' },
|
||||
]
|
||||
|
||||
describe('replaceStringWithValues', () => {
|
||||
it('should replace placeholders when inputs have values', () => {
|
||||
const template = 'Hello {{user}} talking about {{topic}}'
|
||||
const result = replaceStringWithValues(template, promptVariables, { user: 'Alice', topic: 'cats' })
|
||||
expect(result).toBe('Hello Alice talking about cats')
|
||||
})
|
||||
|
||||
it('should use prompt variable name when value is missing', () => {
|
||||
const template = 'Hi {{user}} from {{topic}}'
|
||||
const result = replaceStringWithValues(template, promptVariables, {})
|
||||
expect(result).toBe('Hi {{User}} from {{Topic}}')
|
||||
})
|
||||
|
||||
it('should leave placeholder untouched when no variable is defined', () => {
|
||||
const template = 'Unknown {{missing}} placeholder'
|
||||
const result = replaceStringWithValues(template, promptVariables, {})
|
||||
expect(result).toBe('Unknown {{missing}} placeholder')
|
||||
})
|
||||
})
|
||||
|
|
@ -0,0 +1,162 @@
|
|||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { useRouter } from 'next/navigation'
|
||||
import { afterAll, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { trackEvent } from '@/app/components/base/amplitude'
|
||||
|
||||
import { ToastContext } from '@/app/components/base/toast'
|
||||
import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
import { createApp } from '@/service/apps'
|
||||
import { AppModeEnum } from '@/types/app'
|
||||
import { getRedirection } from '@/utils/app-redirection'
|
||||
import CreateAppModal from './index'
|
||||
|
||||
vi.mock('ahooks', () => ({
|
||||
useDebounceFn: (fn: (...args: any[]) => any) => {
|
||||
const run = (...args: any[]) => fn(...args)
|
||||
const cancel = vi.fn()
|
||||
const flush = vi.fn()
|
||||
return { run, cancel, flush }
|
||||
},
|
||||
useKeyPress: vi.fn(),
|
||||
useHover: () => false,
|
||||
}))
|
||||
vi.mock('next/navigation', () => ({
|
||||
useRouter: vi.fn(),
|
||||
}))
|
||||
vi.mock('@/app/components/base/amplitude', () => ({
|
||||
trackEvent: vi.fn(),
|
||||
}))
|
||||
vi.mock('@/service/apps', () => ({
|
||||
createApp: vi.fn(),
|
||||
}))
|
||||
vi.mock('@/utils/app-redirection', () => ({
|
||||
getRedirection: vi.fn(),
|
||||
}))
|
||||
vi.mock('@/context/provider-context', () => ({
|
||||
useProviderContext: vi.fn(),
|
||||
}))
|
||||
vi.mock('@/context/app-context', () => ({
|
||||
useAppContext: vi.fn(),
|
||||
}))
|
||||
vi.mock('@/context/i18n', () => ({
|
||||
useDocLink: () => () => '/guides',
|
||||
}))
|
||||
vi.mock('@/hooks/use-theme', () => ({
|
||||
__esModule: true,
|
||||
default: () => ({ theme: 'light' }),
|
||||
}))
|
||||
|
||||
const mockNotify = vi.fn()
|
||||
const mockUseRouter = vi.mocked(useRouter)
|
||||
const mockPush = vi.fn()
|
||||
const mockCreateApp = vi.mocked(createApp)
|
||||
const mockTrackEvent = vi.mocked(trackEvent)
|
||||
const mockGetRedirection = vi.mocked(getRedirection)
|
||||
const mockUseProviderContext = vi.mocked(useProviderContext)
|
||||
const mockUseAppContext = vi.mocked(useAppContext)
|
||||
|
||||
const defaultPlanUsage = {
|
||||
buildApps: 0,
|
||||
teamMembers: 0,
|
||||
annotatedResponse: 0,
|
||||
documentsUploadQuota: 0,
|
||||
apiRateLimit: 0,
|
||||
triggerEvents: 0,
|
||||
vectorSpace: 0,
|
||||
}
|
||||
|
||||
const renderModal = () => {
|
||||
const onClose = vi.fn()
|
||||
const onSuccess = vi.fn()
|
||||
render(
|
||||
<ToastContext.Provider value={{ notify: mockNotify, close: vi.fn() }}>
|
||||
<CreateAppModal show onClose={onClose} onSuccess={onSuccess} defaultAppMode={AppModeEnum.ADVANCED_CHAT} />
|
||||
</ToastContext.Provider>,
|
||||
)
|
||||
return { onClose, onSuccess }
|
||||
}
|
||||
|
||||
describe('CreateAppModal', () => {
|
||||
const mockSetItem = vi.fn()
|
||||
const originalLocalStorage = window.localStorage
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockUseRouter.mockReturnValue({ push: mockPush } as any)
|
||||
mockUseProviderContext.mockReturnValue({
|
||||
plan: {
|
||||
type: AppModeEnum.ADVANCED_CHAT,
|
||||
usage: defaultPlanUsage,
|
||||
total: { ...defaultPlanUsage, buildApps: 1 },
|
||||
reset: {},
|
||||
},
|
||||
enableBilling: true,
|
||||
} as any)
|
||||
mockUseAppContext.mockReturnValue({
|
||||
isCurrentWorkspaceEditor: true,
|
||||
} as any)
|
||||
mockSetItem.mockClear()
|
||||
Object.defineProperty(window, 'localStorage', {
|
||||
value: {
|
||||
setItem: mockSetItem,
|
||||
getItem: vi.fn(),
|
||||
removeItem: vi.fn(),
|
||||
clear: vi.fn(),
|
||||
key: vi.fn(),
|
||||
length: 0,
|
||||
},
|
||||
writable: true,
|
||||
})
|
||||
})
|
||||
|
||||
afterAll(() => {
|
||||
Object.defineProperty(window, 'localStorage', {
|
||||
value: originalLocalStorage,
|
||||
writable: true,
|
||||
})
|
||||
})
|
||||
|
||||
it('creates an app, notifies success, and fires callbacks', async () => {
|
||||
const mockApp = { id: 'app-1', mode: AppModeEnum.ADVANCED_CHAT }
|
||||
mockCreateApp.mockResolvedValue(mockApp as any)
|
||||
const { onClose, onSuccess } = renderModal()
|
||||
|
||||
const nameInput = screen.getByPlaceholderText('app.newApp.appNamePlaceholder')
|
||||
fireEvent.change(nameInput, { target: { value: 'My App' } })
|
||||
fireEvent.click(screen.getByRole('button', { name: 'app.newApp.Create' }))
|
||||
|
||||
await waitFor(() => expect(mockCreateApp).toHaveBeenCalledWith({
|
||||
name: 'My App',
|
||||
description: '',
|
||||
icon_type: 'emoji',
|
||||
icon: '🤖',
|
||||
icon_background: '#FFEAD5',
|
||||
mode: AppModeEnum.ADVANCED_CHAT,
|
||||
}))
|
||||
|
||||
expect(mockTrackEvent).toHaveBeenCalledWith('create_app', {
|
||||
app_mode: AppModeEnum.ADVANCED_CHAT,
|
||||
description: '',
|
||||
})
|
||||
expect(mockNotify).toHaveBeenCalledWith({ type: 'success', message: 'app.newApp.appCreated' })
|
||||
expect(onSuccess).toHaveBeenCalled()
|
||||
expect(onClose).toHaveBeenCalled()
|
||||
await waitFor(() => expect(mockSetItem).toHaveBeenCalledWith(NEED_REFRESH_APP_LIST_KEY, '1'))
|
||||
await waitFor(() => expect(mockGetRedirection).toHaveBeenCalledWith(true, mockApp, mockPush))
|
||||
})
|
||||
|
||||
it('shows error toast when creation fails', async () => {
|
||||
mockCreateApp.mockRejectedValue(new Error('boom'))
|
||||
const { onClose } = renderModal()
|
||||
|
||||
const nameInput = screen.getByPlaceholderText('app.newApp.appNamePlaceholder')
|
||||
fireEvent.change(nameInput, { target: { value: 'My App' } })
|
||||
fireEvent.click(screen.getByRole('button', { name: 'app.newApp.Create' }))
|
||||
|
||||
await waitFor(() => expect(mockCreateApp).toHaveBeenCalled())
|
||||
expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: 'boom' })
|
||||
expect(onClose).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
|
@ -139,14 +139,14 @@ const getFormattedChatList = (messages: ChatMessage[], conversationId: string, t
|
|||
id: item.id,
|
||||
content: item.answer,
|
||||
agent_thoughts: addFileInfos(item.agent_thoughts ? sortAgentSorts(item.agent_thoughts) : item.agent_thoughts, item.message_files),
|
||||
feedback: item.feedbacks.find(item => item.from_source === 'user'), // user feedback
|
||||
adminFeedback: item.feedbacks.find(item => item.from_source === 'admin'), // admin feedback
|
||||
feedback: item.feedbacks?.find(item => item.from_source === 'user'), // user feedback
|
||||
adminFeedback: item.feedbacks?.find(item => item.from_source === 'admin'), // admin feedback
|
||||
feedbackDisabled: false,
|
||||
isAnswer: true,
|
||||
message_files: getProcessedFilesFromResponse(answerFiles.map((item: any) => ({ ...item, related_id: item.id }))),
|
||||
log: [
|
||||
...item.message,
|
||||
...(item.message[item.message.length - 1]?.role !== 'assistant'
|
||||
...(item.message ?? []),
|
||||
...(item.message?.[item.message.length - 1]?.role !== 'assistant'
|
||||
? [
|
||||
{
|
||||
role: 'assistant',
|
||||
|
|
@ -165,7 +165,7 @@ const getFormattedChatList = (messages: ChatMessage[], conversationId: string, t
|
|||
more: {
|
||||
time: dayjs.unix(item.created_at).tz(timezone).format(format),
|
||||
tokens: item.answer_tokens + item.message_tokens,
|
||||
latency: item.provider_response_latency.toFixed(2),
|
||||
latency: (item.provider_response_latency ?? 0).toFixed(2),
|
||||
},
|
||||
citation: item.metadata?.retriever_resources,
|
||||
annotation: (() => {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,121 @@
|
|||
import type { SiteInfo } from '@/models/share'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import copy from 'copy-to-clipboard'
|
||||
import * as React from 'react'
|
||||
|
||||
import { act } from 'react'
|
||||
import { afterAll, afterEach, describe, expect, it, vi } from 'vitest'
|
||||
import Embedded from './index'
|
||||
|
||||
vi.mock('./style.module.css', () => ({
|
||||
__esModule: true,
|
||||
default: {
|
||||
option: 'option',
|
||||
active: 'active',
|
||||
iframeIcon: 'iframeIcon',
|
||||
scriptsIcon: 'scriptsIcon',
|
||||
chromePluginIcon: 'chromePluginIcon',
|
||||
pluginInstallIcon: 'pluginInstallIcon',
|
||||
},
|
||||
}))
|
||||
const mockThemeBuilder = {
|
||||
buildTheme: vi.fn(),
|
||||
theme: {
|
||||
primaryColor: '#123456',
|
||||
},
|
||||
}
|
||||
const mockUseAppContext = vi.fn(() => ({
|
||||
langGeniusVersionInfo: {
|
||||
current_env: 'PRODUCTION',
|
||||
current_version: '',
|
||||
latest_version: '',
|
||||
release_date: '',
|
||||
release_notes: '',
|
||||
version: '',
|
||||
can_auto_update: false,
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('copy-to-clipboard', () => ({
|
||||
__esModule: true,
|
||||
default: vi.fn(),
|
||||
}))
|
||||
vi.mock('@/app/components/base/chat/embedded-chatbot/theme/theme-context', () => ({
|
||||
useThemeContext: () => mockThemeBuilder,
|
||||
}))
|
||||
vi.mock('@/context/app-context', () => ({
|
||||
useAppContext: () => mockUseAppContext(),
|
||||
}))
|
||||
const mockWindowOpen = vi.spyOn(window, 'open').mockImplementation(() => null)
|
||||
const mockedCopy = vi.mocked(copy)
|
||||
|
||||
const siteInfo: SiteInfo = {
|
||||
title: 'test site',
|
||||
chat_color_theme: '#000000',
|
||||
chat_color_theme_inverted: false,
|
||||
}
|
||||
|
||||
const baseProps = {
|
||||
isShow: true,
|
||||
siteInfo,
|
||||
onClose: vi.fn(),
|
||||
appBaseUrl: 'https://app.example.com',
|
||||
accessToken: 'token',
|
||||
className: 'custom-modal',
|
||||
}
|
||||
|
||||
const getCopyButton = () => {
|
||||
const buttons = screen.getAllByRole('button')
|
||||
const actionButton = buttons.find(button => button.className.includes('action-btn'))
|
||||
expect(actionButton).toBeDefined()
|
||||
return actionButton!
|
||||
}
|
||||
|
||||
describe('Embedded', () => {
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockWindowOpen.mockClear()
|
||||
})
|
||||
|
||||
afterAll(() => {
|
||||
mockWindowOpen.mockRestore()
|
||||
})
|
||||
|
||||
it('builds theme and copies iframe snippet', async () => {
|
||||
await act(async () => {
|
||||
render(<Embedded {...baseProps} />)
|
||||
})
|
||||
|
||||
const actionButton = getCopyButton()
|
||||
const innerDiv = actionButton.querySelector('div')
|
||||
act(() => {
|
||||
fireEvent.click(innerDiv ?? actionButton)
|
||||
})
|
||||
|
||||
expect(mockThemeBuilder.buildTheme).toHaveBeenCalledWith(siteInfo.chat_color_theme, siteInfo.chat_color_theme_inverted)
|
||||
expect(mockedCopy).toHaveBeenCalledWith(expect.stringContaining('/chatbot/token'))
|
||||
})
|
||||
|
||||
it('opens chrome plugin store link when chrome option selected', async () => {
|
||||
await act(async () => {
|
||||
render(<Embedded {...baseProps} />)
|
||||
})
|
||||
|
||||
const optionButtons = document.body.querySelectorAll('[class*="option"]')
|
||||
expect(optionButtons.length).toBeGreaterThanOrEqual(3)
|
||||
act(() => {
|
||||
fireEvent.click(optionButtons[2])
|
||||
})
|
||||
|
||||
const [chromeText] = screen.getAllByText('appOverview.overview.appInfo.embedded.chromePlugin')
|
||||
act(() => {
|
||||
fireEvent.click(chromeText)
|
||||
})
|
||||
|
||||
expect(mockWindowOpen).toHaveBeenCalledWith(
|
||||
'https://chrome.google.com/webstore/detail/dify-chatbot/ceehdapohffmjmkdcifjofadiaoeggaf',
|
||||
'_blank',
|
||||
'noopener,noreferrer',
|
||||
)
|
||||
})
|
||||
})
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
import type { ISavedItemsProps } from './index'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import copy from 'copy-to-clipboard'
|
||||
|
||||
import * as React from 'react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import SavedItems from './index'
|
||||
|
||||
vi.mock('copy-to-clipboard', () => ({
|
||||
__esModule: true,
|
||||
default: vi.fn(),
|
||||
}))
|
||||
vi.mock('next/navigation', () => ({
|
||||
useParams: () => ({}),
|
||||
usePathname: () => '/',
|
||||
}))
|
||||
|
||||
const mockCopy = vi.mocked(copy)
|
||||
const toastNotifySpy = vi.spyOn(Toast, 'notify')
|
||||
|
||||
const baseProps: ISavedItemsProps = {
|
||||
list: [
|
||||
{ id: '1', answer: 'hello world' },
|
||||
],
|
||||
isShowTextToSpeech: true,
|
||||
onRemove: vi.fn(),
|
||||
onStartCreateContent: vi.fn(),
|
||||
}
|
||||
|
||||
describe('SavedItems', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
toastNotifySpy.mockClear()
|
||||
})
|
||||
|
||||
it('renders saved answers with metadata and controls', () => {
|
||||
const { container } = render(<SavedItems {...baseProps} />)
|
||||
|
||||
const markdownElement = container.querySelector('.markdown-body')
|
||||
expect(markdownElement).toBeInTheDocument()
|
||||
expect(screen.getByText('11 common.unit.char')).toBeInTheDocument()
|
||||
|
||||
const actionArea = container.querySelector('[class*="bg-components-actionbar-bg"]')
|
||||
const actionButtons = actionArea?.querySelectorAll('button') ?? []
|
||||
expect(actionButtons.length).toBeGreaterThanOrEqual(3)
|
||||
})
|
||||
|
||||
it('copies content and notifies, and triggers remove callback', () => {
|
||||
const handleRemove = vi.fn()
|
||||
const { container } = render(<SavedItems {...baseProps} onRemove={handleRemove} />)
|
||||
|
||||
const actionArea = container.querySelector('[class*="bg-components-actionbar-bg"]')
|
||||
const actionButtons = actionArea?.querySelectorAll('button') ?? []
|
||||
expect(actionButtons.length).toBeGreaterThanOrEqual(3)
|
||||
|
||||
const copyButton = actionButtons[1]
|
||||
const deleteButton = actionButtons[2]
|
||||
|
||||
fireEvent.click(copyButton)
|
||||
expect(mockCopy).toHaveBeenCalledWith('hello world')
|
||||
expect(toastNotifySpy).toHaveBeenCalledWith({ type: 'success', message: 'common.actionMsg.copySuccessfully' })
|
||||
|
||||
fireEvent.click(deleteButton)
|
||||
expect(handleRemove).toHaveBeenCalledWith('1')
|
||||
})
|
||||
})
|
||||
|
|
@ -0,0 +1,22 @@
|
|||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import NoData from './index'
|
||||
|
||||
describe('NoData', () => {
|
||||
it('renders title/description and calls callback when button clicked', () => {
|
||||
const handleStart = vi.fn()
|
||||
render(<NoData onStartCreateContent={handleStart} />)
|
||||
|
||||
const title = screen.getByText('share.generation.savedNoData.title')
|
||||
const description = screen.getByText('share.generation.savedNoData.description')
|
||||
const button = screen.getByRole('button', { name: 'share.generation.savedNoData.startCreateContent' })
|
||||
|
||||
expect(title).toBeInTheDocument()
|
||||
expect(description).toBeInTheDocument()
|
||||
expect(button).toBeInTheDocument()
|
||||
|
||||
fireEvent.click(button)
|
||||
expect(handleStart).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
|
|
@ -0,0 +1,308 @@
|
|||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import Avatar from './index'
|
||||
|
||||
describe('Avatar', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
// Rendering tests - verify component renders correctly in different states
|
||||
describe('Rendering', () => {
|
||||
it('should render img element with correct alt and src when avatar URL is provided', () => {
|
||||
const avatarUrl = 'https://example.com/avatar.jpg'
|
||||
const props = { name: 'John Doe', avatar: avatarUrl }
|
||||
|
||||
render(<Avatar {...props} />)
|
||||
|
||||
const img = screen.getByRole('img', { name: 'John Doe' })
|
||||
expect(img).toBeInTheDocument()
|
||||
expect(img).toHaveAttribute('src', avatarUrl)
|
||||
})
|
||||
|
||||
it('should render fallback div with uppercase initial when avatar is null', () => {
|
||||
const props = { name: 'alice', avatar: null }
|
||||
|
||||
render(<Avatar {...props} />)
|
||||
|
||||
expect(screen.queryByRole('img')).not.toBeInTheDocument()
|
||||
expect(screen.getByText('A')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
// Props tests - verify all props are applied correctly
|
||||
describe('Props', () => {
|
||||
describe('size prop', () => {
|
||||
it.each([
|
||||
{ size: undefined, expected: '30px', label: 'default (30px)' },
|
||||
{ size: 50, expected: '50px', label: 'custom (50px)' },
|
||||
])('should apply $label size to img element', ({ size, expected }) => {
|
||||
const props = { name: 'Test', avatar: 'https://example.com/avatar.jpg', size }
|
||||
|
||||
render(<Avatar {...props} />)
|
||||
|
||||
expect(screen.getByRole('img')).toHaveStyle({
|
||||
width: expected,
|
||||
height: expected,
|
||||
fontSize: expected,
|
||||
lineHeight: expected,
|
||||
})
|
||||
})
|
||||
|
||||
it('should apply size to fallback div when avatar is null', () => {
|
||||
const props = { name: 'Test', avatar: null, size: 40 }
|
||||
|
||||
render(<Avatar {...props} />)
|
||||
|
||||
const textElement = screen.getByText('T')
|
||||
const outerDiv = textElement.parentElement as HTMLElement
|
||||
expect(outerDiv).toHaveStyle({ width: '40px', height: '40px' })
|
||||
})
|
||||
})
|
||||
|
||||
describe('className prop', () => {
|
||||
it('should merge className with default avatar classes on img', () => {
|
||||
const props = {
|
||||
name: 'Test',
|
||||
avatar: 'https://example.com/avatar.jpg',
|
||||
className: 'custom-class',
|
||||
}
|
||||
|
||||
render(<Avatar {...props} />)
|
||||
|
||||
const img = screen.getByRole('img')
|
||||
expect(img).toHaveClass('custom-class')
|
||||
expect(img).toHaveClass('shrink-0', 'flex', 'items-center', 'rounded-full', 'bg-primary-600')
|
||||
})
|
||||
|
||||
it('should merge className with default avatar classes on fallback div', () => {
|
||||
const props = {
|
||||
name: 'Test',
|
||||
avatar: null,
|
||||
className: 'my-custom-class',
|
||||
}
|
||||
|
||||
render(<Avatar {...props} />)
|
||||
|
||||
const textElement = screen.getByText('T')
|
||||
const outerDiv = textElement.parentElement as HTMLElement
|
||||
expect(outerDiv).toHaveClass('my-custom-class')
|
||||
expect(outerDiv).toHaveClass('shrink-0', 'flex', 'items-center', 'rounded-full', 'bg-primary-600')
|
||||
})
|
||||
})
|
||||
|
||||
describe('textClassName prop', () => {
|
||||
it('should apply textClassName to the initial text element', () => {
|
||||
const props = {
|
||||
name: 'Test',
|
||||
avatar: null,
|
||||
textClassName: 'custom-text-class',
|
||||
}
|
||||
|
||||
render(<Avatar {...props} />)
|
||||
|
||||
const textElement = screen.getByText('T')
|
||||
expect(textElement).toHaveClass('custom-text-class')
|
||||
expect(textElement).toHaveClass('scale-[0.4]', 'text-center', 'text-white')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// State Management tests - verify useState and useEffect behavior
|
||||
describe('State Management', () => {
|
||||
it('should switch to fallback when image fails to load', async () => {
|
||||
const props = { name: 'John', avatar: 'https://example.com/broken.jpg' }
|
||||
render(<Avatar {...props} />)
|
||||
const img = screen.getByRole('img')
|
||||
|
||||
fireEvent.error(img)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByRole('img')).not.toBeInTheDocument()
|
||||
})
|
||||
expect(screen.getByText('J')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should reset error state when avatar URL changes', async () => {
|
||||
const initialProps = { name: 'John', avatar: 'https://example.com/broken.jpg' }
|
||||
const { rerender } = render(<Avatar {...initialProps} />)
|
||||
const img = screen.getByRole('img')
|
||||
|
||||
// First, trigger error
|
||||
fireEvent.error(img)
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByRole('img')).not.toBeInTheDocument()
|
||||
})
|
||||
expect(screen.getByText('J')).toBeInTheDocument()
|
||||
|
||||
rerender(<Avatar name="John" avatar="https://example.com/new-avatar.jpg" />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByRole('img')).toBeInTheDocument()
|
||||
})
|
||||
expect(screen.queryByText('J')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not reset error state if avatar becomes null', async () => {
|
||||
const initialProps = { name: 'John', avatar: 'https://example.com/broken.jpg' }
|
||||
const { rerender } = render(<Avatar {...initialProps} />)
|
||||
|
||||
// Trigger error
|
||||
fireEvent.error(screen.getByRole('img'))
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('J')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
rerender(<Avatar name="John" avatar={null} />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByRole('img')).not.toBeInTheDocument()
|
||||
})
|
||||
expect(screen.getByText('J')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
// Event Handlers tests - verify onError callback behavior
|
||||
describe('Event Handlers', () => {
|
||||
it('should call onError with true when image fails to load', () => {
|
||||
const onErrorMock = vi.fn()
|
||||
const props = {
|
||||
name: 'John',
|
||||
avatar: 'https://example.com/broken.jpg',
|
||||
onError: onErrorMock,
|
||||
}
|
||||
render(<Avatar {...props} />)
|
||||
|
||||
fireEvent.error(screen.getByRole('img'))
|
||||
|
||||
expect(onErrorMock).toHaveBeenCalledTimes(1)
|
||||
expect(onErrorMock).toHaveBeenCalledWith(true)
|
||||
})
|
||||
|
||||
it('should call onError with false when image loads successfully', () => {
|
||||
const onErrorMock = vi.fn()
|
||||
const props = {
|
||||
name: 'John',
|
||||
avatar: 'https://example.com/avatar.jpg',
|
||||
onError: onErrorMock,
|
||||
}
|
||||
render(<Avatar {...props} />)
|
||||
|
||||
fireEvent.load(screen.getByRole('img'))
|
||||
|
||||
expect(onErrorMock).toHaveBeenCalledTimes(1)
|
||||
expect(onErrorMock).toHaveBeenCalledWith(false)
|
||||
})
|
||||
|
||||
it('should not throw when onError is not provided', async () => {
|
||||
const props = { name: 'John', avatar: 'https://example.com/broken.jpg' }
|
||||
render(<Avatar {...props} />)
|
||||
|
||||
expect(() => fireEvent.error(screen.getByRole('img'))).not.toThrow()
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('J')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// Edge Cases tests - verify handling of unusual inputs
|
||||
describe('Edge Cases', () => {
|
||||
it('should handle empty string name gracefully', () => {
|
||||
const props = { name: '', avatar: null }
|
||||
|
||||
const { container } = render(<Avatar {...props} />)
|
||||
|
||||
// Note: Using querySelector here because empty name produces no visible text,
|
||||
// making semantic queries (getByRole, getByText) impossible
|
||||
const textElement = container.querySelector('.text-white') as HTMLElement
|
||||
expect(textElement).toBeInTheDocument()
|
||||
expect(textElement.textContent).toBe('')
|
||||
})
|
||||
|
||||
it.each([
|
||||
{ name: '中文名', expected: '中', label: 'Chinese characters' },
|
||||
{ name: '123User', expected: '1', label: 'number' },
|
||||
])('should display first character when name starts with $label', ({ name, expected }) => {
|
||||
const props = { name, avatar: null }
|
||||
|
||||
render(<Avatar {...props} />)
|
||||
|
||||
expect(screen.getByText(expected)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle empty string avatar as falsy value', () => {
|
||||
const props = { name: 'Test', avatar: '' as string | null }
|
||||
|
||||
render(<Avatar {...props} />)
|
||||
|
||||
expect(screen.queryByRole('img')).not.toBeInTheDocument()
|
||||
expect(screen.getByText('T')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle undefined className and textClassName', () => {
|
||||
const props = { name: 'Test', avatar: null }
|
||||
|
||||
render(<Avatar {...props} />)
|
||||
|
||||
const textElement = screen.getByText('T')
|
||||
const outerDiv = textElement.parentElement as HTMLElement
|
||||
expect(outerDiv).toHaveClass('shrink-0', 'flex', 'items-center', 'rounded-full', 'bg-primary-600')
|
||||
})
|
||||
|
||||
it.each([
|
||||
{ size: 0, expected: '0px', label: 'zero' },
|
||||
{ size: 1000, expected: '1000px', label: 'very large' },
|
||||
])('should handle $label size value', ({ size, expected }) => {
|
||||
const props = { name: 'Test', avatar: null, size }
|
||||
|
||||
render(<Avatar {...props} />)
|
||||
|
||||
const textElement = screen.getByText('T')
|
||||
const outerDiv = textElement.parentElement as HTMLElement
|
||||
expect(outerDiv).toHaveStyle({ width: expected, height: expected })
|
||||
})
|
||||
})
|
||||
|
||||
// Combined props tests - verify props work together correctly
|
||||
describe('Combined Props', () => {
|
||||
it('should apply all props correctly when used together', () => {
|
||||
const onErrorMock = vi.fn()
|
||||
const props = {
|
||||
name: 'Test User',
|
||||
avatar: 'https://example.com/avatar.jpg',
|
||||
size: 64,
|
||||
className: 'custom-avatar',
|
||||
onError: onErrorMock,
|
||||
}
|
||||
|
||||
render(<Avatar {...props} />)
|
||||
|
||||
const img = screen.getByRole('img')
|
||||
expect(img).toHaveAttribute('alt', 'Test User')
|
||||
expect(img).toHaveAttribute('src', 'https://example.com/avatar.jpg')
|
||||
expect(img).toHaveStyle({ width: '64px', height: '64px' })
|
||||
expect(img).toHaveClass('custom-avatar')
|
||||
|
||||
// Trigger load to verify onError callback
|
||||
fireEvent.load(img)
|
||||
expect(onErrorMock).toHaveBeenCalledWith(false)
|
||||
})
|
||||
|
||||
it('should apply all fallback props correctly when used together', () => {
|
||||
const props = {
|
||||
name: 'Fallback User',
|
||||
avatar: null,
|
||||
size: 48,
|
||||
className: 'fallback-custom',
|
||||
textClassName: 'custom-text-style',
|
||||
}
|
||||
|
||||
render(<Avatar {...props} />)
|
||||
|
||||
const textElement = screen.getByText('F')
|
||||
const outerDiv = textElement.parentElement as HTMLElement
|
||||
expect(outerDiv).toHaveClass('fallback-custom')
|
||||
expect(outerDiv).toHaveStyle({ width: '48px', height: '48px' })
|
||||
expect(textElement).toHaveClass('custom-text-style')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
@ -0,0 +1,147 @@
|
|||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { getImageUploadErrorMessage, imageUpload } from '@/app/components/base/image-uploader/utils'
|
||||
import { useToastContext } from '@/app/components/base/toast'
|
||||
import { Plan } from '@/app/components/billing/type'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
import { updateCurrentWorkspace } from '@/service/common'
|
||||
import CustomWebAppBrand from './index'
|
||||
|
||||
vi.mock('@/app/components/base/toast', () => ({
|
||||
useToastContext: vi.fn(),
|
||||
}))
|
||||
vi.mock('@/service/common', () => ({
|
||||
updateCurrentWorkspace: vi.fn(),
|
||||
}))
|
||||
vi.mock('@/context/app-context', () => ({
|
||||
useAppContext: vi.fn(),
|
||||
}))
|
||||
vi.mock('@/context/provider-context', () => ({
|
||||
useProviderContext: vi.fn(),
|
||||
}))
|
||||
vi.mock('@/context/global-public-context', () => ({
|
||||
useGlobalPublicStore: vi.fn(),
|
||||
}))
|
||||
vi.mock('@/app/components/base/image-uploader/utils', () => ({
|
||||
imageUpload: vi.fn(),
|
||||
getImageUploadErrorMessage: vi.fn(),
|
||||
}))
|
||||
|
||||
const mockNotify = vi.fn()
|
||||
const mockUseToastContext = vi.mocked(useToastContext)
|
||||
const mockUpdateCurrentWorkspace = vi.mocked(updateCurrentWorkspace)
|
||||
const mockUseAppContext = vi.mocked(useAppContext)
|
||||
const mockUseProviderContext = vi.mocked(useProviderContext)
|
||||
const mockUseGlobalPublicStore = vi.mocked(useGlobalPublicStore)
|
||||
const mockImageUpload = vi.mocked(imageUpload)
|
||||
const mockGetImageUploadErrorMessage = vi.mocked(getImageUploadErrorMessage)
|
||||
|
||||
const defaultPlanUsage = {
|
||||
buildApps: 0,
|
||||
teamMembers: 0,
|
||||
annotatedResponse: 0,
|
||||
documentsUploadQuota: 0,
|
||||
apiRateLimit: 0,
|
||||
triggerEvents: 0,
|
||||
vectorSpace: 0,
|
||||
}
|
||||
|
||||
const renderComponent = () => render(<CustomWebAppBrand />)
|
||||
|
||||
describe('CustomWebAppBrand', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockUseToastContext.mockReturnValue({ notify: mockNotify } as any)
|
||||
mockUpdateCurrentWorkspace.mockResolvedValue({} as any)
|
||||
mockUseAppContext.mockReturnValue({
|
||||
currentWorkspace: {
|
||||
custom_config: {
|
||||
replace_webapp_logo: 'https://example.com/replace.png',
|
||||
remove_webapp_brand: false,
|
||||
},
|
||||
},
|
||||
mutateCurrentWorkspace: vi.fn(),
|
||||
isCurrentWorkspaceManager: true,
|
||||
} as any)
|
||||
mockUseProviderContext.mockReturnValue({
|
||||
plan: {
|
||||
type: Plan.professional,
|
||||
usage: defaultPlanUsage,
|
||||
total: defaultPlanUsage,
|
||||
reset: {},
|
||||
},
|
||||
enableBilling: false,
|
||||
} as any)
|
||||
const systemFeaturesState = {
|
||||
branding: {
|
||||
enabled: true,
|
||||
workspace_logo: 'https://example.com/workspace-logo.png',
|
||||
},
|
||||
}
|
||||
mockUseGlobalPublicStore.mockImplementation(selector => selector ? selector({ systemFeatures: systemFeaturesState } as any) : { systemFeatures: systemFeaturesState })
|
||||
mockGetImageUploadErrorMessage.mockReturnValue('upload error')
|
||||
})
|
||||
|
||||
it('disables upload controls when the user cannot manage the workspace', () => {
|
||||
mockUseAppContext.mockReturnValue({
|
||||
currentWorkspace: {
|
||||
custom_config: {
|
||||
replace_webapp_logo: '',
|
||||
remove_webapp_brand: false,
|
||||
},
|
||||
},
|
||||
mutateCurrentWorkspace: vi.fn(),
|
||||
isCurrentWorkspaceManager: false,
|
||||
} as any)
|
||||
|
||||
const { container } = renderComponent()
|
||||
const fileInput = container.querySelector('input[type="file"]') as HTMLInputElement
|
||||
expect(fileInput).toBeDisabled()
|
||||
})
|
||||
|
||||
it('toggles remove brand switch and calls the backend + mutate', async () => {
|
||||
const mutateMock = vi.fn()
|
||||
mockUseAppContext.mockReturnValue({
|
||||
currentWorkspace: {
|
||||
custom_config: {
|
||||
replace_webapp_logo: '',
|
||||
remove_webapp_brand: false,
|
||||
},
|
||||
},
|
||||
mutateCurrentWorkspace: mutateMock,
|
||||
isCurrentWorkspaceManager: true,
|
||||
} as any)
|
||||
|
||||
renderComponent()
|
||||
const switchInput = screen.getByRole('switch')
|
||||
fireEvent.click(switchInput)
|
||||
|
||||
await waitFor(() => expect(mockUpdateCurrentWorkspace).toHaveBeenCalledWith({
|
||||
url: '/workspaces/custom-config',
|
||||
body: { remove_webapp_brand: true },
|
||||
}))
|
||||
await waitFor(() => expect(mutateMock).toHaveBeenCalled())
|
||||
})
|
||||
|
||||
it('shows cancel/apply buttons after successful upload and cancels properly', async () => {
|
||||
mockImageUpload.mockImplementation(({ onProgressCallback, onSuccessCallback }) => {
|
||||
onProgressCallback(50)
|
||||
onSuccessCallback({ id: 'new-logo' })
|
||||
})
|
||||
|
||||
const { container } = renderComponent()
|
||||
const fileInput = container.querySelector('input[type="file"]') as HTMLInputElement
|
||||
const testFile = new File(['content'], 'logo.png', { type: 'image/png' })
|
||||
fireEvent.change(fileInput, { target: { files: [testFile] } })
|
||||
|
||||
await waitFor(() => expect(mockImageUpload).toHaveBeenCalled())
|
||||
await waitFor(() => screen.getByRole('button', { name: 'custom.apply' }))
|
||||
|
||||
const cancelButton = screen.getByRole('button', { name: 'common.operation.cancel' })
|
||||
fireEvent.click(cancelButton)
|
||||
|
||||
await waitFor(() => expect(screen.queryByRole('button', { name: 'custom.apply' })).toBeNull())
|
||||
})
|
||||
})
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
import type { FC } from 'react'
|
||||
import { RiArchive2Line, RiCheckboxCircleLine, RiCloseCircleLine, RiDeleteBinLine, RiDraftLine } from '@remixicon/react'
|
||||
import { RiArchive2Line, RiCheckboxCircleLine, RiCloseCircleLine, RiDeleteBinLine, RiDraftLine, RiRefreshLine } from '@remixicon/react'
|
||||
import { useBoolean } from 'ahooks'
|
||||
import * as React from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
|
@ -17,6 +17,7 @@ type IBatchActionProps = {
|
|||
onBatchDelete: () => Promise<void>
|
||||
onArchive?: () => void
|
||||
onEditMetadata?: () => void
|
||||
onBatchReIndex?: () => void
|
||||
onCancel: () => void
|
||||
}
|
||||
|
||||
|
|
@ -28,6 +29,7 @@ const BatchAction: FC<IBatchActionProps> = ({
|
|||
onArchive,
|
||||
onBatchDelete,
|
||||
onEditMetadata,
|
||||
onBatchReIndex,
|
||||
onCancel,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
|
|
@ -91,6 +93,16 @@ const BatchAction: FC<IBatchActionProps> = ({
|
|||
<span className="px-0.5">{t(`${i18nPrefix}.archive`)}</span>
|
||||
</Button>
|
||||
)}
|
||||
{onBatchReIndex && (
|
||||
<Button
|
||||
variant="ghost"
|
||||
className="gap-x-0.5 px-3"
|
||||
onClick={onBatchReIndex}
|
||||
>
|
||||
<RiRefreshLine className="size-4" />
|
||||
<span className="px-0.5">{t(`${i18nPrefix}.reIndex`)}</span>
|
||||
</Button>
|
||||
)}
|
||||
<Button
|
||||
variant="ghost"
|
||||
destructive
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ import { useDatasetDetailContextWithSelector as useDatasetDetailContext } from '
|
|||
import useTimestamp from '@/hooks/use-timestamp'
|
||||
import { ChunkingMode, DataSourceType, DocumentActionType } from '@/models/datasets'
|
||||
import { DatasourceType } from '@/models/pipeline'
|
||||
import { useDocumentArchive, useDocumentDelete, useDocumentDisable, useDocumentEnable } from '@/service/knowledge/use-document'
|
||||
import { useDocumentArchive, useDocumentBatchRetryIndex, useDocumentDelete, useDocumentDisable, useDocumentEnable } from '@/service/knowledge/use-document'
|
||||
import { asyncRunSafe } from '@/utils'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { formatNumber } from '@/utils/format'
|
||||
|
|
@ -220,6 +220,7 @@ const DocumentList: FC<IDocumentListProps> = ({
|
|||
const { mutateAsync: enableDocument } = useDocumentEnable()
|
||||
const { mutateAsync: disableDocument } = useDocumentDisable()
|
||||
const { mutateAsync: deleteDocument } = useDocumentDelete()
|
||||
const { mutateAsync: retryIndexDocument } = useDocumentBatchRetryIndex()
|
||||
|
||||
const handleAction = (actionName: DocumentActionType) => {
|
||||
return async () => {
|
||||
|
|
@ -250,6 +251,22 @@ const DocumentList: FC<IDocumentListProps> = ({
|
|||
}
|
||||
}
|
||||
|
||||
const handleBatchReIndex = async () => {
|
||||
const [e] = await asyncRunSafe<CommonResponse>(retryIndexDocument({ datasetId, documentIds: selectedIds }))
|
||||
if (!e) {
|
||||
onSelectedIdChange([])
|
||||
Toast.notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') })
|
||||
onUpdate()
|
||||
}
|
||||
else {
|
||||
Toast.notify({ type: 'error', message: t('common.actionMsg.modifiedUnsuccessfully') })
|
||||
}
|
||||
}
|
||||
|
||||
const hasErrorDocumentsSelected = useMemo(() => {
|
||||
return localDocs.some(doc => selectedIds.includes(doc.id) && doc.display_status === 'error')
|
||||
}, [localDocs, selectedIds])
|
||||
|
||||
const getFileExtension = useCallback((fileName: string): string => {
|
||||
if (!fileName)
|
||||
return ''
|
||||
|
|
@ -447,6 +464,7 @@ const DocumentList: FC<IDocumentListProps> = ({
|
|||
onBatchDisable={handleAction(DocumentActionType.disable)}
|
||||
onBatchDelete={handleAction(DocumentActionType.delete)}
|
||||
onEditMetadata={showEditModal}
|
||||
onBatchReIndex={hasErrorDocumentsSelected ? handleBatchReIndex : undefined}
|
||||
onCancel={() => {
|
||||
onSelectedIdChange([])
|
||||
}}
|
||||
|
|
|
|||
|
|
@ -34,7 +34,6 @@ import Records from './components/records'
|
|||
import ResultItem from './components/result-item'
|
||||
import ResultItemExternal from './components/result-item-external'
|
||||
import ModifyRetrievalModal from './modify-retrieval-modal'
|
||||
import s from './style.module.css'
|
||||
|
||||
const limit = 10
|
||||
|
||||
|
|
@ -115,8 +114,8 @@ const HitTestingPage: FC<Props> = ({ datasetId }: Props) => {
|
|||
}, [isMobile, setShowRightPanel])
|
||||
|
||||
return (
|
||||
<div className={s.container}>
|
||||
<div className="flex flex-col px-6 py-3">
|
||||
<div className="relative flex h-full w-full gap-x-6 overflow-y-auto pl-6">
|
||||
<div className="flex min-w-0 flex-1 flex-col py-3">
|
||||
<div className="mb-4 flex flex-col justify-center">
|
||||
<h1 className="text-base font-semibold text-text-primary">{t('datasetHitTesting.title')}</h1>
|
||||
<p className="mt-0.5 text-[13px] font-normal leading-4 text-text-tertiary">{t('datasetHitTesting.desc')}</p>
|
||||
|
|
@ -161,7 +160,7 @@ const HitTestingPage: FC<Props> = ({ datasetId }: Props) => {
|
|||
onClose={hideRightPanel}
|
||||
footer={null}
|
||||
>
|
||||
<div className="flex flex-col pt-3">
|
||||
<div className="flex min-w-0 flex-1 flex-col pt-3">
|
||||
{isRetrievalLoading
|
||||
? (
|
||||
<div className="flex h-full flex-col rounded-tl-2xl bg-background-body px-4 py-3">
|
||||
|
|
|
|||
|
|
@ -1,43 +0,0 @@
|
|||
.container {
|
||||
@apply flex h-full w-full relative overflow-y-auto;
|
||||
}
|
||||
|
||||
.container>div {
|
||||
@apply flex-1 h-full;
|
||||
}
|
||||
|
||||
.commonIcon {
|
||||
@apply w-3.5 h-3.5 inline-block align-middle;
|
||||
background-repeat: no-repeat;
|
||||
background-position: center center;
|
||||
background-size: contain;
|
||||
}
|
||||
|
||||
.app_icon {
|
||||
background-image: url(./assets/grid.svg);
|
||||
}
|
||||
|
||||
.hit_testing_icon {
|
||||
background-image: url(../documents/assets/target.svg);
|
||||
}
|
||||
|
||||
.plugin_icon {
|
||||
background-image: url(./assets/plugin.svg);
|
||||
}
|
||||
|
||||
.cardWrapper {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(auto-fill, minmax(284px, auto));
|
||||
grid-gap: 16px;
|
||||
grid-auto-rows: 216px;
|
||||
}
|
||||
|
||||
.clockWrapper {
|
||||
border: 0.5px solid #eaecf5;
|
||||
@apply rounded-lg w-11 h-11 flex justify-center items-center;
|
||||
}
|
||||
|
||||
.clockIcon {
|
||||
mask-image: url(./assets/clock.svg);
|
||||
@apply bg-gray-500;
|
||||
}
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
'use client'
|
||||
|
||||
import { TanStackDevtools } from '@tanstack/react-devtools'
|
||||
import { formDevtoolsPlugin } from '@tanstack/react-form-devtools'
|
||||
import { ReactQueryDevtoolsPanel } from '@tanstack/react-query-devtools'
|
||||
import * as React from 'react'
|
||||
|
||||
export function TanStackDevtoolsWrapper() {
|
||||
return (
|
||||
<TanStackDevtools
|
||||
plugins={[
|
||||
// Query Devtools (Official Plugin)
|
||||
{
|
||||
name: 'React Query',
|
||||
render: () => <ReactQueryDevtoolsPanel />,
|
||||
},
|
||||
|
||||
// Form Devtools (Official Plugin)
|
||||
formDevtoolsPlugin(),
|
||||
]}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
|
@ -3,14 +3,14 @@
|
|||
import * as Sentry from '@sentry/react'
|
||||
import { useEffect } from 'react'
|
||||
|
||||
const isDevelopment = process.env.NODE_ENV === 'development'
|
||||
import { IS_DEV } from '@/config'
|
||||
|
||||
const SentryInitializer = ({
|
||||
children,
|
||||
}: { children: React.ReactElement }) => {
|
||||
useEffect(() => {
|
||||
const SENTRY_DSN = document?.body?.getAttribute('data-public-sentry-dsn')
|
||||
if (!isDevelopment && SENTRY_DSN) {
|
||||
if (!IS_DEV && SENTRY_DSN) {
|
||||
Sentry.init({
|
||||
dsn: SENTRY_DSN,
|
||||
integrations: [
|
||||
|
|
|
|||
|
|
@ -48,6 +48,7 @@ const EditCustomCollectionModal: FC<Props> = ({
|
|||
|
||||
const [editFirst, setEditFirst] = useState(!isAdd)
|
||||
const [paramsSchemas, setParamsSchemas] = useState<CustomParamSchema[]>(payload?.tools || [])
|
||||
const [labels, setLabels] = useState<string[]>(payload?.labels || [])
|
||||
const [customCollection, setCustomCollection, getCustomCollection] = useGetState<CustomCollectionBackend>(isAdd
|
||||
? {
|
||||
provider: '',
|
||||
|
|
@ -67,6 +68,15 @@ const EditCustomCollectionModal: FC<Props> = ({
|
|||
|
||||
const originalProvider = isEdit ? payload.provider : ''
|
||||
|
||||
// Sync customCollection state when payload changes
|
||||
useEffect(() => {
|
||||
if (isEdit) {
|
||||
setCustomCollection(payload)
|
||||
setParamsSchemas(payload.tools || [])
|
||||
setLabels(payload.labels || [])
|
||||
}
|
||||
}, [isEdit, payload])
|
||||
|
||||
const [showEmojiPicker, setShowEmojiPicker] = useState(false)
|
||||
const emoji = customCollection.icon
|
||||
const setEmoji = (emoji: Emoji) => {
|
||||
|
|
@ -124,7 +134,6 @@ const EditCustomCollectionModal: FC<Props> = ({
|
|||
const [currTool, setCurrTool] = useState<CustomParamSchema | null>(null)
|
||||
const [isShowTestApi, setIsShowTestApi] = useState(false)
|
||||
|
||||
const [labels, setLabels] = useState<string[]>(payload?.labels || [])
|
||||
const handleLabelSelect = (value: string[]) => {
|
||||
setLabels(value)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -100,9 +100,28 @@ const ProviderDetail = ({
|
|||
const [isShowEditCollectionToolModal, setIsShowEditCustomCollectionModal] = useState(false)
|
||||
const [showConfirmDelete, setShowConfirmDelete] = useState(false)
|
||||
const [deleteAction, setDeleteAction] = useState('')
|
||||
|
||||
const getCustomProvider = useCallback(async () => {
|
||||
setIsDetailLoading(true)
|
||||
const res = await fetchCustomCollection(collection.name)
|
||||
if (res.credentials.auth_type === AuthType.apiKey && !res.credentials.api_key_header_prefix) {
|
||||
if (res.credentials.api_key_value)
|
||||
res.credentials.api_key_header_prefix = AuthHeaderPrefix.custom
|
||||
}
|
||||
setCustomCollection({
|
||||
...res,
|
||||
labels: collection.labels,
|
||||
provider: collection.name,
|
||||
})
|
||||
setIsDetailLoading(false)
|
||||
}, [collection.labels, collection.name])
|
||||
|
||||
const doUpdateCustomToolCollection = async (data: CustomCollectionBackend) => {
|
||||
await updateCustomCollection(data)
|
||||
onRefreshData()
|
||||
await getCustomProvider()
|
||||
// Use fresh data from form submission to avoid race condition with collection.labels
|
||||
setCustomCollection(prev => prev ? { ...prev, labels: data.labels } : null)
|
||||
Toast.notify({
|
||||
type: 'success',
|
||||
message: t('common.api.actionSuccess'),
|
||||
|
|
@ -118,20 +137,6 @@ const ProviderDetail = ({
|
|||
})
|
||||
setIsShowEditCustomCollectionModal(false)
|
||||
}
|
||||
const getCustomProvider = useCallback(async () => {
|
||||
setIsDetailLoading(true)
|
||||
const res = await fetchCustomCollection(collection.name)
|
||||
if (res.credentials.auth_type === AuthType.apiKey && !res.credentials.api_key_header_prefix) {
|
||||
if (res.credentials.api_key_value)
|
||||
res.credentials.api_key_header_prefix = AuthHeaderPrefix.custom
|
||||
}
|
||||
setCustomCollection({
|
||||
...res,
|
||||
labels: collection.labels,
|
||||
provider: collection.name,
|
||||
})
|
||||
setIsDetailLoading(false)
|
||||
}, [collection.labels, collection.name])
|
||||
// workflow provider
|
||||
const [isShowEditWorkflowToolModal, setIsShowEditWorkflowToolModal] = useState(false)
|
||||
const getWorkflowToolProvider = useCallback(async () => {
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@ const VersionHistoryButton: FC<VersionHistoryButtonProps> = ({
|
|||
>
|
||||
<Button
|
||||
className={cn(
|
||||
'p-2 rounded-lg border border-transparent',
|
||||
'rounded-lg border border-transparent p-2',
|
||||
theme === 'dark' && 'border-black/5 bg-white/10 backdrop-blur-sm',
|
||||
)}
|
||||
onClick={handleViewVersionHistory}
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ import ReactFlow, {
|
|||
useReactFlow,
|
||||
useStoreApi,
|
||||
} from 'reactflow'
|
||||
import { IS_DEV } from '@/config'
|
||||
import { useEventEmitterContextContext } from '@/context/event-emitter'
|
||||
import {
|
||||
useAllBuiltInTools,
|
||||
|
|
@ -361,7 +362,7 @@ export const Workflow: FC<WorkflowProps> = memo(({
|
|||
}
|
||||
}, [schemaTypeDefinitions, fetchInspectVars, isLoadedVars, vars, customTools, buildInTools, workflowTools, mcpTools, dataSourceList])
|
||||
|
||||
if (process.env.NODE_ENV === 'development') {
|
||||
if (IS_DEV) {
|
||||
store.getState().onError = (code, message) => {
|
||||
if (code === '002')
|
||||
return
|
||||
|
|
|
|||
|
|
@ -1,68 +0,0 @@
|
|||
'use client'
|
||||
import BaseForm from '../components/base/form/form-scenarios/base'
|
||||
import { BaseFieldType } from '../components/base/form/form-scenarios/base/types'
|
||||
|
||||
export default function Page() {
|
||||
return (
|
||||
<div className="flex h-screen w-full items-center justify-center p-20">
|
||||
<div className="w-[400px] rounded-lg border border-components-panel-border bg-components-panel-bg">
|
||||
<BaseForm
|
||||
initialData={{
|
||||
type: 'option_1',
|
||||
variable: 'test',
|
||||
label: 'Test',
|
||||
maxLength: 48,
|
||||
required: true,
|
||||
}}
|
||||
configurations={[
|
||||
{
|
||||
type: BaseFieldType.textInput,
|
||||
variable: 'variable',
|
||||
label: 'Variable',
|
||||
required: true,
|
||||
showConditions: [],
|
||||
},
|
||||
{
|
||||
type: BaseFieldType.textInput,
|
||||
variable: 'label',
|
||||
label: 'Label',
|
||||
required: true,
|
||||
showConditions: [],
|
||||
},
|
||||
{
|
||||
type: BaseFieldType.numberInput,
|
||||
variable: 'maxLength',
|
||||
label: 'Max Length',
|
||||
required: true,
|
||||
showConditions: [],
|
||||
max: 100,
|
||||
min: 1,
|
||||
},
|
||||
{
|
||||
type: BaseFieldType.checkbox,
|
||||
variable: 'required',
|
||||
label: 'Required',
|
||||
required: true,
|
||||
showConditions: [],
|
||||
},
|
||||
{
|
||||
type: BaseFieldType.select,
|
||||
variable: 'type',
|
||||
label: 'Type',
|
||||
required: true,
|
||||
showConditions: [],
|
||||
options: [
|
||||
{ label: 'Option 1', value: 'option_1' },
|
||||
{ label: 'Option 2', value: 'option_2' },
|
||||
{ label: 'Option 3', value: 'option_3' },
|
||||
],
|
||||
},
|
||||
]}
|
||||
onSubmit={(value) => {
|
||||
console.log('onSubmit', value)
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
|
@ -1,6 +1,8 @@
|
|||
import type { Viewport } from 'next'
|
||||
import { ThemeProvider } from 'next-themes'
|
||||
import dynamic from 'next/dynamic'
|
||||
import { Instrument_Serif } from 'next/font/google'
|
||||
import { IS_DEV } from '@/config'
|
||||
import GlobalPublicStoreProvider from '@/context/global-public-context'
|
||||
import { TanstackQueryInitializer } from '@/context/query-client'
|
||||
import { getLocaleOnServer } from '@/i18n-config/server'
|
||||
|
|
@ -8,12 +10,15 @@ import { DatasetAttr } from '@/types/feature'
|
|||
import { cn } from '@/utils/classnames'
|
||||
import BrowserInitializer from './components/browser-initializer'
|
||||
import I18nServer from './components/i18n-server'
|
||||
import { ReactScan } from './components/react-scan'
|
||||
import SentryInitializer from './components/sentry-initializer'
|
||||
import RoutePrefixHandle from './routePrefixHandle'
|
||||
import './styles/globals.css'
|
||||
import './styles/markdown.scss'
|
||||
|
||||
const ReactScan = IS_DEV
|
||||
? dynamic(() => import('./components/react-scan').then(m => m.ReactScan), { ssr: false })
|
||||
: () => null
|
||||
|
||||
export const viewport: Viewport = {
|
||||
width: 'device-width',
|
||||
initialScale: 1,
|
||||
|
|
|
|||
|
|
@ -2,7 +2,14 @@
|
|||
|
||||
import type { FC, PropsWithChildren } from 'react'
|
||||
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
|
||||
import { ReactQueryDevtools } from '@tanstack/react-query-devtools'
|
||||
import { lazy, Suspense } from 'react'
|
||||
import { IS_DEV } from '@/config'
|
||||
|
||||
const TanStackDevtoolsWrapper = lazy(() =>
|
||||
import('@/app/components/devtools').then(module => ({
|
||||
default: module.TanStackDevtoolsWrapper,
|
||||
})),
|
||||
)
|
||||
|
||||
const STALE_TIME = 1000 * 60 * 30 // 30 minutes
|
||||
|
||||
|
|
@ -19,7 +26,11 @@ export const TanstackQueryInitializer: FC<PropsWithChildren> = (props) => {
|
|||
return (
|
||||
<QueryClientProvider client={client}>
|
||||
{children}
|
||||
<ReactQueryDevtools initialIsOpen={false} />
|
||||
{IS_DEV && (
|
||||
<Suspense fallback={null}>
|
||||
<TanStackDevtoolsWrapper />
|
||||
</Suspense>
|
||||
)}
|
||||
</QueryClientProvider>
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -170,6 +170,7 @@ const translation = {
|
|||
enable: 'Enable',
|
||||
disable: 'Disable',
|
||||
archive: 'Archive',
|
||||
reIndex: 'Re-index',
|
||||
delete: 'Delete',
|
||||
cancel: 'Cancel',
|
||||
},
|
||||
|
|
|
|||
|
|
@ -170,6 +170,7 @@ const translation = {
|
|||
enable: '启用',
|
||||
disable: '禁用',
|
||||
archive: '归档',
|
||||
reIndex: '重新索引',
|
||||
delete: '删除',
|
||||
cancel: '取消',
|
||||
},
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
"type": "module",
|
||||
"version": "1.11.2",
|
||||
"private": true,
|
||||
"packageManager": "pnpm@10.26.1+sha512.664074abc367d2c9324fdc18037097ce0a8f126034160f709928e9e9f95d98714347044e5c3164d65bd5da6c59c6be362b107546292a8eecb7999196e5ce58fa",
|
||||
"packageManager": "pnpm@10.26.2+sha512.0e308ff2005fc7410366f154f625f6631ab2b16b1d2e70238444dd6ae9d630a8482d92a451144debc492416896ed16f7b114a86ec68b8404b2443869e68ffda6",
|
||||
"engines": {
|
||||
"node": ">=v22.11.0"
|
||||
},
|
||||
|
|
@ -71,7 +71,6 @@
|
|||
"@tailwindcss/typography": "^0.5.19",
|
||||
"@tanstack/react-form": "^1.23.7",
|
||||
"@tanstack/react-query": "^5.90.5",
|
||||
"@tanstack/react-query-devtools": "^5.90.2",
|
||||
"abcjs": "^6.5.2",
|
||||
"ahooks": "^3.9.5",
|
||||
"class-variance-authority": "^0.7.1",
|
||||
|
|
@ -134,7 +133,7 @@
|
|||
"remark-breaks": "^4.0.0",
|
||||
"remark-gfm": "^4.0.1",
|
||||
"remark-math": "^6.0.0",
|
||||
"scheduler": "^0.26.0",
|
||||
"scheduler": "^0.27.0",
|
||||
"semver": "^7.7.3",
|
||||
"sharp": "^0.33.5",
|
||||
"sortablejs": "^1.15.6",
|
||||
|
|
@ -164,6 +163,9 @@
|
|||
"@storybook/addon-themes": "9.1.13",
|
||||
"@storybook/nextjs": "9.1.13",
|
||||
"@storybook/react": "9.1.13",
|
||||
"@tanstack/react-devtools": "^0.9.0",
|
||||
"@tanstack/react-form-devtools": "^0.2.9",
|
||||
"@tanstack/react-query-devtools": "^5.90.2",
|
||||
"@testing-library/dom": "^10.4.1",
|
||||
"@testing-library/jest-dom": "^6.9.1",
|
||||
"@testing-library/react": "^16.3.0",
|
||||
|
|
|
|||
|
|
@ -129,9 +129,6 @@ importers:
|
|||
'@tanstack/react-query':
|
||||
specifier: ^5.90.5
|
||||
version: 5.90.12(react@19.2.3)
|
||||
'@tanstack/react-query-devtools':
|
||||
specifier: ^5.90.2
|
||||
version: 5.91.1(@tanstack/react-query@5.90.12(react@19.2.3))(react@19.2.3)
|
||||
abcjs:
|
||||
specifier: ^6.5.2
|
||||
version: 6.5.2
|
||||
|
|
@ -319,8 +316,8 @@ importers:
|
|||
specifier: ^6.0.0
|
||||
version: 6.0.0
|
||||
scheduler:
|
||||
specifier: ^0.26.0
|
||||
version: 0.26.0
|
||||
specifier: ^0.27.0
|
||||
version: 0.27.0
|
||||
semver:
|
||||
specifier: ^7.7.3
|
||||
version: 7.7.3
|
||||
|
|
@ -341,7 +338,7 @@ importers:
|
|||
version: 7.0.19
|
||||
use-context-selector:
|
||||
specifier: ^2.0.0
|
||||
version: 2.0.0(react@19.2.3)(scheduler@0.26.0)
|
||||
version: 2.0.0(react@19.2.3)(scheduler@0.27.0)
|
||||
uuid:
|
||||
specifier: ^10.0.0
|
||||
version: 10.0.0
|
||||
|
|
@ -403,6 +400,15 @@ importers:
|
|||
'@storybook/react':
|
||||
specifier: 9.1.13
|
||||
version: 9.1.13(react-dom@19.2.3(react@19.2.3))(react@19.2.3)(storybook@9.1.17(@testing-library/dom@10.4.1)(vite@7.3.0(@types/node@18.15.0)(jiti@1.21.7)(sass@1.95.0)(terser@5.44.1)(tsx@4.21.0)(yaml@2.8.2)))(typescript@5.9.3)
|
||||
'@tanstack/react-devtools':
|
||||
specifier: ^0.9.0
|
||||
version: 0.9.0(@types/react-dom@19.2.3(@types/react@19.2.7))(@types/react@19.2.7)(csstype@3.2.3)(react-dom@19.2.3(react@19.2.3))(react@19.2.3)(solid-js@1.9.10)
|
||||
'@tanstack/react-form-devtools':
|
||||
specifier: ^0.2.9
|
||||
version: 0.2.9(@types/react@19.2.7)(csstype@3.2.3)(react@19.2.3)(solid-js@1.9.10)
|
||||
'@tanstack/react-query-devtools':
|
||||
specifier: ^5.90.2
|
||||
version: 5.91.1(@tanstack/react-query@5.90.12(react@19.2.3))(react@19.2.3)
|
||||
'@testing-library/dom':
|
||||
specifier: ^10.4.1
|
||||
version: 10.4.1
|
||||
|
|
@ -3169,6 +3175,36 @@ packages:
|
|||
resolution: {integrity: sha512-t09vSN3MdfsyCHoFcTRCH/iUtG7OJ0CsjzB8cjAmKc/va/kIgeDI/TxsigdncE/4be734m0cvIYwNaV4i2XqAw==}
|
||||
engines: {node: '>=10'}
|
||||
|
||||
'@solid-primitives/event-listener@2.4.3':
|
||||
resolution: {integrity: sha512-h4VqkYFv6Gf+L7SQj+Y6puigL/5DIi7x5q07VZET7AWcS+9/G3WfIE9WheniHWJs51OEkRB43w6lDys5YeFceg==}
|
||||
peerDependencies:
|
||||
solid-js: ^1.6.12
|
||||
|
||||
'@solid-primitives/keyboard@1.3.3':
|
||||
resolution: {integrity: sha512-9dQHTTgLBqyAI7aavtO+HnpTVJgWQA1ghBSrmLtMu1SMxLPDuLfuNr+Tk5udb4AL4Ojg7h9JrKOGEEDqsJXWJA==}
|
||||
peerDependencies:
|
||||
solid-js: ^1.6.12
|
||||
|
||||
'@solid-primitives/resize-observer@2.1.3':
|
||||
resolution: {integrity: sha512-zBLje5E06TgOg93S7rGPldmhDnouNGhvfZVKOp+oG2XU8snA+GoCSSCz1M+jpNAg5Ek2EakU5UVQqL152WmdXQ==}
|
||||
peerDependencies:
|
||||
solid-js: ^1.6.12
|
||||
|
||||
'@solid-primitives/rootless@1.5.2':
|
||||
resolution: {integrity: sha512-9HULb0QAzL2r47CCad0M+NKFtQ+LrGGNHZfteX/ThdGvKIg2o2GYhBooZubTCd/RTu2l2+Nw4s+dEfiDGvdrrQ==}
|
||||
peerDependencies:
|
||||
solid-js: ^1.6.12
|
||||
|
||||
'@solid-primitives/static-store@0.1.2':
|
||||
resolution: {integrity: sha512-ReK+5O38lJ7fT+L6mUFvUr6igFwHBESZF+2Ug842s7fvlVeBdIVEdTCErygff6w7uR6+jrr7J8jQo+cYrEq4Iw==}
|
||||
peerDependencies:
|
||||
solid-js: ^1.6.12
|
||||
|
||||
'@solid-primitives/utils@6.3.2':
|
||||
resolution: {integrity: sha512-hZ/M/qr25QOCcwDPOHtGjxTD8w2mNyVAYvcfgwzBHq2RwNqHNdDNsMZYap20+ruRwW4A3Cdkczyoz0TSxLCAPQ==}
|
||||
peerDependencies:
|
||||
solid-js: ^1.6.12
|
||||
|
||||
'@standard-schema/spec@1.1.0':
|
||||
resolution: {integrity: sha512-l2aFy5jALhniG5HgqrD6jXLi/rUWrKvqN/qJx6yoJsgKhblVd+iqqU4RCXavm/jPityDo5TCvKMnpjKnOriy0w==}
|
||||
|
||||
|
|
@ -3308,13 +3344,67 @@ packages:
|
|||
peerDependencies:
|
||||
tailwindcss: '>=3.0.0 || insiders || >=4.0.0-alpha.20 || >=4.0.0-beta.1'
|
||||
|
||||
'@tanstack/devtools-client@0.0.5':
|
||||
resolution: {integrity: sha512-hsNDE3iu4frt9cC2ppn1mNRnLKo2uc1/1hXAyY9z4UYb+o40M2clFAhiFoo4HngjfGJDV3x18KVVIq7W4Un+zA==}
|
||||
engines: {node: '>=18'}
|
||||
|
||||
'@tanstack/devtools-event-bus@0.4.0':
|
||||
resolution: {integrity: sha512-1t+/csFuDzi+miDxAOh6Xv7VDE80gJEItkTcAZLjV5MRulbO/W8ocjHLI2Do/p2r2/FBU0eKCRTpdqvXaYoHpQ==}
|
||||
engines: {node: '>=18'}
|
||||
|
||||
'@tanstack/devtools-event-client@0.3.5':
|
||||
resolution: {integrity: sha512-RL1f5ZlfZMpghrCIdzl6mLOFLTuhqmPNblZgBaeKfdtk5rfbjykurv+VfYydOFXj0vxVIoA2d/zT7xfD7Ph8fw==}
|
||||
engines: {node: '>=18'}
|
||||
|
||||
'@tanstack/devtools-event-client@0.4.0':
|
||||
resolution: {integrity: sha512-RPfGuk2bDZgcu9bAJodvO2lnZeHuz4/71HjZ0bGb/SPg8+lyTA+RLSKQvo7fSmPSi8/vcH3aKQ8EM9ywf1olaw==}
|
||||
engines: {node: '>=18'}
|
||||
|
||||
'@tanstack/devtools-ui@0.4.4':
|
||||
resolution: {integrity: sha512-5xHXFyX3nom0UaNfiOM92o6ziaHjGo3mcSGe2HD5Xs8dWRZNpdZ0Smd0B9ddEhy0oB+gXyMzZgUJb9DmrZV0Mg==}
|
||||
engines: {node: '>=18'}
|
||||
peerDependencies:
|
||||
solid-js: '>=1.9.7'
|
||||
|
||||
'@tanstack/devtools-utils@0.0.9':
|
||||
resolution: {integrity: sha512-tCObM6wbEjuHeGNs3JDhrqBhoMxpJpVuVIg5Kc33EmUI1ZO7KLpC1277Qf6AmSWy3aVOreGwn3y5bJzxmAJNXg==}
|
||||
engines: {node: '>=18'}
|
||||
peerDependencies:
|
||||
'@types/react': ~19.2.7
|
||||
react: '>=17.0.0'
|
||||
solid-js: '>=1.9.7'
|
||||
vue: '>=3.2.0'
|
||||
peerDependenciesMeta:
|
||||
'@types/react':
|
||||
optional: true
|
||||
react:
|
||||
optional: true
|
||||
solid-js:
|
||||
optional: true
|
||||
vue:
|
||||
optional: true
|
||||
|
||||
'@tanstack/devtools@0.10.1':
|
||||
resolution: {integrity: sha512-1gtPmCDXV4Pl1nVtoqwjV0tc4E9GMuFtlkBX1Lz1KfqI3W9JojT5YsVifOQ/g8BTQ5w5+tyIANwHU7WYgLq/MQ==}
|
||||
engines: {node: '>=18'}
|
||||
peerDependencies:
|
||||
solid-js: '>=1.9.7'
|
||||
|
||||
'@tanstack/form-core@1.27.1':
|
||||
resolution: {integrity: sha512-hPM+0tUnZ2C2zb2TE1lar1JJ0S0cbnQHlUwFcCnVBpMV3rjtUzkoM766gUpWrlmTGCzNad0GbJ0aTxVsjT6J8g==}
|
||||
|
||||
'@tanstack/form-core@1.27.6':
|
||||
resolution: {integrity: sha512-1C4PUpOcCpivddKxtAeqdeqncxnPKiPpTVDRknDExCba+6zCsAjxgL+p3qYA3hu+EFyUAdW71rU+uqYbEa7qqA==}
|
||||
|
||||
'@tanstack/form-devtools@0.2.9':
|
||||
resolution: {integrity: sha512-KOJiwvlFPsHeuWXvHUXRVdciXG1OPhg1c476MsLre0YLdaw1jeMlDYSlqq7sdEULX+2Sg/lhNpX86QbQuxzd2A==}
|
||||
peerDependencies:
|
||||
solid-js: '>=1.9.9'
|
||||
|
||||
'@tanstack/pacer-lite@0.1.1':
|
||||
resolution: {integrity: sha512-y/xtNPNt/YeyoVxE/JCx+T7yjEzpezmbb+toK8DDD1P4m7Kzs5YR956+7OKexG3f8aXgC3rLZl7b1V+yNUSy5w==}
|
||||
engines: {node: '>=18'}
|
||||
|
||||
'@tanstack/pacer@0.15.4':
|
||||
resolution: {integrity: sha512-vGY+CWsFZeac3dELgB6UZ4c7OacwsLb8hvL2gLS6hTgy8Fl0Bm/aLokHaeDIP+q9F9HUZTnp360z9uv78eg8pg==}
|
||||
engines: {node: '>=18'}
|
||||
|
|
@ -3325,6 +3415,20 @@ packages:
|
|||
'@tanstack/query-devtools@5.91.1':
|
||||
resolution: {integrity: sha512-l8bxjk6BMsCaVQH6NzQEE/bEgFy1hAs5qbgXl0xhzezlaQbPk6Mgz9BqEg2vTLPOHD8N4k+w/gdgCbEzecGyNg==}
|
||||
|
||||
'@tanstack/react-devtools@0.9.0':
|
||||
resolution: {integrity: sha512-Lq0svXOTG5N61SHgx8F0on6zz2GB0kmFjN/yyfNLrJyRgJ+U3jYFRd9ti3uBPABsXzHQMHYYujnTXrOYp/OaUg==}
|
||||
engines: {node: '>=18'}
|
||||
peerDependencies:
|
||||
'@types/react': ~19.2.7
|
||||
'@types/react-dom': ~19.2.3
|
||||
react: '>=16.8'
|
||||
react-dom: '>=16.8'
|
||||
|
||||
'@tanstack/react-form-devtools@0.2.9':
|
||||
resolution: {integrity: sha512-wg0xrcVY8evIFGVHrnl9s+/9ENzuVbqv5Ru4HyAJjjL4uECtl6KdDJsi0lZdOyoM1UYEQoVdcN8jfBbxkA3q1g==}
|
||||
peerDependencies:
|
||||
react: ^17.0.0 || ^18.0.0 || ^19.0.0
|
||||
|
||||
'@tanstack/react-form@1.27.1':
|
||||
resolution: {integrity: sha512-HKP0Ew2ae9AL5vU1PkJ+oAC2p+xBtA905u0fiNLzlfn1vLkBxenfg5L6TOA+rZITHpQsSo10tqwc5Yw6qn8Mpg==}
|
||||
peerDependencies:
|
||||
|
|
@ -3594,6 +3698,9 @@ packages:
|
|||
'@types/node@20.19.26':
|
||||
resolution: {integrity: sha512-0l6cjgF0XnihUpndDhk+nyD3exio3iKaYROSgvh/qSevPXax3L8p5DBRFjbvalnwatGgHEQn2R88y2fA3g4irg==}
|
||||
|
||||
'@types/node@20.19.27':
|
||||
resolution: {integrity: sha512-N2clP5pJhB2YnZJ3PIHFk5RkygRX5WO/5f0WC08tp0wd+sv0rsJk3MqWn3CbNmT2J505a5336jaQj4ph1AdMug==}
|
||||
|
||||
'@types/papaparse@5.5.1':
|
||||
resolution: {integrity: sha512-esEO+VISsLIyE+JZBmb89NzsYYbpwV8lmv2rPo6oX5y9KhBaIP7hhHgjuTut54qjdKVMufTEcrh5fUl9+58huw==}
|
||||
|
||||
|
|
@ -5629,6 +5736,11 @@ packages:
|
|||
globrex@0.1.2:
|
||||
resolution: {integrity: sha512-uHJgbwAMwNFf5mLst7IWLNg14x1CkeqglJb/K3doi4dw6q2IvAAmM/Y81kevy83wP+Sst+nutFTYOGg3d1lsxg==}
|
||||
|
||||
goober@2.1.18:
|
||||
resolution: {integrity: sha512-2vFqsaDVIT9Gz7N6kAL++pLpp41l3PfDuusHcjnGLfR6+huZkl6ziX+zgVC3ZxpqWhzH6pyDdGrCeDhMIvwaxw==}
|
||||
peerDependencies:
|
||||
csstype: ^3.0.10
|
||||
|
||||
got@11.8.6:
|
||||
resolution: {integrity: sha512-6tfZ91bOr7bOXnK7PRDCGBLa1H4U080YHNaAQ2KsMGlLEzRbk44nsZF2E1IeRc3vtJHPVbKCYgdFbaGO2ljd8g==}
|
||||
engines: {node: '>=10.19.0'}
|
||||
|
|
@ -6203,8 +6315,8 @@ packages:
|
|||
lexical@0.38.2:
|
||||
resolution: {integrity: sha512-JJmfsG3c4gwBHzUGffbV7ifMNkKAWMCnYE3xJl87gty7hjyV5f3xq7eqTjP5HFYvO4XpjJvvWO2/djHp5S10tw==}
|
||||
|
||||
lib0@0.2.115:
|
||||
resolution: {integrity: sha512-noaW4yNp6hCjOgDnWWxW0vGXE3kZQI5Kqiwz+jIWXavI9J9WyfJ9zjsbQlQlgjIbHBrvlA/x3TSIXBUJj+0L6g==}
|
||||
lib0@0.2.116:
|
||||
resolution: {integrity: sha512-4zsosjzmt33rx5XjmFVYUAeLNh+BTeDTiwGdLt4muxiir2btsc60Nal0EvkvDRizg+pnlK1q+BtYi7M+d4eStw==}
|
||||
engines: {node: '>=16'}
|
||||
hasBin: true
|
||||
|
||||
|
|
@ -7641,9 +7753,6 @@ packages:
|
|||
resolution: {integrity: sha512-xAg7SOnEhrm5zI3puOOKyy1OMcMlIJZYNJY7xLBwSze0UjhPLnWfj2GF2EpT0jmzaJKIWKHLsaSSajf35bcYnA==}
|
||||
engines: {node: '>=v12.22.7'}
|
||||
|
||||
scheduler@0.26.0:
|
||||
resolution: {integrity: sha512-NlHwttCI/l5gCPR3D1nNXtWABUmBwvZpEQiD4IXSbIDq8BzLIK/7Ir5gTFSGZDUu37K5cMNp0hFtzO38sC7gWA==}
|
||||
|
||||
scheduler@0.27.0:
|
||||
resolution: {integrity: sha512-eNv+WrVbKu1f3vbYJT/xtiF5syA5HPIMtf9IgY/nKg0sWqzAUEvqY/xm7OcZc/qafLx/iO9FgOmeSAp4v5ti/Q==}
|
||||
|
||||
|
|
@ -7687,6 +7796,16 @@ packages:
|
|||
serialize-javascript@6.0.2:
|
||||
resolution: {integrity: sha512-Saa1xPByTTq2gdeFZYLLo+RFE35NHZkAbqZeWNd3BpzppeVisAqpDjcp8dyf6uIvEqJRd46jemmyA4iFIeVk8g==}
|
||||
|
||||
seroval-plugins@1.3.3:
|
||||
resolution: {integrity: sha512-16OL3NnUBw8JG1jBLUoZJsLnQq0n5Ua6aHalhJK4fMQkz1lqR7Osz1sA30trBtd9VUDc2NgkuRCn8+/pBwqZ+w==}
|
||||
engines: {node: '>=10'}
|
||||
peerDependencies:
|
||||
seroval: ^1.0
|
||||
|
||||
seroval@1.3.2:
|
||||
resolution: {integrity: sha512-RbcPH1n5cfwKrru7v7+zrZvjLurgHhGyso3HTyGtRivGWgYjbOmGuivCQaORNELjNONoK35nj28EoWul9sb1zQ==}
|
||||
engines: {node: '>=10'}
|
||||
|
||||
setimmediate@1.0.5:
|
||||
resolution: {integrity: sha512-MATJdZp8sLqDl/68LfQmbP8zKPLQNV6BIZoIgrscFDQ+RsvK/BxeDQOgyxKKoh0y/8h3BqVFnCqQ/gd+reiIXA==}
|
||||
|
||||
|
|
@ -7753,6 +7872,9 @@ packages:
|
|||
resolution: {integrity: sha512-QlaZEqcAH3/RtNyet1IPIYPsEWAaYyXXv1Krsi+1L/QHppjX4Ifm8MQsBISz9vE8cHicIq3clogsheili5vhaQ==}
|
||||
engines: {node: '>= 18'}
|
||||
|
||||
solid-js@1.9.10:
|
||||
resolution: {integrity: sha512-Coz956cos/EPDlhs6+jsdTxKuJDPT7B5SVIWgABwROyxjY7Xbr8wkzD68Et+NxnV7DLJ3nJdAC2r9InuV/4Jew==}
|
||||
|
||||
sortablejs@1.15.6:
|
||||
resolution: {integrity: sha512-aNfiuwMEpfBM/CN6LY0ibyhxPfPbyFeBTYJKCvzkJ2GkUpazIt3H+QIPAMHwqQ7tMKaHz1Qj+rJJCqljnf4p3A==}
|
||||
|
||||
|
|
@ -8589,6 +8711,7 @@ packages:
|
|||
whatwg-encoding@3.1.1:
|
||||
resolution: {integrity: sha512-6qN4hJdMwfYBtE3YBTTHhoeuUrDBPZmbQaxWAqSALV/MeEnR5z1xd8UKud2RAkFoPkmB+hli1TZSnyi84xz1vQ==}
|
||||
engines: {node: '>=18'}
|
||||
deprecated: Use @exodus/bytes instead for a more spec-conformant and faster implementation
|
||||
|
||||
whatwg-mimetype@3.0.0:
|
||||
resolution: {integrity: sha512-nt+N2dzIutVRxARx1nghPKGv1xHikU7HKdfafKkLNLindmPU/ch3U31NOCGGA/dmPcmb1VlofO0vnKAcsm0o/Q==}
|
||||
|
|
@ -11539,6 +11662,40 @@ snapshots:
|
|||
|
||||
'@sindresorhus/is@4.6.0': {}
|
||||
|
||||
'@solid-primitives/event-listener@2.4.3(solid-js@1.9.10)':
|
||||
dependencies:
|
||||
'@solid-primitives/utils': 6.3.2(solid-js@1.9.10)
|
||||
solid-js: 1.9.10
|
||||
|
||||
'@solid-primitives/keyboard@1.3.3(solid-js@1.9.10)':
|
||||
dependencies:
|
||||
'@solid-primitives/event-listener': 2.4.3(solid-js@1.9.10)
|
||||
'@solid-primitives/rootless': 1.5.2(solid-js@1.9.10)
|
||||
'@solid-primitives/utils': 6.3.2(solid-js@1.9.10)
|
||||
solid-js: 1.9.10
|
||||
|
||||
'@solid-primitives/resize-observer@2.1.3(solid-js@1.9.10)':
|
||||
dependencies:
|
||||
'@solid-primitives/event-listener': 2.4.3(solid-js@1.9.10)
|
||||
'@solid-primitives/rootless': 1.5.2(solid-js@1.9.10)
|
||||
'@solid-primitives/static-store': 0.1.2(solid-js@1.9.10)
|
||||
'@solid-primitives/utils': 6.3.2(solid-js@1.9.10)
|
||||
solid-js: 1.9.10
|
||||
|
||||
'@solid-primitives/rootless@1.5.2(solid-js@1.9.10)':
|
||||
dependencies:
|
||||
'@solid-primitives/utils': 6.3.2(solid-js@1.9.10)
|
||||
solid-js: 1.9.10
|
||||
|
||||
'@solid-primitives/static-store@0.1.2(solid-js@1.9.10)':
|
||||
dependencies:
|
||||
'@solid-primitives/utils': 6.3.2(solid-js@1.9.10)
|
||||
solid-js: 1.9.10
|
||||
|
||||
'@solid-primitives/utils@6.3.2(solid-js@1.9.10)':
|
||||
dependencies:
|
||||
solid-js: 1.9.10
|
||||
|
||||
'@standard-schema/spec@1.1.0': {}
|
||||
|
||||
'@standard-schema/utils@0.3.0': {}
|
||||
|
|
@ -11766,14 +11923,84 @@ snapshots:
|
|||
postcss-selector-parser: 6.0.10
|
||||
tailwindcss: 3.4.18(tsx@4.21.0)(yaml@2.8.2)
|
||||
|
||||
'@tanstack/devtools-client@0.0.5':
|
||||
dependencies:
|
||||
'@tanstack/devtools-event-client': 0.4.0
|
||||
|
||||
'@tanstack/devtools-event-bus@0.4.0':
|
||||
dependencies:
|
||||
ws: 8.18.3
|
||||
transitivePeerDependencies:
|
||||
- bufferutil
|
||||
- utf-8-validate
|
||||
|
||||
'@tanstack/devtools-event-client@0.3.5': {}
|
||||
|
||||
'@tanstack/devtools-event-client@0.4.0': {}
|
||||
|
||||
'@tanstack/devtools-ui@0.4.4(csstype@3.2.3)(solid-js@1.9.10)':
|
||||
dependencies:
|
||||
clsx: 2.1.1
|
||||
goober: 2.1.18(csstype@3.2.3)
|
||||
solid-js: 1.9.10
|
||||
transitivePeerDependencies:
|
||||
- csstype
|
||||
|
||||
'@tanstack/devtools-utils@0.0.9(@types/react@19.2.7)(csstype@3.2.3)(react@19.2.3)(solid-js@1.9.10)':
|
||||
dependencies:
|
||||
'@tanstack/devtools-ui': 0.4.4(csstype@3.2.3)(solid-js@1.9.10)
|
||||
optionalDependencies:
|
||||
'@types/react': 19.2.7
|
||||
react: 19.2.3
|
||||
solid-js: 1.9.10
|
||||
transitivePeerDependencies:
|
||||
- csstype
|
||||
|
||||
'@tanstack/devtools@0.10.1(csstype@3.2.3)(solid-js@1.9.10)':
|
||||
dependencies:
|
||||
'@solid-primitives/event-listener': 2.4.3(solid-js@1.9.10)
|
||||
'@solid-primitives/keyboard': 1.3.3(solid-js@1.9.10)
|
||||
'@solid-primitives/resize-observer': 2.1.3(solid-js@1.9.10)
|
||||
'@tanstack/devtools-client': 0.0.5
|
||||
'@tanstack/devtools-event-bus': 0.4.0
|
||||
'@tanstack/devtools-ui': 0.4.4(csstype@3.2.3)(solid-js@1.9.10)
|
||||
clsx: 2.1.1
|
||||
goober: 2.1.18(csstype@3.2.3)
|
||||
solid-js: 1.9.10
|
||||
transitivePeerDependencies:
|
||||
- bufferutil
|
||||
- csstype
|
||||
- utf-8-validate
|
||||
|
||||
'@tanstack/form-core@1.27.1':
|
||||
dependencies:
|
||||
'@tanstack/devtools-event-client': 0.3.5
|
||||
'@tanstack/pacer': 0.15.4
|
||||
'@tanstack/store': 0.7.7
|
||||
|
||||
'@tanstack/form-core@1.27.6':
|
||||
dependencies:
|
||||
'@tanstack/devtools-event-client': 0.4.0
|
||||
'@tanstack/pacer-lite': 0.1.1
|
||||
'@tanstack/store': 0.7.7
|
||||
|
||||
'@tanstack/form-devtools@0.2.9(@types/react@19.2.7)(csstype@3.2.3)(react@19.2.3)(solid-js@1.9.10)':
|
||||
dependencies:
|
||||
'@tanstack/devtools-ui': 0.4.4(csstype@3.2.3)(solid-js@1.9.10)
|
||||
'@tanstack/devtools-utils': 0.0.9(@types/react@19.2.7)(csstype@3.2.3)(react@19.2.3)(solid-js@1.9.10)
|
||||
'@tanstack/form-core': 1.27.6
|
||||
clsx: 2.1.1
|
||||
dayjs: 1.11.19
|
||||
goober: 2.1.18(csstype@3.2.3)
|
||||
solid-js: 1.9.10
|
||||
transitivePeerDependencies:
|
||||
- '@types/react'
|
||||
- csstype
|
||||
- react
|
||||
- vue
|
||||
|
||||
'@tanstack/pacer-lite@0.1.1': {}
|
||||
|
||||
'@tanstack/pacer@0.15.4':
|
||||
dependencies:
|
||||
'@tanstack/devtools-event-client': 0.3.5
|
||||
|
|
@ -11783,6 +12010,30 @@ snapshots:
|
|||
|
||||
'@tanstack/query-devtools@5.91.1': {}
|
||||
|
||||
'@tanstack/react-devtools@0.9.0(@types/react-dom@19.2.3(@types/react@19.2.7))(@types/react@19.2.7)(csstype@3.2.3)(react-dom@19.2.3(react@19.2.3))(react@19.2.3)(solid-js@1.9.10)':
|
||||
dependencies:
|
||||
'@tanstack/devtools': 0.10.1(csstype@3.2.3)(solid-js@1.9.10)
|
||||
'@types/react': 19.2.7
|
||||
'@types/react-dom': 19.2.3(@types/react@19.2.7)
|
||||
react: 19.2.3
|
||||
react-dom: 19.2.3(react@19.2.3)
|
||||
transitivePeerDependencies:
|
||||
- bufferutil
|
||||
- csstype
|
||||
- solid-js
|
||||
- utf-8-validate
|
||||
|
||||
'@tanstack/react-form-devtools@0.2.9(@types/react@19.2.7)(csstype@3.2.3)(react@19.2.3)(solid-js@1.9.10)':
|
||||
dependencies:
|
||||
'@tanstack/devtools-utils': 0.0.9(@types/react@19.2.7)(csstype@3.2.3)(react@19.2.3)(solid-js@1.9.10)
|
||||
'@tanstack/form-devtools': 0.2.9(@types/react@19.2.7)(csstype@3.2.3)(react@19.2.3)(solid-js@1.9.10)
|
||||
react: 19.2.3
|
||||
transitivePeerDependencies:
|
||||
- '@types/react'
|
||||
- csstype
|
||||
- solid-js
|
||||
- vue
|
||||
|
||||
'@tanstack/react-form@1.27.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
|
||||
dependencies:
|
||||
'@tanstack/form-core': 1.27.1
|
||||
|
|
@ -12091,6 +12342,11 @@ snapshots:
|
|||
dependencies:
|
||||
undici-types: 6.21.0
|
||||
|
||||
'@types/node@20.19.27':
|
||||
dependencies:
|
||||
undici-types: 6.21.0
|
||||
optional: true
|
||||
|
||||
'@types/papaparse@5.5.1':
|
||||
dependencies:
|
||||
'@types/node': 18.15.0
|
||||
|
|
@ -14486,6 +14742,10 @@ snapshots:
|
|||
|
||||
globrex@0.1.2: {}
|
||||
|
||||
goober@2.1.18(csstype@3.2.3):
|
||||
dependencies:
|
||||
csstype: 3.2.3
|
||||
|
||||
got@11.8.6:
|
||||
dependencies:
|
||||
'@sindresorhus/is': 4.6.0
|
||||
|
|
@ -14512,7 +14772,7 @@ snapshots:
|
|||
|
||||
happy-dom@20.0.11:
|
||||
dependencies:
|
||||
'@types/node': 20.19.26
|
||||
'@types/node': 20.19.27
|
||||
'@types/whatwg-mimetype': 3.0.2
|
||||
whatwg-mimetype: 3.0.0
|
||||
optional: true
|
||||
|
|
@ -15127,7 +15387,7 @@ snapshots:
|
|||
|
||||
lexical@0.38.2: {}
|
||||
|
||||
lib0@0.2.115:
|
||||
lib0@0.2.116:
|
||||
dependencies:
|
||||
isomorphic.js: 0.2.5
|
||||
|
||||
|
|
@ -17044,8 +17304,6 @@ snapshots:
|
|||
dependencies:
|
||||
xmlchars: 2.2.0
|
||||
|
||||
scheduler@0.26.0: {}
|
||||
|
||||
scheduler@0.27.0: {}
|
||||
|
||||
schema-utils@2.7.1:
|
||||
|
|
@ -17089,6 +17347,12 @@ snapshots:
|
|||
dependencies:
|
||||
randombytes: 2.1.0
|
||||
|
||||
seroval-plugins@1.3.3(seroval@1.3.2):
|
||||
dependencies:
|
||||
seroval: 1.3.2
|
||||
|
||||
seroval@1.3.2: {}
|
||||
|
||||
setimmediate@1.0.5: {}
|
||||
|
||||
sha.js@2.4.12:
|
||||
|
|
@ -17203,6 +17467,12 @@ snapshots:
|
|||
|
||||
smol-toml@1.5.2: {}
|
||||
|
||||
solid-js@1.9.10:
|
||||
dependencies:
|
||||
csstype: 3.2.3
|
||||
seroval: 1.3.2
|
||||
seroval-plugins: 1.3.3(seroval@1.3.2)
|
||||
|
||||
sortablejs@1.15.6: {}
|
||||
|
||||
source-list-map@2.0.1: {}
|
||||
|
|
@ -17734,10 +18004,10 @@ snapshots:
|
|||
optionalDependencies:
|
||||
'@types/react': 19.2.7
|
||||
|
||||
use-context-selector@2.0.0(react@19.2.3)(scheduler@0.26.0):
|
||||
use-context-selector@2.0.0(react@19.2.3)(scheduler@0.27.0):
|
||||
dependencies:
|
||||
react: 19.2.3
|
||||
scheduler: 0.26.0
|
||||
scheduler: 0.27.0
|
||||
|
||||
use-isomorphic-layout-effect@1.2.1(@types/react@19.2.7)(react@19.2.3):
|
||||
dependencies:
|
||||
|
|
@ -18202,7 +18472,7 @@ snapshots:
|
|||
|
||||
yjs@13.6.27:
|
||||
dependencies:
|
||||
lib0: 0.2.115
|
||||
lib0: 0.2.116
|
||||
|
||||
yocto-queue@0.1.0: {}
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import {
|
|||
} from '@tanstack/react-query'
|
||||
import { normalizeStatusForQuery } from '@/app/components/datasets/documents/status-filter'
|
||||
import { DocumentActionType } from '@/models/datasets'
|
||||
import { del, get, patch } from '../base'
|
||||
import { del, get, patch, post } from '../base'
|
||||
import { pauseDocIndexing, resumeDocIndexing } from '../datasets'
|
||||
import { useInvalid } from '../use-base'
|
||||
|
||||
|
|
@ -163,3 +163,15 @@ export const useDocumentResume = () => {
|
|||
},
|
||||
})
|
||||
}
|
||||
|
||||
export const useDocumentBatchRetryIndex = () => {
|
||||
return useMutation({
|
||||
mutationFn: ({ datasetId, documentIds }: { datasetId: string, documentIds: string[] }) => {
|
||||
return post<CommonResponse>(`/datasets/${datasetId}/retry`, {
|
||||
body: {
|
||||
document_ids: documentIds,
|
||||
},
|
||||
})
|
||||
},
|
||||
})
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue