mirror of https://github.com/langgenius/dify.git
Merge branch 'main' into feat/log-formatter
This commit is contained in:
commit
2f67c5aa75
|
|
@ -1,6 +1,5 @@
|
|||
import logging
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import marshal_with
|
||||
|
|
@ -26,6 +25,7 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
|
|||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from fields.message_fields import message_infinite_scroll_pagination_fields
|
||||
from libs import helper
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from libs.login import current_account_with_tenant
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
|
|
@ -44,8 +44,8 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class MessageListQuery(BaseModel):
|
||||
conversation_id: UUID
|
||||
first_id: UUID | None = None
|
||||
conversation_id: UUIDStrOrEmpty
|
||||
first_id: UUIDStrOrEmpty | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
|
|
@ -10,19 +8,19 @@ from controllers.console import console_ns
|
|||
from controllers.console.explore.error import NotCompletionAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from fields.conversation_fields import message_file_fields
|
||||
from libs.helper import TimestampField
|
||||
from libs.helper import TimestampField, UUIDStrOrEmpty
|
||||
from libs.login import current_account_with_tenant
|
||||
from services.errors.message import MessageNotExistsError
|
||||
from services.saved_message_service import SavedMessageService
|
||||
|
||||
|
||||
class SavedMessageListQuery(BaseModel):
|
||||
last_id: UUID | None = None
|
||||
last_id: UUIDStrOrEmpty | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
class SavedMessageCreatePayload(BaseModel):
|
||||
message_id: UUID
|
||||
message_id: UUIDStrOrEmpty
|
||||
|
||||
|
||||
register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
|
|
@ -10,10 +12,20 @@ from models import TenantAccountRole
|
|||
from services.model_load_balancing_service import ModelLoadBalancingService
|
||||
|
||||
|
||||
class LoadBalancingCredentialPayload(BaseModel):
|
||||
model: str
|
||||
model_type: ModelType
|
||||
credentials: dict[str, object]
|
||||
|
||||
|
||||
register_schema_models(console_ns, LoadBalancingCredentialPayload)
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/credentials-validate"
|
||||
)
|
||||
class LoadBalancingCredentialsValidateApi(Resource):
|
||||
@console_ns.expect(console_ns.models[LoadBalancingCredentialPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -24,20 +36,7 @@ class LoadBalancingCredentialsValidateApi(Resource):
|
|||
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = LoadBalancingCredentialPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
# validate model load balancing credentials
|
||||
model_load_balancing_service = ModelLoadBalancingService()
|
||||
|
|
@ -49,9 +48,9 @@ class LoadBalancingCredentialsValidateApi(Resource):
|
|||
model_load_balancing_service.validate_load_balancing_credentials(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model=args["model"],
|
||||
model_type=args["model_type"],
|
||||
credentials=args["credentials"],
|
||||
model=payload.model,
|
||||
model_type=payload.model_type,
|
||||
credentials=payload.credentials,
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
result = False
|
||||
|
|
@ -69,6 +68,7 @@ class LoadBalancingCredentialsValidateApi(Resource):
|
|||
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate"
|
||||
)
|
||||
class LoadBalancingConfigCredentialsValidateApi(Resource):
|
||||
@console_ns.expect(console_ns.models[LoadBalancingCredentialPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -79,20 +79,7 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
|
|||
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = LoadBalancingCredentialPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
# validate model load balancing config credentials
|
||||
model_load_balancing_service = ModelLoadBalancingService()
|
||||
|
|
@ -104,9 +91,9 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
|
|||
model_load_balancing_service.validate_load_balancing_credentials(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model=args["model"],
|
||||
model_type=args["model_type"],
|
||||
credentials=args["credentials"],
|
||||
model=payload.model,
|
||||
model_type=payload.model_type,
|
||||
credentials=payload.credentials,
|
||||
config_id=config_id,
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import io
|
||||
import logging
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from flask import make_response, redirect, request, send_file
|
||||
|
|
@ -17,6 +18,7 @@ from controllers.console.wraps import (
|
|||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.db.session_factory import session_factory
|
||||
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
|
||||
from core.helper.tool_provider_cache import ToolProviderListCache
|
||||
from core.mcp.auth.auth_flow import auth, handle_callback
|
||||
|
|
@ -40,6 +42,8 @@ from services.tools.tools_manage_service import ToolCommonService
|
|||
from services.tools.tools_transform_service import ToolTransformService
|
||||
from services.tools.workflow_tools_manage_service import WorkflowToolManageService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_valid_url(url: str) -> bool:
|
||||
if not url:
|
||||
|
|
@ -945,8 +949,8 @@ class ToolProviderMCPApi(Resource):
|
|||
configuration = MCPConfiguration.model_validate(args["configuration"])
|
||||
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
|
||||
|
||||
# Create provider in transaction
|
||||
with Session(db.engine) as session, session.begin():
|
||||
# 1) Create provider in a short transaction (no network I/O inside)
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
result = service.create_provider(
|
||||
tenant_id=tenant_id,
|
||||
|
|
@ -962,7 +966,28 @@ class ToolProviderMCPApi(Resource):
|
|||
authentication=authentication,
|
||||
)
|
||||
|
||||
# Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
|
||||
# 2) Try to fetch tools immediately after creation so they appear without a second save.
|
||||
# Perform network I/O outside any DB session to avoid holding locks.
|
||||
try:
|
||||
reconnect = MCPToolManageService.reconnect_with_url(
|
||||
server_url=args["server_url"],
|
||||
headers=args.get("headers") or {},
|
||||
timeout=configuration.timeout,
|
||||
sse_read_timeout=configuration.sse_read_timeout,
|
||||
)
|
||||
# Update just-created provider with authed/tools in a new short transaction
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
db_provider = service.get_provider(provider_id=result.id, tenant_id=tenant_id)
|
||||
db_provider.authed = reconnect.authed
|
||||
db_provider.tools = reconnect.tools
|
||||
|
||||
result = ToolTransformService.mcp_provider_to_user_provider(db_provider, for_list=True)
|
||||
except Exception:
|
||||
# Best-effort: if initial fetch fails (e.g., auth required), return created provider as-is
|
||||
logger.warning("Failed to fetch MCP tools after creation", exc_info=True)
|
||||
|
||||
# Final cache invalidation to ensure list views are up to date
|
||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||
|
||||
return jsonable_encoder(result)
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ from controllers.service_api.dataset.error import DatasetInUseError, DatasetName
|
|||
from controllers.service_api.wraps import (
|
||||
DatasetApiResource,
|
||||
cloud_edition_billing_rate_limit_check,
|
||||
validate_dataset_token,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.provider_manager import ProviderManager
|
||||
|
|
@ -460,9 +459,8 @@ class DatasetTagsApi(DatasetApiResource):
|
|||
401: "Unauthorized - invalid API token",
|
||||
}
|
||||
)
|
||||
@validate_dataset_token
|
||||
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
|
||||
def get(self, _, dataset_id):
|
||||
def get(self, _):
|
||||
"""Get all knowledge type tags."""
|
||||
assert isinstance(current_user, Account)
|
||||
cid = current_user.current_tenant_id
|
||||
|
|
@ -482,8 +480,7 @@ class DatasetTagsApi(DatasetApiResource):
|
|||
}
|
||||
)
|
||||
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
|
||||
@validate_dataset_token
|
||||
def post(self, _, dataset_id):
|
||||
def post(self, _):
|
||||
"""Add a knowledge type tag."""
|
||||
assert isinstance(current_user, Account)
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
|
|
@ -506,8 +503,7 @@ class DatasetTagsApi(DatasetApiResource):
|
|||
}
|
||||
)
|
||||
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
|
||||
@validate_dataset_token
|
||||
def patch(self, _, dataset_id):
|
||||
def patch(self, _):
|
||||
assert isinstance(current_user, Account)
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
|
@ -533,9 +529,8 @@ class DatasetTagsApi(DatasetApiResource):
|
|||
403: "Forbidden - insufficient permissions",
|
||||
}
|
||||
)
|
||||
@validate_dataset_token
|
||||
@edit_permission_required
|
||||
def delete(self, _, dataset_id):
|
||||
def delete(self, _):
|
||||
"""Delete a knowledge type tag."""
|
||||
payload = TagDeletePayload.model_validate(service_api_ns.payload or {})
|
||||
TagService.delete_tag(payload.tag_id)
|
||||
|
|
@ -555,8 +550,7 @@ class DatasetTagBindingApi(DatasetApiResource):
|
|||
403: "Forbidden - insufficient permissions",
|
||||
}
|
||||
)
|
||||
@validate_dataset_token
|
||||
def post(self, _, dataset_id):
|
||||
def post(self, _):
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
assert isinstance(current_user, Account)
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
|
|
@ -580,8 +574,7 @@ class DatasetTagUnbindingApi(DatasetApiResource):
|
|||
403: "Forbidden - insufficient permissions",
|
||||
}
|
||||
)
|
||||
@validate_dataset_token
|
||||
def post(self, _, dataset_id):
|
||||
def post(self, _):
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
assert isinstance(current_user, Account)
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
|
|
@ -604,7 +597,6 @@ class DatasetTagsBindingStatusApi(DatasetApiResource):
|
|||
401: "Unauthorized - invalid API token",
|
||||
}
|
||||
)
|
||||
@validate_dataset_token
|
||||
def get(self, _, *args, **kwargs):
|
||||
"""Get all knowledge type tags."""
|
||||
dataset_id = kwargs.get("dataset_id")
|
||||
|
|
|
|||
|
|
@ -255,7 +255,10 @@ class PGVector(BaseVector):
|
|||
return
|
||||
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
|
||||
cur.execute("SELECT 1 FROM pg_extension WHERE extname = 'vector'")
|
||||
if not cur.fetchone():
|
||||
cur.execute("CREATE EXTENSION vector")
|
||||
|
||||
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
|
||||
# PG hnsw index only support 2000 dimension or less
|
||||
# ref: https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from collections.abc import Generator, Mapping
|
|||
from typing import Any, Union, cast
|
||||
|
||||
from flask import Flask, current_app
|
||||
from sqlalchemy import and_, or_, select
|
||||
from sqlalchemy import and_, literal, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.app_config.entities import (
|
||||
|
|
@ -1036,7 +1036,7 @@ class DatasetRetrieval:
|
|||
if automatic_metadata_filters:
|
||||
conditions = []
|
||||
for sequence, filter in enumerate(automatic_metadata_filters):
|
||||
self._process_metadata_filter_func(
|
||||
self.process_metadata_filter_func(
|
||||
sequence,
|
||||
filter.get("condition"), # type: ignore
|
||||
filter.get("metadata_name"), # type: ignore
|
||||
|
|
@ -1072,7 +1072,7 @@ class DatasetRetrieval:
|
|||
value=expected_value,
|
||||
)
|
||||
)
|
||||
filters = self._process_metadata_filter_func(
|
||||
filters = self.process_metadata_filter_func(
|
||||
sequence,
|
||||
condition.comparison_operator,
|
||||
metadata_name,
|
||||
|
|
@ -1168,8 +1168,9 @@ class DatasetRetrieval:
|
|||
return None
|
||||
return automatic_metadata_filters
|
||||
|
||||
def _process_metadata_filter_func(
|
||||
self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
|
||||
@classmethod
|
||||
def process_metadata_filter_func(
|
||||
cls, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
|
||||
):
|
||||
if value is None and condition not in ("empty", "not empty"):
|
||||
return filters
|
||||
|
|
@ -1218,6 +1219,20 @@ class DatasetRetrieval:
|
|||
|
||||
case "≥" | ">=":
|
||||
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() >= value)
|
||||
case "in" | "not in":
|
||||
if isinstance(value, str):
|
||||
value_list = [v.strip() for v in value.split(",") if v.strip()]
|
||||
elif isinstance(value, (list, tuple)):
|
||||
value_list = [str(v) for v in value if v is not None]
|
||||
else:
|
||||
value_list = [str(value)] if value is not None else []
|
||||
|
||||
if not value_list:
|
||||
# `field in []` is False, `field not in []` is True
|
||||
filters.append(literal(condition == "not in"))
|
||||
else:
|
||||
op = json_field.in_ if condition == "in" else json_field.notin_
|
||||
filters.append(op(value_list))
|
||||
case _:
|
||||
pass
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from collections import defaultdict
|
|||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from sqlalchemy import and_, func, literal, or_, select
|
||||
from sqlalchemy import and_, func, or_, select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity
|
||||
|
|
@ -460,7 +460,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
|||
if automatic_metadata_filters:
|
||||
conditions = []
|
||||
for sequence, filter in enumerate(automatic_metadata_filters):
|
||||
self._process_metadata_filter_func(
|
||||
DatasetRetrieval.process_metadata_filter_func(
|
||||
sequence,
|
||||
filter.get("condition", ""),
|
||||
filter.get("metadata_name", ""),
|
||||
|
|
@ -504,7 +504,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
|||
value=expected_value,
|
||||
)
|
||||
)
|
||||
filters = self._process_metadata_filter_func(
|
||||
filters = DatasetRetrieval.process_metadata_filter_func(
|
||||
sequence,
|
||||
condition.comparison_operator,
|
||||
metadata_name,
|
||||
|
|
@ -603,87 +603,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
|||
return [], usage
|
||||
return automatic_metadata_filters, usage
|
||||
|
||||
def _process_metadata_filter_func(
|
||||
self, sequence: int, condition: str, metadata_name: str, value: Any, filters: list[Any]
|
||||
) -> list[Any]:
|
||||
if value is None and condition not in ("empty", "not empty"):
|
||||
return filters
|
||||
|
||||
json_field = Document.doc_metadata[metadata_name].as_string()
|
||||
|
||||
match condition:
|
||||
case "contains":
|
||||
filters.append(json_field.like(f"%{value}%"))
|
||||
|
||||
case "not contains":
|
||||
filters.append(json_field.notlike(f"%{value}%"))
|
||||
|
||||
case "start with":
|
||||
filters.append(json_field.like(f"{value}%"))
|
||||
|
||||
case "end with":
|
||||
filters.append(json_field.like(f"%{value}"))
|
||||
case "in":
|
||||
if isinstance(value, str):
|
||||
value_list = [v.strip() for v in value.split(",") if v.strip()]
|
||||
elif isinstance(value, (list, tuple)):
|
||||
value_list = [str(v) for v in value if v is not None]
|
||||
else:
|
||||
value_list = [str(value)] if value is not None else []
|
||||
|
||||
if not value_list:
|
||||
filters.append(literal(False))
|
||||
else:
|
||||
filters.append(json_field.in_(value_list))
|
||||
|
||||
case "not in":
|
||||
if isinstance(value, str):
|
||||
value_list = [v.strip() for v in value.split(",") if v.strip()]
|
||||
elif isinstance(value, (list, tuple)):
|
||||
value_list = [str(v) for v in value if v is not None]
|
||||
else:
|
||||
value_list = [str(value)] if value is not None else []
|
||||
|
||||
if not value_list:
|
||||
filters.append(literal(True))
|
||||
else:
|
||||
filters.append(json_field.notin_(value_list))
|
||||
|
||||
case "is" | "=":
|
||||
if isinstance(value, str):
|
||||
filters.append(json_field == value)
|
||||
elif isinstance(value, (int, float)):
|
||||
filters.append(Document.doc_metadata[metadata_name].as_float() == value)
|
||||
|
||||
case "is not" | "≠":
|
||||
if isinstance(value, str):
|
||||
filters.append(json_field != value)
|
||||
elif isinstance(value, (int, float)):
|
||||
filters.append(Document.doc_metadata[metadata_name].as_float() != value)
|
||||
|
||||
case "empty":
|
||||
filters.append(Document.doc_metadata[metadata_name].is_(None))
|
||||
|
||||
case "not empty":
|
||||
filters.append(Document.doc_metadata[metadata_name].isnot(None))
|
||||
|
||||
case "before" | "<":
|
||||
filters.append(Document.doc_metadata[metadata_name].as_float() < value)
|
||||
|
||||
case "after" | ">":
|
||||
filters.append(Document.doc_metadata[metadata_name].as_float() > value)
|
||||
|
||||
case "≤" | "<=":
|
||||
filters.append(Document.doc_metadata[metadata_name].as_float() <= value)
|
||||
|
||||
case "≥" | ">=":
|
||||
filters.append(Document.doc_metadata[metadata_name].as_float() >= value)
|
||||
|
||||
case _:
|
||||
pass
|
||||
|
||||
return filters
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
|
|
|
|||
|
|
@ -110,5 +110,5 @@ class EnterpriseService:
|
|||
if not app_id:
|
||||
raise ValueError("app_id must be provided.")
|
||||
|
||||
body = {"appId": app_id}
|
||||
EnterpriseRequest.send_request("DELETE", "/webapp/clean", json=body)
|
||||
params = {"appId": app_id}
|
||||
EnterpriseRequest.send_request("DELETE", "/webapp/clean", params=params)
|
||||
|
|
|
|||
|
|
@ -319,8 +319,14 @@ class MCPToolManageService:
|
|||
except MCPError as e:
|
||||
raise ValueError(f"Failed to connect to MCP server: {e}")
|
||||
|
||||
# Update database with retrieved tools
|
||||
db_provider.tools = json.dumps([tool.model_dump() for tool in tools])
|
||||
# Update database with retrieved tools (ensure description is a non-null string)
|
||||
tools_payload = []
|
||||
for tool in tools:
|
||||
data = tool.model_dump()
|
||||
if data.get("description") is None:
|
||||
data["description"] = ""
|
||||
tools_payload.append(data)
|
||||
db_provider.tools = json.dumps(tools_payload)
|
||||
db_provider.authed = True
|
||||
db_provider.updated_at = datetime.now()
|
||||
self._session.flush()
|
||||
|
|
@ -620,6 +626,21 @@ class MCPToolManageService:
|
|||
server_url_hash=new_server_url_hash,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def reconnect_with_url(
|
||||
*,
|
||||
server_url: str,
|
||||
headers: dict[str, str],
|
||||
timeout: float | None,
|
||||
sse_read_timeout: float | None,
|
||||
) -> ReconnectResult:
|
||||
return MCPToolManageService._reconnect_with_url(
|
||||
server_url=server_url,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
sse_read_timeout=sse_read_timeout,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _reconnect_with_url(
|
||||
*,
|
||||
|
|
@ -642,9 +663,16 @@ class MCPToolManageService:
|
|||
sse_read_timeout=sse_read_timeout,
|
||||
) as mcp_client:
|
||||
tools = mcp_client.list_tools()
|
||||
# Ensure tool descriptions are non-null in payload
|
||||
tools_payload = []
|
||||
for t in tools:
|
||||
d = t.model_dump()
|
||||
if d.get("description") is None:
|
||||
d["description"] = ""
|
||||
tools_payload.append(d)
|
||||
return ReconnectResult(
|
||||
authed=True,
|
||||
tools=json.dumps([tool.model_dump() for tool in tools]),
|
||||
tools=json.dumps(tools_payload),
|
||||
encrypted_credentials=EMPTY_CREDENTIALS_JSON,
|
||||
)
|
||||
except MCPAuthError:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,145 @@
|
|||
"""Unit tests for load balancing credential validation APIs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import builtins
|
||||
import importlib
|
||||
import sys
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
|
||||
from models.account import TenantAccountRole
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app() -> Flask:
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def load_balancing_module(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Reload controller module with lightweight decorators for testing."""
|
||||
|
||||
from controllers.console import console_ns, wraps
|
||||
from libs import login
|
||||
|
||||
def _noop(func):
|
||||
return func
|
||||
|
||||
monkeypatch.setattr(login, "login_required", _noop)
|
||||
monkeypatch.setattr(wraps, "setup_required", _noop)
|
||||
monkeypatch.setattr(wraps, "account_initialization_required", _noop)
|
||||
|
||||
def _noop_route(*args, **kwargs): # type: ignore[override]
|
||||
def _decorator(cls):
|
||||
return cls
|
||||
|
||||
return _decorator
|
||||
|
||||
monkeypatch.setattr(console_ns, "route", _noop_route)
|
||||
|
||||
module_name = "controllers.console.workspace.load_balancing_config"
|
||||
sys.modules.pop(module_name, None)
|
||||
module = importlib.import_module(module_name)
|
||||
return module
|
||||
|
||||
|
||||
def _mock_user(role: TenantAccountRole) -> SimpleNamespace:
|
||||
return SimpleNamespace(current_role=role)
|
||||
|
||||
|
||||
def _prepare_context(module, monkeypatch: pytest.MonkeyPatch, role=TenantAccountRole.OWNER):
|
||||
user = _mock_user(role)
|
||||
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "tenant-123"))
|
||||
mock_service = MagicMock()
|
||||
monkeypatch.setattr(module, "ModelLoadBalancingService", lambda: mock_service)
|
||||
return mock_service
|
||||
|
||||
|
||||
def _request_payload():
|
||||
return {"model": "gpt-4o", "model_type": ModelType.LLM, "credentials": {"api_key": "sk-***"}}
|
||||
|
||||
|
||||
def test_validate_credentials_success(app: Flask, load_balancing_module, monkeypatch: pytest.MonkeyPatch):
|
||||
service = _prepare_context(load_balancing_module, monkeypatch)
|
||||
|
||||
with app.test_request_context(
|
||||
"/workspaces/current/model-providers/openai/models/load-balancing-configs/credentials-validate",
|
||||
method="POST",
|
||||
json=_request_payload(),
|
||||
):
|
||||
response = load_balancing_module.LoadBalancingCredentialsValidateApi().post(provider="openai")
|
||||
|
||||
assert response == {"result": "success"}
|
||||
service.validate_load_balancing_credentials.assert_called_once_with(
|
||||
tenant_id="tenant-123",
|
||||
provider="openai",
|
||||
model="gpt-4o",
|
||||
model_type=ModelType.LLM,
|
||||
credentials={"api_key": "sk-***"},
|
||||
)
|
||||
|
||||
|
||||
def test_validate_credentials_returns_error_message(app: Flask, load_balancing_module, monkeypatch: pytest.MonkeyPatch):
|
||||
service = _prepare_context(load_balancing_module, monkeypatch)
|
||||
service.validate_load_balancing_credentials.side_effect = CredentialsValidateFailedError("invalid credentials")
|
||||
|
||||
with app.test_request_context(
|
||||
"/workspaces/current/model-providers/openai/models/load-balancing-configs/credentials-validate",
|
||||
method="POST",
|
||||
json=_request_payload(),
|
||||
):
|
||||
response = load_balancing_module.LoadBalancingCredentialsValidateApi().post(provider="openai")
|
||||
|
||||
assert response == {"result": "error", "error": "invalid credentials"}
|
||||
|
||||
|
||||
def test_validate_credentials_requires_privileged_role(
|
||||
app: Flask, load_balancing_module, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
_prepare_context(load_balancing_module, monkeypatch, role=TenantAccountRole.NORMAL)
|
||||
|
||||
with app.test_request_context(
|
||||
"/workspaces/current/model-providers/openai/models/load-balancing-configs/credentials-validate",
|
||||
method="POST",
|
||||
json=_request_payload(),
|
||||
):
|
||||
api = load_balancing_module.LoadBalancingCredentialsValidateApi()
|
||||
with pytest.raises(Forbidden):
|
||||
api.post(provider="openai")
|
||||
|
||||
|
||||
def test_validate_credentials_with_config_id(app: Flask, load_balancing_module, monkeypatch: pytest.MonkeyPatch):
|
||||
service = _prepare_context(load_balancing_module, monkeypatch)
|
||||
|
||||
with app.test_request_context(
|
||||
"/workspaces/current/model-providers/openai/models/load-balancing-configs/cfg-1/credentials-validate",
|
||||
method="POST",
|
||||
json=_request_payload(),
|
||||
):
|
||||
response = load_balancing_module.LoadBalancingConfigCredentialsValidateApi().post(
|
||||
provider="openai", config_id="cfg-1"
|
||||
)
|
||||
|
||||
assert response == {"result": "success"}
|
||||
service.validate_load_balancing_credentials.assert_called_once_with(
|
||||
tenant_id="tenant-123",
|
||||
provider="openai",
|
||||
model="gpt-4o",
|
||||
model_type=ModelType.LLM,
|
||||
credentials={"api_key": "sk-***"},
|
||||
config_id="cfg-1",
|
||||
)
|
||||
|
|
@ -0,0 +1,103 @@
|
|||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask_restx import Api
|
||||
|
||||
from controllers.console.workspace.tool_providers import ToolProviderMCPApi
|
||||
from core.db.session_factory import configure_session_factory
|
||||
from extensions.ext_database import db
|
||||
from services.tools.mcp_tools_manage_service import ReconnectResult
|
||||
|
||||
|
||||
# Backward-compat fixtures referenced by @pytest.mark.usefixtures in this file.
|
||||
# They are intentionally no-ops because the test already patches the required
|
||||
# behaviors explicitly via @patch and context managers below.
|
||||
@pytest.fixture
|
||||
def _mock_cache():
|
||||
return
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def _mock_user_tenant():
|
||||
return
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"
|
||||
api = Api(app)
|
||||
api.add_resource(ToolProviderMCPApi, "/console/api/workspaces/current/tool-provider/mcp")
|
||||
db.init_app(app)
|
||||
# Configure session factory used by controller code
|
||||
with app.app_context():
|
||||
configure_session_factory(db.engine)
|
||||
return app.test_client()
|
||||
|
||||
|
||||
@patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant", return_value=(MagicMock(id="u1"), "t1")
|
||||
)
|
||||
@patch("controllers.console.workspace.tool_providers.ToolProviderListCache.invalidate_cache", return_value=None)
|
||||
@patch("controllers.console.workspace.tool_providers.Session")
|
||||
@patch("controllers.console.workspace.tool_providers.MCPToolManageService._reconnect_with_url")
|
||||
@pytest.mark.usefixtures("_mock_cache", "_mock_user_tenant")
|
||||
def test_create_mcp_provider_populates_tools(
|
||||
mock_reconnect, mock_session, mock_invalidate_cache, mock_current_account_with_tenant, client
|
||||
):
|
||||
# Arrange: reconnect returns tools immediately
|
||||
mock_reconnect.return_value = ReconnectResult(
|
||||
authed=True,
|
||||
tools=json.dumps(
|
||||
[{"name": "ping", "description": "ok", "inputSchema": {"type": "object"}, "outputSchema": {}}]
|
||||
),
|
||||
encrypted_credentials="{}",
|
||||
)
|
||||
|
||||
# Fake service.create_provider -> returns object with id for reload
|
||||
svc = MagicMock()
|
||||
create_result = MagicMock()
|
||||
create_result.id = "provider-1"
|
||||
svc.create_provider.return_value = create_result
|
||||
svc.get_provider.return_value = MagicMock(id="provider-1", tenant_id="t1") # used by reload path
|
||||
mock_session.return_value.__enter__.return_value = MagicMock()
|
||||
# Patch MCPToolManageService constructed inside controller
|
||||
with patch("controllers.console.workspace.tool_providers.MCPToolManageService", return_value=svc):
|
||||
payload = {
|
||||
"server_url": "http://example.com/mcp",
|
||||
"name": "demo",
|
||||
"icon": "😀",
|
||||
"icon_type": "emoji",
|
||||
"icon_background": "#000",
|
||||
"server_identifier": "demo-sid",
|
||||
"configuration": {"timeout": 5, "sse_read_timeout": 30},
|
||||
"headers": {},
|
||||
"authentication": {},
|
||||
}
|
||||
# Act
|
||||
with (
|
||||
patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"), # bypass setup_required DB check
|
||||
patch("controllers.console.wraps.current_account_with_tenant", return_value=(MagicMock(id="u1"), "t1")),
|
||||
patch("libs.login.check_csrf_token", return_value=None), # bypass CSRF in login_required
|
||||
patch("libs.login._get_user", return_value=MagicMock(id="u1", is_authenticated=True)), # login
|
||||
patch(
|
||||
"services.tools.tools_transform_service.ToolTransformService.mcp_provider_to_user_provider",
|
||||
return_value={"id": "provider-1", "tools": [{"name": "ping"}]},
|
||||
),
|
||||
):
|
||||
resp = client.post(
|
||||
"/console/api/workspaces/current/tool-provider/mcp",
|
||||
data=json.dumps(payload),
|
||||
content_type="application/json",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert resp.status_code == 200
|
||||
body = resp.get_json()
|
||||
assert body.get("id") == "provider-1"
|
||||
# 若 transform 后包含 tools 字段,确保非空
|
||||
assert isinstance(body.get("tools"), list)
|
||||
assert body["tools"]
|
||||
|
|
@ -0,0 +1,327 @@
|
|||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.datasource.vdb.pgvector.pgvector import (
|
||||
PGVector,
|
||||
PGVectorConfig,
|
||||
)
|
||||
|
||||
|
||||
class TestPGVector(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.config = PGVectorConfig(
|
||||
host="localhost",
|
||||
port=5432,
|
||||
user="test_user",
|
||||
password="test_password",
|
||||
database="test_db",
|
||||
min_connection=1,
|
||||
max_connection=5,
|
||||
pg_bigm=False,
|
||||
)
|
||||
self.collection_name = "test_collection"
|
||||
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
|
||||
def test_init(self, mock_pool_class):
|
||||
"""Test PGVector initialization."""
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
pgvector = PGVector(self.collection_name, self.config)
|
||||
|
||||
assert pgvector._collection_name == self.collection_name
|
||||
assert pgvector.table_name == f"embedding_{self.collection_name}"
|
||||
assert pgvector.get_type() == "pgvector"
|
||||
assert pgvector.pool is not None
|
||||
assert pgvector.pg_bigm is False
|
||||
assert pgvector.index_hash is not None
|
||||
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
|
||||
def test_init_with_pg_bigm(self, mock_pool_class):
|
||||
"""Test PGVector initialization with pg_bigm enabled."""
|
||||
config = PGVectorConfig(
|
||||
host="localhost",
|
||||
port=5432,
|
||||
user="test_user",
|
||||
password="test_password",
|
||||
database="test_db",
|
||||
min_connection=1,
|
||||
max_connection=5,
|
||||
pg_bigm=True,
|
||||
)
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
pgvector = PGVector(self.collection_name, config)
|
||||
|
||||
assert pgvector.pg_bigm is True
|
||||
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
|
||||
def test_create_collection_basic(self, mock_redis, mock_pool_class):
|
||||
"""Test basic collection creation."""
|
||||
# Mock Redis operations
|
||||
mock_lock = MagicMock()
|
||||
mock_lock.__enter__ = MagicMock()
|
||||
mock_lock.__exit__ = MagicMock()
|
||||
mock_redis.lock.return_value = mock_lock
|
||||
mock_redis.get.return_value = None
|
||||
mock_redis.set.return_value = None
|
||||
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
# Mock connection and cursor
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.getconn.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.return_value = [1] # vector extension exists
|
||||
|
||||
pgvector = PGVector(self.collection_name, self.config)
|
||||
pgvector._create_collection(1536)
|
||||
|
||||
# Verify SQL execution calls
|
||||
assert mock_cursor.execute.called
|
||||
|
||||
# Check that CREATE TABLE was called with correct dimension
|
||||
create_table_calls = [call for call in mock_cursor.execute.call_args_list if "CREATE TABLE" in str(call)]
|
||||
assert len(create_table_calls) == 1
|
||||
assert "vector(1536)" in create_table_calls[0][0][0]
|
||||
|
||||
# Check that CREATE INDEX was called (dimension <= 2000)
|
||||
create_index_calls = [
|
||||
call for call in mock_cursor.execute.call_args_list if "CREATE INDEX" in str(call) and "hnsw" in str(call)
|
||||
]
|
||||
assert len(create_index_calls) == 1
|
||||
|
||||
# Verify Redis cache was set
|
||||
mock_redis.set.assert_called_once()
|
||||
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
|
||||
def test_create_collection_with_large_dimension(self, mock_redis, mock_pool_class):
|
||||
"""Test collection creation with dimension > 2000 (no HNSW index)."""
|
||||
# Mock Redis operations
|
||||
mock_lock = MagicMock()
|
||||
mock_lock.__enter__ = MagicMock()
|
||||
mock_lock.__exit__ = MagicMock()
|
||||
mock_redis.lock.return_value = mock_lock
|
||||
mock_redis.get.return_value = None
|
||||
mock_redis.set.return_value = None
|
||||
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
# Mock connection and cursor
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.getconn.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.return_value = [1] # vector extension exists
|
||||
|
||||
pgvector = PGVector(self.collection_name, self.config)
|
||||
pgvector._create_collection(3072) # Dimension > 2000
|
||||
|
||||
# Check that CREATE TABLE was called
|
||||
create_table_calls = [call for call in mock_cursor.execute.call_args_list if "CREATE TABLE" in str(call)]
|
||||
assert len(create_table_calls) == 1
|
||||
assert "vector(3072)" in create_table_calls[0][0][0]
|
||||
|
||||
# Check that HNSW index was NOT created (dimension > 2000)
|
||||
hnsw_index_calls = [call for call in mock_cursor.execute.call_args_list if "hnsw" in str(call)]
|
||||
assert len(hnsw_index_calls) == 0
|
||||
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
|
||||
def test_create_collection_with_pg_bigm(self, mock_redis, mock_pool_class):
|
||||
"""Test collection creation with pg_bigm enabled."""
|
||||
config = PGVectorConfig(
|
||||
host="localhost",
|
||||
port=5432,
|
||||
user="test_user",
|
||||
password="test_password",
|
||||
database="test_db",
|
||||
min_connection=1,
|
||||
max_connection=5,
|
||||
pg_bigm=True,
|
||||
)
|
||||
|
||||
# Mock Redis operations
|
||||
mock_lock = MagicMock()
|
||||
mock_lock.__enter__ = MagicMock()
|
||||
mock_lock.__exit__ = MagicMock()
|
||||
mock_redis.lock.return_value = mock_lock
|
||||
mock_redis.get.return_value = None
|
||||
mock_redis.set.return_value = None
|
||||
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
# Mock connection and cursor
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.getconn.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.return_value = [1] # vector extension exists
|
||||
|
||||
pgvector = PGVector(self.collection_name, config)
|
||||
pgvector._create_collection(1536)
|
||||
|
||||
# Check that pg_bigm index was created
|
||||
bigm_index_calls = [call for call in mock_cursor.execute.call_args_list if "gin_bigm_ops" in str(call)]
|
||||
assert len(bigm_index_calls) == 1
|
||||
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
|
||||
def test_create_collection_creates_vector_extension(self, mock_redis, mock_pool_class):
|
||||
"""Test that vector extension is created if it doesn't exist."""
|
||||
# Mock Redis operations
|
||||
mock_lock = MagicMock()
|
||||
mock_lock.__enter__ = MagicMock()
|
||||
mock_lock.__exit__ = MagicMock()
|
||||
mock_redis.lock.return_value = mock_lock
|
||||
mock_redis.get.return_value = None
|
||||
mock_redis.set.return_value = None
|
||||
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
# Mock connection and cursor
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.getconn.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
# First call: vector extension doesn't exist
|
||||
mock_cursor.fetchone.return_value = None
|
||||
|
||||
pgvector = PGVector(self.collection_name, self.config)
|
||||
pgvector._create_collection(1536)
|
||||
|
||||
# Check that CREATE EXTENSION was called
|
||||
create_extension_calls = [
|
||||
call for call in mock_cursor.execute.call_args_list if "CREATE EXTENSION vector" in str(call)
|
||||
]
|
||||
assert len(create_extension_calls) == 1
|
||||
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
|
||||
def test_create_collection_with_cache_hit(self, mock_redis, mock_pool_class):
|
||||
"""Test that collection creation is skipped when cache exists."""
|
||||
# Mock Redis operations - cache exists
|
||||
mock_lock = MagicMock()
|
||||
mock_lock.__enter__ = MagicMock()
|
||||
mock_lock.__exit__ = MagicMock()
|
||||
mock_redis.lock.return_value = mock_lock
|
||||
mock_redis.get.return_value = 1 # Cache exists
|
||||
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
# Mock connection and cursor
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.getconn.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
|
||||
pgvector = PGVector(self.collection_name, self.config)
|
||||
pgvector._create_collection(1536)
|
||||
|
||||
# Check that no SQL was executed (early return due to cache)
|
||||
assert mock_cursor.execute.call_count == 0
|
||||
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
|
||||
def test_create_collection_with_redis_lock(self, mock_redis, mock_pool_class):
|
||||
"""Test that Redis lock is used during collection creation."""
|
||||
# Mock Redis operations
|
||||
mock_lock = MagicMock()
|
||||
mock_lock.__enter__ = MagicMock()
|
||||
mock_lock.__exit__ = MagicMock()
|
||||
mock_redis.lock.return_value = mock_lock
|
||||
mock_redis.get.return_value = None
|
||||
mock_redis.set.return_value = None
|
||||
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
# Mock connection and cursor
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.getconn.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.return_value = [1] # vector extension exists
|
||||
|
||||
pgvector = PGVector(self.collection_name, self.config)
|
||||
pgvector._create_collection(1536)
|
||||
|
||||
# Verify Redis lock was acquired with correct lock name
|
||||
mock_redis.lock.assert_called_once_with("vector_indexing_test_collection_lock", timeout=20)
|
||||
|
||||
# Verify lock context manager was entered and exited
|
||||
mock_lock.__enter__.assert_called_once()
|
||||
mock_lock.__exit__.assert_called_once()
|
||||
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
|
||||
def test_get_cursor_context_manager(self, mock_pool_class):
|
||||
"""Test that _get_cursor properly manages connection lifecycle."""
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.getconn.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
|
||||
pgvector = PGVector(self.collection_name, self.config)
|
||||
|
||||
with pgvector._get_cursor() as cur:
|
||||
assert cur == mock_cursor
|
||||
|
||||
# Verify connection lifecycle methods were called
|
||||
mock_pool.getconn.assert_called_once()
|
||||
mock_cursor.close.assert_called_once()
|
||||
mock_conn.commit.assert_called_once()
|
||||
mock_pool.putconn.assert_called_once_with(mock_conn)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_config_override",
|
||||
[
|
||||
{"host": ""}, # Test empty host
|
||||
{"port": 0}, # Test invalid port
|
||||
{"user": ""}, # Test empty user
|
||||
{"password": ""}, # Test empty password
|
||||
{"database": ""}, # Test empty database
|
||||
{"min_connection": 0}, # Test invalid min_connection
|
||||
{"max_connection": 0}, # Test invalid max_connection
|
||||
{"min_connection": 10, "max_connection": 5}, # Test min > max
|
||||
],
|
||||
)
|
||||
def test_config_validation_parametrized(invalid_config_override):
|
||||
"""Test configuration validation for various invalid inputs using parametrize."""
|
||||
config = {
|
||||
"host": "localhost",
|
||||
"port": 5432,
|
||||
"user": "test_user",
|
||||
"password": "test_password",
|
||||
"database": "test_db",
|
||||
"min_connection": 1,
|
||||
"max_connection": 5,
|
||||
}
|
||||
config.update(invalid_config_override)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
PGVectorConfig(**config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -0,0 +1,873 @@
|
|||
"""
|
||||
Unit tests for DatasetRetrieval.process_metadata_filter_func.
|
||||
|
||||
This module provides comprehensive test coverage for the process_metadata_filter_func
|
||||
method in the DatasetRetrieval class, which is responsible for building SQLAlchemy
|
||||
filter expressions based on metadata filtering conditions.
|
||||
|
||||
Conditions Tested:
|
||||
==================
|
||||
1. **String Conditions**: contains, not contains, start with, end with
|
||||
2. **Equality Conditions**: is / =, is not / ≠
|
||||
3. **Null Conditions**: empty, not empty
|
||||
4. **Numeric Comparisons**: before / <, after / >, ≤ / <=, ≥ / >=
|
||||
5. **List Conditions**: in
|
||||
6. **Edge Cases**: None values, different data types (str, int, float)
|
||||
|
||||
Test Architecture:
|
||||
==================
|
||||
- Direct instantiation of DatasetRetrieval
|
||||
- Mocking of DatasetDocument model attributes
|
||||
- Verification of SQLAlchemy filter expressions
|
||||
- Follows Arrange-Act-Assert (AAA) pattern
|
||||
|
||||
Running Tests:
|
||||
==============
|
||||
# Run all tests in this module
|
||||
uv run --project api pytest \
|
||||
api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py -v
|
||||
|
||||
# Run a specific test
|
||||
uv run --project api pytest \
|
||||
api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py::\
|
||||
TestProcessMetadataFilterFunc::test_contains_condition -v
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
|
||||
|
||||
class TestProcessMetadataFilterFunc:
|
||||
"""
|
||||
Comprehensive test suite for process_metadata_filter_func method.
|
||||
|
||||
This test class validates all metadata filtering conditions supported by
|
||||
the DatasetRetrieval class, including string operations, numeric comparisons,
|
||||
null checks, and list operations.
|
||||
|
||||
Method Signature:
|
||||
==================
|
||||
def process_metadata_filter_func(
|
||||
self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
|
||||
) -> list:
|
||||
|
||||
The method builds SQLAlchemy filter expressions by:
|
||||
1. Validating value is not None (except for empty/not empty conditions)
|
||||
2. Using DatasetDocument.doc_metadata JSON field operations
|
||||
3. Adding appropriate SQLAlchemy expressions to the filters list
|
||||
4. Returning the updated filters list
|
||||
|
||||
Mocking Strategy:
|
||||
==================
|
||||
- Mock DatasetDocument.doc_metadata to avoid database dependencies
|
||||
- Verify filter expressions are created correctly
|
||||
- Test with various data types (str, int, float, list)
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def retrieval(self):
|
||||
"""
|
||||
Create a DatasetRetrieval instance for testing.
|
||||
|
||||
Returns:
|
||||
DatasetRetrieval: Instance to test process_metadata_filter_func
|
||||
"""
|
||||
return DatasetRetrieval()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_doc_metadata(self):
|
||||
"""
|
||||
Mock the DatasetDocument.doc_metadata JSON field.
|
||||
|
||||
The method uses DatasetDocument.doc_metadata[metadata_name] to access
|
||||
JSON fields. We mock this to avoid database dependencies.
|
||||
|
||||
Returns:
|
||||
Mock: Mocked doc_metadata attribute
|
||||
"""
|
||||
mock_metadata_field = MagicMock()
|
||||
|
||||
# Create mock for string access
|
||||
mock_string_access = MagicMock()
|
||||
mock_string_access.like = MagicMock()
|
||||
mock_string_access.notlike = MagicMock()
|
||||
mock_string_access.__eq__ = MagicMock(return_value=MagicMock())
|
||||
mock_string_access.__ne__ = MagicMock(return_value=MagicMock())
|
||||
mock_string_access.in_ = MagicMock(return_value=MagicMock())
|
||||
|
||||
# Create mock for float access (for numeric comparisons)
|
||||
mock_float_access = MagicMock()
|
||||
mock_float_access.__eq__ = MagicMock(return_value=MagicMock())
|
||||
mock_float_access.__ne__ = MagicMock(return_value=MagicMock())
|
||||
mock_float_access.__lt__ = MagicMock(return_value=MagicMock())
|
||||
mock_float_access.__gt__ = MagicMock(return_value=MagicMock())
|
||||
mock_float_access.__le__ = MagicMock(return_value=MagicMock())
|
||||
mock_float_access.__ge__ = MagicMock(return_value=MagicMock())
|
||||
|
||||
# Create mock for null checks
|
||||
mock_null_access = MagicMock()
|
||||
mock_null_access.is_ = MagicMock(return_value=MagicMock())
|
||||
mock_null_access.isnot = MagicMock(return_value=MagicMock())
|
||||
|
||||
# Setup __getitem__ to return appropriate mock based on usage
|
||||
def getitem_side_effect(name):
|
||||
if name in ["author", "title", "category"]:
|
||||
return mock_string_access
|
||||
elif name in ["year", "price", "rating"]:
|
||||
return mock_float_access
|
||||
else:
|
||||
return mock_string_access
|
||||
|
||||
mock_metadata_field.__getitem__ = MagicMock(side_effect=getitem_side_effect)
|
||||
mock_metadata_field.as_string.return_value = mock_string_access
|
||||
mock_metadata_field.as_float.return_value = mock_float_access
|
||||
mock_metadata_field[metadata_name:str].is_ = mock_null_access.is_
|
||||
mock_metadata_field[metadata_name:str].isnot = mock_null_access.isnot
|
||||
|
||||
return mock_metadata_field
|
||||
|
||||
# ==================== String Condition Tests ====================
|
||||
|
||||
def test_contains_condition_string_value(self, retrieval):
|
||||
"""
|
||||
Test 'contains' condition with string value.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with LIKE expression
|
||||
- Pattern matching uses %value% syntax
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "contains"
|
||||
metadata_name = "author"
|
||||
value = "John"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_not_contains_condition(self, retrieval):
|
||||
"""
|
||||
Test 'not contains' condition.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with NOT LIKE expression
|
||||
- Pattern matching uses %value% syntax with negation
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "not contains"
|
||||
metadata_name = "title"
|
||||
value = "banned"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_start_with_condition(self, retrieval):
|
||||
"""
|
||||
Test 'start with' condition.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with LIKE expression
|
||||
- Pattern matching uses value% syntax
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "start with"
|
||||
metadata_name = "category"
|
||||
value = "tech"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_end_with_condition(self, retrieval):
|
||||
"""
|
||||
Test 'end with' condition.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with LIKE expression
|
||||
- Pattern matching uses %value syntax
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "end with"
|
||||
metadata_name = "filename"
|
||||
value = ".pdf"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
# ==================== Equality Condition Tests ====================
|
||||
|
||||
def test_is_condition_with_string_value(self, retrieval):
|
||||
"""
|
||||
Test 'is' (=) condition with string value.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with equality expression
|
||||
- String comparison is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "is"
|
||||
metadata_name = "author"
|
||||
value = "Jane Doe"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_equals_condition_with_string_value(self, retrieval):
|
||||
"""
|
||||
Test '=' condition with string value.
|
||||
|
||||
Verifies:
|
||||
- Same behavior as 'is' condition
|
||||
- String comparison is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "="
|
||||
metadata_name = "category"
|
||||
value = "technology"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_is_condition_with_int_value(self, retrieval):
|
||||
"""
|
||||
Test 'is' condition with integer value.
|
||||
|
||||
Verifies:
|
||||
- Numeric comparison is used
|
||||
- as_float() is called on the metadata field
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "is"
|
||||
metadata_name = "year"
|
||||
value = 2023
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_is_condition_with_float_value(self, retrieval):
|
||||
"""
|
||||
Test 'is' condition with float value.
|
||||
|
||||
Verifies:
|
||||
- Numeric comparison is used
|
||||
- as_float() is called on the metadata field
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "is"
|
||||
metadata_name = "price"
|
||||
value = 19.99
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_is_not_condition_with_string_value(self, retrieval):
|
||||
"""
|
||||
Test 'is not' (≠) condition with string value.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with inequality expression
|
||||
- String comparison is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "is not"
|
||||
metadata_name = "author"
|
||||
value = "Unknown"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_not_equals_condition(self, retrieval):
|
||||
"""
|
||||
Test '≠' condition with string value.
|
||||
|
||||
Verifies:
|
||||
- Same behavior as 'is not' condition
|
||||
- Inequality expression is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "≠"
|
||||
metadata_name = "category"
|
||||
value = "archived"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_is_not_condition_with_numeric_value(self, retrieval):
|
||||
"""
|
||||
Test 'is not' condition with numeric value.
|
||||
|
||||
Verifies:
|
||||
- Numeric inequality comparison is used
|
||||
- as_float() is called on the metadata field
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "is not"
|
||||
metadata_name = "year"
|
||||
value = 2000
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
# ==================== Null Condition Tests ====================
|
||||
|
||||
def test_empty_condition(self, retrieval):
|
||||
"""
|
||||
Test 'empty' condition (null check).
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with IS NULL expression
|
||||
- Value can be None for this condition
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "empty"
|
||||
metadata_name = "author"
|
||||
value = None
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_not_empty_condition(self, retrieval):
|
||||
"""
|
||||
Test 'not empty' condition (not null check).
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with IS NOT NULL expression
|
||||
- Value can be None for this condition
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "not empty"
|
||||
metadata_name = "description"
|
||||
value = None
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
# ==================== Numeric Comparison Tests ====================
|
||||
|
||||
def test_before_condition(self, retrieval):
|
||||
"""
|
||||
Test 'before' (<) condition.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with less than expression
|
||||
- Numeric comparison is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "before"
|
||||
metadata_name = "year"
|
||||
value = 2020
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_less_than_condition(self, retrieval):
|
||||
"""
|
||||
Test '<' condition.
|
||||
|
||||
Verifies:
|
||||
- Same behavior as 'before' condition
|
||||
- Less than expression is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "<"
|
||||
metadata_name = "price"
|
||||
value = 100.0
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_after_condition(self, retrieval):
|
||||
"""
|
||||
Test 'after' (>) condition.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with greater than expression
|
||||
- Numeric comparison is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "after"
|
||||
metadata_name = "year"
|
||||
value = 2020
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_greater_than_condition(self, retrieval):
|
||||
"""
|
||||
Test '>' condition.
|
||||
|
||||
Verifies:
|
||||
- Same behavior as 'after' condition
|
||||
- Greater than expression is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = ">"
|
||||
metadata_name = "rating"
|
||||
value = 4.5
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_less_than_or_equal_condition_unicode(self, retrieval):
|
||||
"""
|
||||
Test '≤' condition.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with less than or equal expression
|
||||
- Numeric comparison is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "≤"
|
||||
metadata_name = "price"
|
||||
value = 50.0
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_less_than_or_equal_condition_ascii(self, retrieval):
|
||||
"""
|
||||
Test '<=' condition.
|
||||
|
||||
Verifies:
|
||||
- Same behavior as '≤' condition
|
||||
- Less than or equal expression is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "<="
|
||||
metadata_name = "year"
|
||||
value = 2023
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_greater_than_or_equal_condition_unicode(self, retrieval):
|
||||
"""
|
||||
Test '≥' condition.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with greater than or equal expression
|
||||
- Numeric comparison is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "≥"
|
||||
metadata_name = "rating"
|
||||
value = 3.5
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_greater_than_or_equal_condition_ascii(self, retrieval):
|
||||
"""
|
||||
Test '>=' condition.
|
||||
|
||||
Verifies:
|
||||
- Same behavior as '≥' condition
|
||||
- Greater than or equal expression is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = ">="
|
||||
metadata_name = "year"
|
||||
value = 2000
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
# ==================== List/In Condition Tests ====================
|
||||
|
||||
def test_in_condition_with_comma_separated_string(self, retrieval):
|
||||
"""
|
||||
Test 'in' condition with comma-separated string value.
|
||||
|
||||
Verifies:
|
||||
- String is split into list
|
||||
- Whitespace is trimmed from each value
|
||||
- IN expression is created
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "in"
|
||||
metadata_name = "category"
|
||||
value = "tech, science, AI "
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_in_condition_with_list_value(self, retrieval):
|
||||
"""
|
||||
Test 'in' condition with list value.
|
||||
|
||||
Verifies:
|
||||
- List is processed correctly
|
||||
- None values are filtered out
|
||||
- IN expression is created with valid values
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "in"
|
||||
metadata_name = "tags"
|
||||
value = ["python", "javascript", None, "golang"]
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_in_condition_with_tuple_value(self, retrieval):
|
||||
"""
|
||||
Test 'in' condition with tuple value.
|
||||
|
||||
Verifies:
|
||||
- Tuple is processed like a list
|
||||
- IN expression is created
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "in"
|
||||
metadata_name = "category"
|
||||
value = ("tech", "science", "ai")
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_in_condition_with_empty_string(self, retrieval):
|
||||
"""
|
||||
Test 'in' condition with empty string value.
|
||||
|
||||
Verifies:
|
||||
- Empty string results in literal(False) filter
|
||||
- No valid values to match
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "in"
|
||||
metadata_name = "category"
|
||||
value = ""
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
# Verify it's a literal(False) expression
|
||||
# This is a bit tricky to test without access to the actual expression
|
||||
|
||||
def test_in_condition_with_only_whitespace(self, retrieval):
|
||||
"""
|
||||
Test 'in' condition with whitespace-only string value.
|
||||
|
||||
Verifies:
|
||||
- Whitespace-only string results in literal(False) filter
|
||||
- All values are stripped and filtered out
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "in"
|
||||
metadata_name = "category"
|
||||
value = " , , "
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_in_condition_with_single_string(self, retrieval):
|
||||
"""
|
||||
Test 'in' condition with single non-comma string.
|
||||
|
||||
Verifies:
|
||||
- Single string is treated as single-item list
|
||||
- IN expression is created with one value
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "in"
|
||||
metadata_name = "category"
|
||||
value = "technology"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
# ==================== Edge Case Tests ====================
|
||||
|
||||
def test_none_value_with_non_empty_condition(self, retrieval):
|
||||
"""
|
||||
Test None value with conditions that require value.
|
||||
|
||||
Verifies:
|
||||
- Original filters list is returned unchanged
|
||||
- No filter is added for None values (except empty/not empty)
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "contains"
|
||||
metadata_name = "author"
|
||||
value = None
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 0 # No filter added
|
||||
|
||||
def test_none_value_with_equals_condition(self, retrieval):
|
||||
"""
|
||||
Test None value with 'is' (=) condition.
|
||||
|
||||
Verifies:
|
||||
- Original filters list is returned unchanged
|
||||
- No filter is added for None values
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "is"
|
||||
metadata_name = "author"
|
||||
value = None
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 0
|
||||
|
||||
def test_none_value_with_numeric_condition(self, retrieval):
|
||||
"""
|
||||
Test None value with numeric comparison condition.
|
||||
|
||||
Verifies:
|
||||
- Original filters list is returned unchanged
|
||||
- No filter is added for None values
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = ">"
|
||||
metadata_name = "year"
|
||||
value = None
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 0
|
||||
|
||||
def test_existing_filters_preserved(self, retrieval):
|
||||
"""
|
||||
Test that existing filters are preserved.
|
||||
|
||||
Verifies:
|
||||
- Existing filters in the list are not removed
|
||||
- New filters are appended to the list
|
||||
"""
|
||||
existing_filter = MagicMock()
|
||||
filters = [existing_filter]
|
||||
sequence = 0
|
||||
condition = "contains"
|
||||
metadata_name = "author"
|
||||
value = "test"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 2
|
||||
assert filters[0] == existing_filter
|
||||
|
||||
def test_multiple_filters_accumulated(self, retrieval):
|
||||
"""
|
||||
Test multiple calls to accumulate filters.
|
||||
|
||||
Verifies:
|
||||
- Each call adds a new filter to the list
|
||||
- All filters are preserved across calls
|
||||
"""
|
||||
filters = []
|
||||
|
||||
# First filter
|
||||
retrieval.process_metadata_filter_func(0, "contains", "author", "John", filters)
|
||||
assert len(filters) == 1
|
||||
|
||||
# Second filter
|
||||
retrieval.process_metadata_filter_func(1, ">", "year", 2020, filters)
|
||||
assert len(filters) == 2
|
||||
|
||||
# Third filter
|
||||
retrieval.process_metadata_filter_func(2, "is", "category", "tech", filters)
|
||||
assert len(filters) == 3
|
||||
|
||||
def test_unknown_condition(self, retrieval):
|
||||
"""
|
||||
Test unknown/unsupported condition.
|
||||
|
||||
Verifies:
|
||||
- Original filters list is returned unchanged
|
||||
- No filter is added for unknown conditions
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "unknown_condition"
|
||||
metadata_name = "author"
|
||||
value = "test"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 0
|
||||
|
||||
def test_empty_string_value_with_contains(self, retrieval):
|
||||
"""
|
||||
Test empty string value with 'contains' condition.
|
||||
|
||||
Verifies:
|
||||
- Filter is added even with empty string
|
||||
- LIKE expression is created
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "contains"
|
||||
metadata_name = "author"
|
||||
value = ""
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_special_characters_in_value(self, retrieval):
|
||||
"""
|
||||
Test special characters in value string.
|
||||
|
||||
Verifies:
|
||||
- Special characters are handled in value
|
||||
- LIKE expression is created correctly
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "contains"
|
||||
metadata_name = "title"
|
||||
value = "C++ & Python's features"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_zero_value_with_numeric_condition(self, retrieval):
|
||||
"""
|
||||
Test zero value with numeric comparison condition.
|
||||
|
||||
Verifies:
|
||||
- Zero is treated as valid value
|
||||
- Numeric comparison is performed
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = ">"
|
||||
metadata_name = "price"
|
||||
value = 0
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_negative_value_with_numeric_condition(self, retrieval):
|
||||
"""
|
||||
Test negative value with numeric comparison condition.
|
||||
|
||||
Verifies:
|
||||
- Negative numbers are handled correctly
|
||||
- Numeric comparison is performed
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "<"
|
||||
metadata_name = "temperature"
|
||||
value = -10.5
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_float_value_with_integer_comparison(self, retrieval):
|
||||
"""
|
||||
Test float value with numeric comparison condition.
|
||||
|
||||
Verifies:
|
||||
- Float values work correctly
|
||||
- Numeric comparison is performed
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = ">="
|
||||
metadata_name = "rating"
|
||||
value = 4.5
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
|
@ -399,6 +399,7 @@ CONSOLE_CORS_ALLOW_ORIGINS=*
|
|||
COOKIE_DOMAIN=
|
||||
# When the frontend and backend run on different subdomains, set NEXT_PUBLIC_COOKIE_DOMAIN=1.
|
||||
NEXT_PUBLIC_COOKIE_DOMAIN=
|
||||
NEXT_PUBLIC_BATCH_CONCURRENCY=5
|
||||
|
||||
# ------------------------------
|
||||
# File Storage Configuration
|
||||
|
|
|
|||
|
|
@ -108,6 +108,7 @@ x-shared-env: &shared-api-worker-env
|
|||
CONSOLE_CORS_ALLOW_ORIGINS: ${CONSOLE_CORS_ALLOW_ORIGINS:-*}
|
||||
COOKIE_DOMAIN: ${COOKIE_DOMAIN:-}
|
||||
NEXT_PUBLIC_COOKIE_DOMAIN: ${NEXT_PUBLIC_COOKIE_DOMAIN:-}
|
||||
NEXT_PUBLIC_BATCH_CONCURRENCY: ${NEXT_PUBLIC_BATCH_CONCURRENCY:-5}
|
||||
STORAGE_TYPE: ${STORAGE_TYPE:-opendal}
|
||||
OPENDAL_SCHEME: ${OPENDAL_SCHEME:-fs}
|
||||
OPENDAL_FS_ROOT: ${OPENDAL_FS_ROOT:-storage}
|
||||
|
|
|
|||
|
|
@ -73,3 +73,6 @@ NEXT_PUBLIC_MAX_TREE_DEPTH=50
|
|||
|
||||
# The API key of amplitude
|
||||
NEXT_PUBLIC_AMPLITUDE_API_KEY=
|
||||
|
||||
# number of concurrency
|
||||
NEXT_PUBLIC_BATCH_CONCURRENCY=5
|
||||
|
|
|
|||
|
|
@ -176,7 +176,7 @@ const DatasetConfig: FC = () => {
|
|||
}))
|
||||
}, [setDatasetConfigs, datasetConfigsRef])
|
||||
|
||||
const handleAddCondition = useCallback<HandleAddCondition>(({ name, type }) => {
|
||||
const handleAddCondition = useCallback<HandleAddCondition>(({ id, name, type }) => {
|
||||
let operator: ComparisonOperator = ComparisonOperator.is
|
||||
|
||||
if (type === MetadataFilteringVariableType.number)
|
||||
|
|
@ -184,6 +184,7 @@ const DatasetConfig: FC = () => {
|
|||
|
||||
const newCondition = {
|
||||
id: uuid4(),
|
||||
metadata_id: id, // Save metadata.id for reliable reference
|
||||
name,
|
||||
comparison_operator: operator,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,141 @@
|
|||
import type { DataSet } from '@/models/datasets'
|
||||
import { act, fireEvent, render, screen } from '@testing-library/react'
|
||||
import * as React from 'react'
|
||||
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
import { IndexingType } from '@/app/components/datasets/create/step-two'
|
||||
import { DatasetPermission } from '@/models/datasets'
|
||||
import { RETRIEVE_METHOD } from '@/types/app'
|
||||
import SelectDataSet from './index'
|
||||
|
||||
vi.mock('@/i18n-config/i18next-config', () => ({
|
||||
__esModule: true,
|
||||
default: {
|
||||
changeLanguage: vi.fn(),
|
||||
addResourceBundle: vi.fn(),
|
||||
use: vi.fn().mockReturnThis(),
|
||||
init: vi.fn(),
|
||||
addResource: vi.fn(),
|
||||
hasResourceBundle: vi.fn().mockReturnValue(true),
|
||||
},
|
||||
}))
|
||||
const mockUseInfiniteScroll = vi.fn()
|
||||
vi.mock('ahooks', async (importOriginal) => {
|
||||
const actual = await importOriginal()
|
||||
return {
|
||||
...(typeof actual === 'object' && actual !== null ? actual : {}),
|
||||
useInfiniteScroll: (...args: any[]) => mockUseInfiniteScroll(...args),
|
||||
}
|
||||
})
|
||||
|
||||
const mockUseInfiniteDatasets = vi.fn()
|
||||
vi.mock('@/service/knowledge/use-dataset', () => ({
|
||||
useInfiniteDatasets: (...args: any[]) => mockUseInfiniteDatasets(...args),
|
||||
}))
|
||||
|
||||
vi.mock('@/hooks/use-knowledge', () => ({
|
||||
useKnowledge: () => ({
|
||||
formatIndexingTechniqueAndMethod: (tech: string, method: string) => `${tech}:${method}`,
|
||||
}),
|
||||
}))
|
||||
|
||||
const baseProps = {
|
||||
isShow: true,
|
||||
onClose: vi.fn(),
|
||||
selectedIds: [] as string[],
|
||||
onSelect: vi.fn(),
|
||||
}
|
||||
|
||||
const makeDataset = (overrides: Partial<DataSet>): DataSet => ({
|
||||
id: 'dataset-id',
|
||||
name: 'Dataset Name',
|
||||
provider: 'internal',
|
||||
icon_info: {
|
||||
icon_type: 'emoji',
|
||||
icon: '💾',
|
||||
icon_background: '#fff',
|
||||
icon_url: '',
|
||||
},
|
||||
embedding_available: true,
|
||||
is_multimodal: false,
|
||||
description: '',
|
||||
permission: DatasetPermission.allTeamMembers,
|
||||
indexing_technique: IndexingType.ECONOMICAL,
|
||||
retrieval_model_dict: {
|
||||
search_method: RETRIEVE_METHOD.fullText,
|
||||
top_k: 5,
|
||||
reranking_enable: false,
|
||||
reranking_model: {
|
||||
reranking_model_name: '',
|
||||
reranking_provider_name: '',
|
||||
},
|
||||
score_threshold_enabled: false,
|
||||
score_threshold: 0,
|
||||
},
|
||||
...overrides,
|
||||
} as DataSet)
|
||||
|
||||
describe('SelectDataSet', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('renders dataset entries, allows selection, and fires onSelect', async () => {
|
||||
const datasetOne = makeDataset({
|
||||
id: 'set-1',
|
||||
name: 'Dataset One',
|
||||
is_multimodal: true,
|
||||
indexing_technique: IndexingType.ECONOMICAL,
|
||||
})
|
||||
const datasetTwo = makeDataset({
|
||||
id: 'set-2',
|
||||
name: 'Hidden Dataset',
|
||||
embedding_available: false,
|
||||
provider: 'external',
|
||||
})
|
||||
mockUseInfiniteDatasets.mockReturnValue({
|
||||
data: { pages: [{ data: [datasetOne, datasetTwo] }] },
|
||||
isLoading: false,
|
||||
isFetchingNextPage: false,
|
||||
fetchNextPage: vi.fn(),
|
||||
hasNextPage: false,
|
||||
})
|
||||
|
||||
const onSelect = vi.fn()
|
||||
await act(async () => {
|
||||
render(<SelectDataSet {...baseProps} onSelect={onSelect} selectedIds={[]} />)
|
||||
})
|
||||
|
||||
expect(screen.getByText('Dataset One')).toBeInTheDocument()
|
||||
expect(screen.getByText('Hidden Dataset')).toBeInTheDocument()
|
||||
|
||||
await act(async () => {
|
||||
fireEvent.click(screen.getByText('Dataset One'))
|
||||
})
|
||||
expect(screen.getByText('1 appDebug.feature.dataSet.selected')).toBeInTheDocument()
|
||||
|
||||
const addButton = screen.getByRole('button', { name: 'common.operation.add' })
|
||||
await act(async () => {
|
||||
fireEvent.click(addButton)
|
||||
})
|
||||
expect(onSelect).toHaveBeenCalledWith([datasetOne])
|
||||
})
|
||||
|
||||
it('shows empty state when no datasets are available and disables add', async () => {
|
||||
mockUseInfiniteDatasets.mockReturnValue({
|
||||
data: { pages: [{ data: [] }] },
|
||||
isLoading: false,
|
||||
isFetchingNextPage: false,
|
||||
fetchNextPage: vi.fn(),
|
||||
hasNextPage: false,
|
||||
})
|
||||
|
||||
await act(async () => {
|
||||
render(<SelectDataSet {...baseProps} onSelect={vi.fn()} selectedIds={[]} />)
|
||||
})
|
||||
|
||||
expect(screen.getByText('appDebug.feature.dataSet.noDataSet')).toBeInTheDocument()
|
||||
expect(screen.getByRole('link', { name: 'appDebug.feature.dataSet.toCreate' })).toHaveAttribute('href', '/datasets/create')
|
||||
expect(screen.getByRole('button', { name: 'common.operation.add' })).toBeDisabled()
|
||||
})
|
||||
})
|
||||
|
|
@ -0,0 +1,125 @@
|
|||
import type { IPromptValuePanelProps } from './index'
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import * as React from 'react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { useStore } from '@/app/components/app/store'
|
||||
import ConfigContext from '@/context/debug-configuration'
|
||||
import { AppModeEnum, ModelModeType, Resolution } from '@/types/app'
|
||||
import PromptValuePanel from './index'
|
||||
|
||||
vi.mock('@/app/components/app/store', () => ({
|
||||
useStore: vi.fn(),
|
||||
}))
|
||||
vi.mock('@/app/components/base/features/new-feature-panel/feature-bar', () => ({
|
||||
__esModule: true,
|
||||
default: ({ onFeatureBarClick }: { onFeatureBarClick: () => void }) => (
|
||||
<button type="button" onClick={onFeatureBarClick}>
|
||||
feature bar
|
||||
</button>
|
||||
),
|
||||
}))
|
||||
|
||||
const mockSetShowAppConfigureFeaturesModal = vi.fn()
|
||||
const mockUseStore = vi.mocked(useStore)
|
||||
const mockSetInputs = vi.fn()
|
||||
const mockOnSend = vi.fn()
|
||||
|
||||
const promptVariables = [
|
||||
{ key: 'textVar', name: 'Text Var', type: 'string', required: true },
|
||||
{ key: 'boolVar', name: 'Boolean Var', type: 'checkbox' },
|
||||
] as const
|
||||
|
||||
const baseContextValue: any = {
|
||||
modelModeType: ModelModeType.completion,
|
||||
modelConfig: {
|
||||
configs: {
|
||||
prompt_template: 'prompt template',
|
||||
prompt_variables: promptVariables,
|
||||
},
|
||||
},
|
||||
setInputs: mockSetInputs,
|
||||
mode: AppModeEnum.COMPLETION,
|
||||
isAdvancedMode: false,
|
||||
completionPromptConfig: {
|
||||
prompt: { text: 'completion' },
|
||||
conversation_histories_role: { user_prefix: 'user', assistant_prefix: 'assistant' },
|
||||
},
|
||||
chatPromptConfig: { prompt: [] },
|
||||
} as any
|
||||
|
||||
const defaultProps: IPromptValuePanelProps = {
|
||||
appType: AppModeEnum.COMPLETION,
|
||||
onSend: mockOnSend,
|
||||
inputs: { textVar: 'initial', boolVar: false },
|
||||
visionConfig: { enabled: false, number_limits: 0, detail: Resolution.low, transfer_methods: [] },
|
||||
onVisionFilesChange: vi.fn(),
|
||||
}
|
||||
|
||||
const renderPanel = (options: {
|
||||
context?: Partial<typeof baseContextValue>
|
||||
props?: Partial<IPromptValuePanelProps>
|
||||
} = {}) => {
|
||||
const contextValue = { ...baseContextValue, ...options.context }
|
||||
const props = { ...defaultProps, ...options.props }
|
||||
return render(
|
||||
<ConfigContext.Provider value={contextValue}>
|
||||
<PromptValuePanel {...props} />
|
||||
</ConfigContext.Provider>,
|
||||
)
|
||||
}
|
||||
|
||||
describe('PromptValuePanel', () => {
|
||||
beforeEach(() => {
|
||||
mockUseStore.mockImplementation(selector => selector({
|
||||
setShowAppConfigureFeaturesModal: mockSetShowAppConfigureFeaturesModal,
|
||||
appSidebarExpand: '',
|
||||
currentLogModalActiveTab: 'prompt',
|
||||
showPromptLogModal: false,
|
||||
showAgentLogModal: false,
|
||||
setShowPromptLogModal: vi.fn(),
|
||||
setShowAgentLogModal: vi.fn(),
|
||||
showMessageLogModal: false,
|
||||
showAppConfigureFeaturesModal: false,
|
||||
} as any))
|
||||
mockSetInputs.mockClear()
|
||||
mockOnSend.mockClear()
|
||||
mockSetShowAppConfigureFeaturesModal.mockClear()
|
||||
})
|
||||
|
||||
it('updates inputs, clears values, and triggers run when ready', async () => {
|
||||
renderPanel()
|
||||
|
||||
const textInput = screen.getByPlaceholderText('Text Var')
|
||||
fireEvent.change(textInput, { target: { value: 'updated' } })
|
||||
expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ textVar: 'updated' }))
|
||||
|
||||
const clearButton = screen.getByRole('button', { name: 'common.operation.clear' })
|
||||
fireEvent.click(clearButton)
|
||||
|
||||
expect(mockSetInputs).toHaveBeenLastCalledWith({
|
||||
textVar: '',
|
||||
boolVar: '',
|
||||
})
|
||||
|
||||
const runButton = screen.getByRole('button', { name: 'appDebug.inputs.run' })
|
||||
expect(runButton).not.toBeDisabled()
|
||||
fireEvent.click(runButton)
|
||||
await waitFor(() => expect(mockOnSend).toHaveBeenCalledTimes(1))
|
||||
})
|
||||
|
||||
it('disables run when mode is not completion', () => {
|
||||
renderPanel({
|
||||
context: {
|
||||
mode: AppModeEnum.CHAT,
|
||||
},
|
||||
props: {
|
||||
appType: AppModeEnum.CHAT,
|
||||
},
|
||||
})
|
||||
|
||||
const runButton = screen.getByRole('button', { name: 'appDebug.inputs.run' })
|
||||
expect(runButton).toBeDisabled()
|
||||
fireEvent.click(runButton)
|
||||
expect(mockOnSend).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
|
@ -0,0 +1,29 @@
|
|||
import type { PromptVariable } from '@/models/debug'
|
||||
|
||||
import { describe, expect, it } from 'vitest'
|
||||
import { replaceStringWithValues } from './utils'
|
||||
|
||||
const promptVariables: PromptVariable[] = [
|
||||
{ key: 'user', name: 'User', type: 'string' },
|
||||
{ key: 'topic', name: 'Topic', type: 'string' },
|
||||
]
|
||||
|
||||
describe('replaceStringWithValues', () => {
|
||||
it('should replace placeholders when inputs have values', () => {
|
||||
const template = 'Hello {{user}} talking about {{topic}}'
|
||||
const result = replaceStringWithValues(template, promptVariables, { user: 'Alice', topic: 'cats' })
|
||||
expect(result).toBe('Hello Alice talking about cats')
|
||||
})
|
||||
|
||||
it('should use prompt variable name when value is missing', () => {
|
||||
const template = 'Hi {{user}} from {{topic}}'
|
||||
const result = replaceStringWithValues(template, promptVariables, {})
|
||||
expect(result).toBe('Hi {{User}} from {{Topic}}')
|
||||
})
|
||||
|
||||
it('should leave placeholder untouched when no variable is defined', () => {
|
||||
const template = 'Unknown {{missing}} placeholder'
|
||||
const result = replaceStringWithValues(template, promptVariables, {})
|
||||
expect(result).toBe('Unknown {{missing}} placeholder')
|
||||
})
|
||||
})
|
||||
|
|
@ -0,0 +1,162 @@
|
|||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { useRouter } from 'next/navigation'
|
||||
import { afterAll, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { trackEvent } from '@/app/components/base/amplitude'
|
||||
|
||||
import { ToastContext } from '@/app/components/base/toast'
|
||||
import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
import { createApp } from '@/service/apps'
|
||||
import { AppModeEnum } from '@/types/app'
|
||||
import { getRedirection } from '@/utils/app-redirection'
|
||||
import CreateAppModal from './index'
|
||||
|
||||
vi.mock('ahooks', () => ({
|
||||
useDebounceFn: (fn: (...args: any[]) => any) => {
|
||||
const run = (...args: any[]) => fn(...args)
|
||||
const cancel = vi.fn()
|
||||
const flush = vi.fn()
|
||||
return { run, cancel, flush }
|
||||
},
|
||||
useKeyPress: vi.fn(),
|
||||
useHover: () => false,
|
||||
}))
|
||||
vi.mock('next/navigation', () => ({
|
||||
useRouter: vi.fn(),
|
||||
}))
|
||||
vi.mock('@/app/components/base/amplitude', () => ({
|
||||
trackEvent: vi.fn(),
|
||||
}))
|
||||
vi.mock('@/service/apps', () => ({
|
||||
createApp: vi.fn(),
|
||||
}))
|
||||
vi.mock('@/utils/app-redirection', () => ({
|
||||
getRedirection: vi.fn(),
|
||||
}))
|
||||
vi.mock('@/context/provider-context', () => ({
|
||||
useProviderContext: vi.fn(),
|
||||
}))
|
||||
vi.mock('@/context/app-context', () => ({
|
||||
useAppContext: vi.fn(),
|
||||
}))
|
||||
vi.mock('@/context/i18n', () => ({
|
||||
useDocLink: () => () => '/guides',
|
||||
}))
|
||||
vi.mock('@/hooks/use-theme', () => ({
|
||||
__esModule: true,
|
||||
default: () => ({ theme: 'light' }),
|
||||
}))
|
||||
|
||||
const mockNotify = vi.fn()
|
||||
const mockUseRouter = vi.mocked(useRouter)
|
||||
const mockPush = vi.fn()
|
||||
const mockCreateApp = vi.mocked(createApp)
|
||||
const mockTrackEvent = vi.mocked(trackEvent)
|
||||
const mockGetRedirection = vi.mocked(getRedirection)
|
||||
const mockUseProviderContext = vi.mocked(useProviderContext)
|
||||
const mockUseAppContext = vi.mocked(useAppContext)
|
||||
|
||||
const defaultPlanUsage = {
|
||||
buildApps: 0,
|
||||
teamMembers: 0,
|
||||
annotatedResponse: 0,
|
||||
documentsUploadQuota: 0,
|
||||
apiRateLimit: 0,
|
||||
triggerEvents: 0,
|
||||
vectorSpace: 0,
|
||||
}
|
||||
|
||||
const renderModal = () => {
|
||||
const onClose = vi.fn()
|
||||
const onSuccess = vi.fn()
|
||||
render(
|
||||
<ToastContext.Provider value={{ notify: mockNotify, close: vi.fn() }}>
|
||||
<CreateAppModal show onClose={onClose} onSuccess={onSuccess} defaultAppMode={AppModeEnum.ADVANCED_CHAT} />
|
||||
</ToastContext.Provider>,
|
||||
)
|
||||
return { onClose, onSuccess }
|
||||
}
|
||||
|
||||
describe('CreateAppModal', () => {
|
||||
const mockSetItem = vi.fn()
|
||||
const originalLocalStorage = window.localStorage
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockUseRouter.mockReturnValue({ push: mockPush } as any)
|
||||
mockUseProviderContext.mockReturnValue({
|
||||
plan: {
|
||||
type: AppModeEnum.ADVANCED_CHAT,
|
||||
usage: defaultPlanUsage,
|
||||
total: { ...defaultPlanUsage, buildApps: 1 },
|
||||
reset: {},
|
||||
},
|
||||
enableBilling: true,
|
||||
} as any)
|
||||
mockUseAppContext.mockReturnValue({
|
||||
isCurrentWorkspaceEditor: true,
|
||||
} as any)
|
||||
mockSetItem.mockClear()
|
||||
Object.defineProperty(window, 'localStorage', {
|
||||
value: {
|
||||
setItem: mockSetItem,
|
||||
getItem: vi.fn(),
|
||||
removeItem: vi.fn(),
|
||||
clear: vi.fn(),
|
||||
key: vi.fn(),
|
||||
length: 0,
|
||||
},
|
||||
writable: true,
|
||||
})
|
||||
})
|
||||
|
||||
afterAll(() => {
|
||||
Object.defineProperty(window, 'localStorage', {
|
||||
value: originalLocalStorage,
|
||||
writable: true,
|
||||
})
|
||||
})
|
||||
|
||||
it('creates an app, notifies success, and fires callbacks', async () => {
|
||||
const mockApp = { id: 'app-1', mode: AppModeEnum.ADVANCED_CHAT }
|
||||
mockCreateApp.mockResolvedValue(mockApp as any)
|
||||
const { onClose, onSuccess } = renderModal()
|
||||
|
||||
const nameInput = screen.getByPlaceholderText('app.newApp.appNamePlaceholder')
|
||||
fireEvent.change(nameInput, { target: { value: 'My App' } })
|
||||
fireEvent.click(screen.getByRole('button', { name: 'app.newApp.Create' }))
|
||||
|
||||
await waitFor(() => expect(mockCreateApp).toHaveBeenCalledWith({
|
||||
name: 'My App',
|
||||
description: '',
|
||||
icon_type: 'emoji',
|
||||
icon: '🤖',
|
||||
icon_background: '#FFEAD5',
|
||||
mode: AppModeEnum.ADVANCED_CHAT,
|
||||
}))
|
||||
|
||||
expect(mockTrackEvent).toHaveBeenCalledWith('create_app', {
|
||||
app_mode: AppModeEnum.ADVANCED_CHAT,
|
||||
description: '',
|
||||
})
|
||||
expect(mockNotify).toHaveBeenCalledWith({ type: 'success', message: 'app.newApp.appCreated' })
|
||||
expect(onSuccess).toHaveBeenCalled()
|
||||
expect(onClose).toHaveBeenCalled()
|
||||
await waitFor(() => expect(mockSetItem).toHaveBeenCalledWith(NEED_REFRESH_APP_LIST_KEY, '1'))
|
||||
await waitFor(() => expect(mockGetRedirection).toHaveBeenCalledWith(true, mockApp, mockPush))
|
||||
})
|
||||
|
||||
it('shows error toast when creation fails', async () => {
|
||||
mockCreateApp.mockRejectedValue(new Error('boom'))
|
||||
const { onClose } = renderModal()
|
||||
|
||||
const nameInput = screen.getByPlaceholderText('app.newApp.appNamePlaceholder')
|
||||
fireEvent.change(nameInput, { target: { value: 'My App' } })
|
||||
fireEvent.click(screen.getByRole('button', { name: 'app.newApp.Create' }))
|
||||
|
||||
await waitFor(() => expect(mockCreateApp).toHaveBeenCalled())
|
||||
expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: 'boom' })
|
||||
expect(onClose).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
|
@ -0,0 +1,121 @@
|
|||
import type { SiteInfo } from '@/models/share'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import copy from 'copy-to-clipboard'
|
||||
import * as React from 'react'
|
||||
|
||||
import { act } from 'react'
|
||||
import { afterAll, afterEach, describe, expect, it, vi } from 'vitest'
|
||||
import Embedded from './index'
|
||||
|
||||
vi.mock('./style.module.css', () => ({
|
||||
__esModule: true,
|
||||
default: {
|
||||
option: 'option',
|
||||
active: 'active',
|
||||
iframeIcon: 'iframeIcon',
|
||||
scriptsIcon: 'scriptsIcon',
|
||||
chromePluginIcon: 'chromePluginIcon',
|
||||
pluginInstallIcon: 'pluginInstallIcon',
|
||||
},
|
||||
}))
|
||||
const mockThemeBuilder = {
|
||||
buildTheme: vi.fn(),
|
||||
theme: {
|
||||
primaryColor: '#123456',
|
||||
},
|
||||
}
|
||||
const mockUseAppContext = vi.fn(() => ({
|
||||
langGeniusVersionInfo: {
|
||||
current_env: 'PRODUCTION',
|
||||
current_version: '',
|
||||
latest_version: '',
|
||||
release_date: '',
|
||||
release_notes: '',
|
||||
version: '',
|
||||
can_auto_update: false,
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('copy-to-clipboard', () => ({
|
||||
__esModule: true,
|
||||
default: vi.fn(),
|
||||
}))
|
||||
vi.mock('@/app/components/base/chat/embedded-chatbot/theme/theme-context', () => ({
|
||||
useThemeContext: () => mockThemeBuilder,
|
||||
}))
|
||||
vi.mock('@/context/app-context', () => ({
|
||||
useAppContext: () => mockUseAppContext(),
|
||||
}))
|
||||
const mockWindowOpen = vi.spyOn(window, 'open').mockImplementation(() => null)
|
||||
const mockedCopy = vi.mocked(copy)
|
||||
|
||||
const siteInfo: SiteInfo = {
|
||||
title: 'test site',
|
||||
chat_color_theme: '#000000',
|
||||
chat_color_theme_inverted: false,
|
||||
}
|
||||
|
||||
const baseProps = {
|
||||
isShow: true,
|
||||
siteInfo,
|
||||
onClose: vi.fn(),
|
||||
appBaseUrl: 'https://app.example.com',
|
||||
accessToken: 'token',
|
||||
className: 'custom-modal',
|
||||
}
|
||||
|
||||
const getCopyButton = () => {
|
||||
const buttons = screen.getAllByRole('button')
|
||||
const actionButton = buttons.find(button => button.className.includes('action-btn'))
|
||||
expect(actionButton).toBeDefined()
|
||||
return actionButton!
|
||||
}
|
||||
|
||||
describe('Embedded', () => {
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockWindowOpen.mockClear()
|
||||
})
|
||||
|
||||
afterAll(() => {
|
||||
mockWindowOpen.mockRestore()
|
||||
})
|
||||
|
||||
it('builds theme and copies iframe snippet', async () => {
|
||||
await act(async () => {
|
||||
render(<Embedded {...baseProps} />)
|
||||
})
|
||||
|
||||
const actionButton = getCopyButton()
|
||||
const innerDiv = actionButton.querySelector('div')
|
||||
act(() => {
|
||||
fireEvent.click(innerDiv ?? actionButton)
|
||||
})
|
||||
|
||||
expect(mockThemeBuilder.buildTheme).toHaveBeenCalledWith(siteInfo.chat_color_theme, siteInfo.chat_color_theme_inverted)
|
||||
expect(mockedCopy).toHaveBeenCalledWith(expect.stringContaining('/chatbot/token'))
|
||||
})
|
||||
|
||||
it('opens chrome plugin store link when chrome option selected', async () => {
|
||||
await act(async () => {
|
||||
render(<Embedded {...baseProps} />)
|
||||
})
|
||||
|
||||
const optionButtons = document.body.querySelectorAll('[class*="option"]')
|
||||
expect(optionButtons.length).toBeGreaterThanOrEqual(3)
|
||||
act(() => {
|
||||
fireEvent.click(optionButtons[2])
|
||||
})
|
||||
|
||||
const [chromeText] = screen.getAllByText('appOverview.overview.appInfo.embedded.chromePlugin')
|
||||
act(() => {
|
||||
fireEvent.click(chromeText)
|
||||
})
|
||||
|
||||
expect(mockWindowOpen).toHaveBeenCalledWith(
|
||||
'https://chrome.google.com/webstore/detail/dify-chatbot/ceehdapohffmjmkdcifjofadiaoeggaf',
|
||||
'_blank',
|
||||
'noopener,noreferrer',
|
||||
)
|
||||
})
|
||||
})
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
import type { ISavedItemsProps } from './index'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import copy from 'copy-to-clipboard'
|
||||
|
||||
import * as React from 'react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import SavedItems from './index'
|
||||
|
||||
vi.mock('copy-to-clipboard', () => ({
|
||||
__esModule: true,
|
||||
default: vi.fn(),
|
||||
}))
|
||||
vi.mock('next/navigation', () => ({
|
||||
useParams: () => ({}),
|
||||
usePathname: () => '/',
|
||||
}))
|
||||
|
||||
const mockCopy = vi.mocked(copy)
|
||||
const toastNotifySpy = vi.spyOn(Toast, 'notify')
|
||||
|
||||
const baseProps: ISavedItemsProps = {
|
||||
list: [
|
||||
{ id: '1', answer: 'hello world' },
|
||||
],
|
||||
isShowTextToSpeech: true,
|
||||
onRemove: vi.fn(),
|
||||
onStartCreateContent: vi.fn(),
|
||||
}
|
||||
|
||||
describe('SavedItems', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
toastNotifySpy.mockClear()
|
||||
})
|
||||
|
||||
it('renders saved answers with metadata and controls', () => {
|
||||
const { container } = render(<SavedItems {...baseProps} />)
|
||||
|
||||
const markdownElement = container.querySelector('.markdown-body')
|
||||
expect(markdownElement).toBeInTheDocument()
|
||||
expect(screen.getByText('11 common.unit.char')).toBeInTheDocument()
|
||||
|
||||
const actionArea = container.querySelector('[class*="bg-components-actionbar-bg"]')
|
||||
const actionButtons = actionArea?.querySelectorAll('button') ?? []
|
||||
expect(actionButtons.length).toBeGreaterThanOrEqual(3)
|
||||
})
|
||||
|
||||
it('copies content and notifies, and triggers remove callback', () => {
|
||||
const handleRemove = vi.fn()
|
||||
const { container } = render(<SavedItems {...baseProps} onRemove={handleRemove} />)
|
||||
|
||||
const actionArea = container.querySelector('[class*="bg-components-actionbar-bg"]')
|
||||
const actionButtons = actionArea?.querySelectorAll('button') ?? []
|
||||
expect(actionButtons.length).toBeGreaterThanOrEqual(3)
|
||||
|
||||
const copyButton = actionButtons[1]
|
||||
const deleteButton = actionButtons[2]
|
||||
|
||||
fireEvent.click(copyButton)
|
||||
expect(mockCopy).toHaveBeenCalledWith('hello world')
|
||||
expect(toastNotifySpy).toHaveBeenCalledWith({ type: 'success', message: 'common.actionMsg.copySuccessfully' })
|
||||
|
||||
fireEvent.click(deleteButton)
|
||||
expect(handleRemove).toHaveBeenCalledWith('1')
|
||||
})
|
||||
})
|
||||
|
|
@ -0,0 +1,22 @@
|
|||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import NoData from './index'
|
||||
|
||||
describe('NoData', () => {
|
||||
it('renders title/description and calls callback when button clicked', () => {
|
||||
const handleStart = vi.fn()
|
||||
render(<NoData onStartCreateContent={handleStart} />)
|
||||
|
||||
const title = screen.getByText('share.generation.savedNoData.title')
|
||||
const description = screen.getByText('share.generation.savedNoData.description')
|
||||
const button = screen.getByRole('button', { name: 'share.generation.savedNoData.startCreateContent' })
|
||||
|
||||
expect(title).toBeInTheDocument()
|
||||
expect(description).toBeInTheDocument()
|
||||
expect(button).toBeInTheDocument()
|
||||
|
||||
fireEvent.click(button)
|
||||
expect(handleStart).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
|
|
@ -0,0 +1,147 @@
|
|||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { getImageUploadErrorMessage, imageUpload } from '@/app/components/base/image-uploader/utils'
|
||||
import { useToastContext } from '@/app/components/base/toast'
|
||||
import { Plan } from '@/app/components/billing/type'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
import { updateCurrentWorkspace } from '@/service/common'
|
||||
import CustomWebAppBrand from './index'
|
||||
|
||||
vi.mock('@/app/components/base/toast', () => ({
|
||||
useToastContext: vi.fn(),
|
||||
}))
|
||||
vi.mock('@/service/common', () => ({
|
||||
updateCurrentWorkspace: vi.fn(),
|
||||
}))
|
||||
vi.mock('@/context/app-context', () => ({
|
||||
useAppContext: vi.fn(),
|
||||
}))
|
||||
vi.mock('@/context/provider-context', () => ({
|
||||
useProviderContext: vi.fn(),
|
||||
}))
|
||||
vi.mock('@/context/global-public-context', () => ({
|
||||
useGlobalPublicStore: vi.fn(),
|
||||
}))
|
||||
vi.mock('@/app/components/base/image-uploader/utils', () => ({
|
||||
imageUpload: vi.fn(),
|
||||
getImageUploadErrorMessage: vi.fn(),
|
||||
}))
|
||||
|
||||
const mockNotify = vi.fn()
|
||||
const mockUseToastContext = vi.mocked(useToastContext)
|
||||
const mockUpdateCurrentWorkspace = vi.mocked(updateCurrentWorkspace)
|
||||
const mockUseAppContext = vi.mocked(useAppContext)
|
||||
const mockUseProviderContext = vi.mocked(useProviderContext)
|
||||
const mockUseGlobalPublicStore = vi.mocked(useGlobalPublicStore)
|
||||
const mockImageUpload = vi.mocked(imageUpload)
|
||||
const mockGetImageUploadErrorMessage = vi.mocked(getImageUploadErrorMessage)
|
||||
|
||||
const defaultPlanUsage = {
|
||||
buildApps: 0,
|
||||
teamMembers: 0,
|
||||
annotatedResponse: 0,
|
||||
documentsUploadQuota: 0,
|
||||
apiRateLimit: 0,
|
||||
triggerEvents: 0,
|
||||
vectorSpace: 0,
|
||||
}
|
||||
|
||||
const renderComponent = () => render(<CustomWebAppBrand />)
|
||||
|
||||
describe('CustomWebAppBrand', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockUseToastContext.mockReturnValue({ notify: mockNotify } as any)
|
||||
mockUpdateCurrentWorkspace.mockResolvedValue({} as any)
|
||||
mockUseAppContext.mockReturnValue({
|
||||
currentWorkspace: {
|
||||
custom_config: {
|
||||
replace_webapp_logo: 'https://example.com/replace.png',
|
||||
remove_webapp_brand: false,
|
||||
},
|
||||
},
|
||||
mutateCurrentWorkspace: vi.fn(),
|
||||
isCurrentWorkspaceManager: true,
|
||||
} as any)
|
||||
mockUseProviderContext.mockReturnValue({
|
||||
plan: {
|
||||
type: Plan.professional,
|
||||
usage: defaultPlanUsage,
|
||||
total: defaultPlanUsage,
|
||||
reset: {},
|
||||
},
|
||||
enableBilling: false,
|
||||
} as any)
|
||||
const systemFeaturesState = {
|
||||
branding: {
|
||||
enabled: true,
|
||||
workspace_logo: 'https://example.com/workspace-logo.png',
|
||||
},
|
||||
}
|
||||
mockUseGlobalPublicStore.mockImplementation(selector => selector ? selector({ systemFeatures: systemFeaturesState } as any) : { systemFeatures: systemFeaturesState })
|
||||
mockGetImageUploadErrorMessage.mockReturnValue('upload error')
|
||||
})
|
||||
|
||||
it('disables upload controls when the user cannot manage the workspace', () => {
|
||||
mockUseAppContext.mockReturnValue({
|
||||
currentWorkspace: {
|
||||
custom_config: {
|
||||
replace_webapp_logo: '',
|
||||
remove_webapp_brand: false,
|
||||
},
|
||||
},
|
||||
mutateCurrentWorkspace: vi.fn(),
|
||||
isCurrentWorkspaceManager: false,
|
||||
} as any)
|
||||
|
||||
const { container } = renderComponent()
|
||||
const fileInput = container.querySelector('input[type="file"]') as HTMLInputElement
|
||||
expect(fileInput).toBeDisabled()
|
||||
})
|
||||
|
||||
it('toggles remove brand switch and calls the backend + mutate', async () => {
|
||||
const mutateMock = vi.fn()
|
||||
mockUseAppContext.mockReturnValue({
|
||||
currentWorkspace: {
|
||||
custom_config: {
|
||||
replace_webapp_logo: '',
|
||||
remove_webapp_brand: false,
|
||||
},
|
||||
},
|
||||
mutateCurrentWorkspace: mutateMock,
|
||||
isCurrentWorkspaceManager: true,
|
||||
} as any)
|
||||
|
||||
renderComponent()
|
||||
const switchInput = screen.getByRole('switch')
|
||||
fireEvent.click(switchInput)
|
||||
|
||||
await waitFor(() => expect(mockUpdateCurrentWorkspace).toHaveBeenCalledWith({
|
||||
url: '/workspaces/custom-config',
|
||||
body: { remove_webapp_brand: true },
|
||||
}))
|
||||
await waitFor(() => expect(mutateMock).toHaveBeenCalled())
|
||||
})
|
||||
|
||||
it('shows cancel/apply buttons after successful upload and cancels properly', async () => {
|
||||
mockImageUpload.mockImplementation(({ onProgressCallback, onSuccessCallback }) => {
|
||||
onProgressCallback(50)
|
||||
onSuccessCallback({ id: 'new-logo' })
|
||||
})
|
||||
|
||||
const { container } = renderComponent()
|
||||
const fileInput = container.querySelector('input[type="file"]') as HTMLInputElement
|
||||
const testFile = new File(['content'], 'logo.png', { type: 'image/png' })
|
||||
fireEvent.change(fileInput, { target: { files: [testFile] } })
|
||||
|
||||
await waitFor(() => expect(mockImageUpload).toHaveBeenCalled())
|
||||
await waitFor(() => screen.getByRole('button', { name: 'custom.apply' }))
|
||||
|
||||
const cancelButton = screen.getByRole('button', { name: 'common.operation.cancel' })
|
||||
fireEvent.click(cancelButton)
|
||||
|
||||
await waitFor(() => expect(screen.queryByRole('button', { name: 'custom.apply' })).toBeNull())
|
||||
})
|
||||
})
|
||||
|
|
@ -26,7 +26,7 @@ import DifyLogo from '@/app/components/base/logo/dify-logo'
|
|||
import Toast from '@/app/components/base/toast'
|
||||
import Res from '@/app/components/share/text-generation/result'
|
||||
import RunOnce from '@/app/components/share/text-generation/run-once'
|
||||
import { appDefaultIconBackground, DEFAULT_VALUE_MAX_LEN } from '@/config'
|
||||
import { appDefaultIconBackground, BATCH_CONCURRENCY, DEFAULT_VALUE_MAX_LEN } from '@/config'
|
||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
import { useWebAppStore } from '@/context/web-app-context'
|
||||
import { useAppFavicon } from '@/hooks/use-app-favicon'
|
||||
|
|
@ -43,7 +43,7 @@ import MenuDropdown from './menu-dropdown'
|
|||
import RunBatch from './run-batch'
|
||||
import ResDownload from './run-batch/res-download'
|
||||
|
||||
const GROUP_SIZE = 5 // to avoid RPM(Request per minute) limit. The group task finished then the next group.
|
||||
const GROUP_SIZE = BATCH_CONCURRENCY // to avoid RPM(Request per minute) limit. The group task finished then the next group.
|
||||
enum TaskStatus {
|
||||
pending = 'pending',
|
||||
running = 'running',
|
||||
|
|
|
|||
|
|
@ -62,8 +62,15 @@ const ConditionItem = ({
|
|||
}, [onRemoveCondition, condition.id])
|
||||
|
||||
const currentMetadata = useMemo(() => {
|
||||
// Try to match by metadata_id first (reliable reference)
|
||||
if (condition.metadata_id) {
|
||||
const found = metadataList.find(metadata => metadata.id === condition.metadata_id)
|
||||
if (found)
|
||||
return found
|
||||
}
|
||||
// Fallback to name matching for backward compatibility with old conditions
|
||||
return metadataList.find(metadata => metadata.name === condition.name)
|
||||
}, [metadataList, condition.name])
|
||||
}, [metadataList, condition.metadata_id, condition.name])
|
||||
|
||||
const handleConditionOperatorChange = useCallback((operator: ComparisonOperator) => {
|
||||
onUpdateCondition?.(
|
||||
|
|
|
|||
|
|
@ -27,11 +27,17 @@ const MetadataTrigger = ({
|
|||
useEffect(() => {
|
||||
if (selectedDatasetsLoaded) {
|
||||
conditions.forEach((condition) => {
|
||||
if (!metadataList.find(metadata => metadata.name === condition.name))
|
||||
// First try to match by metadata_id for reliable reference
|
||||
const foundById = condition.metadata_id && metadataList.find(metadata => metadata.id === condition.metadata_id)
|
||||
// Fallback to name matching only for backward compatibility with old conditions
|
||||
const foundByName = !condition.metadata_id && metadataList.find(metadata => metadata.name === condition.name)
|
||||
|
||||
// Only remove condition if both metadata_id and name matching fail
|
||||
if (!foundById && !foundByName)
|
||||
handleRemoveCondition(condition.id)
|
||||
})
|
||||
}
|
||||
}, [metadataList, handleRemoveCondition, selectedDatasetsLoaded])
|
||||
}, [metadataFilteringConditions, metadataList, handleRemoveCondition, selectedDatasetsLoaded])
|
||||
|
||||
return (
|
||||
<PortalToFollowElem
|
||||
|
|
|
|||
|
|
@ -86,6 +86,7 @@ export enum MetadataFilteringVariableType {
|
|||
export type MetadataFilteringCondition = {
|
||||
id: string
|
||||
name: string
|
||||
metadata_id?: string
|
||||
comparison_operator: ComparisonOperator
|
||||
value?: string | number
|
||||
}
|
||||
|
|
|
|||
|
|
@ -305,7 +305,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
|||
}))
|
||||
}, [setInputs])
|
||||
|
||||
const handleAddCondition = useCallback<HandleAddCondition>(({ name, type }) => {
|
||||
const handleAddCondition = useCallback<HandleAddCondition>(({ id, name, type }) => {
|
||||
let operator: ComparisonOperator = ComparisonOperator.is
|
||||
|
||||
if (type === MetadataFilteringVariableType.number)
|
||||
|
|
@ -313,6 +313,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
|||
|
||||
const newCondition = {
|
||||
id: uuid4(),
|
||||
metadata_id: id, // Save metadata.id for reliable reference
|
||||
name,
|
||||
comparison_operator: operator,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -67,6 +67,7 @@ const LocaleLayout = async ({
|
|||
[DatasetAttr.NEXT_PUBLIC_ZENDESK_FIELD_ID_EMAIL]: process.env.NEXT_PUBLIC_ZENDESK_FIELD_ID_EMAIL,
|
||||
[DatasetAttr.NEXT_PUBLIC_ZENDESK_FIELD_ID_WORKSPACE_ID]: process.env.NEXT_PUBLIC_ZENDESK_FIELD_ID_WORKSPACE_ID,
|
||||
[DatasetAttr.NEXT_PUBLIC_ZENDESK_FIELD_ID_PLAN]: process.env.NEXT_PUBLIC_ZENDESK_FIELD_ID_PLAN,
|
||||
[DatasetAttr.DATA_PUBLIC_BATCH_CONCURRENCY]: process.env.NEXT_PUBLIC_BATCH_CONCURRENCY,
|
||||
}
|
||||
|
||||
return (
|
||||
|
|
|
|||
|
|
@ -164,6 +164,13 @@ const COOKIE_DOMAIN = getStringConfig(
|
|||
DatasetAttr.DATA_PUBLIC_COOKIE_DOMAIN,
|
||||
'',
|
||||
).trim()
|
||||
|
||||
export const BATCH_CONCURRENCY = getNumberConfig(
|
||||
process.env.NEXT_PUBLIC_BATCH_CONCURRENCY,
|
||||
DatasetAttr.DATA_PUBLIC_BATCH_CONCURRENCY,
|
||||
5, // default
|
||||
)
|
||||
|
||||
export const CSRF_COOKIE_NAME = () => {
|
||||
if (COOKIE_DOMAIN)
|
||||
return 'csrf_token'
|
||||
|
|
|
|||
|
|
@ -131,4 +131,5 @@ export enum DatasetAttr {
|
|||
NEXT_PUBLIC_ZENDESK_FIELD_ID_EMAIL = 'next-public-zendesk-field-id-email',
|
||||
NEXT_PUBLIC_ZENDESK_FIELD_ID_WORKSPACE_ID = 'next-public-zendesk-field-id-workspace-id',
|
||||
NEXT_PUBLIC_ZENDESK_FIELD_ID_PLAN = 'next-public-zendesk-field-id-plan',
|
||||
DATA_PUBLIC_BATCH_CONCURRENCY = 'data-public-batch-concurrency',
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue