Merge branch 'main' into feat/hitl-frontend

This commit is contained in:
twwu 2025-12-29 10:35:51 +08:00
commit c716c4ccbe
54 changed files with 4074 additions and 1147 deletions

8
.claude/settings.json Normal file
View File

@ -0,0 +1,8 @@
{
"enabledPlugins": {
"feature-dev@claude-plugins-official": true,
"context7@claude-plugins-official": true,
"typescript-lsp@claude-plugins-official": true,
"pyright-lsp@claude-plugins-official": true
}
}

View File

@ -1,19 +0,0 @@
{
"permissions": {
"allow": [],
"deny": []
},
"env": {
"__comment": "Environment variables for MCP servers. Override in .claude/settings.local.json with actual values.",
"GITHUB_PERSONAL_ACCESS_TOKEN": "ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
},
"enabledMcpjsonServers": [
"context7",
"sequential-thinking",
"github",
"fetch",
"playwright",
"ide"
],
"enableAllProjectMcpServers": true
}

View File

@ -1,34 +0,0 @@
{
"mcpServers": {
"context7": {
"type": "http",
"url": "https://mcp.context7.com/mcp"
},
"sequential-thinking": {
"type": "stdio",
"command": "npx",
"args": ["-y", "@modelcontextprotocol/server-sequential-thinking"],
"env": {}
},
"github": {
"type": "stdio",
"command": "npx",
"args": ["-y", "@modelcontextprotocol/server-github"],
"env": {
"GITHUB_PERSONAL_ACCESS_TOKEN": "${GITHUB_PERSONAL_ACCESS_TOKEN}"
}
},
"fetch": {
"type": "stdio",
"command": "uvx",
"args": ["mcp-server-fetch"],
"env": {}
},
"playwright": {
"type": "stdio",
"command": "npx",
"args": ["-y", "@playwright/mcp@latest"],
"env": {}
}
}
}

View File

@ -313,17 +313,20 @@ class StreamableHTTPTransport:
if is_initialization:
self._maybe_extract_session_id_from_response(response)
content_type = cast(str, response.headers.get(CONTENT_TYPE, "").lower())
# Per https://modelcontextprotocol.io/specification/2025-06-18/basic#notifications:
# The server MUST NOT send a response to notifications.
if isinstance(message.root, JSONRPCRequest):
content_type = cast(str, response.headers.get(CONTENT_TYPE, "").lower())
if content_type.startswith(JSON):
self._handle_json_response(response, ctx.server_to_client_queue)
elif content_type.startswith(SSE):
self._handle_sse_response(response, ctx)
else:
self._handle_unexpected_content_type(
content_type,
ctx.server_to_client_queue,
)
if content_type.startswith(JSON):
self._handle_json_response(response, ctx.server_to_client_queue)
elif content_type.startswith(SSE):
self._handle_sse_response(response, ctx)
else:
self._handle_unexpected_content_type(
content_type,
ctx.server_to_client_queue,
)
def _handle_json_response(
self,

View File

@ -13,7 +13,7 @@ from core.model_runtime.entities.model_entities import ModelType
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.embedding.retrieval import RetrievalSegments
from core.rag.embedding.retrieval import RetrievalChildChunk, RetrievalSegments
from core.rag.entities.metadata_entities import MetadataCondition
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
@ -381,10 +381,9 @@ class RetrievalService:
records = []
include_segment_ids = set()
segment_child_map = {}
segment_file_map = {}
valid_dataset_documents = {}
image_doc_ids = []
image_doc_ids: list[Any] = []
child_index_node_ids = []
index_node_ids = []
doc_to_document_map = {}
@ -417,28 +416,39 @@ class RetrievalService:
child_index_node_ids = [i for i in child_index_node_ids if i]
index_node_ids = [i for i in index_node_ids if i]
segment_ids = []
segment_ids: list[str] = []
index_node_segments: list[DocumentSegment] = []
segments: list[DocumentSegment] = []
attachment_map = {}
child_chunk_map = {}
doc_segment_map = {}
attachment_map: dict[str, list[dict[str, Any]]] = {}
child_chunk_map: dict[str, list[ChildChunk]] = {}
doc_segment_map: dict[str, list[str]] = {}
with session_factory.create_session() as session:
attachments = cls.get_segment_attachment_infos(image_doc_ids, session)
for attachment in attachments:
segment_ids.append(attachment["segment_id"])
attachment_map[attachment["segment_id"]] = attachment
doc_segment_map[attachment["segment_id"]] = attachment["attachment_id"]
if attachment["segment_id"] in attachment_map:
attachment_map[attachment["segment_id"]].append(attachment["attachment_info"])
else:
attachment_map[attachment["segment_id"]] = [attachment["attachment_info"]]
if attachment["segment_id"] in doc_segment_map:
doc_segment_map[attachment["segment_id"]].append(attachment["attachment_id"])
else:
doc_segment_map[attachment["segment_id"]] = [attachment["attachment_id"]]
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id.in_(child_index_node_ids))
child_index_nodes = session.execute(child_chunk_stmt).scalars().all()
for i in child_index_nodes:
segment_ids.append(i.segment_id)
child_chunk_map[i.segment_id] = i
doc_segment_map[i.segment_id] = i.index_node_id
if i.segment_id in child_chunk_map:
child_chunk_map[i.segment_id].append(i)
else:
child_chunk_map[i.segment_id] = [i]
if i.segment_id in doc_segment_map:
doc_segment_map[i.segment_id].append(i.index_node_id)
else:
doc_segment_map[i.segment_id] = [i.index_node_id]
if index_node_ids:
document_segment_stmt = select(DocumentSegment).where(
@ -448,7 +458,7 @@ class RetrievalService:
)
index_node_segments = session.execute(document_segment_stmt).scalars().all() # type: ignore
for index_node_segment in index_node_segments:
doc_segment_map[index_node_segment.id] = index_node_segment.index_node_id
doc_segment_map[index_node_segment.id] = [index_node_segment.index_node_id]
if segment_ids:
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.enabled == True,
@ -461,95 +471,86 @@ class RetrievalService:
segments.extend(index_node_segments)
for segment in segments:
doc_id = doc_segment_map.get(segment.id)
child_chunk = child_chunk_map.get(segment.id)
attachment_info = attachment_map.get(segment.id)
child_chunks: list[ChildChunk] = child_chunk_map.get(segment.id, [])
attachment_infos: list[dict[str, Any]] = attachment_map.get(segment.id, [])
ds_dataset_document: DatasetDocument | None = valid_dataset_documents.get(segment.document_id)
if doc_id:
document = doc_to_document_map[doc_id]
ds_dataset_document: DatasetDocument | None = valid_dataset_documents.get(
document.metadata.get("document_id")
)
if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
if segment.id not in include_segment_ids:
include_segment_ids.add(segment.id)
if child_chunk:
if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
if segment.id not in include_segment_ids:
include_segment_ids.add(segment.id)
if child_chunks or attachment_infos:
child_chunk_details = []
max_score = 0.0
for child_chunk in child_chunks:
document = doc_to_document_map[child_chunk.index_node_id]
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0) if document else 0.0,
}
map_detail = {
"max_score": document.metadata.get("score", 0.0) if document else 0.0,
"child_chunks": [child_chunk_detail],
}
segment_child_map[segment.id] = map_detail
record = {
"segment": segment,
child_chunk_details.append(child_chunk_detail)
max_score = max(max_score, document.metadata.get("score", 0.0) if document else 0.0)
for attachment_info in attachment_infos:
file_document = doc_to_document_map[attachment_info["id"]]
max_score = max(
max_score, file_document.metadata.get("score", 0.0) if file_document else 0.0
)
map_detail = {
"max_score": max_score,
"child_chunks": child_chunk_details,
}
if attachment_info:
segment_file_map[segment.id] = [attachment_info]
records.append(record)
else:
if child_chunk:
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0),
}
if segment.id in segment_child_map:
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail) # type: ignore
segment_child_map[segment.id]["max_score"] = max(
segment_child_map[segment.id]["max_score"],
document.metadata.get("score", 0.0) if document else 0.0,
)
else:
segment_child_map[segment.id] = {
"max_score": document.metadata.get("score", 0.0) if document else 0.0,
"child_chunks": [child_chunk_detail],
}
if attachment_info:
if segment.id in segment_file_map:
segment_file_map[segment.id].append(attachment_info)
else:
segment_file_map[segment.id] = [attachment_info]
else:
if segment.id not in include_segment_ids:
include_segment_ids.add(segment.id)
record = {
"segment": segment,
"score": document.metadata.get("score", 0.0), # type: ignore
}
if attachment_info:
segment_file_map[segment.id] = [attachment_info]
records.append(record)
else:
if attachment_info:
attachment_infos = segment_file_map.get(segment.id, [])
if attachment_info not in attachment_infos:
attachment_infos.append(attachment_info)
segment_file_map[segment.id] = attachment_infos
segment_child_map[segment.id] = map_detail
record: dict[str, Any] = {
"segment": segment,
}
records.append(record)
else:
if segment.id not in include_segment_ids:
include_segment_ids.add(segment.id)
max_score = 0.0
segment_document = doc_to_document_map.get(segment.index_node_id)
if segment_document:
max_score = max(max_score, segment_document.metadata.get("score", 0.0))
for attachment_info in attachment_infos:
file_doc = doc_to_document_map.get(attachment_info["id"])
if file_doc:
max_score = max(max_score, file_doc.metadata.get("score", 0.0))
record = {
"segment": segment,
"score": max_score,
}
records.append(record)
# Add child chunks information to records
for record in records:
if record["segment"].id in segment_child_map:
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore
record["score"] = segment_child_map[record["segment"].id]["max_score"] # type: ignore
if record["segment"].id in segment_file_map:
record["files"] = segment_file_map[record["segment"].id] # type: ignore[assignment]
if record["segment"].id in attachment_map:
record["files"] = attachment_map[record["segment"].id] # type: ignore[assignment]
result = []
result: list[RetrievalSegments] = []
for record in records:
# Extract segment
segment = record["segment"]
# Extract child_chunks, ensuring it's a list or None
child_chunks = record.get("child_chunks")
if not isinstance(child_chunks, list):
child_chunks = None
raw_child_chunks = record.get("child_chunks")
child_chunks_list: list[RetrievalChildChunk] | None = None
if isinstance(raw_child_chunks, list):
# Sort by score descending
sorted_chunks = sorted(raw_child_chunks, key=lambda x: x.get("score", 0.0), reverse=True)
child_chunks_list = [
RetrievalChildChunk(
id=chunk["id"],
content=chunk["content"],
score=chunk.get("score", 0.0),
position=chunk["position"],
)
for chunk in sorted_chunks
]
# Extract files, ensuring it's a list or None
files = record.get("files")
@ -566,11 +567,11 @@ class RetrievalService:
# Create RetrievalSegments object
retrieval_segment = RetrievalSegments(
segment=segment, child_chunks=child_chunks, score=score, files=files
segment=segment, child_chunks=child_chunks_list, score=score, files=files
)
result.append(retrieval_segment)
return result
return sorted(result, key=lambda x: x.score if x.score is not None else 0.0, reverse=True)
except Exception as e:
db.session.rollback()
raise e

View File

@ -255,7 +255,10 @@ class PGVector(BaseVector):
return
with self._get_cursor() as cur:
cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
cur.execute("SELECT 1 FROM pg_extension WHERE extname = 'vector'")
if not cur.fetchone():
cur.execute("CREATE EXTENSION vector")
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
# PG hnsw index only support 2000 dimension or less
# ref: https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing

View File

@ -7,7 +7,7 @@ from collections.abc import Generator, Mapping
from typing import Any, Union, cast
from flask import Flask, current_app
from sqlalchemy import and_, or_, select
from sqlalchemy import and_, literal, or_, select
from sqlalchemy.orm import Session
from core.app.app_config.entities import (
@ -1036,7 +1036,7 @@ class DatasetRetrieval:
if automatic_metadata_filters:
conditions = []
for sequence, filter in enumerate(automatic_metadata_filters):
self._process_metadata_filter_func(
self.process_metadata_filter_func(
sequence,
filter.get("condition"), # type: ignore
filter.get("metadata_name"), # type: ignore
@ -1072,7 +1072,7 @@ class DatasetRetrieval:
value=expected_value,
)
)
filters = self._process_metadata_filter_func(
filters = self.process_metadata_filter_func(
sequence,
condition.comparison_operator,
metadata_name,
@ -1168,8 +1168,9 @@ class DatasetRetrieval:
return None
return automatic_metadata_filters
def _process_metadata_filter_func(
self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
@classmethod
def process_metadata_filter_func(
cls, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
):
if value is None and condition not in ("empty", "not empty"):
return filters
@ -1218,6 +1219,20 @@ class DatasetRetrieval:
case "" | ">=":
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() >= value)
case "in" | "not in":
if isinstance(value, str):
value_list = [v.strip() for v in value.split(",") if v.strip()]
elif isinstance(value, (list, tuple)):
value_list = [str(v) for v in value if v is not None]
else:
value_list = [str(value)] if value is not None else []
if not value_list:
# `field in []` is False, `field not in []` is True
filters.append(literal(condition == "not in"))
else:
op = json_field.in_ if condition == "in" else json_field.notin_
filters.append(op(value_list))
case _:
pass

View File

@ -6,7 +6,15 @@ from typing import Any
from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPConnectionError
from core.mcp.types import AudioContent, CallToolResult, ImageContent, TextContent
from core.mcp.types import (
AudioContent,
BlobResourceContents,
CallToolResult,
EmbeddedResource,
ImageContent,
TextContent,
TextResourceContents,
)
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
@ -53,10 +61,19 @@ class MCPTool(Tool):
for content in result.content:
if isinstance(content, TextContent):
yield from self._process_text_content(content)
elif isinstance(content, ImageContent):
yield self._process_image_content(content)
elif isinstance(content, AudioContent):
yield self._process_audio_content(content)
elif isinstance(content, ImageContent | AudioContent):
yield self.create_blob_message(
blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType}
)
elif isinstance(content, EmbeddedResource):
resource = content.resource
if isinstance(resource, TextResourceContents):
yield self.create_text_message(resource.text)
elif isinstance(resource, BlobResourceContents):
mime_type = resource.mimeType or "application/octet-stream"
yield self.create_blob_message(blob=base64.b64decode(resource.blob), meta={"mime_type": mime_type})
else:
raise ToolInvokeError(f"Unsupported embedded resource type: {type(resource)}")
else:
logger.warning("Unsupported content type=%s", type(content))
@ -101,14 +118,6 @@ class MCPTool(Tool):
for item in json_list:
yield self.create_json_message(item)
def _process_image_content(self, content: ImageContent) -> ToolInvokeMessage:
"""Process image content and return a blob message."""
return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType})
def _process_audio_content(self, content: AudioContent) -> ToolInvokeMessage:
"""Process audio content and return a blob message."""
return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType})
def fork_tool_runtime(self, runtime: ToolRuntime) -> "MCPTool":
return MCPTool(
entity=self.entity,

View File

@ -6,7 +6,7 @@ from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast
from sqlalchemy import and_, func, literal, or_, select
from sqlalchemy import and_, func, or_, select
from sqlalchemy.orm import sessionmaker
from core.app.app_config.entities import DatasetRetrieveConfigEntity
@ -460,7 +460,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
if automatic_metadata_filters:
conditions = []
for sequence, filter in enumerate(automatic_metadata_filters):
self._process_metadata_filter_func(
DatasetRetrieval.process_metadata_filter_func(
sequence,
filter.get("condition", ""),
filter.get("metadata_name", ""),
@ -504,7 +504,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
value=expected_value,
)
)
filters = self._process_metadata_filter_func(
filters = DatasetRetrieval.process_metadata_filter_func(
sequence,
condition.comparison_operator,
metadata_name,
@ -603,87 +603,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
return [], usage
return automatic_metadata_filters, usage
def _process_metadata_filter_func(
self, sequence: int, condition: str, metadata_name: str, value: Any, filters: list[Any]
) -> list[Any]:
if value is None and condition not in ("empty", "not empty"):
return filters
json_field = Document.doc_metadata[metadata_name].as_string()
match condition:
case "contains":
filters.append(json_field.like(f"%{value}%"))
case "not contains":
filters.append(json_field.notlike(f"%{value}%"))
case "start with":
filters.append(json_field.like(f"{value}%"))
case "end with":
filters.append(json_field.like(f"%{value}"))
case "in":
if isinstance(value, str):
value_list = [v.strip() for v in value.split(",") if v.strip()]
elif isinstance(value, (list, tuple)):
value_list = [str(v) for v in value if v is not None]
else:
value_list = [str(value)] if value is not None else []
if not value_list:
filters.append(literal(False))
else:
filters.append(json_field.in_(value_list))
case "not in":
if isinstance(value, str):
value_list = [v.strip() for v in value.split(",") if v.strip()]
elif isinstance(value, (list, tuple)):
value_list = [str(v) for v in value if v is not None]
else:
value_list = [str(value)] if value is not None else []
if not value_list:
filters.append(literal(True))
else:
filters.append(json_field.notin_(value_list))
case "is" | "=":
if isinstance(value, str):
filters.append(json_field == value)
elif isinstance(value, (int, float)):
filters.append(Document.doc_metadata[metadata_name].as_float() == value)
case "is not" | "":
if isinstance(value, str):
filters.append(json_field != value)
elif isinstance(value, (int, float)):
filters.append(Document.doc_metadata[metadata_name].as_float() != value)
case "empty":
filters.append(Document.doc_metadata[metadata_name].is_(None))
case "not empty":
filters.append(Document.doc_metadata[metadata_name].isnot(None))
case "before" | "<":
filters.append(Document.doc_metadata[metadata_name].as_float() < value)
case "after" | ">":
filters.append(Document.doc_metadata[metadata_name].as_float() > value)
case "" | "<=":
filters.append(Document.doc_metadata[metadata_name].as_float() <= value)
case "" | ">=":
filters.append(Document.doc_metadata[metadata_name].as_float() >= value)
case _:
pass
return filters
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,

View File

@ -14,7 +14,8 @@ from enums.quota_type import QuotaType, unlimited
from extensions.otel import AppGenerateHandler, trace_span
from models.model import Account, App, AppMode, EndUser
from models.workflow import Workflow
from services.errors.app import InvokeRateLimitError, QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError
from services.errors.app import QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError
from services.workflow_service import WorkflowService

View File

@ -21,7 +21,7 @@ from models.model import App, EndUser
from models.trigger import WorkflowTriggerLog
from models.workflow import Workflow
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
from services.errors.app import InvokeRateLimitError, QuotaExceededError, WorkflowNotFoundError
from services.errors.app import QuotaExceededError, WorkflowNotFoundError, WorkflowQuotaLimitError
from services.workflow.entities import AsyncTriggerResponse, TriggerData, WorkflowTaskData
from services.workflow.queue_dispatcher import QueueDispatcherManager, QueuePriority
from services.workflow_service import WorkflowService
@ -141,7 +141,7 @@ class AsyncWorkflowService:
trigger_log_repo.update(trigger_log)
session.commit()
raise InvokeRateLimitError(
raise WorkflowQuotaLimitError(
f"Workflow execution quota limit reached for tenant {trigger_data.tenant_id}"
) from e

View File

@ -110,5 +110,5 @@ class EnterpriseService:
if not app_id:
raise ValueError("app_id must be provided.")
body = {"appId": app_id}
EnterpriseRequest.send_request("DELETE", "/webapp/clean", json=body)
params = {"appId": app_id}
EnterpriseRequest.send_request("DELETE", "/webapp/clean", params=params)

View File

@ -18,8 +18,8 @@ class WorkflowIdFormatError(Exception):
pass
class InvokeRateLimitError(Exception):
"""Raised when rate limit is exceeded for workflow invocations."""
class WorkflowQuotaLimitError(Exception):
"""Raised when workflow execution quota is exceeded (for async/background workflows)."""
pass

View File

@ -146,7 +146,7 @@ class PluginParameterService:
provider,
action,
resolved_credentials,
CredentialType.API_KEY.value,
original_subscription.credential_type or CredentialType.UNAUTHORIZED.value,
parameter,
)
.options

View File

@ -868,48 +868,111 @@ class TriggerProviderService:
if not provider_controller:
raise ValueError(f"Provider {provider_id} not found")
subscription = TriggerProviderService.get_subscription_by_id(
tenant_id=tenant_id,
subscription_id=subscription_id,
)
if not subscription:
raise ValueError(f"Subscription {subscription_id} not found")
# Use distributed lock to prevent race conditions on the same subscription
lock_key = f"trigger_subscription_rebuild_lock:{tenant_id}_{subscription_id}"
with redis_client.lock(lock_key, timeout=20):
with Session(db.engine, expire_on_commit=False) as session:
try:
# Get subscription within the transaction
subscription: TriggerSubscription | None = (
session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
)
if not subscription:
raise ValueError(f"Subscription {subscription_id} not found")
credential_type = CredentialType.of(subscription.credential_type)
if credential_type not in [CredentialType.OAUTH2, CredentialType.API_KEY]:
raise ValueError("Credential type not supported for rebuild")
credential_type = CredentialType.of(subscription.credential_type)
if credential_type not in [CredentialType.OAUTH2, CredentialType.API_KEY]:
raise ValueError("Credential type not supported for rebuild")
# TODO: Trying to invoke update api of the plugin trigger provider
# Decrypt existing credentials for merging
credential_encrypter, _ = create_trigger_provider_encrypter_for_subscription(
tenant_id=tenant_id,
controller=provider_controller,
subscription=subscription,
)
decrypted_credentials = dict(credential_encrypter.decrypt(subscription.credentials))
# FALLBACK: If the update api is not implemented, delete the previous subscription and create a new one
# Merge credentials: if caller passed HIDDEN_VALUE, retain existing decrypted value
merged_credentials: dict[str, Any] = {
key: value if value != HIDDEN_VALUE else decrypted_credentials.get(key, UNKNOWN_VALUE)
for key, value in credentials.items()
}
# Delete the previous subscription
user_id = subscription.user_id
TriggerManager.unsubscribe_trigger(
tenant_id=tenant_id,
user_id=user_id,
provider_id=provider_id,
subscription=subscription.to_entity(),
credentials=subscription.credentials,
credential_type=credential_type,
)
user_id = subscription.user_id
# Create a new subscription with the same subscription_id and endpoint_id
new_subscription: TriggerSubscriptionEntity = TriggerManager.subscribe_trigger(
tenant_id=tenant_id,
user_id=user_id,
provider_id=provider_id,
endpoint=generate_plugin_trigger_endpoint_url(subscription.endpoint_id),
parameters=parameters,
credentials=credentials,
credential_type=credential_type,
)
TriggerProviderService.update_trigger_subscription(
tenant_id=tenant_id,
subscription_id=subscription.id,
name=name,
parameters=parameters,
credentials=credentials,
properties=new_subscription.properties,
expires_at=new_subscription.expires_at,
)
# TODO: Trying to invoke update api of the plugin trigger provider
# FALLBACK: If the update api is not implemented,
# delete the previous subscription and create a new one
# Unsubscribe the previous subscription (external call, but we'll handle errors)
try:
TriggerManager.unsubscribe_trigger(
tenant_id=tenant_id,
user_id=user_id,
provider_id=provider_id,
subscription=subscription.to_entity(),
credentials=decrypted_credentials,
credential_type=credential_type,
)
except Exception as e:
logger.exception("Error unsubscribing trigger during rebuild", exc_info=e)
# Continue anyway - the subscription might already be deleted externally
# Create a new subscription with the same subscription_id and endpoint_id (external call)
new_subscription: TriggerSubscriptionEntity = TriggerManager.subscribe_trigger(
tenant_id=tenant_id,
user_id=user_id,
provider_id=provider_id,
endpoint=generate_plugin_trigger_endpoint_url(subscription.endpoint_id),
parameters=parameters,
credentials=merged_credentials,
credential_type=credential_type,
)
# Update the subscription in the same transaction
# Inline update logic to reuse the same session
if name is not None and name != subscription.name:
existing = (
session.query(TriggerSubscription)
.filter_by(tenant_id=tenant_id, provider_id=str(provider_id), name=name)
.first()
)
if existing and existing.id != subscription.id:
raise ValueError(f"Subscription name '{name}' already exists for this provider")
subscription.name = name
# Update parameters
subscription.parameters = dict(parameters)
# Update credentials with merged (and encrypted) values
subscription.credentials = dict(credential_encrypter.encrypt(merged_credentials))
# Update properties
if new_subscription.properties:
properties_encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=provider_controller.get_properties_schema(),
cache=NoOpProviderCredentialCache(),
)
subscription.properties = dict(properties_encrypter.encrypt(dict(new_subscription.properties)))
# Update expiration timestamp
if new_subscription.expires_at is not None:
subscription.expires_at = new_subscription.expires_at
# Commit the transaction
session.commit()
# Clear subscription cache
delete_cache_for_subscription(
tenant_id=tenant_id,
provider_id=subscription.provider_id,
subscription_id=subscription.id,
)
except Exception as e:
# Rollback on any error
session.rollback()
logger.exception("Failed to rebuild trigger subscription", exc_info=e)
raise

View File

@ -863,10 +863,18 @@ class WebhookService:
not_found_in_cache.append(node_id)
continue
with Session(db.engine) as session:
try:
# lock the concurrent webhook trigger creation
redis_client.lock(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:apps:{app.id}:lock", timeout=10)
lock_key = f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:apps:{app.id}:lock"
lock = redis_client.lock(lock_key, timeout=10)
lock_acquired = False
try:
# acquire the lock with blocking and timeout
lock_acquired = lock.acquire(blocking=True, blocking_timeout=10)
if not lock_acquired:
logger.warning("Failed to acquire lock for webhook sync, app %s", app.id)
raise RuntimeError("Failed to acquire lock for webhook trigger synchronization")
with Session(db.engine) as session:
# fetch the non-cached nodes from DB
all_records = session.scalars(
select(WorkflowWebhookTrigger).where(
@ -903,11 +911,16 @@ class WebhookService:
session.delete(nodes_id_in_db[node_id])
redis_client.delete(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:{app.id}:{node_id}")
session.commit()
except Exception:
logger.exception("Failed to sync webhook relationships for app %s", app.id)
raise
finally:
redis_client.delete(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:apps:{app.id}:lock")
except Exception:
logger.exception("Failed to sync webhook relationships for app %s", app.id)
raise
finally:
# release the lock only if it was acquired
if lock_acquired:
try:
lock.release()
except Exception:
logger.exception("Failed to release lock for webhook sync, app %s", app.id)
@classmethod
def generate_webhook_id(cls) -> str:

View File

@ -0,0 +1,682 @@
from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from core.plugin.entities.plugin_daemon import CredentialType
from core.trigger.entities.entities import Subscription as TriggerSubscriptionEntity
from extensions.ext_database import db
from models.provider_ids import TriggerProviderID
from models.trigger import TriggerSubscription
from services.trigger.trigger_provider_service import TriggerProviderService
class TestTriggerProviderService:
"""Integration tests for TriggerProviderService using testcontainers."""
@pytest.fixture
def mock_external_service_dependencies(self):
"""Mock setup for external service dependencies."""
with (
patch("services.trigger.trigger_provider_service.TriggerManager") as mock_trigger_manager,
patch("services.trigger.trigger_provider_service.redis_client") as mock_redis_client,
patch("services.trigger.trigger_provider_service.delete_cache_for_subscription") as mock_delete_cache,
patch("services.account_service.FeatureService") as mock_account_feature_service,
):
# Setup default mock returns
mock_provider_controller = MagicMock()
mock_provider_controller.get_credential_schema_config.return_value = MagicMock()
mock_provider_controller.get_properties_schema.return_value = MagicMock()
mock_trigger_manager.get_trigger_provider.return_value = mock_provider_controller
# Mock redis lock
mock_lock = MagicMock()
mock_lock.__enter__ = MagicMock(return_value=None)
mock_lock.__exit__ = MagicMock(return_value=None)
mock_redis_client.lock.return_value = mock_lock
# Setup account feature service mock
mock_account_feature_service.get_system_features.return_value.is_allow_register = True
yield {
"trigger_manager": mock_trigger_manager,
"redis_client": mock_redis_client,
"delete_cache": mock_delete_cache,
"provider_controller": mock_provider_controller,
"account_feature_service": mock_account_feature_service,
}
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
"""
Helper method to create a test account and tenant for testing.
Args:
db_session_with_containers: Database session from testcontainers infrastructure
mock_external_service_dependencies: Mock dependencies
Returns:
tuple: (account, tenant) - Created account and tenant instances
"""
fake = Faker()
from services.account_service import AccountService, TenantService
# Setup mocks for account creation
mock_external_service_dependencies[
"account_feature_service"
].get_system_features.return_value.is_allow_register = True
mock_external_service_dependencies[
"trigger_manager"
].get_trigger_provider.return_value = mock_external_service_dependencies["provider_controller"]
# Create account and tenant
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
return account, tenant
def _create_test_subscription(
self,
db_session_with_containers,
tenant_id,
user_id,
provider_id,
credential_type,
credentials,
mock_external_service_dependencies,
):
"""
Helper method to create a test trigger subscription.
Args:
db_session_with_containers: Database session
tenant_id: Tenant ID
user_id: User ID
provider_id: Provider ID
credential_type: Credential type
credentials: Credentials dict
mock_external_service_dependencies: Mock dependencies
Returns:
TriggerSubscription: Created subscription instance
"""
fake = Faker()
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.helper.provider_encryption import create_provider_encrypter
# Use mock provider controller to encrypt credentials
provider_controller = mock_external_service_dependencies["provider_controller"]
# Create encrypter for credentials
credential_encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=provider_controller.get_credential_schema_config(credential_type),
cache=NoOpProviderCredentialCache(),
)
subscription = TriggerSubscription(
name=fake.word(),
tenant_id=tenant_id,
user_id=user_id,
provider_id=str(provider_id),
endpoint_id=fake.uuid4(),
parameters={"param1": "value1"},
properties={"prop1": "value1"},
credentials=dict(credential_encrypter.encrypt(credentials)),
credential_type=credential_type.value,
credential_expires_at=-1,
expires_at=-1,
)
db.session.add(subscription)
db.session.commit()
db.session.refresh(subscription)
return subscription
def test_rebuild_trigger_subscription_success_with_merged_credentials(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test successful rebuild with credential merging (HIDDEN_VALUE handling).
This test verifies:
- Credentials are properly merged (HIDDEN_VALUE replaced with existing values)
- Single transaction wraps all operations
- Merged credentials are used for subscribe and update
- Database state is correctly updated
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.API_KEY
# Create initial subscription with credentials
original_credentials = {"api_key": "original-secret-key", "api_secret": "original-secret"}
subscription = self._create_test_subscription(
db_session_with_containers,
tenant.id,
account.id,
provider_id,
credential_type,
original_credentials,
mock_external_service_dependencies,
)
# Prepare new credentials with HIDDEN_VALUE for api_key (should keep original)
# and new value for api_secret (should update)
new_credentials = {
"api_key": HIDDEN_VALUE, # Should be replaced with original
"api_secret": "new-secret-value", # Should be updated
}
# Mock subscribe_trigger to return a new subscription entity
new_subscription_entity = TriggerSubscriptionEntity(
endpoint=subscription.endpoint_id,
parameters={"param1": "value1"},
properties={"prop1": "new_prop_value"},
expires_at=1234567890,
)
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
# Mock unsubscribe_trigger
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock()
# Execute rebuild
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=tenant.id,
provider_id=provider_id,
subscription_id=subscription.id,
credentials=new_credentials,
parameters={"param1": "updated_value"},
name="updated_name",
)
# Verify unsubscribe was called with decrypted original credentials
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.assert_called_once()
unsubscribe_call_args = mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.call_args
assert unsubscribe_call_args.kwargs["tenant_id"] == tenant.id
assert unsubscribe_call_args.kwargs["provider_id"] == provider_id
assert unsubscribe_call_args.kwargs["credential_type"] == credential_type
# Verify subscribe was called with merged credentials (api_key from original, api_secret new)
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.assert_called_once()
subscribe_call_args = mock_external_service_dependencies["trigger_manager"].subscribe_trigger.call_args
subscribe_credentials = subscribe_call_args.kwargs["credentials"]
assert subscribe_credentials["api_key"] == original_credentials["api_key"] # Merged from original
assert subscribe_credentials["api_secret"] == "new-secret-value" # New value
# Verify database state was updated
db.session.refresh(subscription)
assert subscription.name == "updated_name"
assert subscription.parameters == {"param1": "updated_value"}
# Verify credentials in DB were updated with merged values (decrypt to check)
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.helper.provider_encryption import create_provider_encrypter
# Use mock provider controller to decrypt credentials
provider_controller = mock_external_service_dependencies["provider_controller"]
credential_encrypter, _ = create_provider_encrypter(
tenant_id=tenant.id,
config=provider_controller.get_credential_schema_config(credential_type),
cache=NoOpProviderCredentialCache(),
)
decrypted_db_credentials = dict(credential_encrypter.decrypt(subscription.credentials))
assert decrypted_db_credentials["api_key"] == original_credentials["api_key"]
assert decrypted_db_credentials["api_secret"] == "new-secret-value"
# Verify cache was cleared
mock_external_service_dependencies["delete_cache"].assert_called_once_with(
tenant_id=tenant.id,
provider_id=subscription.provider_id,
subscription_id=subscription.id,
)
def test_rebuild_trigger_subscription_with_all_new_credentials(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test rebuild when all credentials are new (no HIDDEN_VALUE).
This test verifies:
- All new credentials are used when no HIDDEN_VALUE is present
- Merged credentials contain only new values
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.API_KEY
# Create initial subscription
original_credentials = {"api_key": "original-key", "api_secret": "original-secret"}
subscription = self._create_test_subscription(
db_session_with_containers,
tenant.id,
account.id,
provider_id,
credential_type,
original_credentials,
mock_external_service_dependencies,
)
# All new credentials (no HIDDEN_VALUE)
new_credentials = {
"api_key": "completely-new-key",
"api_secret": "completely-new-secret",
}
new_subscription_entity = TriggerSubscriptionEntity(
endpoint=subscription.endpoint_id,
parameters={},
properties={},
expires_at=-1,
)
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock()
# Execute rebuild
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=tenant.id,
provider_id=provider_id,
subscription_id=subscription.id,
credentials=new_credentials,
parameters={},
)
# Verify subscribe was called with all new credentials
subscribe_call_args = mock_external_service_dependencies["trigger_manager"].subscribe_trigger.call_args
subscribe_credentials = subscribe_call_args.kwargs["credentials"]
assert subscribe_credentials["api_key"] == "completely-new-key"
assert subscribe_credentials["api_secret"] == "completely-new-secret"
def test_rebuild_trigger_subscription_with_all_hidden_values(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test rebuild when all credentials are HIDDEN_VALUE (preserve all existing).
This test verifies:
- All HIDDEN_VALUE credentials are replaced with existing values
- Original credentials are preserved
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.API_KEY
original_credentials = {"api_key": "original-key", "api_secret": "original-secret"}
subscription = self._create_test_subscription(
db_session_with_containers,
tenant.id,
account.id,
provider_id,
credential_type,
original_credentials,
mock_external_service_dependencies,
)
# All HIDDEN_VALUE (should preserve all original)
new_credentials = {
"api_key": HIDDEN_VALUE,
"api_secret": HIDDEN_VALUE,
}
new_subscription_entity = TriggerSubscriptionEntity(
endpoint=subscription.endpoint_id,
parameters={},
properties={},
expires_at=-1,
)
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock()
# Execute rebuild
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=tenant.id,
provider_id=provider_id,
subscription_id=subscription.id,
credentials=new_credentials,
parameters={},
)
# Verify subscribe was called with all original credentials
subscribe_call_args = mock_external_service_dependencies["trigger_manager"].subscribe_trigger.call_args
subscribe_credentials = subscribe_call_args.kwargs["credentials"]
assert subscribe_credentials["api_key"] == original_credentials["api_key"]
assert subscribe_credentials["api_secret"] == original_credentials["api_secret"]
def test_rebuild_trigger_subscription_with_missing_key_uses_unknown_value(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test rebuild when HIDDEN_VALUE is used for a key that doesn't exist in original.
This test verifies:
- UNKNOWN_VALUE is used when HIDDEN_VALUE key doesn't exist in original credentials
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.API_KEY
# Original has only api_key
original_credentials = {"api_key": "original-key"}
subscription = self._create_test_subscription(
db_session_with_containers,
tenant.id,
account.id,
provider_id,
credential_type,
original_credentials,
mock_external_service_dependencies,
)
# HIDDEN_VALUE for non-existent key should use UNKNOWN_VALUE
new_credentials = {
"api_key": HIDDEN_VALUE,
"non_existent_key": HIDDEN_VALUE, # This key doesn't exist in original
}
new_subscription_entity = TriggerSubscriptionEntity(
endpoint=subscription.endpoint_id,
parameters={},
properties={},
expires_at=-1,
)
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock()
# Execute rebuild
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=tenant.id,
provider_id=provider_id,
subscription_id=subscription.id,
credentials=new_credentials,
parameters={},
)
# Verify subscribe was called with original api_key and UNKNOWN_VALUE for missing key
subscribe_call_args = mock_external_service_dependencies["trigger_manager"].subscribe_trigger.call_args
subscribe_credentials = subscribe_call_args.kwargs["credentials"]
assert subscribe_credentials["api_key"] == original_credentials["api_key"]
assert subscribe_credentials["non_existent_key"] == UNKNOWN_VALUE
def test_rebuild_trigger_subscription_rollback_on_error(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test that transaction is rolled back on error.
This test verifies:
- Database transaction is rolled back when an error occurs
- Original subscription state is preserved
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.API_KEY
original_credentials = {"api_key": "original-key"}
subscription = self._create_test_subscription(
db_session_with_containers,
tenant.id,
account.id,
provider_id,
credential_type,
original_credentials,
mock_external_service_dependencies,
)
original_name = subscription.name
original_parameters = subscription.parameters.copy()
# Make subscribe_trigger raise an error
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.side_effect = ValueError(
"Subscribe failed"
)
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock()
# Execute rebuild and expect error
with pytest.raises(ValueError, match="Subscribe failed"):
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=tenant.id,
provider_id=provider_id,
subscription_id=subscription.id,
credentials={"api_key": "new-key"},
parameters={},
)
# Verify subscription state was not changed (rolled back)
db.session.refresh(subscription)
assert subscription.name == original_name
assert subscription.parameters == original_parameters
def test_rebuild_trigger_subscription_unsubscribe_error_continues(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test that unsubscribe errors are handled gracefully and operation continues.
This test verifies:
- Unsubscribe errors are caught and logged but don't stop the rebuild
- Rebuild continues even if unsubscribe fails
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.API_KEY
original_credentials = {"api_key": "original-key"}
subscription = self._create_test_subscription(
db_session_with_containers,
tenant.id,
account.id,
provider_id,
credential_type,
original_credentials,
mock_external_service_dependencies,
)
# Make unsubscribe_trigger raise an error (should be caught and continue)
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.side_effect = ValueError(
"Unsubscribe failed"
)
new_subscription_entity = TriggerSubscriptionEntity(
endpoint=subscription.endpoint_id,
parameters={},
properties={},
expires_at=-1,
)
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
# Execute rebuild - should succeed despite unsubscribe error
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=tenant.id,
provider_id=provider_id,
subscription_id=subscription.id,
credentials={"api_key": "new-key"},
parameters={},
)
# Verify subscribe was still called (operation continued)
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.assert_called_once()
# Verify subscription was updated
db.session.refresh(subscription)
assert subscription.parameters == {}
def test_rebuild_trigger_subscription_subscription_not_found(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test error when subscription is not found.
This test verifies:
- Proper error is raised when subscription doesn't exist
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
fake_subscription_id = fake.uuid4()
with pytest.raises(ValueError, match="not found"):
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=tenant.id,
provider_id=provider_id,
subscription_id=fake_subscription_id,
credentials={},
parameters={},
)
def test_rebuild_trigger_subscription_provider_not_found(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test error when provider is not found.
This test verifies:
- Proper error is raised when provider doesn't exist
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider_id = TriggerProviderID("non_existent_org/non_existent_plugin/non_existent_provider")
# Make get_trigger_provider return None
mock_external_service_dependencies["trigger_manager"].get_trigger_provider.return_value = None
with pytest.raises(ValueError, match="Provider.*not found"):
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=tenant.id,
provider_id=provider_id,
subscription_id=fake.uuid4(),
credentials={},
parameters={},
)
def test_rebuild_trigger_subscription_unsupported_credential_type(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test error when credential type is not supported for rebuild.
This test verifies:
- Proper error is raised for unsupported credential types (not OAUTH2 or API_KEY)
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.UNAUTHORIZED # Not supported
subscription = self._create_test_subscription(
db_session_with_containers,
tenant.id,
account.id,
provider_id,
credential_type,
{},
mock_external_service_dependencies,
)
with pytest.raises(ValueError, match="Credential type not supported for rebuild"):
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=tenant.id,
provider_id=provider_id,
subscription_id=subscription.id,
credentials={},
parameters={},
)
def test_rebuild_trigger_subscription_name_uniqueness_check(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test that name uniqueness is checked when updating name.
This test verifies:
- Error is raised when new name conflicts with existing subscription
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.API_KEY
# Create first subscription
subscription1 = self._create_test_subscription(
db_session_with_containers,
tenant.id,
account.id,
provider_id,
credential_type,
{"api_key": "key1"},
mock_external_service_dependencies,
)
# Create second subscription with different name
subscription2 = self._create_test_subscription(
db_session_with_containers,
tenant.id,
account.id,
provider_id,
credential_type,
{"api_key": "key2"},
mock_external_service_dependencies,
)
new_subscription_entity = TriggerSubscriptionEntity(
endpoint=subscription2.endpoint_id,
parameters={},
properties={},
expires_at=-1,
)
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock()
# Try to rename subscription2 to subscription1's name (should fail)
with pytest.raises(ValueError, match="already exists"):
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=tenant.id,
provider_id=provider_id,
subscription_id=subscription2.id,
credentials={"api_key": "new-key"},
parameters={},
name=subscription1.name, # Conflicting name
)

View File

@ -0,0 +1,327 @@
import unittest
from unittest.mock import MagicMock, patch
import pytest
from core.rag.datasource.vdb.pgvector.pgvector import (
PGVector,
PGVectorConfig,
)
class TestPGVector(unittest.TestCase):
def setUp(self):
self.config = PGVectorConfig(
host="localhost",
port=5432,
user="test_user",
password="test_password",
database="test_db",
min_connection=1,
max_connection=5,
pg_bigm=False,
)
self.collection_name = "test_collection"
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
def test_init(self, mock_pool_class):
"""Test PGVector initialization."""
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
pgvector = PGVector(self.collection_name, self.config)
assert pgvector._collection_name == self.collection_name
assert pgvector.table_name == f"embedding_{self.collection_name}"
assert pgvector.get_type() == "pgvector"
assert pgvector.pool is not None
assert pgvector.pg_bigm is False
assert pgvector.index_hash is not None
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
def test_init_with_pg_bigm(self, mock_pool_class):
"""Test PGVector initialization with pg_bigm enabled."""
config = PGVectorConfig(
host="localhost",
port=5432,
user="test_user",
password="test_password",
database="test_db",
min_connection=1,
max_connection=5,
pg_bigm=True,
)
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
pgvector = PGVector(self.collection_name, config)
assert pgvector.pg_bigm is True
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
def test_create_collection_basic(self, mock_redis, mock_pool_class):
"""Test basic collection creation."""
# Mock Redis operations
mock_lock = MagicMock()
mock_lock.__enter__ = MagicMock()
mock_lock.__exit__ = MagicMock()
mock_redis.lock.return_value = mock_lock
mock_redis.get.return_value = None
mock_redis.set.return_value = None
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
# Mock connection and cursor
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.getconn.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.return_value = [1] # vector extension exists
pgvector = PGVector(self.collection_name, self.config)
pgvector._create_collection(1536)
# Verify SQL execution calls
assert mock_cursor.execute.called
# Check that CREATE TABLE was called with correct dimension
create_table_calls = [call for call in mock_cursor.execute.call_args_list if "CREATE TABLE" in str(call)]
assert len(create_table_calls) == 1
assert "vector(1536)" in create_table_calls[0][0][0]
# Check that CREATE INDEX was called (dimension <= 2000)
create_index_calls = [
call for call in mock_cursor.execute.call_args_list if "CREATE INDEX" in str(call) and "hnsw" in str(call)
]
assert len(create_index_calls) == 1
# Verify Redis cache was set
mock_redis.set.assert_called_once()
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
def test_create_collection_with_large_dimension(self, mock_redis, mock_pool_class):
"""Test collection creation with dimension > 2000 (no HNSW index)."""
# Mock Redis operations
mock_lock = MagicMock()
mock_lock.__enter__ = MagicMock()
mock_lock.__exit__ = MagicMock()
mock_redis.lock.return_value = mock_lock
mock_redis.get.return_value = None
mock_redis.set.return_value = None
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
# Mock connection and cursor
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.getconn.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.return_value = [1] # vector extension exists
pgvector = PGVector(self.collection_name, self.config)
pgvector._create_collection(3072) # Dimension > 2000
# Check that CREATE TABLE was called
create_table_calls = [call for call in mock_cursor.execute.call_args_list if "CREATE TABLE" in str(call)]
assert len(create_table_calls) == 1
assert "vector(3072)" in create_table_calls[0][0][0]
# Check that HNSW index was NOT created (dimension > 2000)
hnsw_index_calls = [call for call in mock_cursor.execute.call_args_list if "hnsw" in str(call)]
assert len(hnsw_index_calls) == 0
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
def test_create_collection_with_pg_bigm(self, mock_redis, mock_pool_class):
"""Test collection creation with pg_bigm enabled."""
config = PGVectorConfig(
host="localhost",
port=5432,
user="test_user",
password="test_password",
database="test_db",
min_connection=1,
max_connection=5,
pg_bigm=True,
)
# Mock Redis operations
mock_lock = MagicMock()
mock_lock.__enter__ = MagicMock()
mock_lock.__exit__ = MagicMock()
mock_redis.lock.return_value = mock_lock
mock_redis.get.return_value = None
mock_redis.set.return_value = None
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
# Mock connection and cursor
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.getconn.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.return_value = [1] # vector extension exists
pgvector = PGVector(self.collection_name, config)
pgvector._create_collection(1536)
# Check that pg_bigm index was created
bigm_index_calls = [call for call in mock_cursor.execute.call_args_list if "gin_bigm_ops" in str(call)]
assert len(bigm_index_calls) == 1
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
def test_create_collection_creates_vector_extension(self, mock_redis, mock_pool_class):
"""Test that vector extension is created if it doesn't exist."""
# Mock Redis operations
mock_lock = MagicMock()
mock_lock.__enter__ = MagicMock()
mock_lock.__exit__ = MagicMock()
mock_redis.lock.return_value = mock_lock
mock_redis.get.return_value = None
mock_redis.set.return_value = None
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
# Mock connection and cursor
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.getconn.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
# First call: vector extension doesn't exist
mock_cursor.fetchone.return_value = None
pgvector = PGVector(self.collection_name, self.config)
pgvector._create_collection(1536)
# Check that CREATE EXTENSION was called
create_extension_calls = [
call for call in mock_cursor.execute.call_args_list if "CREATE EXTENSION vector" in str(call)
]
assert len(create_extension_calls) == 1
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
def test_create_collection_with_cache_hit(self, mock_redis, mock_pool_class):
"""Test that collection creation is skipped when cache exists."""
# Mock Redis operations - cache exists
mock_lock = MagicMock()
mock_lock.__enter__ = MagicMock()
mock_lock.__exit__ = MagicMock()
mock_redis.lock.return_value = mock_lock
mock_redis.get.return_value = 1 # Cache exists
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
# Mock connection and cursor
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.getconn.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
pgvector = PGVector(self.collection_name, self.config)
pgvector._create_collection(1536)
# Check that no SQL was executed (early return due to cache)
assert mock_cursor.execute.call_count == 0
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
def test_create_collection_with_redis_lock(self, mock_redis, mock_pool_class):
"""Test that Redis lock is used during collection creation."""
# Mock Redis operations
mock_lock = MagicMock()
mock_lock.__enter__ = MagicMock()
mock_lock.__exit__ = MagicMock()
mock_redis.lock.return_value = mock_lock
mock_redis.get.return_value = None
mock_redis.set.return_value = None
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
# Mock connection and cursor
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.getconn.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.return_value = [1] # vector extension exists
pgvector = PGVector(self.collection_name, self.config)
pgvector._create_collection(1536)
# Verify Redis lock was acquired with correct lock name
mock_redis.lock.assert_called_once_with("vector_indexing_test_collection_lock", timeout=20)
# Verify lock context manager was entered and exited
mock_lock.__enter__.assert_called_once()
mock_lock.__exit__.assert_called_once()
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
def test_get_cursor_context_manager(self, mock_pool_class):
"""Test that _get_cursor properly manages connection lifecycle."""
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.getconn.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
pgvector = PGVector(self.collection_name, self.config)
with pgvector._get_cursor() as cur:
assert cur == mock_cursor
# Verify connection lifecycle methods were called
mock_pool.getconn.assert_called_once()
mock_cursor.close.assert_called_once()
mock_conn.commit.assert_called_once()
mock_pool.putconn.assert_called_once_with(mock_conn)
@pytest.mark.parametrize(
"invalid_config_override",
[
{"host": ""}, # Test empty host
{"port": 0}, # Test invalid port
{"user": ""}, # Test empty user
{"password": ""}, # Test empty password
{"database": ""}, # Test empty database
{"min_connection": 0}, # Test invalid min_connection
{"max_connection": 0}, # Test invalid max_connection
{"min_connection": 10, "max_connection": 5}, # Test min > max
],
)
def test_config_validation_parametrized(invalid_config_override):
"""Test configuration validation for various invalid inputs using parametrize."""
config = {
"host": "localhost",
"port": 5432,
"user": "test_user",
"password": "test_password",
"database": "test_db",
"min_connection": 1,
"max_connection": 5,
}
config.update(invalid_config_override)
with pytest.raises(ValueError):
PGVectorConfig(**config)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,873 @@
"""
Unit tests for DatasetRetrieval.process_metadata_filter_func.
This module provides comprehensive test coverage for the process_metadata_filter_func
method in the DatasetRetrieval class, which is responsible for building SQLAlchemy
filter expressions based on metadata filtering conditions.
Conditions Tested:
==================
1. **String Conditions**: contains, not contains, start with, end with
2. **Equality Conditions**: is / =, is not /
3. **Null Conditions**: empty, not empty
4. **Numeric Comparisons**: before / <, after / >, / <=, / >=
5. **List Conditions**: in
6. **Edge Cases**: None values, different data types (str, int, float)
Test Architecture:
==================
- Direct instantiation of DatasetRetrieval
- Mocking of DatasetDocument model attributes
- Verification of SQLAlchemy filter expressions
- Follows Arrange-Act-Assert (AAA) pattern
Running Tests:
==============
# Run all tests in this module
uv run --project api pytest \
api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py -v
# Run a specific test
uv run --project api pytest \
api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py::\
TestProcessMetadataFilterFunc::test_contains_condition -v
"""
from unittest.mock import MagicMock
import pytest
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
class TestProcessMetadataFilterFunc:
"""
Comprehensive test suite for process_metadata_filter_func method.
This test class validates all metadata filtering conditions supported by
the DatasetRetrieval class, including string operations, numeric comparisons,
null checks, and list operations.
Method Signature:
==================
def process_metadata_filter_func(
self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
) -> list:
The method builds SQLAlchemy filter expressions by:
1. Validating value is not None (except for empty/not empty conditions)
2. Using DatasetDocument.doc_metadata JSON field operations
3. Adding appropriate SQLAlchemy expressions to the filters list
4. Returning the updated filters list
Mocking Strategy:
==================
- Mock DatasetDocument.doc_metadata to avoid database dependencies
- Verify filter expressions are created correctly
- Test with various data types (str, int, float, list)
"""
@pytest.fixture
def retrieval(self):
"""
Create a DatasetRetrieval instance for testing.
Returns:
DatasetRetrieval: Instance to test process_metadata_filter_func
"""
return DatasetRetrieval()
@pytest.fixture
def mock_doc_metadata(self):
"""
Mock the DatasetDocument.doc_metadata JSON field.
The method uses DatasetDocument.doc_metadata[metadata_name] to access
JSON fields. We mock this to avoid database dependencies.
Returns:
Mock: Mocked doc_metadata attribute
"""
mock_metadata_field = MagicMock()
# Create mock for string access
mock_string_access = MagicMock()
mock_string_access.like = MagicMock()
mock_string_access.notlike = MagicMock()
mock_string_access.__eq__ = MagicMock(return_value=MagicMock())
mock_string_access.__ne__ = MagicMock(return_value=MagicMock())
mock_string_access.in_ = MagicMock(return_value=MagicMock())
# Create mock for float access (for numeric comparisons)
mock_float_access = MagicMock()
mock_float_access.__eq__ = MagicMock(return_value=MagicMock())
mock_float_access.__ne__ = MagicMock(return_value=MagicMock())
mock_float_access.__lt__ = MagicMock(return_value=MagicMock())
mock_float_access.__gt__ = MagicMock(return_value=MagicMock())
mock_float_access.__le__ = MagicMock(return_value=MagicMock())
mock_float_access.__ge__ = MagicMock(return_value=MagicMock())
# Create mock for null checks
mock_null_access = MagicMock()
mock_null_access.is_ = MagicMock(return_value=MagicMock())
mock_null_access.isnot = MagicMock(return_value=MagicMock())
# Setup __getitem__ to return appropriate mock based on usage
def getitem_side_effect(name):
if name in ["author", "title", "category"]:
return mock_string_access
elif name in ["year", "price", "rating"]:
return mock_float_access
else:
return mock_string_access
mock_metadata_field.__getitem__ = MagicMock(side_effect=getitem_side_effect)
mock_metadata_field.as_string.return_value = mock_string_access
mock_metadata_field.as_float.return_value = mock_float_access
mock_metadata_field[metadata_name:str].is_ = mock_null_access.is_
mock_metadata_field[metadata_name:str].isnot = mock_null_access.isnot
return mock_metadata_field
# ==================== String Condition Tests ====================
def test_contains_condition_string_value(self, retrieval):
"""
Test 'contains' condition with string value.
Verifies:
- Filters list is populated with LIKE expression
- Pattern matching uses %value% syntax
"""
filters = []
sequence = 0
condition = "contains"
metadata_name = "author"
value = "John"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_not_contains_condition(self, retrieval):
"""
Test 'not contains' condition.
Verifies:
- Filters list is populated with NOT LIKE expression
- Pattern matching uses %value% syntax with negation
"""
filters = []
sequence = 0
condition = "not contains"
metadata_name = "title"
value = "banned"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_start_with_condition(self, retrieval):
"""
Test 'start with' condition.
Verifies:
- Filters list is populated with LIKE expression
- Pattern matching uses value% syntax
"""
filters = []
sequence = 0
condition = "start with"
metadata_name = "category"
value = "tech"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_end_with_condition(self, retrieval):
"""
Test 'end with' condition.
Verifies:
- Filters list is populated with LIKE expression
- Pattern matching uses %value syntax
"""
filters = []
sequence = 0
condition = "end with"
metadata_name = "filename"
value = ".pdf"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
# ==================== Equality Condition Tests ====================
def test_is_condition_with_string_value(self, retrieval):
"""
Test 'is' (=) condition with string value.
Verifies:
- Filters list is populated with equality expression
- String comparison is used
"""
filters = []
sequence = 0
condition = "is"
metadata_name = "author"
value = "Jane Doe"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_equals_condition_with_string_value(self, retrieval):
"""
Test '=' condition with string value.
Verifies:
- Same behavior as 'is' condition
- String comparison is used
"""
filters = []
sequence = 0
condition = "="
metadata_name = "category"
value = "technology"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_is_condition_with_int_value(self, retrieval):
"""
Test 'is' condition with integer value.
Verifies:
- Numeric comparison is used
- as_float() is called on the metadata field
"""
filters = []
sequence = 0
condition = "is"
metadata_name = "year"
value = 2023
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_is_condition_with_float_value(self, retrieval):
"""
Test 'is' condition with float value.
Verifies:
- Numeric comparison is used
- as_float() is called on the metadata field
"""
filters = []
sequence = 0
condition = "is"
metadata_name = "price"
value = 19.99
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_is_not_condition_with_string_value(self, retrieval):
"""
Test 'is not' () condition with string value.
Verifies:
- Filters list is populated with inequality expression
- String comparison is used
"""
filters = []
sequence = 0
condition = "is not"
metadata_name = "author"
value = "Unknown"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_not_equals_condition(self, retrieval):
"""
Test '' condition with string value.
Verifies:
- Same behavior as 'is not' condition
- Inequality expression is used
"""
filters = []
sequence = 0
condition = ""
metadata_name = "category"
value = "archived"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_is_not_condition_with_numeric_value(self, retrieval):
"""
Test 'is not' condition with numeric value.
Verifies:
- Numeric inequality comparison is used
- as_float() is called on the metadata field
"""
filters = []
sequence = 0
condition = "is not"
metadata_name = "year"
value = 2000
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
# ==================== Null Condition Tests ====================
def test_empty_condition(self, retrieval):
"""
Test 'empty' condition (null check).
Verifies:
- Filters list is populated with IS NULL expression
- Value can be None for this condition
"""
filters = []
sequence = 0
condition = "empty"
metadata_name = "author"
value = None
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_not_empty_condition(self, retrieval):
"""
Test 'not empty' condition (not null check).
Verifies:
- Filters list is populated with IS NOT NULL expression
- Value can be None for this condition
"""
filters = []
sequence = 0
condition = "not empty"
metadata_name = "description"
value = None
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
# ==================== Numeric Comparison Tests ====================
def test_before_condition(self, retrieval):
"""
Test 'before' (<) condition.
Verifies:
- Filters list is populated with less than expression
- Numeric comparison is used
"""
filters = []
sequence = 0
condition = "before"
metadata_name = "year"
value = 2020
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_less_than_condition(self, retrieval):
"""
Test '<' condition.
Verifies:
- Same behavior as 'before' condition
- Less than expression is used
"""
filters = []
sequence = 0
condition = "<"
metadata_name = "price"
value = 100.0
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_after_condition(self, retrieval):
"""
Test 'after' (>) condition.
Verifies:
- Filters list is populated with greater than expression
- Numeric comparison is used
"""
filters = []
sequence = 0
condition = "after"
metadata_name = "year"
value = 2020
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_greater_than_condition(self, retrieval):
"""
Test '>' condition.
Verifies:
- Same behavior as 'after' condition
- Greater than expression is used
"""
filters = []
sequence = 0
condition = ">"
metadata_name = "rating"
value = 4.5
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_less_than_or_equal_condition_unicode(self, retrieval):
"""
Test '' condition.
Verifies:
- Filters list is populated with less than or equal expression
- Numeric comparison is used
"""
filters = []
sequence = 0
condition = ""
metadata_name = "price"
value = 50.0
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_less_than_or_equal_condition_ascii(self, retrieval):
"""
Test '<=' condition.
Verifies:
- Same behavior as '' condition
- Less than or equal expression is used
"""
filters = []
sequence = 0
condition = "<="
metadata_name = "year"
value = 2023
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_greater_than_or_equal_condition_unicode(self, retrieval):
"""
Test '' condition.
Verifies:
- Filters list is populated with greater than or equal expression
- Numeric comparison is used
"""
filters = []
sequence = 0
condition = ""
metadata_name = "rating"
value = 3.5
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_greater_than_or_equal_condition_ascii(self, retrieval):
"""
Test '>=' condition.
Verifies:
- Same behavior as '' condition
- Greater than or equal expression is used
"""
filters = []
sequence = 0
condition = ">="
metadata_name = "year"
value = 2000
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
# ==================== List/In Condition Tests ====================
def test_in_condition_with_comma_separated_string(self, retrieval):
"""
Test 'in' condition with comma-separated string value.
Verifies:
- String is split into list
- Whitespace is trimmed from each value
- IN expression is created
"""
filters = []
sequence = 0
condition = "in"
metadata_name = "category"
value = "tech, science, AI "
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_in_condition_with_list_value(self, retrieval):
"""
Test 'in' condition with list value.
Verifies:
- List is processed correctly
- None values are filtered out
- IN expression is created with valid values
"""
filters = []
sequence = 0
condition = "in"
metadata_name = "tags"
value = ["python", "javascript", None, "golang"]
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_in_condition_with_tuple_value(self, retrieval):
"""
Test 'in' condition with tuple value.
Verifies:
- Tuple is processed like a list
- IN expression is created
"""
filters = []
sequence = 0
condition = "in"
metadata_name = "category"
value = ("tech", "science", "ai")
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_in_condition_with_empty_string(self, retrieval):
"""
Test 'in' condition with empty string value.
Verifies:
- Empty string results in literal(False) filter
- No valid values to match
"""
filters = []
sequence = 0
condition = "in"
metadata_name = "category"
value = ""
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
# Verify it's a literal(False) expression
# This is a bit tricky to test without access to the actual expression
def test_in_condition_with_only_whitespace(self, retrieval):
"""
Test 'in' condition with whitespace-only string value.
Verifies:
- Whitespace-only string results in literal(False) filter
- All values are stripped and filtered out
"""
filters = []
sequence = 0
condition = "in"
metadata_name = "category"
value = " , , "
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_in_condition_with_single_string(self, retrieval):
"""
Test 'in' condition with single non-comma string.
Verifies:
- Single string is treated as single-item list
- IN expression is created with one value
"""
filters = []
sequence = 0
condition = "in"
metadata_name = "category"
value = "technology"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
# ==================== Edge Case Tests ====================
def test_none_value_with_non_empty_condition(self, retrieval):
"""
Test None value with conditions that require value.
Verifies:
- Original filters list is returned unchanged
- No filter is added for None values (except empty/not empty)
"""
filters = []
sequence = 0
condition = "contains"
metadata_name = "author"
value = None
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 0 # No filter added
def test_none_value_with_equals_condition(self, retrieval):
"""
Test None value with 'is' (=) condition.
Verifies:
- Original filters list is returned unchanged
- No filter is added for None values
"""
filters = []
sequence = 0
condition = "is"
metadata_name = "author"
value = None
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 0
def test_none_value_with_numeric_condition(self, retrieval):
"""
Test None value with numeric comparison condition.
Verifies:
- Original filters list is returned unchanged
- No filter is added for None values
"""
filters = []
sequence = 0
condition = ">"
metadata_name = "year"
value = None
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 0
def test_existing_filters_preserved(self, retrieval):
"""
Test that existing filters are preserved.
Verifies:
- Existing filters in the list are not removed
- New filters are appended to the list
"""
existing_filter = MagicMock()
filters = [existing_filter]
sequence = 0
condition = "contains"
metadata_name = "author"
value = "test"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 2
assert filters[0] == existing_filter
def test_multiple_filters_accumulated(self, retrieval):
"""
Test multiple calls to accumulate filters.
Verifies:
- Each call adds a new filter to the list
- All filters are preserved across calls
"""
filters = []
# First filter
retrieval.process_metadata_filter_func(0, "contains", "author", "John", filters)
assert len(filters) == 1
# Second filter
retrieval.process_metadata_filter_func(1, ">", "year", 2020, filters)
assert len(filters) == 2
# Third filter
retrieval.process_metadata_filter_func(2, "is", "category", "tech", filters)
assert len(filters) == 3
def test_unknown_condition(self, retrieval):
"""
Test unknown/unsupported condition.
Verifies:
- Original filters list is returned unchanged
- No filter is added for unknown conditions
"""
filters = []
sequence = 0
condition = "unknown_condition"
metadata_name = "author"
value = "test"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 0
def test_empty_string_value_with_contains(self, retrieval):
"""
Test empty string value with 'contains' condition.
Verifies:
- Filter is added even with empty string
- LIKE expression is created
"""
filters = []
sequence = 0
condition = "contains"
metadata_name = "author"
value = ""
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_special_characters_in_value(self, retrieval):
"""
Test special characters in value string.
Verifies:
- Special characters are handled in value
- LIKE expression is created correctly
"""
filters = []
sequence = 0
condition = "contains"
metadata_name = "title"
value = "C++ & Python's features"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_zero_value_with_numeric_condition(self, retrieval):
"""
Test zero value with numeric comparison condition.
Verifies:
- Zero is treated as valid value
- Numeric comparison is performed
"""
filters = []
sequence = 0
condition = ">"
metadata_name = "price"
value = 0
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_negative_value_with_numeric_condition(self, retrieval):
"""
Test negative value with numeric comparison condition.
Verifies:
- Negative numbers are handled correctly
- Numeric comparison is performed
"""
filters = []
sequence = 0
condition = "<"
metadata_name = "temperature"
value = -10.5
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_float_value_with_integer_comparison(self, retrieval):
"""
Test float value with numeric comparison condition.
Verifies:
- Float values work correctly
- Numeric comparison is performed
"""
filters = []
sequence = 0
condition = ">="
metadata_name = "rating"
value = 4.5
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1

View File

@ -0,0 +1,122 @@
import base64
from unittest.mock import Mock, patch
import pytest
from core.mcp.types import (
AudioContent,
BlobResourceContents,
CallToolResult,
EmbeddedResource,
ImageContent,
TextResourceContents,
)
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage
from core.tools.mcp_tool.tool import MCPTool
def _make_mcp_tool(output_schema: dict | None = None) -> MCPTool:
identity = ToolIdentity(
author="test",
name="test_mcp_tool",
label=I18nObject(en_US="Test MCP Tool", zh_Hans="测试MCP工具"),
provider="test_provider",
)
entity = ToolEntity(identity=identity, output_schema=output_schema or {})
runtime = Mock(spec=ToolRuntime)
runtime.credentials = {}
return MCPTool(
entity=entity,
runtime=runtime,
tenant_id="test_tenant",
icon="",
server_url="https://server.invalid",
provider_id="provider_1",
headers={},
)
class TestMCPToolInvoke:
@pytest.mark.parametrize(
("content_factory", "mime_type"),
[
(
lambda b64, mt: ImageContent(type="image", data=b64, mimeType=mt),
"image/png",
),
(
lambda b64, mt: AudioContent(type="audio", data=b64, mimeType=mt),
"audio/mpeg",
),
],
)
def test_invoke_image_or_audio_yields_blob(self, content_factory, mime_type) -> None:
tool = _make_mcp_tool()
raw = b"\x00\x01test-bytes\x02"
b64 = base64.b64encode(raw).decode()
content = content_factory(b64, mime_type)
result = CallToolResult(content=[content])
with patch.object(tool, "invoke_remote_mcp_tool", return_value=result):
messages = list(tool._invoke(user_id="test_user", tool_parameters={}))
assert len(messages) == 1
msg = messages[0]
assert msg.type == ToolInvokeMessage.MessageType.BLOB
assert isinstance(msg.message, ToolInvokeMessage.BlobMessage)
assert msg.message.blob == raw
assert msg.meta == {"mime_type": mime_type}
def test_invoke_embedded_text_resource_yields_text(self) -> None:
tool = _make_mcp_tool()
text_resource = TextResourceContents(uri="file://test.txt", mimeType="text/plain", text="hello world")
content = EmbeddedResource(type="resource", resource=text_resource)
result = CallToolResult(content=[content])
with patch.object(tool, "invoke_remote_mcp_tool", return_value=result):
messages = list(tool._invoke(user_id="test_user", tool_parameters={}))
assert len(messages) == 1
msg = messages[0]
assert msg.type == ToolInvokeMessage.MessageType.TEXT
assert isinstance(msg.message, ToolInvokeMessage.TextMessage)
assert msg.message.text == "hello world"
@pytest.mark.parametrize(
("mime_type", "expected_mime"),
[("application/pdf", "application/pdf"), (None, "application/octet-stream")],
)
def test_invoke_embedded_blob_resource_yields_blob(self, mime_type, expected_mime) -> None:
tool = _make_mcp_tool()
raw = b"binary-data"
b64 = base64.b64encode(raw).decode()
blob_resource = BlobResourceContents(uri="file://doc.bin", mimeType=mime_type, blob=b64)
content = EmbeddedResource(type="resource", resource=blob_resource)
result = CallToolResult(content=[content])
with patch.object(tool, "invoke_remote_mcp_tool", return_value=result):
messages = list(tool._invoke(user_id="test_user", tool_parameters={}))
assert len(messages) == 1
msg = messages[0]
assert msg.type == ToolInvokeMessage.MessageType.BLOB
assert isinstance(msg.message, ToolInvokeMessage.BlobMessage)
assert msg.message.blob == raw
assert msg.meta == {"mime_type": expected_mime}
def test_invoke_yields_variables_when_structured_content_and_schema(self) -> None:
tool = _make_mcp_tool(output_schema={"type": "object"})
result = CallToolResult(content=[], structuredContent={"a": 1, "b": "x"})
with patch.object(tool, "invoke_remote_mcp_tool", return_value=result):
messages = list(tool._invoke(user_id="test_user", tool_parameters={}))
# Expect two variable messages corresponding to keys a and b
assert len(messages) == 2
var_msgs = [m for m in messages if isinstance(m.message, ToolInvokeMessage.VariableMessage)]
assert {m.message.variable_name for m in var_msgs} == {"a", "b"}
# Validate values
values = {m.message.variable_name: m.message.variable_value for m in var_msgs}
assert values == {"a": 1, "b": "x"}

View File

@ -3072,11 +3072,11 @@ wheels = [
[[package]]
name = "json-repair"
version = "0.54.1"
version = "0.54.3"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/00/46/d3a4d9a3dad39bb4a2ad16b8adb9fe2e8611b20b71197fe33daa6768e85d/json_repair-0.54.1.tar.gz", hash = "sha256:d010bc31f1fc66e7c36dc33bff5f8902674498ae5cb8e801ad455a53b455ad1d", size = 38555, upload-time = "2025-11-19T14:55:24.265Z" }
sdist = { url = "https://files.pythonhosted.org/packages/b5/86/48b12ac02032f121ac7e5f11a32143edca6c1e3d19ffc54d6fb9ca0aafd0/json_repair-0.54.3.tar.gz", hash = "sha256:e50feec9725e52ac91f12184609754684ac1656119dfbd31de09bdaf9a1d8bf6", size = 38626, upload-time = "2025-12-15T09:41:58.594Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/db/96/c9aad7ee949cc1bf15df91f347fbc2d3bd10b30b80c7df689ce6fe9332b5/json_repair-0.54.1-py3-none-any.whl", hash = "sha256:016160c5db5d5fe443164927bb58d2dfbba5f43ad85719fa9bc51c713a443ab1", size = 29311, upload-time = "2025-11-19T14:55:22.886Z" },
{ url = "https://files.pythonhosted.org/packages/e9/08/abe317237add63c3e62f18a981bccf92112b431835b43d844aedaf61f4a0/json_repair-0.54.3-py3-none-any.whl", hash = "sha256:4cdc132ee27d4780576f71bf27a113877046224a808bfc17392e079cb344fb81", size = 29357, upload-time = "2025-12-15T09:41:57.436Z" },
]
[[package]]

View File

@ -54,17 +54,17 @@
"publish:npm": "./scripts/publish.sh"
},
"dependencies": {
"axios": "^1.3.5"
"axios": "^1.13.2"
},
"devDependencies": {
"@eslint/js": "^9.2.0",
"@types/node": "^20.11.30",
"@eslint/js": "^9.39.2",
"@types/node": "^25.0.3",
"@typescript-eslint/eslint-plugin": "^8.50.1",
"@typescript-eslint/parser": "^8.50.1",
"@vitest/coverage-v8": "1.6.1",
"eslint": "^9.2.0",
"@vitest/coverage-v8": "4.0.16",
"eslint": "^9.39.2",
"tsup": "^8.5.1",
"typescript": "^5.4.5",
"vitest": "^1.5.0"
"typescript": "^5.9.3",
"vitest": "^4.0.16"
}
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,2 @@
onlyBuiltDependencies:
- esbuild

View File

@ -47,6 +47,12 @@ const getCheckboxDefaultSelectValue = (value: InputVar['default']) => {
const parseCheckboxSelectValue = (value: string) =>
value === CHECKBOX_DEFAULT_TRUE_VALUE
const normalizeSelectDefaultValue = (inputVar: InputVar) => {
if (inputVar.type === InputVarType.select && inputVar.default === '')
return { ...inputVar, default: undefined }
return inputVar
}
export type IConfigModalProps = {
isCreate?: boolean
payload?: InputVar
@ -67,7 +73,7 @@ const ConfigModal: FC<IConfigModalProps> = ({
}) => {
const { modelConfig } = useContext(ConfigContext)
const { t } = useTranslation()
const [tempPayload, setTempPayload] = useState<InputVar>(() => payload || getNewVarInWorkflow('') as any)
const [tempPayload, setTempPayload] = useState<InputVar>(() => normalizeSelectDefaultValue(payload || getNewVarInWorkflow('') as any))
const { type, label, variable, options, max_length } = tempPayload
const modalRef = useRef<HTMLDivElement>(null)
const appDetail = useAppStore(state => state.appDetail)
@ -182,6 +188,8 @@ const ConfigModal: FC<IConfigModalProps> = ({
const newPayload = produce(tempPayload, (draft) => {
draft.type = type
if (type === InputVarType.select)
draft.default = undefined
if ([InputVarType.singleFile, InputVarType.multiFiles].includes(type)) {
(Object.keys(DEFAULT_FILE_UPLOAD_SETTING)).forEach((key) => {
if (key !== 'max_length')

View File

@ -0,0 +1,141 @@
import type { DataSet } from '@/models/datasets'
import { act, fireEvent, render, screen } from '@testing-library/react'
import * as React from 'react'
import { describe, expect, it, vi } from 'vitest'
import { IndexingType } from '@/app/components/datasets/create/step-two'
import { DatasetPermission } from '@/models/datasets'
import { RETRIEVE_METHOD } from '@/types/app'
import SelectDataSet from './index'
vi.mock('@/i18n-config/i18next-config', () => ({
__esModule: true,
default: {
changeLanguage: vi.fn(),
addResourceBundle: vi.fn(),
use: vi.fn().mockReturnThis(),
init: vi.fn(),
addResource: vi.fn(),
hasResourceBundle: vi.fn().mockReturnValue(true),
},
}))
const mockUseInfiniteScroll = vi.fn()
vi.mock('ahooks', async (importOriginal) => {
const actual = await importOriginal()
return {
...(typeof actual === 'object' && actual !== null ? actual : {}),
useInfiniteScroll: (...args: any[]) => mockUseInfiniteScroll(...args),
}
})
const mockUseInfiniteDatasets = vi.fn()
vi.mock('@/service/knowledge/use-dataset', () => ({
useInfiniteDatasets: (...args: any[]) => mockUseInfiniteDatasets(...args),
}))
vi.mock('@/hooks/use-knowledge', () => ({
useKnowledge: () => ({
formatIndexingTechniqueAndMethod: (tech: string, method: string) => `${tech}:${method}`,
}),
}))
const baseProps = {
isShow: true,
onClose: vi.fn(),
selectedIds: [] as string[],
onSelect: vi.fn(),
}
const makeDataset = (overrides: Partial<DataSet>): DataSet => ({
id: 'dataset-id',
name: 'Dataset Name',
provider: 'internal',
icon_info: {
icon_type: 'emoji',
icon: '💾',
icon_background: '#fff',
icon_url: '',
},
embedding_available: true,
is_multimodal: false,
description: '',
permission: DatasetPermission.allTeamMembers,
indexing_technique: IndexingType.ECONOMICAL,
retrieval_model_dict: {
search_method: RETRIEVE_METHOD.fullText,
top_k: 5,
reranking_enable: false,
reranking_model: {
reranking_model_name: '',
reranking_provider_name: '',
},
score_threshold_enabled: false,
score_threshold: 0,
},
...overrides,
} as DataSet)
describe('SelectDataSet', () => {
beforeEach(() => {
vi.clearAllMocks()
})
it('renders dataset entries, allows selection, and fires onSelect', async () => {
const datasetOne = makeDataset({
id: 'set-1',
name: 'Dataset One',
is_multimodal: true,
indexing_technique: IndexingType.ECONOMICAL,
})
const datasetTwo = makeDataset({
id: 'set-2',
name: 'Hidden Dataset',
embedding_available: false,
provider: 'external',
})
mockUseInfiniteDatasets.mockReturnValue({
data: { pages: [{ data: [datasetOne, datasetTwo] }] },
isLoading: false,
isFetchingNextPage: false,
fetchNextPage: vi.fn(),
hasNextPage: false,
})
const onSelect = vi.fn()
await act(async () => {
render(<SelectDataSet {...baseProps} onSelect={onSelect} selectedIds={[]} />)
})
expect(screen.getByText('Dataset One')).toBeInTheDocument()
expect(screen.getByText('Hidden Dataset')).toBeInTheDocument()
await act(async () => {
fireEvent.click(screen.getByText('Dataset One'))
})
expect(screen.getByText('1 appDebug.feature.dataSet.selected')).toBeInTheDocument()
const addButton = screen.getByRole('button', { name: 'common.operation.add' })
await act(async () => {
fireEvent.click(addButton)
})
expect(onSelect).toHaveBeenCalledWith([datasetOne])
})
it('shows empty state when no datasets are available and disables add', async () => {
mockUseInfiniteDatasets.mockReturnValue({
data: { pages: [{ data: [] }] },
isLoading: false,
isFetchingNextPage: false,
fetchNextPage: vi.fn(),
hasNextPage: false,
})
await act(async () => {
render(<SelectDataSet {...baseProps} onSelect={vi.fn()} selectedIds={[]} />)
})
expect(screen.getByText('appDebug.feature.dataSet.noDataSet')).toBeInTheDocument()
expect(screen.getByRole('link', { name: 'appDebug.feature.dataSet.toCreate' })).toHaveAttribute('href', '/datasets/create')
expect(screen.getByRole('button', { name: 'common.operation.add' })).toBeDisabled()
})
})

View File

@ -0,0 +1,125 @@
import type { IPromptValuePanelProps } from './index'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import * as React from 'react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { useStore } from '@/app/components/app/store'
import ConfigContext from '@/context/debug-configuration'
import { AppModeEnum, ModelModeType, Resolution } from '@/types/app'
import PromptValuePanel from './index'
vi.mock('@/app/components/app/store', () => ({
useStore: vi.fn(),
}))
vi.mock('@/app/components/base/features/new-feature-panel/feature-bar', () => ({
__esModule: true,
default: ({ onFeatureBarClick }: { onFeatureBarClick: () => void }) => (
<button type="button" onClick={onFeatureBarClick}>
feature bar
</button>
),
}))
const mockSetShowAppConfigureFeaturesModal = vi.fn()
const mockUseStore = vi.mocked(useStore)
const mockSetInputs = vi.fn()
const mockOnSend = vi.fn()
const promptVariables = [
{ key: 'textVar', name: 'Text Var', type: 'string', required: true },
{ key: 'boolVar', name: 'Boolean Var', type: 'checkbox' },
] as const
const baseContextValue: any = {
modelModeType: ModelModeType.completion,
modelConfig: {
configs: {
prompt_template: 'prompt template',
prompt_variables: promptVariables,
},
},
setInputs: mockSetInputs,
mode: AppModeEnum.COMPLETION,
isAdvancedMode: false,
completionPromptConfig: {
prompt: { text: 'completion' },
conversation_histories_role: { user_prefix: 'user', assistant_prefix: 'assistant' },
},
chatPromptConfig: { prompt: [] },
} as any
const defaultProps: IPromptValuePanelProps = {
appType: AppModeEnum.COMPLETION,
onSend: mockOnSend,
inputs: { textVar: 'initial', boolVar: false },
visionConfig: { enabled: false, number_limits: 0, detail: Resolution.low, transfer_methods: [] },
onVisionFilesChange: vi.fn(),
}
const renderPanel = (options: {
context?: Partial<typeof baseContextValue>
props?: Partial<IPromptValuePanelProps>
} = {}) => {
const contextValue = { ...baseContextValue, ...options.context }
const props = { ...defaultProps, ...options.props }
return render(
<ConfigContext.Provider value={contextValue}>
<PromptValuePanel {...props} />
</ConfigContext.Provider>,
)
}
describe('PromptValuePanel', () => {
beforeEach(() => {
mockUseStore.mockImplementation(selector => selector({
setShowAppConfigureFeaturesModal: mockSetShowAppConfigureFeaturesModal,
appSidebarExpand: '',
currentLogModalActiveTab: 'prompt',
showPromptLogModal: false,
showAgentLogModal: false,
setShowPromptLogModal: vi.fn(),
setShowAgentLogModal: vi.fn(),
showMessageLogModal: false,
showAppConfigureFeaturesModal: false,
} as any))
mockSetInputs.mockClear()
mockOnSend.mockClear()
mockSetShowAppConfigureFeaturesModal.mockClear()
})
it('updates inputs, clears values, and triggers run when ready', async () => {
renderPanel()
const textInput = screen.getByPlaceholderText('Text Var')
fireEvent.change(textInput, { target: { value: 'updated' } })
expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ textVar: 'updated' }))
const clearButton = screen.getByRole('button', { name: 'common.operation.clear' })
fireEvent.click(clearButton)
expect(mockSetInputs).toHaveBeenLastCalledWith({
textVar: '',
boolVar: '',
})
const runButton = screen.getByRole('button', { name: 'appDebug.inputs.run' })
expect(runButton).not.toBeDisabled()
fireEvent.click(runButton)
await waitFor(() => expect(mockOnSend).toHaveBeenCalledTimes(1))
})
it('disables run when mode is not completion', () => {
renderPanel({
context: {
mode: AppModeEnum.CHAT,
},
props: {
appType: AppModeEnum.CHAT,
},
})
const runButton = screen.getByRole('button', { name: 'appDebug.inputs.run' })
expect(runButton).toBeDisabled()
fireEvent.click(runButton)
expect(mockOnSend).not.toHaveBeenCalled()
})
})

View File

@ -0,0 +1,29 @@
import type { PromptVariable } from '@/models/debug'
import { describe, expect, it } from 'vitest'
import { replaceStringWithValues } from './utils'
const promptVariables: PromptVariable[] = [
{ key: 'user', name: 'User', type: 'string' },
{ key: 'topic', name: 'Topic', type: 'string' },
]
describe('replaceStringWithValues', () => {
it('should replace placeholders when inputs have values', () => {
const template = 'Hello {{user}} talking about {{topic}}'
const result = replaceStringWithValues(template, promptVariables, { user: 'Alice', topic: 'cats' })
expect(result).toBe('Hello Alice talking about cats')
})
it('should use prompt variable name when value is missing', () => {
const template = 'Hi {{user}} from {{topic}}'
const result = replaceStringWithValues(template, promptVariables, {})
expect(result).toBe('Hi {{User}} from {{Topic}}')
})
it('should leave placeholder untouched when no variable is defined', () => {
const template = 'Unknown {{missing}} placeholder'
const result = replaceStringWithValues(template, promptVariables, {})
expect(result).toBe('Unknown {{missing}} placeholder')
})
})

View File

@ -0,0 +1,162 @@
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { useRouter } from 'next/navigation'
import { afterAll, beforeEach, describe, expect, it, vi } from 'vitest'
import { trackEvent } from '@/app/components/base/amplitude'
import { ToastContext } from '@/app/components/base/toast'
import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
import { useAppContext } from '@/context/app-context'
import { useProviderContext } from '@/context/provider-context'
import { createApp } from '@/service/apps'
import { AppModeEnum } from '@/types/app'
import { getRedirection } from '@/utils/app-redirection'
import CreateAppModal from './index'
vi.mock('ahooks', () => ({
useDebounceFn: (fn: (...args: any[]) => any) => {
const run = (...args: any[]) => fn(...args)
const cancel = vi.fn()
const flush = vi.fn()
return { run, cancel, flush }
},
useKeyPress: vi.fn(),
useHover: () => false,
}))
vi.mock('next/navigation', () => ({
useRouter: vi.fn(),
}))
vi.mock('@/app/components/base/amplitude', () => ({
trackEvent: vi.fn(),
}))
vi.mock('@/service/apps', () => ({
createApp: vi.fn(),
}))
vi.mock('@/utils/app-redirection', () => ({
getRedirection: vi.fn(),
}))
vi.mock('@/context/provider-context', () => ({
useProviderContext: vi.fn(),
}))
vi.mock('@/context/app-context', () => ({
useAppContext: vi.fn(),
}))
vi.mock('@/context/i18n', () => ({
useDocLink: () => () => '/guides',
}))
vi.mock('@/hooks/use-theme', () => ({
__esModule: true,
default: () => ({ theme: 'light' }),
}))
const mockNotify = vi.fn()
const mockUseRouter = vi.mocked(useRouter)
const mockPush = vi.fn()
const mockCreateApp = vi.mocked(createApp)
const mockTrackEvent = vi.mocked(trackEvent)
const mockGetRedirection = vi.mocked(getRedirection)
const mockUseProviderContext = vi.mocked(useProviderContext)
const mockUseAppContext = vi.mocked(useAppContext)
const defaultPlanUsage = {
buildApps: 0,
teamMembers: 0,
annotatedResponse: 0,
documentsUploadQuota: 0,
apiRateLimit: 0,
triggerEvents: 0,
vectorSpace: 0,
}
const renderModal = () => {
const onClose = vi.fn()
const onSuccess = vi.fn()
render(
<ToastContext.Provider value={{ notify: mockNotify, close: vi.fn() }}>
<CreateAppModal show onClose={onClose} onSuccess={onSuccess} defaultAppMode={AppModeEnum.ADVANCED_CHAT} />
</ToastContext.Provider>,
)
return { onClose, onSuccess }
}
describe('CreateAppModal', () => {
const mockSetItem = vi.fn()
const originalLocalStorage = window.localStorage
beforeEach(() => {
vi.clearAllMocks()
mockUseRouter.mockReturnValue({ push: mockPush } as any)
mockUseProviderContext.mockReturnValue({
plan: {
type: AppModeEnum.ADVANCED_CHAT,
usage: defaultPlanUsage,
total: { ...defaultPlanUsage, buildApps: 1 },
reset: {},
},
enableBilling: true,
} as any)
mockUseAppContext.mockReturnValue({
isCurrentWorkspaceEditor: true,
} as any)
mockSetItem.mockClear()
Object.defineProperty(window, 'localStorage', {
value: {
setItem: mockSetItem,
getItem: vi.fn(),
removeItem: vi.fn(),
clear: vi.fn(),
key: vi.fn(),
length: 0,
},
writable: true,
})
})
afterAll(() => {
Object.defineProperty(window, 'localStorage', {
value: originalLocalStorage,
writable: true,
})
})
it('creates an app, notifies success, and fires callbacks', async () => {
const mockApp = { id: 'app-1', mode: AppModeEnum.ADVANCED_CHAT }
mockCreateApp.mockResolvedValue(mockApp as any)
const { onClose, onSuccess } = renderModal()
const nameInput = screen.getByPlaceholderText('app.newApp.appNamePlaceholder')
fireEvent.change(nameInput, { target: { value: 'My App' } })
fireEvent.click(screen.getByRole('button', { name: 'app.newApp.Create' }))
await waitFor(() => expect(mockCreateApp).toHaveBeenCalledWith({
name: 'My App',
description: '',
icon_type: 'emoji',
icon: '🤖',
icon_background: '#FFEAD5',
mode: AppModeEnum.ADVANCED_CHAT,
}))
expect(mockTrackEvent).toHaveBeenCalledWith('create_app', {
app_mode: AppModeEnum.ADVANCED_CHAT,
description: '',
})
expect(mockNotify).toHaveBeenCalledWith({ type: 'success', message: 'app.newApp.appCreated' })
expect(onSuccess).toHaveBeenCalled()
expect(onClose).toHaveBeenCalled()
await waitFor(() => expect(mockSetItem).toHaveBeenCalledWith(NEED_REFRESH_APP_LIST_KEY, '1'))
await waitFor(() => expect(mockGetRedirection).toHaveBeenCalledWith(true, mockApp, mockPush))
})
it('shows error toast when creation fails', async () => {
mockCreateApp.mockRejectedValue(new Error('boom'))
const { onClose } = renderModal()
const nameInput = screen.getByPlaceholderText('app.newApp.appNamePlaceholder')
fireEvent.change(nameInput, { target: { value: 'My App' } })
fireEvent.click(screen.getByRole('button', { name: 'app.newApp.Create' }))
await waitFor(() => expect(mockCreateApp).toHaveBeenCalled())
expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: 'boom' })
expect(onClose).not.toHaveBeenCalled()
})
})

View File

@ -139,14 +139,14 @@ const getFormattedChatList = (messages: ChatMessage[], conversationId: string, t
id: item.id,
content: item.answer,
agent_thoughts: addFileInfos(item.agent_thoughts ? sortAgentSorts(item.agent_thoughts) : item.agent_thoughts, item.message_files),
feedback: item.feedbacks.find(item => item.from_source === 'user'), // user feedback
adminFeedback: item.feedbacks.find(item => item.from_source === 'admin'), // admin feedback
feedback: item.feedbacks?.find(item => item.from_source === 'user'), // user feedback
adminFeedback: item.feedbacks?.find(item => item.from_source === 'admin'), // admin feedback
feedbackDisabled: false,
isAnswer: true,
message_files: getProcessedFilesFromResponse(answerFiles.map((item: any) => ({ ...item, related_id: item.id }))),
log: [
...item.message,
...(item.message[item.message.length - 1]?.role !== 'assistant'
...(item.message ?? []),
...(item.message?.[item.message.length - 1]?.role !== 'assistant'
? [
{
role: 'assistant',
@ -165,7 +165,7 @@ const getFormattedChatList = (messages: ChatMessage[], conversationId: string, t
more: {
time: dayjs.unix(item.created_at).tz(timezone).format(format),
tokens: item.answer_tokens + item.message_tokens,
latency: item.provider_response_latency.toFixed(2),
latency: (item.provider_response_latency ?? 0).toFixed(2),
},
citation: item.metadata?.retriever_resources,
annotation: (() => {

View File

@ -0,0 +1,121 @@
import type { SiteInfo } from '@/models/share'
import { fireEvent, render, screen } from '@testing-library/react'
import copy from 'copy-to-clipboard'
import * as React from 'react'
import { act } from 'react'
import { afterAll, afterEach, describe, expect, it, vi } from 'vitest'
import Embedded from './index'
vi.mock('./style.module.css', () => ({
__esModule: true,
default: {
option: 'option',
active: 'active',
iframeIcon: 'iframeIcon',
scriptsIcon: 'scriptsIcon',
chromePluginIcon: 'chromePluginIcon',
pluginInstallIcon: 'pluginInstallIcon',
},
}))
const mockThemeBuilder = {
buildTheme: vi.fn(),
theme: {
primaryColor: '#123456',
},
}
const mockUseAppContext = vi.fn(() => ({
langGeniusVersionInfo: {
current_env: 'PRODUCTION',
current_version: '',
latest_version: '',
release_date: '',
release_notes: '',
version: '',
can_auto_update: false,
},
}))
vi.mock('copy-to-clipboard', () => ({
__esModule: true,
default: vi.fn(),
}))
vi.mock('@/app/components/base/chat/embedded-chatbot/theme/theme-context', () => ({
useThemeContext: () => mockThemeBuilder,
}))
vi.mock('@/context/app-context', () => ({
useAppContext: () => mockUseAppContext(),
}))
const mockWindowOpen = vi.spyOn(window, 'open').mockImplementation(() => null)
const mockedCopy = vi.mocked(copy)
const siteInfo: SiteInfo = {
title: 'test site',
chat_color_theme: '#000000',
chat_color_theme_inverted: false,
}
const baseProps = {
isShow: true,
siteInfo,
onClose: vi.fn(),
appBaseUrl: 'https://app.example.com',
accessToken: 'token',
className: 'custom-modal',
}
const getCopyButton = () => {
const buttons = screen.getAllByRole('button')
const actionButton = buttons.find(button => button.className.includes('action-btn'))
expect(actionButton).toBeDefined()
return actionButton!
}
describe('Embedded', () => {
afterEach(() => {
vi.clearAllMocks()
mockWindowOpen.mockClear()
})
afterAll(() => {
mockWindowOpen.mockRestore()
})
it('builds theme and copies iframe snippet', async () => {
await act(async () => {
render(<Embedded {...baseProps} />)
})
const actionButton = getCopyButton()
const innerDiv = actionButton.querySelector('div')
act(() => {
fireEvent.click(innerDiv ?? actionButton)
})
expect(mockThemeBuilder.buildTheme).toHaveBeenCalledWith(siteInfo.chat_color_theme, siteInfo.chat_color_theme_inverted)
expect(mockedCopy).toHaveBeenCalledWith(expect.stringContaining('/chatbot/token'))
})
it('opens chrome plugin store link when chrome option selected', async () => {
await act(async () => {
render(<Embedded {...baseProps} />)
})
const optionButtons = document.body.querySelectorAll('[class*="option"]')
expect(optionButtons.length).toBeGreaterThanOrEqual(3)
act(() => {
fireEvent.click(optionButtons[2])
})
const [chromeText] = screen.getAllByText('appOverview.overview.appInfo.embedded.chromePlugin')
act(() => {
fireEvent.click(chromeText)
})
expect(mockWindowOpen).toHaveBeenCalledWith(
'https://chrome.google.com/webstore/detail/dify-chatbot/ceehdapohffmjmkdcifjofadiaoeggaf',
'_blank',
'noopener,noreferrer',
)
})
})

View File

@ -0,0 +1,67 @@
import type { ISavedItemsProps } from './index'
import { fireEvent, render, screen } from '@testing-library/react'
import copy from 'copy-to-clipboard'
import * as React from 'react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import Toast from '@/app/components/base/toast'
import SavedItems from './index'
vi.mock('copy-to-clipboard', () => ({
__esModule: true,
default: vi.fn(),
}))
vi.mock('next/navigation', () => ({
useParams: () => ({}),
usePathname: () => '/',
}))
const mockCopy = vi.mocked(copy)
const toastNotifySpy = vi.spyOn(Toast, 'notify')
const baseProps: ISavedItemsProps = {
list: [
{ id: '1', answer: 'hello world' },
],
isShowTextToSpeech: true,
onRemove: vi.fn(),
onStartCreateContent: vi.fn(),
}
describe('SavedItems', () => {
beforeEach(() => {
vi.clearAllMocks()
toastNotifySpy.mockClear()
})
it('renders saved answers with metadata and controls', () => {
const { container } = render(<SavedItems {...baseProps} />)
const markdownElement = container.querySelector('.markdown-body')
expect(markdownElement).toBeInTheDocument()
expect(screen.getByText('11 common.unit.char')).toBeInTheDocument()
const actionArea = container.querySelector('[class*="bg-components-actionbar-bg"]')
const actionButtons = actionArea?.querySelectorAll('button') ?? []
expect(actionButtons.length).toBeGreaterThanOrEqual(3)
})
it('copies content and notifies, and triggers remove callback', () => {
const handleRemove = vi.fn()
const { container } = render(<SavedItems {...baseProps} onRemove={handleRemove} />)
const actionArea = container.querySelector('[class*="bg-components-actionbar-bg"]')
const actionButtons = actionArea?.querySelectorAll('button') ?? []
expect(actionButtons.length).toBeGreaterThanOrEqual(3)
const copyButton = actionButtons[1]
const deleteButton = actionButtons[2]
fireEvent.click(copyButton)
expect(mockCopy).toHaveBeenCalledWith('hello world')
expect(toastNotifySpy).toHaveBeenCalledWith({ type: 'success', message: 'common.actionMsg.copySuccessfully' })
fireEvent.click(deleteButton)
expect(handleRemove).toHaveBeenCalledWith('1')
})
})

View File

@ -0,0 +1,22 @@
import { fireEvent, render, screen } from '@testing-library/react'
import { describe, expect, it, vi } from 'vitest'
import NoData from './index'
describe('NoData', () => {
it('renders title/description and calls callback when button clicked', () => {
const handleStart = vi.fn()
render(<NoData onStartCreateContent={handleStart} />)
const title = screen.getByText('share.generation.savedNoData.title')
const description = screen.getByText('share.generation.savedNoData.description')
const button = screen.getByRole('button', { name: 'share.generation.savedNoData.startCreateContent' })
expect(title).toBeInTheDocument()
expect(description).toBeInTheDocument()
expect(button).toBeInTheDocument()
fireEvent.click(button)
expect(handleStart).toHaveBeenCalledTimes(1)
})
})

View File

@ -0,0 +1,308 @@
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import Avatar from './index'
describe('Avatar', () => {
beforeEach(() => {
vi.clearAllMocks()
})
// Rendering tests - verify component renders correctly in different states
describe('Rendering', () => {
it('should render img element with correct alt and src when avatar URL is provided', () => {
const avatarUrl = 'https://example.com/avatar.jpg'
const props = { name: 'John Doe', avatar: avatarUrl }
render(<Avatar {...props} />)
const img = screen.getByRole('img', { name: 'John Doe' })
expect(img).toBeInTheDocument()
expect(img).toHaveAttribute('src', avatarUrl)
})
it('should render fallback div with uppercase initial when avatar is null', () => {
const props = { name: 'alice', avatar: null }
render(<Avatar {...props} />)
expect(screen.queryByRole('img')).not.toBeInTheDocument()
expect(screen.getByText('A')).toBeInTheDocument()
})
})
// Props tests - verify all props are applied correctly
describe('Props', () => {
describe('size prop', () => {
it.each([
{ size: undefined, expected: '30px', label: 'default (30px)' },
{ size: 50, expected: '50px', label: 'custom (50px)' },
])('should apply $label size to img element', ({ size, expected }) => {
const props = { name: 'Test', avatar: 'https://example.com/avatar.jpg', size }
render(<Avatar {...props} />)
expect(screen.getByRole('img')).toHaveStyle({
width: expected,
height: expected,
fontSize: expected,
lineHeight: expected,
})
})
it('should apply size to fallback div when avatar is null', () => {
const props = { name: 'Test', avatar: null, size: 40 }
render(<Avatar {...props} />)
const textElement = screen.getByText('T')
const outerDiv = textElement.parentElement as HTMLElement
expect(outerDiv).toHaveStyle({ width: '40px', height: '40px' })
})
})
describe('className prop', () => {
it('should merge className with default avatar classes on img', () => {
const props = {
name: 'Test',
avatar: 'https://example.com/avatar.jpg',
className: 'custom-class',
}
render(<Avatar {...props} />)
const img = screen.getByRole('img')
expect(img).toHaveClass('custom-class')
expect(img).toHaveClass('shrink-0', 'flex', 'items-center', 'rounded-full', 'bg-primary-600')
})
it('should merge className with default avatar classes on fallback div', () => {
const props = {
name: 'Test',
avatar: null,
className: 'my-custom-class',
}
render(<Avatar {...props} />)
const textElement = screen.getByText('T')
const outerDiv = textElement.parentElement as HTMLElement
expect(outerDiv).toHaveClass('my-custom-class')
expect(outerDiv).toHaveClass('shrink-0', 'flex', 'items-center', 'rounded-full', 'bg-primary-600')
})
})
describe('textClassName prop', () => {
it('should apply textClassName to the initial text element', () => {
const props = {
name: 'Test',
avatar: null,
textClassName: 'custom-text-class',
}
render(<Avatar {...props} />)
const textElement = screen.getByText('T')
expect(textElement).toHaveClass('custom-text-class')
expect(textElement).toHaveClass('scale-[0.4]', 'text-center', 'text-white')
})
})
})
// State Management tests - verify useState and useEffect behavior
describe('State Management', () => {
it('should switch to fallback when image fails to load', async () => {
const props = { name: 'John', avatar: 'https://example.com/broken.jpg' }
render(<Avatar {...props} />)
const img = screen.getByRole('img')
fireEvent.error(img)
await waitFor(() => {
expect(screen.queryByRole('img')).not.toBeInTheDocument()
})
expect(screen.getByText('J')).toBeInTheDocument()
})
it('should reset error state when avatar URL changes', async () => {
const initialProps = { name: 'John', avatar: 'https://example.com/broken.jpg' }
const { rerender } = render(<Avatar {...initialProps} />)
const img = screen.getByRole('img')
// First, trigger error
fireEvent.error(img)
await waitFor(() => {
expect(screen.queryByRole('img')).not.toBeInTheDocument()
})
expect(screen.getByText('J')).toBeInTheDocument()
rerender(<Avatar name="John" avatar="https://example.com/new-avatar.jpg" />)
await waitFor(() => {
expect(screen.getByRole('img')).toBeInTheDocument()
})
expect(screen.queryByText('J')).not.toBeInTheDocument()
})
it('should not reset error state if avatar becomes null', async () => {
const initialProps = { name: 'John', avatar: 'https://example.com/broken.jpg' }
const { rerender } = render(<Avatar {...initialProps} />)
// Trigger error
fireEvent.error(screen.getByRole('img'))
await waitFor(() => {
expect(screen.getByText('J')).toBeInTheDocument()
})
rerender(<Avatar name="John" avatar={null} />)
await waitFor(() => {
expect(screen.queryByRole('img')).not.toBeInTheDocument()
})
expect(screen.getByText('J')).toBeInTheDocument()
})
})
// Event Handlers tests - verify onError callback behavior
describe('Event Handlers', () => {
it('should call onError with true when image fails to load', () => {
const onErrorMock = vi.fn()
const props = {
name: 'John',
avatar: 'https://example.com/broken.jpg',
onError: onErrorMock,
}
render(<Avatar {...props} />)
fireEvent.error(screen.getByRole('img'))
expect(onErrorMock).toHaveBeenCalledTimes(1)
expect(onErrorMock).toHaveBeenCalledWith(true)
})
it('should call onError with false when image loads successfully', () => {
const onErrorMock = vi.fn()
const props = {
name: 'John',
avatar: 'https://example.com/avatar.jpg',
onError: onErrorMock,
}
render(<Avatar {...props} />)
fireEvent.load(screen.getByRole('img'))
expect(onErrorMock).toHaveBeenCalledTimes(1)
expect(onErrorMock).toHaveBeenCalledWith(false)
})
it('should not throw when onError is not provided', async () => {
const props = { name: 'John', avatar: 'https://example.com/broken.jpg' }
render(<Avatar {...props} />)
expect(() => fireEvent.error(screen.getByRole('img'))).not.toThrow()
await waitFor(() => {
expect(screen.getByText('J')).toBeInTheDocument()
})
})
})
// Edge Cases tests - verify handling of unusual inputs
describe('Edge Cases', () => {
it('should handle empty string name gracefully', () => {
const props = { name: '', avatar: null }
const { container } = render(<Avatar {...props} />)
// Note: Using querySelector here because empty name produces no visible text,
// making semantic queries (getByRole, getByText) impossible
const textElement = container.querySelector('.text-white') as HTMLElement
expect(textElement).toBeInTheDocument()
expect(textElement.textContent).toBe('')
})
it.each([
{ name: '中文名', expected: '中', label: 'Chinese characters' },
{ name: '123User', expected: '1', label: 'number' },
])('should display first character when name starts with $label', ({ name, expected }) => {
const props = { name, avatar: null }
render(<Avatar {...props} />)
expect(screen.getByText(expected)).toBeInTheDocument()
})
it('should handle empty string avatar as falsy value', () => {
const props = { name: 'Test', avatar: '' as string | null }
render(<Avatar {...props} />)
expect(screen.queryByRole('img')).not.toBeInTheDocument()
expect(screen.getByText('T')).toBeInTheDocument()
})
it('should handle undefined className and textClassName', () => {
const props = { name: 'Test', avatar: null }
render(<Avatar {...props} />)
const textElement = screen.getByText('T')
const outerDiv = textElement.parentElement as HTMLElement
expect(outerDiv).toHaveClass('shrink-0', 'flex', 'items-center', 'rounded-full', 'bg-primary-600')
})
it.each([
{ size: 0, expected: '0px', label: 'zero' },
{ size: 1000, expected: '1000px', label: 'very large' },
])('should handle $label size value', ({ size, expected }) => {
const props = { name: 'Test', avatar: null, size }
render(<Avatar {...props} />)
const textElement = screen.getByText('T')
const outerDiv = textElement.parentElement as HTMLElement
expect(outerDiv).toHaveStyle({ width: expected, height: expected })
})
})
// Combined props tests - verify props work together correctly
describe('Combined Props', () => {
it('should apply all props correctly when used together', () => {
const onErrorMock = vi.fn()
const props = {
name: 'Test User',
avatar: 'https://example.com/avatar.jpg',
size: 64,
className: 'custom-avatar',
onError: onErrorMock,
}
render(<Avatar {...props} />)
const img = screen.getByRole('img')
expect(img).toHaveAttribute('alt', 'Test User')
expect(img).toHaveAttribute('src', 'https://example.com/avatar.jpg')
expect(img).toHaveStyle({ width: '64px', height: '64px' })
expect(img).toHaveClass('custom-avatar')
// Trigger load to verify onError callback
fireEvent.load(img)
expect(onErrorMock).toHaveBeenCalledWith(false)
})
it('should apply all fallback props correctly when used together', () => {
const props = {
name: 'Fallback User',
avatar: null,
size: 48,
className: 'fallback-custom',
textClassName: 'custom-text-style',
}
render(<Avatar {...props} />)
const textElement = screen.getByText('F')
const outerDiv = textElement.parentElement as HTMLElement
expect(outerDiv).toHaveClass('fallback-custom')
expect(outerDiv).toHaveStyle({ width: '48px', height: '48px' })
expect(textElement).toHaveClass('custom-text-style')
})
})
})

View File

@ -0,0 +1,147 @@
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { getImageUploadErrorMessage, imageUpload } from '@/app/components/base/image-uploader/utils'
import { useToastContext } from '@/app/components/base/toast'
import { Plan } from '@/app/components/billing/type'
import { useAppContext } from '@/context/app-context'
import { useGlobalPublicStore } from '@/context/global-public-context'
import { useProviderContext } from '@/context/provider-context'
import { updateCurrentWorkspace } from '@/service/common'
import CustomWebAppBrand from './index'
vi.mock('@/app/components/base/toast', () => ({
useToastContext: vi.fn(),
}))
vi.mock('@/service/common', () => ({
updateCurrentWorkspace: vi.fn(),
}))
vi.mock('@/context/app-context', () => ({
useAppContext: vi.fn(),
}))
vi.mock('@/context/provider-context', () => ({
useProviderContext: vi.fn(),
}))
vi.mock('@/context/global-public-context', () => ({
useGlobalPublicStore: vi.fn(),
}))
vi.mock('@/app/components/base/image-uploader/utils', () => ({
imageUpload: vi.fn(),
getImageUploadErrorMessage: vi.fn(),
}))
const mockNotify = vi.fn()
const mockUseToastContext = vi.mocked(useToastContext)
const mockUpdateCurrentWorkspace = vi.mocked(updateCurrentWorkspace)
const mockUseAppContext = vi.mocked(useAppContext)
const mockUseProviderContext = vi.mocked(useProviderContext)
const mockUseGlobalPublicStore = vi.mocked(useGlobalPublicStore)
const mockImageUpload = vi.mocked(imageUpload)
const mockGetImageUploadErrorMessage = vi.mocked(getImageUploadErrorMessage)
const defaultPlanUsage = {
buildApps: 0,
teamMembers: 0,
annotatedResponse: 0,
documentsUploadQuota: 0,
apiRateLimit: 0,
triggerEvents: 0,
vectorSpace: 0,
}
const renderComponent = () => render(<CustomWebAppBrand />)
describe('CustomWebAppBrand', () => {
beforeEach(() => {
vi.clearAllMocks()
mockUseToastContext.mockReturnValue({ notify: mockNotify } as any)
mockUpdateCurrentWorkspace.mockResolvedValue({} as any)
mockUseAppContext.mockReturnValue({
currentWorkspace: {
custom_config: {
replace_webapp_logo: 'https://example.com/replace.png',
remove_webapp_brand: false,
},
},
mutateCurrentWorkspace: vi.fn(),
isCurrentWorkspaceManager: true,
} as any)
mockUseProviderContext.mockReturnValue({
plan: {
type: Plan.professional,
usage: defaultPlanUsage,
total: defaultPlanUsage,
reset: {},
},
enableBilling: false,
} as any)
const systemFeaturesState = {
branding: {
enabled: true,
workspace_logo: 'https://example.com/workspace-logo.png',
},
}
mockUseGlobalPublicStore.mockImplementation(selector => selector ? selector({ systemFeatures: systemFeaturesState } as any) : { systemFeatures: systemFeaturesState })
mockGetImageUploadErrorMessage.mockReturnValue('upload error')
})
it('disables upload controls when the user cannot manage the workspace', () => {
mockUseAppContext.mockReturnValue({
currentWorkspace: {
custom_config: {
replace_webapp_logo: '',
remove_webapp_brand: false,
},
},
mutateCurrentWorkspace: vi.fn(),
isCurrentWorkspaceManager: false,
} as any)
const { container } = renderComponent()
const fileInput = container.querySelector('input[type="file"]') as HTMLInputElement
expect(fileInput).toBeDisabled()
})
it('toggles remove brand switch and calls the backend + mutate', async () => {
const mutateMock = vi.fn()
mockUseAppContext.mockReturnValue({
currentWorkspace: {
custom_config: {
replace_webapp_logo: '',
remove_webapp_brand: false,
},
},
mutateCurrentWorkspace: mutateMock,
isCurrentWorkspaceManager: true,
} as any)
renderComponent()
const switchInput = screen.getByRole('switch')
fireEvent.click(switchInput)
await waitFor(() => expect(mockUpdateCurrentWorkspace).toHaveBeenCalledWith({
url: '/workspaces/custom-config',
body: { remove_webapp_brand: true },
}))
await waitFor(() => expect(mutateMock).toHaveBeenCalled())
})
it('shows cancel/apply buttons after successful upload and cancels properly', async () => {
mockImageUpload.mockImplementation(({ onProgressCallback, onSuccessCallback }) => {
onProgressCallback(50)
onSuccessCallback({ id: 'new-logo' })
})
const { container } = renderComponent()
const fileInput = container.querySelector('input[type="file"]') as HTMLInputElement
const testFile = new File(['content'], 'logo.png', { type: 'image/png' })
fireEvent.change(fileInput, { target: { files: [testFile] } })
await waitFor(() => expect(mockImageUpload).toHaveBeenCalled())
await waitFor(() => screen.getByRole('button', { name: 'custom.apply' }))
const cancelButton = screen.getByRole('button', { name: 'common.operation.cancel' })
fireEvent.click(cancelButton)
await waitFor(() => expect(screen.queryByRole('button', { name: 'custom.apply' })).toBeNull())
})
})

View File

@ -1,5 +1,5 @@
import type { FC } from 'react'
import { RiArchive2Line, RiCheckboxCircleLine, RiCloseCircleLine, RiDeleteBinLine, RiDraftLine } from '@remixicon/react'
import { RiArchive2Line, RiCheckboxCircleLine, RiCloseCircleLine, RiDeleteBinLine, RiDraftLine, RiRefreshLine } from '@remixicon/react'
import { useBoolean } from 'ahooks'
import * as React from 'react'
import { useTranslation } from 'react-i18next'
@ -17,6 +17,7 @@ type IBatchActionProps = {
onBatchDelete: () => Promise<void>
onArchive?: () => void
onEditMetadata?: () => void
onBatchReIndex?: () => void
onCancel: () => void
}
@ -28,6 +29,7 @@ const BatchAction: FC<IBatchActionProps> = ({
onArchive,
onBatchDelete,
onEditMetadata,
onBatchReIndex,
onCancel,
}) => {
const { t } = useTranslation()
@ -91,6 +93,16 @@ const BatchAction: FC<IBatchActionProps> = ({
<span className="px-0.5">{t(`${i18nPrefix}.archive`)}</span>
</Button>
)}
{onBatchReIndex && (
<Button
variant="ghost"
className="gap-x-0.5 px-3"
onClick={onBatchReIndex}
>
<RiRefreshLine className="size-4" />
<span className="px-0.5">{t(`${i18nPrefix}.reIndex`)}</span>
</Button>
)}
<Button
variant="ghost"
destructive

View File

@ -26,7 +26,7 @@ import { useDatasetDetailContextWithSelector as useDatasetDetailContext } from '
import useTimestamp from '@/hooks/use-timestamp'
import { ChunkingMode, DataSourceType, DocumentActionType } from '@/models/datasets'
import { DatasourceType } from '@/models/pipeline'
import { useDocumentArchive, useDocumentDelete, useDocumentDisable, useDocumentEnable } from '@/service/knowledge/use-document'
import { useDocumentArchive, useDocumentBatchRetryIndex, useDocumentDelete, useDocumentDisable, useDocumentEnable } from '@/service/knowledge/use-document'
import { asyncRunSafe } from '@/utils'
import { cn } from '@/utils/classnames'
import { formatNumber } from '@/utils/format'
@ -220,6 +220,7 @@ const DocumentList: FC<IDocumentListProps> = ({
const { mutateAsync: enableDocument } = useDocumentEnable()
const { mutateAsync: disableDocument } = useDocumentDisable()
const { mutateAsync: deleteDocument } = useDocumentDelete()
const { mutateAsync: retryIndexDocument } = useDocumentBatchRetryIndex()
const handleAction = (actionName: DocumentActionType) => {
return async () => {
@ -250,6 +251,22 @@ const DocumentList: FC<IDocumentListProps> = ({
}
}
const handleBatchReIndex = async () => {
const [e] = await asyncRunSafe<CommonResponse>(retryIndexDocument({ datasetId, documentIds: selectedIds }))
if (!e) {
onSelectedIdChange([])
Toast.notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') })
onUpdate()
}
else {
Toast.notify({ type: 'error', message: t('common.actionMsg.modifiedUnsuccessfully') })
}
}
const hasErrorDocumentsSelected = useMemo(() => {
return localDocs.some(doc => selectedIds.includes(doc.id) && doc.display_status === 'error')
}, [localDocs, selectedIds])
const getFileExtension = useCallback((fileName: string): string => {
if (!fileName)
return ''
@ -447,6 +464,7 @@ const DocumentList: FC<IDocumentListProps> = ({
onBatchDisable={handleAction(DocumentActionType.disable)}
onBatchDelete={handleAction(DocumentActionType.delete)}
onEditMetadata={showEditModal}
onBatchReIndex={hasErrorDocumentsSelected ? handleBatchReIndex : undefined}
onCancel={() => {
onSelectedIdChange([])
}}

View File

@ -34,7 +34,6 @@ import Records from './components/records'
import ResultItem from './components/result-item'
import ResultItemExternal from './components/result-item-external'
import ModifyRetrievalModal from './modify-retrieval-modal'
import s from './style.module.css'
const limit = 10
@ -115,8 +114,8 @@ const HitTestingPage: FC<Props> = ({ datasetId }: Props) => {
}, [isMobile, setShowRightPanel])
return (
<div className={s.container}>
<div className="flex flex-col px-6 py-3">
<div className="relative flex h-full w-full gap-x-6 overflow-y-auto pl-6">
<div className="flex min-w-0 flex-1 flex-col py-3">
<div className="mb-4 flex flex-col justify-center">
<h1 className="text-base font-semibold text-text-primary">{t('datasetHitTesting.title')}</h1>
<p className="mt-0.5 text-[13px] font-normal leading-4 text-text-tertiary">{t('datasetHitTesting.desc')}</p>
@ -161,7 +160,7 @@ const HitTestingPage: FC<Props> = ({ datasetId }: Props) => {
onClose={hideRightPanel}
footer={null}
>
<div className="flex flex-col pt-3">
<div className="flex min-w-0 flex-1 flex-col pt-3">
{isRetrievalLoading
? (
<div className="flex h-full flex-col rounded-tl-2xl bg-background-body px-4 py-3">

View File

@ -1,43 +0,0 @@
.container {
@apply flex h-full w-full relative overflow-y-auto;
}
.container>div {
@apply flex-1 h-full;
}
.commonIcon {
@apply w-3.5 h-3.5 inline-block align-middle;
background-repeat: no-repeat;
background-position: center center;
background-size: contain;
}
.app_icon {
background-image: url(./assets/grid.svg);
}
.hit_testing_icon {
background-image: url(../documents/assets/target.svg);
}
.plugin_icon {
background-image: url(./assets/plugin.svg);
}
.cardWrapper {
display: grid;
grid-template-columns: repeat(auto-fill, minmax(284px, auto));
grid-gap: 16px;
grid-auto-rows: 216px;
}
.clockWrapper {
border: 0.5px solid #eaecf5;
@apply rounded-lg w-11 h-11 flex justify-center items-center;
}
.clockIcon {
mask-image: url(./assets/clock.svg);
@apply bg-gray-500;
}

View File

@ -0,0 +1,23 @@
'use client'
import { TanStackDevtools } from '@tanstack/react-devtools'
import { formDevtoolsPlugin } from '@tanstack/react-form-devtools'
import { ReactQueryDevtoolsPanel } from '@tanstack/react-query-devtools'
import * as React from 'react'
export function TanStackDevtoolsWrapper() {
return (
<TanStackDevtools
plugins={[
// Query Devtools (Official Plugin)
{
name: 'React Query',
render: () => <ReactQueryDevtoolsPanel />,
},
// Form Devtools (Official Plugin)
formDevtoolsPlugin(),
]}
/>
)
}

View File

@ -3,14 +3,14 @@
import * as Sentry from '@sentry/react'
import { useEffect } from 'react'
const isDevelopment = process.env.NODE_ENV === 'development'
import { IS_DEV } from '@/config'
const SentryInitializer = ({
children,
}: { children: React.ReactElement }) => {
useEffect(() => {
const SENTRY_DSN = document?.body?.getAttribute('data-public-sentry-dsn')
if (!isDevelopment && SENTRY_DSN) {
if (!IS_DEV && SENTRY_DSN) {
Sentry.init({
dsn: SENTRY_DSN,
integrations: [

View File

@ -48,6 +48,7 @@ const EditCustomCollectionModal: FC<Props> = ({
const [editFirst, setEditFirst] = useState(!isAdd)
const [paramsSchemas, setParamsSchemas] = useState<CustomParamSchema[]>(payload?.tools || [])
const [labels, setLabels] = useState<string[]>(payload?.labels || [])
const [customCollection, setCustomCollection, getCustomCollection] = useGetState<CustomCollectionBackend>(isAdd
? {
provider: '',
@ -67,6 +68,15 @@ const EditCustomCollectionModal: FC<Props> = ({
const originalProvider = isEdit ? payload.provider : ''
// Sync customCollection state when payload changes
useEffect(() => {
if (isEdit) {
setCustomCollection(payload)
setParamsSchemas(payload.tools || [])
setLabels(payload.labels || [])
}
}, [isEdit, payload])
const [showEmojiPicker, setShowEmojiPicker] = useState(false)
const emoji = customCollection.icon
const setEmoji = (emoji: Emoji) => {
@ -124,7 +134,6 @@ const EditCustomCollectionModal: FC<Props> = ({
const [currTool, setCurrTool] = useState<CustomParamSchema | null>(null)
const [isShowTestApi, setIsShowTestApi] = useState(false)
const [labels, setLabels] = useState<string[]>(payload?.labels || [])
const handleLabelSelect = (value: string[]) => {
setLabels(value)
}

View File

@ -100,9 +100,28 @@ const ProviderDetail = ({
const [isShowEditCollectionToolModal, setIsShowEditCustomCollectionModal] = useState(false)
const [showConfirmDelete, setShowConfirmDelete] = useState(false)
const [deleteAction, setDeleteAction] = useState('')
const getCustomProvider = useCallback(async () => {
setIsDetailLoading(true)
const res = await fetchCustomCollection(collection.name)
if (res.credentials.auth_type === AuthType.apiKey && !res.credentials.api_key_header_prefix) {
if (res.credentials.api_key_value)
res.credentials.api_key_header_prefix = AuthHeaderPrefix.custom
}
setCustomCollection({
...res,
labels: collection.labels,
provider: collection.name,
})
setIsDetailLoading(false)
}, [collection.labels, collection.name])
const doUpdateCustomToolCollection = async (data: CustomCollectionBackend) => {
await updateCustomCollection(data)
onRefreshData()
await getCustomProvider()
// Use fresh data from form submission to avoid race condition with collection.labels
setCustomCollection(prev => prev ? { ...prev, labels: data.labels } : null)
Toast.notify({
type: 'success',
message: t('common.api.actionSuccess'),
@ -118,20 +137,6 @@ const ProviderDetail = ({
})
setIsShowEditCustomCollectionModal(false)
}
const getCustomProvider = useCallback(async () => {
setIsDetailLoading(true)
const res = await fetchCustomCollection(collection.name)
if (res.credentials.auth_type === AuthType.apiKey && !res.credentials.api_key_header_prefix) {
if (res.credentials.api_key_value)
res.credentials.api_key_header_prefix = AuthHeaderPrefix.custom
}
setCustomCollection({
...res,
labels: collection.labels,
provider: collection.name,
})
setIsDetailLoading(false)
}, [collection.labels, collection.name])
// workflow provider
const [isShowEditWorkflowToolModal, setIsShowEditWorkflowToolModal] = useState(false)
const getWorkflowToolProvider = useCallback(async () => {

View File

@ -61,7 +61,7 @@ const VersionHistoryButton: FC<VersionHistoryButtonProps> = ({
>
<Button
className={cn(
'p-2 rounded-lg border border-transparent',
'rounded-lg border border-transparent p-2',
theme === 'dark' && 'border-black/5 bg-white/10 backdrop-blur-sm',
)}
onClick={handleViewVersionHistory}

View File

@ -35,6 +35,7 @@ import ReactFlow, {
useReactFlow,
useStoreApi,
} from 'reactflow'
import { IS_DEV } from '@/config'
import { useEventEmitterContextContext } from '@/context/event-emitter'
import {
useAllBuiltInTools,
@ -361,7 +362,7 @@ export const Workflow: FC<WorkflowProps> = memo(({
}
}, [schemaTypeDefinitions, fetchInspectVars, isLoadedVars, vars, customTools, buildInTools, workflowTools, mcpTools, dataSourceList])
if (process.env.NODE_ENV === 'development') {
if (IS_DEV) {
store.getState().onError = (code, message) => {
if (code === '002')
return

View File

@ -1,68 +0,0 @@
'use client'
import BaseForm from '../components/base/form/form-scenarios/base'
import { BaseFieldType } from '../components/base/form/form-scenarios/base/types'
export default function Page() {
return (
<div className="flex h-screen w-full items-center justify-center p-20">
<div className="w-[400px] rounded-lg border border-components-panel-border bg-components-panel-bg">
<BaseForm
initialData={{
type: 'option_1',
variable: 'test',
label: 'Test',
maxLength: 48,
required: true,
}}
configurations={[
{
type: BaseFieldType.textInput,
variable: 'variable',
label: 'Variable',
required: true,
showConditions: [],
},
{
type: BaseFieldType.textInput,
variable: 'label',
label: 'Label',
required: true,
showConditions: [],
},
{
type: BaseFieldType.numberInput,
variable: 'maxLength',
label: 'Max Length',
required: true,
showConditions: [],
max: 100,
min: 1,
},
{
type: BaseFieldType.checkbox,
variable: 'required',
label: 'Required',
required: true,
showConditions: [],
},
{
type: BaseFieldType.select,
variable: 'type',
label: 'Type',
required: true,
showConditions: [],
options: [
{ label: 'Option 1', value: 'option_1' },
{ label: 'Option 2', value: 'option_2' },
{ label: 'Option 3', value: 'option_3' },
],
},
]}
onSubmit={(value) => {
console.log('onSubmit', value)
}}
/>
</div>
</div>
)
}

View File

@ -1,6 +1,8 @@
import type { Viewport } from 'next'
import { ThemeProvider } from 'next-themes'
import dynamic from 'next/dynamic'
import { Instrument_Serif } from 'next/font/google'
import { IS_DEV } from '@/config'
import GlobalPublicStoreProvider from '@/context/global-public-context'
import { TanstackQueryInitializer } from '@/context/query-client'
import { getLocaleOnServer } from '@/i18n-config/server'
@ -8,12 +10,15 @@ import { DatasetAttr } from '@/types/feature'
import { cn } from '@/utils/classnames'
import BrowserInitializer from './components/browser-initializer'
import I18nServer from './components/i18n-server'
import { ReactScan } from './components/react-scan'
import SentryInitializer from './components/sentry-initializer'
import RoutePrefixHandle from './routePrefixHandle'
import './styles/globals.css'
import './styles/markdown.scss'
const ReactScan = IS_DEV
? dynamic(() => import('./components/react-scan').then(m => m.ReactScan), { ssr: false })
: () => null
export const viewport: Viewport = {
width: 'device-width',
initialScale: 1,

View File

@ -2,7 +2,14 @@
import type { FC, PropsWithChildren } from 'react'
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import { ReactQueryDevtools } from '@tanstack/react-query-devtools'
import { lazy, Suspense } from 'react'
import { IS_DEV } from '@/config'
const TanStackDevtoolsWrapper = lazy(() =>
import('@/app/components/devtools').then(module => ({
default: module.TanStackDevtoolsWrapper,
})),
)
const STALE_TIME = 1000 * 60 * 30 // 30 minutes
@ -19,7 +26,11 @@ export const TanstackQueryInitializer: FC<PropsWithChildren> = (props) => {
return (
<QueryClientProvider client={client}>
{children}
<ReactQueryDevtools initialIsOpen={false} />
{IS_DEV && (
<Suspense fallback={null}>
<TanStackDevtoolsWrapper />
</Suspense>
)}
</QueryClientProvider>
)
}

View File

@ -170,6 +170,7 @@ const translation = {
enable: 'Enable',
disable: 'Disable',
archive: 'Archive',
reIndex: 'Re-index',
delete: 'Delete',
cancel: 'Cancel',
},

View File

@ -170,6 +170,7 @@ const translation = {
enable: '启用',
disable: '禁用',
archive: '归档',
reIndex: '重新索引',
delete: '删除',
cancel: '取消',
},

View File

@ -3,7 +3,7 @@
"type": "module",
"version": "1.11.2",
"private": true,
"packageManager": "pnpm@10.26.1+sha512.664074abc367d2c9324fdc18037097ce0a8f126034160f709928e9e9f95d98714347044e5c3164d65bd5da6c59c6be362b107546292a8eecb7999196e5ce58fa",
"packageManager": "pnpm@10.26.2+sha512.0e308ff2005fc7410366f154f625f6631ab2b16b1d2e70238444dd6ae9d630a8482d92a451144debc492416896ed16f7b114a86ec68b8404b2443869e68ffda6",
"engines": {
"node": ">=v22.11.0"
},
@ -71,7 +71,6 @@
"@tailwindcss/typography": "^0.5.19",
"@tanstack/react-form": "^1.23.7",
"@tanstack/react-query": "^5.90.5",
"@tanstack/react-query-devtools": "^5.90.2",
"abcjs": "^6.5.2",
"ahooks": "^3.9.5",
"class-variance-authority": "^0.7.1",
@ -134,7 +133,7 @@
"remark-breaks": "^4.0.0",
"remark-gfm": "^4.0.1",
"remark-math": "^6.0.0",
"scheduler": "^0.26.0",
"scheduler": "^0.27.0",
"semver": "^7.7.3",
"sharp": "^0.33.5",
"sortablejs": "^1.15.6",
@ -164,6 +163,9 @@
"@storybook/addon-themes": "9.1.13",
"@storybook/nextjs": "9.1.13",
"@storybook/react": "9.1.13",
"@tanstack/react-devtools": "^0.9.0",
"@tanstack/react-form-devtools": "^0.2.9",
"@tanstack/react-query-devtools": "^5.90.2",
"@testing-library/dom": "^10.4.1",
"@testing-library/jest-dom": "^6.9.1",
"@testing-library/react": "^16.3.0",

View File

@ -129,9 +129,6 @@ importers:
'@tanstack/react-query':
specifier: ^5.90.5
version: 5.90.12(react@19.2.3)
'@tanstack/react-query-devtools':
specifier: ^5.90.2
version: 5.91.1(@tanstack/react-query@5.90.12(react@19.2.3))(react@19.2.3)
abcjs:
specifier: ^6.5.2
version: 6.5.2
@ -319,8 +316,8 @@ importers:
specifier: ^6.0.0
version: 6.0.0
scheduler:
specifier: ^0.26.0
version: 0.26.0
specifier: ^0.27.0
version: 0.27.0
semver:
specifier: ^7.7.3
version: 7.7.3
@ -341,7 +338,7 @@ importers:
version: 7.0.19
use-context-selector:
specifier: ^2.0.0
version: 2.0.0(react@19.2.3)(scheduler@0.26.0)
version: 2.0.0(react@19.2.3)(scheduler@0.27.0)
uuid:
specifier: ^10.0.0
version: 10.0.0
@ -403,6 +400,15 @@ importers:
'@storybook/react':
specifier: 9.1.13
version: 9.1.13(react-dom@19.2.3(react@19.2.3))(react@19.2.3)(storybook@9.1.17(@testing-library/dom@10.4.1)(vite@7.3.0(@types/node@18.15.0)(jiti@1.21.7)(sass@1.95.0)(terser@5.44.1)(tsx@4.21.0)(yaml@2.8.2)))(typescript@5.9.3)
'@tanstack/react-devtools':
specifier: ^0.9.0
version: 0.9.0(@types/react-dom@19.2.3(@types/react@19.2.7))(@types/react@19.2.7)(csstype@3.2.3)(react-dom@19.2.3(react@19.2.3))(react@19.2.3)(solid-js@1.9.10)
'@tanstack/react-form-devtools':
specifier: ^0.2.9
version: 0.2.9(@types/react@19.2.7)(csstype@3.2.3)(react@19.2.3)(solid-js@1.9.10)
'@tanstack/react-query-devtools':
specifier: ^5.90.2
version: 5.91.1(@tanstack/react-query@5.90.12(react@19.2.3))(react@19.2.3)
'@testing-library/dom':
specifier: ^10.4.1
version: 10.4.1
@ -3169,6 +3175,36 @@ packages:
resolution: {integrity: sha512-t09vSN3MdfsyCHoFcTRCH/iUtG7OJ0CsjzB8cjAmKc/va/kIgeDI/TxsigdncE/4be734m0cvIYwNaV4i2XqAw==}
engines: {node: '>=10'}
'@solid-primitives/event-listener@2.4.3':
resolution: {integrity: sha512-h4VqkYFv6Gf+L7SQj+Y6puigL/5DIi7x5q07VZET7AWcS+9/G3WfIE9WheniHWJs51OEkRB43w6lDys5YeFceg==}
peerDependencies:
solid-js: ^1.6.12
'@solid-primitives/keyboard@1.3.3':
resolution: {integrity: sha512-9dQHTTgLBqyAI7aavtO+HnpTVJgWQA1ghBSrmLtMu1SMxLPDuLfuNr+Tk5udb4AL4Ojg7h9JrKOGEEDqsJXWJA==}
peerDependencies:
solid-js: ^1.6.12
'@solid-primitives/resize-observer@2.1.3':
resolution: {integrity: sha512-zBLje5E06TgOg93S7rGPldmhDnouNGhvfZVKOp+oG2XU8snA+GoCSSCz1M+jpNAg5Ek2EakU5UVQqL152WmdXQ==}
peerDependencies:
solid-js: ^1.6.12
'@solid-primitives/rootless@1.5.2':
resolution: {integrity: sha512-9HULb0QAzL2r47CCad0M+NKFtQ+LrGGNHZfteX/ThdGvKIg2o2GYhBooZubTCd/RTu2l2+Nw4s+dEfiDGvdrrQ==}
peerDependencies:
solid-js: ^1.6.12
'@solid-primitives/static-store@0.1.2':
resolution: {integrity: sha512-ReK+5O38lJ7fT+L6mUFvUr6igFwHBESZF+2Ug842s7fvlVeBdIVEdTCErygff6w7uR6+jrr7J8jQo+cYrEq4Iw==}
peerDependencies:
solid-js: ^1.6.12
'@solid-primitives/utils@6.3.2':
resolution: {integrity: sha512-hZ/M/qr25QOCcwDPOHtGjxTD8w2mNyVAYvcfgwzBHq2RwNqHNdDNsMZYap20+ruRwW4A3Cdkczyoz0TSxLCAPQ==}
peerDependencies:
solid-js: ^1.6.12
'@standard-schema/spec@1.1.0':
resolution: {integrity: sha512-l2aFy5jALhniG5HgqrD6jXLi/rUWrKvqN/qJx6yoJsgKhblVd+iqqU4RCXavm/jPityDo5TCvKMnpjKnOriy0w==}
@ -3308,13 +3344,67 @@ packages:
peerDependencies:
tailwindcss: '>=3.0.0 || insiders || >=4.0.0-alpha.20 || >=4.0.0-beta.1'
'@tanstack/devtools-client@0.0.5':
resolution: {integrity: sha512-hsNDE3iu4frt9cC2ppn1mNRnLKo2uc1/1hXAyY9z4UYb+o40M2clFAhiFoo4HngjfGJDV3x18KVVIq7W4Un+zA==}
engines: {node: '>=18'}
'@tanstack/devtools-event-bus@0.4.0':
resolution: {integrity: sha512-1t+/csFuDzi+miDxAOh6Xv7VDE80gJEItkTcAZLjV5MRulbO/W8ocjHLI2Do/p2r2/FBU0eKCRTpdqvXaYoHpQ==}
engines: {node: '>=18'}
'@tanstack/devtools-event-client@0.3.5':
resolution: {integrity: sha512-RL1f5ZlfZMpghrCIdzl6mLOFLTuhqmPNblZgBaeKfdtk5rfbjykurv+VfYydOFXj0vxVIoA2d/zT7xfD7Ph8fw==}
engines: {node: '>=18'}
'@tanstack/devtools-event-client@0.4.0':
resolution: {integrity: sha512-RPfGuk2bDZgcu9bAJodvO2lnZeHuz4/71HjZ0bGb/SPg8+lyTA+RLSKQvo7fSmPSi8/vcH3aKQ8EM9ywf1olaw==}
engines: {node: '>=18'}
'@tanstack/devtools-ui@0.4.4':
resolution: {integrity: sha512-5xHXFyX3nom0UaNfiOM92o6ziaHjGo3mcSGe2HD5Xs8dWRZNpdZ0Smd0B9ddEhy0oB+gXyMzZgUJb9DmrZV0Mg==}
engines: {node: '>=18'}
peerDependencies:
solid-js: '>=1.9.7'
'@tanstack/devtools-utils@0.0.9':
resolution: {integrity: sha512-tCObM6wbEjuHeGNs3JDhrqBhoMxpJpVuVIg5Kc33EmUI1ZO7KLpC1277Qf6AmSWy3aVOreGwn3y5bJzxmAJNXg==}
engines: {node: '>=18'}
peerDependencies:
'@types/react': ~19.2.7
react: '>=17.0.0'
solid-js: '>=1.9.7'
vue: '>=3.2.0'
peerDependenciesMeta:
'@types/react':
optional: true
react:
optional: true
solid-js:
optional: true
vue:
optional: true
'@tanstack/devtools@0.10.1':
resolution: {integrity: sha512-1gtPmCDXV4Pl1nVtoqwjV0tc4E9GMuFtlkBX1Lz1KfqI3W9JojT5YsVifOQ/g8BTQ5w5+tyIANwHU7WYgLq/MQ==}
engines: {node: '>=18'}
peerDependencies:
solid-js: '>=1.9.7'
'@tanstack/form-core@1.27.1':
resolution: {integrity: sha512-hPM+0tUnZ2C2zb2TE1lar1JJ0S0cbnQHlUwFcCnVBpMV3rjtUzkoM766gUpWrlmTGCzNad0GbJ0aTxVsjT6J8g==}
'@tanstack/form-core@1.27.6':
resolution: {integrity: sha512-1C4PUpOcCpivddKxtAeqdeqncxnPKiPpTVDRknDExCba+6zCsAjxgL+p3qYA3hu+EFyUAdW71rU+uqYbEa7qqA==}
'@tanstack/form-devtools@0.2.9':
resolution: {integrity: sha512-KOJiwvlFPsHeuWXvHUXRVdciXG1OPhg1c476MsLre0YLdaw1jeMlDYSlqq7sdEULX+2Sg/lhNpX86QbQuxzd2A==}
peerDependencies:
solid-js: '>=1.9.9'
'@tanstack/pacer-lite@0.1.1':
resolution: {integrity: sha512-y/xtNPNt/YeyoVxE/JCx+T7yjEzpezmbb+toK8DDD1P4m7Kzs5YR956+7OKexG3f8aXgC3rLZl7b1V+yNUSy5w==}
engines: {node: '>=18'}
'@tanstack/pacer@0.15.4':
resolution: {integrity: sha512-vGY+CWsFZeac3dELgB6UZ4c7OacwsLb8hvL2gLS6hTgy8Fl0Bm/aLokHaeDIP+q9F9HUZTnp360z9uv78eg8pg==}
engines: {node: '>=18'}
@ -3325,6 +3415,20 @@ packages:
'@tanstack/query-devtools@5.91.1':
resolution: {integrity: sha512-l8bxjk6BMsCaVQH6NzQEE/bEgFy1hAs5qbgXl0xhzezlaQbPk6Mgz9BqEg2vTLPOHD8N4k+w/gdgCbEzecGyNg==}
'@tanstack/react-devtools@0.9.0':
resolution: {integrity: sha512-Lq0svXOTG5N61SHgx8F0on6zz2GB0kmFjN/yyfNLrJyRgJ+U3jYFRd9ti3uBPABsXzHQMHYYujnTXrOYp/OaUg==}
engines: {node: '>=18'}
peerDependencies:
'@types/react': ~19.2.7
'@types/react-dom': ~19.2.3
react: '>=16.8'
react-dom: '>=16.8'
'@tanstack/react-form-devtools@0.2.9':
resolution: {integrity: sha512-wg0xrcVY8evIFGVHrnl9s+/9ENzuVbqv5Ru4HyAJjjL4uECtl6KdDJsi0lZdOyoM1UYEQoVdcN8jfBbxkA3q1g==}
peerDependencies:
react: ^17.0.0 || ^18.0.0 || ^19.0.0
'@tanstack/react-form@1.27.1':
resolution: {integrity: sha512-HKP0Ew2ae9AL5vU1PkJ+oAC2p+xBtA905u0fiNLzlfn1vLkBxenfg5L6TOA+rZITHpQsSo10tqwc5Yw6qn8Mpg==}
peerDependencies:
@ -3594,6 +3698,9 @@ packages:
'@types/node@20.19.26':
resolution: {integrity: sha512-0l6cjgF0XnihUpndDhk+nyD3exio3iKaYROSgvh/qSevPXax3L8p5DBRFjbvalnwatGgHEQn2R88y2fA3g4irg==}
'@types/node@20.19.27':
resolution: {integrity: sha512-N2clP5pJhB2YnZJ3PIHFk5RkygRX5WO/5f0WC08tp0wd+sv0rsJk3MqWn3CbNmT2J505a5336jaQj4ph1AdMug==}
'@types/papaparse@5.5.1':
resolution: {integrity: sha512-esEO+VISsLIyE+JZBmb89NzsYYbpwV8lmv2rPo6oX5y9KhBaIP7hhHgjuTut54qjdKVMufTEcrh5fUl9+58huw==}
@ -5629,6 +5736,11 @@ packages:
globrex@0.1.2:
resolution: {integrity: sha512-uHJgbwAMwNFf5mLst7IWLNg14x1CkeqglJb/K3doi4dw6q2IvAAmM/Y81kevy83wP+Sst+nutFTYOGg3d1lsxg==}
goober@2.1.18:
resolution: {integrity: sha512-2vFqsaDVIT9Gz7N6kAL++pLpp41l3PfDuusHcjnGLfR6+huZkl6ziX+zgVC3ZxpqWhzH6pyDdGrCeDhMIvwaxw==}
peerDependencies:
csstype: ^3.0.10
got@11.8.6:
resolution: {integrity: sha512-6tfZ91bOr7bOXnK7PRDCGBLa1H4U080YHNaAQ2KsMGlLEzRbk44nsZF2E1IeRc3vtJHPVbKCYgdFbaGO2ljd8g==}
engines: {node: '>=10.19.0'}
@ -6203,8 +6315,8 @@ packages:
lexical@0.38.2:
resolution: {integrity: sha512-JJmfsG3c4gwBHzUGffbV7ifMNkKAWMCnYE3xJl87gty7hjyV5f3xq7eqTjP5HFYvO4XpjJvvWO2/djHp5S10tw==}
lib0@0.2.115:
resolution: {integrity: sha512-noaW4yNp6hCjOgDnWWxW0vGXE3kZQI5Kqiwz+jIWXavI9J9WyfJ9zjsbQlQlgjIbHBrvlA/x3TSIXBUJj+0L6g==}
lib0@0.2.116:
resolution: {integrity: sha512-4zsosjzmt33rx5XjmFVYUAeLNh+BTeDTiwGdLt4muxiir2btsc60Nal0EvkvDRizg+pnlK1q+BtYi7M+d4eStw==}
engines: {node: '>=16'}
hasBin: true
@ -7641,9 +7753,6 @@ packages:
resolution: {integrity: sha512-xAg7SOnEhrm5zI3puOOKyy1OMcMlIJZYNJY7xLBwSze0UjhPLnWfj2GF2EpT0jmzaJKIWKHLsaSSajf35bcYnA==}
engines: {node: '>=v12.22.7'}
scheduler@0.26.0:
resolution: {integrity: sha512-NlHwttCI/l5gCPR3D1nNXtWABUmBwvZpEQiD4IXSbIDq8BzLIK/7Ir5gTFSGZDUu37K5cMNp0hFtzO38sC7gWA==}
scheduler@0.27.0:
resolution: {integrity: sha512-eNv+WrVbKu1f3vbYJT/xtiF5syA5HPIMtf9IgY/nKg0sWqzAUEvqY/xm7OcZc/qafLx/iO9FgOmeSAp4v5ti/Q==}
@ -7687,6 +7796,16 @@ packages:
serialize-javascript@6.0.2:
resolution: {integrity: sha512-Saa1xPByTTq2gdeFZYLLo+RFE35NHZkAbqZeWNd3BpzppeVisAqpDjcp8dyf6uIvEqJRd46jemmyA4iFIeVk8g==}
seroval-plugins@1.3.3:
resolution: {integrity: sha512-16OL3NnUBw8JG1jBLUoZJsLnQq0n5Ua6aHalhJK4fMQkz1lqR7Osz1sA30trBtd9VUDc2NgkuRCn8+/pBwqZ+w==}
engines: {node: '>=10'}
peerDependencies:
seroval: ^1.0
seroval@1.3.2:
resolution: {integrity: sha512-RbcPH1n5cfwKrru7v7+zrZvjLurgHhGyso3HTyGtRivGWgYjbOmGuivCQaORNELjNONoK35nj28EoWul9sb1zQ==}
engines: {node: '>=10'}
setimmediate@1.0.5:
resolution: {integrity: sha512-MATJdZp8sLqDl/68LfQmbP8zKPLQNV6BIZoIgrscFDQ+RsvK/BxeDQOgyxKKoh0y/8h3BqVFnCqQ/gd+reiIXA==}
@ -7753,6 +7872,9 @@ packages:
resolution: {integrity: sha512-QlaZEqcAH3/RtNyet1IPIYPsEWAaYyXXv1Krsi+1L/QHppjX4Ifm8MQsBISz9vE8cHicIq3clogsheili5vhaQ==}
engines: {node: '>= 18'}
solid-js@1.9.10:
resolution: {integrity: sha512-Coz956cos/EPDlhs6+jsdTxKuJDPT7B5SVIWgABwROyxjY7Xbr8wkzD68Et+NxnV7DLJ3nJdAC2r9InuV/4Jew==}
sortablejs@1.15.6:
resolution: {integrity: sha512-aNfiuwMEpfBM/CN6LY0ibyhxPfPbyFeBTYJKCvzkJ2GkUpazIt3H+QIPAMHwqQ7tMKaHz1Qj+rJJCqljnf4p3A==}
@ -8589,6 +8711,7 @@ packages:
whatwg-encoding@3.1.1:
resolution: {integrity: sha512-6qN4hJdMwfYBtE3YBTTHhoeuUrDBPZmbQaxWAqSALV/MeEnR5z1xd8UKud2RAkFoPkmB+hli1TZSnyi84xz1vQ==}
engines: {node: '>=18'}
deprecated: Use @exodus/bytes instead for a more spec-conformant and faster implementation
whatwg-mimetype@3.0.0:
resolution: {integrity: sha512-nt+N2dzIutVRxARx1nghPKGv1xHikU7HKdfafKkLNLindmPU/ch3U31NOCGGA/dmPcmb1VlofO0vnKAcsm0o/Q==}
@ -11539,6 +11662,40 @@ snapshots:
'@sindresorhus/is@4.6.0': {}
'@solid-primitives/event-listener@2.4.3(solid-js@1.9.10)':
dependencies:
'@solid-primitives/utils': 6.3.2(solid-js@1.9.10)
solid-js: 1.9.10
'@solid-primitives/keyboard@1.3.3(solid-js@1.9.10)':
dependencies:
'@solid-primitives/event-listener': 2.4.3(solid-js@1.9.10)
'@solid-primitives/rootless': 1.5.2(solid-js@1.9.10)
'@solid-primitives/utils': 6.3.2(solid-js@1.9.10)
solid-js: 1.9.10
'@solid-primitives/resize-observer@2.1.3(solid-js@1.9.10)':
dependencies:
'@solid-primitives/event-listener': 2.4.3(solid-js@1.9.10)
'@solid-primitives/rootless': 1.5.2(solid-js@1.9.10)
'@solid-primitives/static-store': 0.1.2(solid-js@1.9.10)
'@solid-primitives/utils': 6.3.2(solid-js@1.9.10)
solid-js: 1.9.10
'@solid-primitives/rootless@1.5.2(solid-js@1.9.10)':
dependencies:
'@solid-primitives/utils': 6.3.2(solid-js@1.9.10)
solid-js: 1.9.10
'@solid-primitives/static-store@0.1.2(solid-js@1.9.10)':
dependencies:
'@solid-primitives/utils': 6.3.2(solid-js@1.9.10)
solid-js: 1.9.10
'@solid-primitives/utils@6.3.2(solid-js@1.9.10)':
dependencies:
solid-js: 1.9.10
'@standard-schema/spec@1.1.0': {}
'@standard-schema/utils@0.3.0': {}
@ -11766,14 +11923,84 @@ snapshots:
postcss-selector-parser: 6.0.10
tailwindcss: 3.4.18(tsx@4.21.0)(yaml@2.8.2)
'@tanstack/devtools-client@0.0.5':
dependencies:
'@tanstack/devtools-event-client': 0.4.0
'@tanstack/devtools-event-bus@0.4.0':
dependencies:
ws: 8.18.3
transitivePeerDependencies:
- bufferutil
- utf-8-validate
'@tanstack/devtools-event-client@0.3.5': {}
'@tanstack/devtools-event-client@0.4.0': {}
'@tanstack/devtools-ui@0.4.4(csstype@3.2.3)(solid-js@1.9.10)':
dependencies:
clsx: 2.1.1
goober: 2.1.18(csstype@3.2.3)
solid-js: 1.9.10
transitivePeerDependencies:
- csstype
'@tanstack/devtools-utils@0.0.9(@types/react@19.2.7)(csstype@3.2.3)(react@19.2.3)(solid-js@1.9.10)':
dependencies:
'@tanstack/devtools-ui': 0.4.4(csstype@3.2.3)(solid-js@1.9.10)
optionalDependencies:
'@types/react': 19.2.7
react: 19.2.3
solid-js: 1.9.10
transitivePeerDependencies:
- csstype
'@tanstack/devtools@0.10.1(csstype@3.2.3)(solid-js@1.9.10)':
dependencies:
'@solid-primitives/event-listener': 2.4.3(solid-js@1.9.10)
'@solid-primitives/keyboard': 1.3.3(solid-js@1.9.10)
'@solid-primitives/resize-observer': 2.1.3(solid-js@1.9.10)
'@tanstack/devtools-client': 0.0.5
'@tanstack/devtools-event-bus': 0.4.0
'@tanstack/devtools-ui': 0.4.4(csstype@3.2.3)(solid-js@1.9.10)
clsx: 2.1.1
goober: 2.1.18(csstype@3.2.3)
solid-js: 1.9.10
transitivePeerDependencies:
- bufferutil
- csstype
- utf-8-validate
'@tanstack/form-core@1.27.1':
dependencies:
'@tanstack/devtools-event-client': 0.3.5
'@tanstack/pacer': 0.15.4
'@tanstack/store': 0.7.7
'@tanstack/form-core@1.27.6':
dependencies:
'@tanstack/devtools-event-client': 0.4.0
'@tanstack/pacer-lite': 0.1.1
'@tanstack/store': 0.7.7
'@tanstack/form-devtools@0.2.9(@types/react@19.2.7)(csstype@3.2.3)(react@19.2.3)(solid-js@1.9.10)':
dependencies:
'@tanstack/devtools-ui': 0.4.4(csstype@3.2.3)(solid-js@1.9.10)
'@tanstack/devtools-utils': 0.0.9(@types/react@19.2.7)(csstype@3.2.3)(react@19.2.3)(solid-js@1.9.10)
'@tanstack/form-core': 1.27.6
clsx: 2.1.1
dayjs: 1.11.19
goober: 2.1.18(csstype@3.2.3)
solid-js: 1.9.10
transitivePeerDependencies:
- '@types/react'
- csstype
- react
- vue
'@tanstack/pacer-lite@0.1.1': {}
'@tanstack/pacer@0.15.4':
dependencies:
'@tanstack/devtools-event-client': 0.3.5
@ -11783,6 +12010,30 @@ snapshots:
'@tanstack/query-devtools@5.91.1': {}
'@tanstack/react-devtools@0.9.0(@types/react-dom@19.2.3(@types/react@19.2.7))(@types/react@19.2.7)(csstype@3.2.3)(react-dom@19.2.3(react@19.2.3))(react@19.2.3)(solid-js@1.9.10)':
dependencies:
'@tanstack/devtools': 0.10.1(csstype@3.2.3)(solid-js@1.9.10)
'@types/react': 19.2.7
'@types/react-dom': 19.2.3(@types/react@19.2.7)
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
transitivePeerDependencies:
- bufferutil
- csstype
- solid-js
- utf-8-validate
'@tanstack/react-form-devtools@0.2.9(@types/react@19.2.7)(csstype@3.2.3)(react@19.2.3)(solid-js@1.9.10)':
dependencies:
'@tanstack/devtools-utils': 0.0.9(@types/react@19.2.7)(csstype@3.2.3)(react@19.2.3)(solid-js@1.9.10)
'@tanstack/form-devtools': 0.2.9(@types/react@19.2.7)(csstype@3.2.3)(react@19.2.3)(solid-js@1.9.10)
react: 19.2.3
transitivePeerDependencies:
- '@types/react'
- csstype
- solid-js
- vue
'@tanstack/react-form@1.27.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
dependencies:
'@tanstack/form-core': 1.27.1
@ -12091,6 +12342,11 @@ snapshots:
dependencies:
undici-types: 6.21.0
'@types/node@20.19.27':
dependencies:
undici-types: 6.21.0
optional: true
'@types/papaparse@5.5.1':
dependencies:
'@types/node': 18.15.0
@ -14486,6 +14742,10 @@ snapshots:
globrex@0.1.2: {}
goober@2.1.18(csstype@3.2.3):
dependencies:
csstype: 3.2.3
got@11.8.6:
dependencies:
'@sindresorhus/is': 4.6.0
@ -14512,7 +14772,7 @@ snapshots:
happy-dom@20.0.11:
dependencies:
'@types/node': 20.19.26
'@types/node': 20.19.27
'@types/whatwg-mimetype': 3.0.2
whatwg-mimetype: 3.0.0
optional: true
@ -15127,7 +15387,7 @@ snapshots:
lexical@0.38.2: {}
lib0@0.2.115:
lib0@0.2.116:
dependencies:
isomorphic.js: 0.2.5
@ -17044,8 +17304,6 @@ snapshots:
dependencies:
xmlchars: 2.2.0
scheduler@0.26.0: {}
scheduler@0.27.0: {}
schema-utils@2.7.1:
@ -17089,6 +17347,12 @@ snapshots:
dependencies:
randombytes: 2.1.0
seroval-plugins@1.3.3(seroval@1.3.2):
dependencies:
seroval: 1.3.2
seroval@1.3.2: {}
setimmediate@1.0.5: {}
sha.js@2.4.12:
@ -17203,6 +17467,12 @@ snapshots:
smol-toml@1.5.2: {}
solid-js@1.9.10:
dependencies:
csstype: 3.2.3
seroval: 1.3.2
seroval-plugins: 1.3.3(seroval@1.3.2)
sortablejs@1.15.6: {}
source-list-map@2.0.1: {}
@ -17734,10 +18004,10 @@ snapshots:
optionalDependencies:
'@types/react': 19.2.7
use-context-selector@2.0.0(react@19.2.3)(scheduler@0.26.0):
use-context-selector@2.0.0(react@19.2.3)(scheduler@0.27.0):
dependencies:
react: 19.2.3
scheduler: 0.26.0
scheduler: 0.27.0
use-isomorphic-layout-effect@1.2.1(@types/react@19.2.7)(react@19.2.3):
dependencies:
@ -18202,7 +18472,7 @@ snapshots:
yjs@13.6.27:
dependencies:
lib0: 0.2.115
lib0: 0.2.116
yocto-queue@0.1.0: {}

View File

@ -7,7 +7,7 @@ import {
} from '@tanstack/react-query'
import { normalizeStatusForQuery } from '@/app/components/datasets/documents/status-filter'
import { DocumentActionType } from '@/models/datasets'
import { del, get, patch } from '../base'
import { del, get, patch, post } from '../base'
import { pauseDocIndexing, resumeDocIndexing } from '../datasets'
import { useInvalid } from '../use-base'
@ -163,3 +163,15 @@ export const useDocumentResume = () => {
},
})
}
export const useDocumentBatchRetryIndex = () => {
return useMutation({
mutationFn: ({ datasetId, documentIds }: { datasetId: string, documentIds: string[] }) => {
return post<CommonResponse>(`/datasets/${datasetId}/retry`, {
body: {
document_ids: documentIds,
},
})
},
})
}