Merge branch 'main' into feat/log-formatter

This commit is contained in:
Byron.wang 2025-12-26 20:18:01 +08:00 committed by GitHub
commit 2f67c5aa75
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
36 changed files with 2421 additions and 162 deletions

View File

@ -1,6 +1,5 @@
import logging
from typing import Literal
from uuid import UUID
from flask import request
from flask_restx import marshal_with
@ -26,6 +25,7 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
from core.model_runtime.errors.invoke import InvokeError
from fields.message_fields import message_infinite_scroll_pagination_fields
from libs import helper
from libs.helper import UUIDStrOrEmpty
from libs.login import current_account_with_tenant
from models.model import AppMode
from services.app_generate_service import AppGenerateService
@ -44,8 +44,8 @@ logger = logging.getLogger(__name__)
class MessageListQuery(BaseModel):
conversation_id: UUID
first_id: UUID | None = None
conversation_id: UUIDStrOrEmpty
first_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)

View File

@ -1,5 +1,3 @@
from uuid import UUID
from flask import request
from flask_restx import fields, marshal_with
from pydantic import BaseModel, Field
@ -10,19 +8,19 @@ from controllers.console import console_ns
from controllers.console.explore.error import NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource
from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField
from libs.helper import TimestampField, UUIDStrOrEmpty
from libs.login import current_account_with_tenant
from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService
class SavedMessageListQuery(BaseModel):
last_id: UUID | None = None
last_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)
class SavedMessageCreatePayload(BaseModel):
message_id: UUID
message_id: UUIDStrOrEmpty
register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload)

View File

@ -1,6 +1,8 @@
from flask_restx import Resource, reqparse
from flask_restx import Resource
from pydantic import BaseModel
from werkzeug.exceptions import Forbidden
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.entities.model_entities import ModelType
@ -10,10 +12,20 @@ from models import TenantAccountRole
from services.model_load_balancing_service import ModelLoadBalancingService
class LoadBalancingCredentialPayload(BaseModel):
    """Request body for validating model load-balancing credentials."""

    # Model identifier as known to the provider (e.g. "gpt-4o").
    model: str
    # Capability the model serves; pydantic coerces the incoming string value.
    model_type: ModelType
    # Provider-specific credential mapping; contents are validated by the service.
    credentials: dict[str, object]
register_schema_models(console_ns, LoadBalancingCredentialPayload)
@console_ns.route(
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/credentials-validate"
)
class LoadBalancingCredentialsValidateApi(Resource):
@console_ns.expect(console_ns.models[LoadBalancingCredentialPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -24,20 +36,7 @@ class LoadBalancingCredentialsValidateApi(Resource):
tenant_id = current_tenant_id
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
args = parser.parse_args()
payload = LoadBalancingCredentialPayload.model_validate(console_ns.payload or {})
# validate model load balancing credentials
model_load_balancing_service = ModelLoadBalancingService()
@ -49,9 +48,9 @@ class LoadBalancingCredentialsValidateApi(Resource):
model_load_balancing_service.validate_load_balancing_credentials(
tenant_id=tenant_id,
provider=provider,
model=args["model"],
model_type=args["model_type"],
credentials=args["credentials"],
model=payload.model,
model_type=payload.model_type,
credentials=payload.credentials,
)
except CredentialsValidateFailedError as ex:
result = False
@ -69,6 +68,7 @@ class LoadBalancingCredentialsValidateApi(Resource):
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate"
)
class LoadBalancingConfigCredentialsValidateApi(Resource):
@console_ns.expect(console_ns.models[LoadBalancingCredentialPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -79,20 +79,7 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
tenant_id = current_tenant_id
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
args = parser.parse_args()
payload = LoadBalancingCredentialPayload.model_validate(console_ns.payload or {})
# validate model load balancing config credentials
model_load_balancing_service = ModelLoadBalancingService()
@ -104,9 +91,9 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
model_load_balancing_service.validate_load_balancing_credentials(
tenant_id=tenant_id,
provider=provider,
model=args["model"],
model_type=args["model_type"],
credentials=args["credentials"],
model=payload.model,
model_type=payload.model_type,
credentials=payload.credentials,
config_id=config_id,
)
except CredentialsValidateFailedError as ex:

View File

@ -1,4 +1,5 @@
import io
import logging
from urllib.parse import urlparse
from flask import make_response, redirect, request, send_file
@ -17,6 +18,7 @@ from controllers.console.wraps import (
is_admin_or_owner_required,
setup_required,
)
from core.db.session_factory import session_factory
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
from core.helper.tool_provider_cache import ToolProviderListCache
from core.mcp.auth.auth_flow import auth, handle_callback
@ -40,6 +42,8 @@ from services.tools.tools_manage_service import ToolCommonService
from services.tools.tools_transform_service import ToolTransformService
from services.tools.workflow_tools_manage_service import WorkflowToolManageService
logger = logging.getLogger(__name__)
def is_valid_url(url: str) -> bool:
if not url:
@ -945,8 +949,8 @@ class ToolProviderMCPApi(Resource):
configuration = MCPConfiguration.model_validate(args["configuration"])
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
# Create provider in transaction
with Session(db.engine) as session, session.begin():
# 1) Create provider in a short transaction (no network I/O inside)
with session_factory.create_session() as session, session.begin():
service = MCPToolManageService(session=session)
result = service.create_provider(
tenant_id=tenant_id,
@ -962,7 +966,28 @@ class ToolProviderMCPApi(Resource):
authentication=authentication,
)
# Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
# 2) Try to fetch tools immediately after creation so they appear without a second save.
# Perform network I/O outside any DB session to avoid holding locks.
try:
reconnect = MCPToolManageService.reconnect_with_url(
server_url=args["server_url"],
headers=args.get("headers") or {},
timeout=configuration.timeout,
sse_read_timeout=configuration.sse_read_timeout,
)
# Update just-created provider with authed/tools in a new short transaction
with session_factory.create_session() as session, session.begin():
service = MCPToolManageService(session=session)
db_provider = service.get_provider(provider_id=result.id, tenant_id=tenant_id)
db_provider.authed = reconnect.authed
db_provider.tools = reconnect.tools
result = ToolTransformService.mcp_provider_to_user_provider(db_provider, for_list=True)
except Exception:
# Best-effort: if initial fetch fails (e.g., auth required), return created provider as-is
logger.warning("Failed to fetch MCP tools after creation", exc_info=True)
# Final cache invalidation to ensure list views are up to date
ToolProviderListCache.invalidate_cache(tenant_id)
return jsonable_encoder(result)

View File

@ -13,7 +13,6 @@ from controllers.service_api.dataset.error import DatasetInUseError, DatasetName
from controllers.service_api.wraps import (
DatasetApiResource,
cloud_edition_billing_rate_limit_check,
validate_dataset_token,
)
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
@ -460,9 +459,8 @@ class DatasetTagsApi(DatasetApiResource):
401: "Unauthorized - invalid API token",
}
)
@validate_dataset_token
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
def get(self, _, dataset_id):
def get(self, _):
"""Get all knowledge type tags."""
assert isinstance(current_user, Account)
cid = current_user.current_tenant_id
@ -482,8 +480,7 @@ class DatasetTagsApi(DatasetApiResource):
}
)
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
@validate_dataset_token
def post(self, _, dataset_id):
def post(self, _):
"""Add a knowledge type tag."""
assert isinstance(current_user, Account)
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
@ -506,8 +503,7 @@ class DatasetTagsApi(DatasetApiResource):
}
)
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
@validate_dataset_token
def patch(self, _, dataset_id):
def patch(self, _):
assert isinstance(current_user, Account)
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
@ -533,9 +529,8 @@ class DatasetTagsApi(DatasetApiResource):
403: "Forbidden - insufficient permissions",
}
)
@validate_dataset_token
@edit_permission_required
def delete(self, _, dataset_id):
def delete(self, _):
"""Delete a knowledge type tag."""
payload = TagDeletePayload.model_validate(service_api_ns.payload or {})
TagService.delete_tag(payload.tag_id)
@ -555,8 +550,7 @@ class DatasetTagBindingApi(DatasetApiResource):
403: "Forbidden - insufficient permissions",
}
)
@validate_dataset_token
def post(self, _, dataset_id):
def post(self, _):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
assert isinstance(current_user, Account)
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
@ -580,8 +574,7 @@ class DatasetTagUnbindingApi(DatasetApiResource):
403: "Forbidden - insufficient permissions",
}
)
@validate_dataset_token
def post(self, _, dataset_id):
def post(self, _):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
assert isinstance(current_user, Account)
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
@ -604,7 +597,6 @@ class DatasetTagsBindingStatusApi(DatasetApiResource):
401: "Unauthorized - invalid API token",
}
)
@validate_dataset_token
def get(self, _, *args, **kwargs):
"""Get all knowledge type tags."""
dataset_id = kwargs.get("dataset_id")

View File

@ -255,7 +255,10 @@ class PGVector(BaseVector):
return
with self._get_cursor() as cur:
cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
cur.execute("SELECT 1 FROM pg_extension WHERE extname = 'vector'")
if not cur.fetchone():
cur.execute("CREATE EXTENSION vector")
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
# PG hnsw index only support 2000 dimension or less
# ref: https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing

View File

@ -7,7 +7,7 @@ from collections.abc import Generator, Mapping
from typing import Any, Union, cast
from flask import Flask, current_app
from sqlalchemy import and_, or_, select
from sqlalchemy import and_, literal, or_, select
from sqlalchemy.orm import Session
from core.app.app_config.entities import (
@ -1036,7 +1036,7 @@ class DatasetRetrieval:
if automatic_metadata_filters:
conditions = []
for sequence, filter in enumerate(automatic_metadata_filters):
self._process_metadata_filter_func(
self.process_metadata_filter_func(
sequence,
filter.get("condition"), # type: ignore
filter.get("metadata_name"), # type: ignore
@ -1072,7 +1072,7 @@ class DatasetRetrieval:
value=expected_value,
)
)
filters = self._process_metadata_filter_func(
filters = self.process_metadata_filter_func(
sequence,
condition.comparison_operator,
metadata_name,
@ -1168,8 +1168,9 @@ class DatasetRetrieval:
return None
return automatic_metadata_filters
def _process_metadata_filter_func(
self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
@classmethod
def process_metadata_filter_func(
cls, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
):
if value is None and condition not in ("empty", "not empty"):
return filters
@ -1218,6 +1219,20 @@ class DatasetRetrieval:
case "" | ">=":
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() >= value)
case "in" | "not in":
if isinstance(value, str):
value_list = [v.strip() for v in value.split(",") if v.strip()]
elif isinstance(value, (list, tuple)):
value_list = [str(v) for v in value if v is not None]
else:
value_list = [str(value)] if value is not None else []
if not value_list:
# `field in []` is False, `field not in []` is True
filters.append(literal(condition == "not in"))
else:
op = json_field.in_ if condition == "in" else json_field.notin_
filters.append(op(value_list))
case _:
pass

View File

@ -6,7 +6,7 @@ from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast
from sqlalchemy import and_, func, literal, or_, select
from sqlalchemy import and_, func, or_, select
from sqlalchemy.orm import sessionmaker
from core.app.app_config.entities import DatasetRetrieveConfigEntity
@ -460,7 +460,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
if automatic_metadata_filters:
conditions = []
for sequence, filter in enumerate(automatic_metadata_filters):
self._process_metadata_filter_func(
DatasetRetrieval.process_metadata_filter_func(
sequence,
filter.get("condition", ""),
filter.get("metadata_name", ""),
@ -504,7 +504,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
value=expected_value,
)
)
filters = self._process_metadata_filter_func(
filters = DatasetRetrieval.process_metadata_filter_func(
sequence,
condition.comparison_operator,
metadata_name,
@ -603,87 +603,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
return [], usage
return automatic_metadata_filters, usage
def _process_metadata_filter_func(
self, sequence: int, condition: str, metadata_name: str, value: Any, filters: list[Any]
) -> list[Any]:
if value is None and condition not in ("empty", "not empty"):
return filters
json_field = Document.doc_metadata[metadata_name].as_string()
match condition:
case "contains":
filters.append(json_field.like(f"%{value}%"))
case "not contains":
filters.append(json_field.notlike(f"%{value}%"))
case "start with":
filters.append(json_field.like(f"{value}%"))
case "end with":
filters.append(json_field.like(f"%{value}"))
case "in":
if isinstance(value, str):
value_list = [v.strip() for v in value.split(",") if v.strip()]
elif isinstance(value, (list, tuple)):
value_list = [str(v) for v in value if v is not None]
else:
value_list = [str(value)] if value is not None else []
if not value_list:
filters.append(literal(False))
else:
filters.append(json_field.in_(value_list))
case "not in":
if isinstance(value, str):
value_list = [v.strip() for v in value.split(",") if v.strip()]
elif isinstance(value, (list, tuple)):
value_list = [str(v) for v in value if v is not None]
else:
value_list = [str(value)] if value is not None else []
if not value_list:
filters.append(literal(True))
else:
filters.append(json_field.notin_(value_list))
case "is" | "=":
if isinstance(value, str):
filters.append(json_field == value)
elif isinstance(value, (int, float)):
filters.append(Document.doc_metadata[metadata_name].as_float() == value)
case "is not" | "":
if isinstance(value, str):
filters.append(json_field != value)
elif isinstance(value, (int, float)):
filters.append(Document.doc_metadata[metadata_name].as_float() != value)
case "empty":
filters.append(Document.doc_metadata[metadata_name].is_(None))
case "not empty":
filters.append(Document.doc_metadata[metadata_name].isnot(None))
case "before" | "<":
filters.append(Document.doc_metadata[metadata_name].as_float() < value)
case "after" | ">":
filters.append(Document.doc_metadata[metadata_name].as_float() > value)
case "" | "<=":
filters.append(Document.doc_metadata[metadata_name].as_float() <= value)
case "" | ">=":
filters.append(Document.doc_metadata[metadata_name].as_float() >= value)
case _:
pass
return filters
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,

View File

@ -110,5 +110,5 @@ class EnterpriseService:
if not app_id:
raise ValueError("app_id must be provided.")
body = {"appId": app_id}
EnterpriseRequest.send_request("DELETE", "/webapp/clean", json=body)
params = {"appId": app_id}
EnterpriseRequest.send_request("DELETE", "/webapp/clean", params=params)

View File

@ -319,8 +319,14 @@ class MCPToolManageService:
except MCPError as e:
raise ValueError(f"Failed to connect to MCP server: {e}")
# Update database with retrieved tools
db_provider.tools = json.dumps([tool.model_dump() for tool in tools])
# Update database with retrieved tools (ensure description is a non-null string)
tools_payload = []
for tool in tools:
data = tool.model_dump()
if data.get("description") is None:
data["description"] = ""
tools_payload.append(data)
db_provider.tools = json.dumps(tools_payload)
db_provider.authed = True
db_provider.updated_at = datetime.now()
self._session.flush()
@ -620,6 +626,21 @@ class MCPToolManageService:
server_url_hash=new_server_url_hash,
)
@staticmethod
def reconnect_with_url(
    *,
    server_url: str,
    headers: dict[str, str],
    timeout: float | None,
    sse_read_timeout: float | None,
) -> ReconnectResult:
    """Public wrapper around `_reconnect_with_url` for callers outside this service.

    Connects to the MCP server at *server_url*, lists its tools, and returns
    a ReconnectResult (auth status, serialized tools, credentials placeholder).
    Keyword-only to avoid positional mix-ups between the two timeout values.
    """
    return MCPToolManageService._reconnect_with_url(
        server_url=server_url,
        headers=headers,
        timeout=timeout,
        sse_read_timeout=sse_read_timeout,
    )
@staticmethod
def _reconnect_with_url(
*,
@ -642,9 +663,16 @@ class MCPToolManageService:
sse_read_timeout=sse_read_timeout,
) as mcp_client:
tools = mcp_client.list_tools()
# Ensure tool descriptions are non-null in payload
tools_payload = []
for t in tools:
d = t.model_dump()
if d.get("description") is None:
d["description"] = ""
tools_payload.append(d)
return ReconnectResult(
authed=True,
tools=json.dumps([tool.model_dump() for tool in tools]),
tools=json.dumps(tools_payload),
encrypted_credentials=EMPTY_CREDENTIALS_JSON,
)
except MCPAuthError:

View File

@ -0,0 +1,145 @@
"""Unit tests for load balancing credential validation APIs."""
from __future__ import annotations
import builtins
import importlib
import sys
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from flask import Flask
from flask.views import MethodView
from werkzeug.exceptions import Forbidden
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
if not hasattr(builtins, "MethodView"):
builtins.MethodView = MethodView # type: ignore[attr-defined]
from models.account import TenantAccountRole
@pytest.fixture
def app() -> Flask:
    """Provide a minimal Flask app in testing mode for request contexts."""
    app = Flask(__name__)
    app.config["TESTING"] = True
    return app
@pytest.fixture
def load_balancing_module(monkeypatch: pytest.MonkeyPatch):
    """Reload controller module with lightweight decorators for testing."""
    from controllers.console import console_ns, wraps
    from libs import login

    # Replace auth/setup decorators with pass-throughs so handlers run bare.
    def _noop(func):
        return func

    monkeypatch.setattr(login, "login_required", _noop)
    monkeypatch.setattr(wraps, "setup_required", _noop)
    monkeypatch.setattr(wraps, "account_initialization_required", _noop)

    # Route registration becomes a no-op so re-importing does not re-register URLs.
    def _noop_route(*args, **kwargs):  # type: ignore[override]
        def _decorator(cls):
            return cls

        return _decorator

    monkeypatch.setattr(console_ns, "route", _noop_route)
    # Drop any cached module so the import below re-executes class definitions
    # with the patched decorators in effect.
    module_name = "controllers.console.workspace.load_balancing_config"
    sys.modules.pop(module_name, None)
    module = importlib.import_module(module_name)
    return module
def _mock_user(role: TenantAccountRole) -> SimpleNamespace:
return SimpleNamespace(current_role=role)
def _prepare_context(module, monkeypatch: pytest.MonkeyPatch, role=TenantAccountRole.OWNER):
    """Patch the controller module with a fake user/tenant and mocked service.

    Returns the mocked ModelLoadBalancingService instance so tests can assert
    on its `validate_load_balancing_credentials` calls.
    """
    user = _mock_user(role)
    monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "tenant-123"))
    mock_service = MagicMock()
    monkeypatch.setattr(module, "ModelLoadBalancingService", lambda: mock_service)
    return mock_service
def _request_payload():
    """Return the canonical JSON body shared by all validation requests here."""
    credentials = {"api_key": "sk-***"}
    return {
        "model": "gpt-4o",
        "model_type": ModelType.LLM,
        "credentials": credentials,
    }
def test_validate_credentials_success(app: Flask, load_balancing_module, monkeypatch: pytest.MonkeyPatch):
    """Happy path: payload is validated and forwarded to the service verbatim."""
    service = _prepare_context(load_balancing_module, monkeypatch)
    with app.test_request_context(
        "/workspaces/current/model-providers/openai/models/load-balancing-configs/credentials-validate",
        method="POST",
        json=_request_payload(),
    ):
        response = load_balancing_module.LoadBalancingCredentialsValidateApi().post(provider="openai")
    assert response == {"result": "success"}
    service.validate_load_balancing_credentials.assert_called_once_with(
        tenant_id="tenant-123",
        provider="openai",
        model="gpt-4o",
        model_type=ModelType.LLM,
        credentials={"api_key": "sk-***"},
    )
def test_validate_credentials_returns_error_message(app: Flask, load_balancing_module, monkeypatch: pytest.MonkeyPatch):
    """Service validation failures are surfaced as a result/error payload, not raised."""
    service = _prepare_context(load_balancing_module, monkeypatch)
    service.validate_load_balancing_credentials.side_effect = CredentialsValidateFailedError("invalid credentials")
    with app.test_request_context(
        "/workspaces/current/model-providers/openai/models/load-balancing-configs/credentials-validate",
        method="POST",
        json=_request_payload(),
    ):
        response = load_balancing_module.LoadBalancingCredentialsValidateApi().post(provider="openai")
    assert response == {"result": "error", "error": "invalid credentials"}
def test_validate_credentials_requires_privileged_role(
    app: Flask, load_balancing_module, monkeypatch: pytest.MonkeyPatch
):
    """Non-privileged tenant roles must be rejected with 403 Forbidden."""
    _prepare_context(load_balancing_module, monkeypatch, role=TenantAccountRole.NORMAL)
    with app.test_request_context(
        "/workspaces/current/model-providers/openai/models/load-balancing-configs/credentials-validate",
        method="POST",
        json=_request_payload(),
    ):
        api = load_balancing_module.LoadBalancingCredentialsValidateApi()
        with pytest.raises(Forbidden):
            api.post(provider="openai")
def test_validate_credentials_with_config_id(app: Flask, load_balancing_module, monkeypatch: pytest.MonkeyPatch):
    """The config-scoped endpoint forwards `config_id` through to the service."""
    service = _prepare_context(load_balancing_module, monkeypatch)
    with app.test_request_context(
        "/workspaces/current/model-providers/openai/models/load-balancing-configs/cfg-1/credentials-validate",
        method="POST",
        json=_request_payload(),
    ):
        response = load_balancing_module.LoadBalancingConfigCredentialsValidateApi().post(
            provider="openai", config_id="cfg-1"
        )
    assert response == {"result": "success"}
    service.validate_load_balancing_credentials.assert_called_once_with(
        tenant_id="tenant-123",
        provider="openai",
        model="gpt-4o",
        model_type=ModelType.LLM,
        credentials={"api_key": "sk-***"},
        config_id="cfg-1",
    )

View File

@ -0,0 +1,103 @@
import json
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from flask_restx import Api
from controllers.console.workspace.tool_providers import ToolProviderMCPApi
from core.db.session_factory import configure_session_factory
from extensions.ext_database import db
from services.tools.mcp_tools_manage_service import ReconnectResult
# Backward-compat fixtures referenced by @pytest.mark.usefixtures in this file.
# They are intentionally no-ops because the test already patches the required
# behaviors explicitly via @patch and context managers below.
@pytest.fixture
def _mock_cache():
    """No-op placeholder; cache invalidation is patched explicitly in the test."""
    return


@pytest.fixture
def _mock_user_tenant():
    """No-op placeholder; user/tenant lookup is patched explicitly in the test."""
    return
@pytest.fixture
def client():
    """Build a Flask test client with the MCP route and an in-memory SQLite DB."""
    app = Flask(__name__)
    app.config["TESTING"] = True
    app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"
    api = Api(app)
    api.add_resource(ToolProviderMCPApi, "/console/api/workspaces/current/tool-provider/mcp")
    db.init_app(app)
    # Configure session factory used by controller code
    with app.app_context():
        configure_session_factory(db.engine)
    return app.test_client()
@patch(
    "controllers.console.workspace.tool_providers.current_account_with_tenant", return_value=(MagicMock(id="u1"), "t1")
)
@patch("controllers.console.workspace.tool_providers.ToolProviderListCache.invalidate_cache", return_value=None)
@patch("controllers.console.workspace.tool_providers.Session")
@patch("controllers.console.workspace.tool_providers.MCPToolManageService._reconnect_with_url")
@pytest.mark.usefixtures("_mock_cache", "_mock_user_tenant")
def test_create_mcp_provider_populates_tools(
    mock_reconnect, mock_session, mock_invalidate_cache, mock_current_account_with_tenant, client
):
    """Creating an MCP provider should immediately fetch and return its tools."""
    # Arrange: reconnect returns tools immediately
    mock_reconnect.return_value = ReconnectResult(
        authed=True,
        tools=json.dumps(
            [{"name": "ping", "description": "ok", "inputSchema": {"type": "object"}, "outputSchema": {}}]
        ),
        encrypted_credentials="{}",
    )
    # Fake service.create_provider -> returns object with id for reload
    svc = MagicMock()
    create_result = MagicMock()
    create_result.id = "provider-1"
    svc.create_provider.return_value = create_result
    svc.get_provider.return_value = MagicMock(id="provider-1", tenant_id="t1")  # used by reload path
    mock_session.return_value.__enter__.return_value = MagicMock()
    # Patch MCPToolManageService constructed inside controller
    with patch("controllers.console.workspace.tool_providers.MCPToolManageService", return_value=svc):
        payload = {
            "server_url": "http://example.com/mcp",
            "name": "demo",
            "icon": "😀",
            "icon_type": "emoji",
            "icon_background": "#000",
            "server_identifier": "demo-sid",
            "configuration": {"timeout": 5, "sse_read_timeout": 30},
            "headers": {},
            "authentication": {},
        }
        # Act
        with (
            patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"),  # bypass setup_required DB check
            patch("controllers.console.wraps.current_account_with_tenant", return_value=(MagicMock(id="u1"), "t1")),
            patch("libs.login.check_csrf_token", return_value=None),  # bypass CSRF in login_required
            patch("libs.login._get_user", return_value=MagicMock(id="u1", is_authenticated=True)),  # login
            patch(
                "services.tools.tools_transform_service.ToolTransformService.mcp_provider_to_user_provider",
                return_value={"id": "provider-1", "tools": [{"name": "ping"}]},
            ),
        ):
            resp = client.post(
                "/console/api/workspaces/current/tool-provider/mcp",
                data=json.dumps(payload),
                content_type="application/json",
            )
    # Assert
    assert resp.status_code == 200
    body = resp.get_json()
    assert body.get("id") == "provider-1"
    # If the transformed response contains a `tools` field, make sure it is non-empty.
    assert isinstance(body.get("tools"), list)
    assert body["tools"]

View File

@ -0,0 +1,327 @@
import unittest
from unittest.mock import MagicMock, patch
import pytest
from core.rag.datasource.vdb.pgvector.pgvector import (
PGVector,
PGVectorConfig,
)
class TestPGVector(unittest.TestCase):
    def setUp(self):
        """Create a baseline pgvector config (pg_bigm disabled) shared by the tests."""
        self.config = PGVectorConfig(
            host="localhost",
            port=5432,
            user="test_user",
            password="test_password",
            database="test_db",
            min_connection=1,
            max_connection=5,
            pg_bigm=False,
        )
        self.collection_name = "test_collection"
self.collection_name = "test_collection"
    @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
    def test_init(self, mock_pool_class):
        """Test PGVector initialization."""
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool
        pgvector = PGVector(self.collection_name, self.config)
        # Table name is derived from the collection name with a fixed prefix.
        assert pgvector._collection_name == self.collection_name
        assert pgvector.table_name == f"embedding_{self.collection_name}"
        assert pgvector.get_type() == "pgvector"
        assert pgvector.pool is not None
        assert pgvector.pg_bigm is False
        assert pgvector.index_hash is not None
    @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
    def test_init_with_pg_bigm(self, mock_pool_class):
        """Test PGVector initialization with pg_bigm enabled."""
        # Same config as setUp except pg_bigm is switched on.
        config = PGVectorConfig(
            host="localhost",
            port=5432,
            user="test_user",
            password="test_password",
            database="test_db",
            min_connection=1,
            max_connection=5,
            pg_bigm=True,
        )
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool
        pgvector = PGVector(self.collection_name, config)
        assert pgvector.pg_bigm is True
    @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
    @patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
    def test_create_collection_basic(self, mock_redis, mock_pool_class):
        """Test basic collection creation."""
        # Mock Redis operations (distributed lock + creation-marker cache).
        mock_lock = MagicMock()
        mock_lock.__enter__ = MagicMock()
        mock_lock.__exit__ = MagicMock()
        mock_redis.lock.return_value = mock_lock
        mock_redis.get.return_value = None
        mock_redis.set.return_value = None
        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool
        # Mock connection and cursor
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.getconn.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        mock_cursor.fetchone.return_value = [1]  # vector extension exists
        pgvector = PGVector(self.collection_name, self.config)
        pgvector._create_collection(1536)
        # Verify SQL execution calls
        assert mock_cursor.execute.called
        # Check that CREATE TABLE was called with correct dimension
        create_table_calls = [call for call in mock_cursor.execute.call_args_list if "CREATE TABLE" in str(call)]
        assert len(create_table_calls) == 1
        assert "vector(1536)" in create_table_calls[0][0][0]
        # Check that CREATE INDEX was called (dimension <= 2000)
        create_index_calls = [
            call for call in mock_cursor.execute.call_args_list if "CREATE INDEX" in str(call) and "hnsw" in str(call)
        ]
        assert len(create_index_calls) == 1
        # Verify Redis cache was set
        mock_redis.set.assert_called_once()
    @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
    @patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
    def test_create_collection_with_large_dimension(self, mock_redis, mock_pool_class):
        """Test collection creation with dimension > 2000 (no HNSW index)."""
        # Mock Redis operations (lock + cache)
        mock_lock = MagicMock()
        mock_lock.__enter__ = MagicMock()
        mock_lock.__exit__ = MagicMock()
        mock_redis.lock.return_value = mock_lock
        mock_redis.get.return_value = None
        mock_redis.set.return_value = None
        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool
        # Mock connection and cursor
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.getconn.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        mock_cursor.fetchone.return_value = [1]  # vector extension exists
        pgvector = PGVector(self.collection_name, self.config)
        pgvector._create_collection(3072)  # Dimension > 2000
        # Check that CREATE TABLE was called
        create_table_calls = [call for call in mock_cursor.execute.call_args_list if "CREATE TABLE" in str(call)]
        assert len(create_table_calls) == 1
        assert "vector(3072)" in create_table_calls[0][0][0]
        # Check that HNSW index was NOT created (pgvector hnsw caps at 2000 dims)
        hnsw_index_calls = [call for call in mock_cursor.execute.call_args_list if "hnsw" in str(call)]
        assert len(hnsw_index_calls) == 0
    @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
    @patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
    def test_create_collection_with_pg_bigm(self, mock_redis, mock_pool_class):
        """Test collection creation with pg_bigm enabled."""
        # pg_bigm=True should add a GIN bigram index for full-text search.
        config = PGVectorConfig(
            host="localhost",
            port=5432,
            user="test_user",
            password="test_password",
            database="test_db",
            min_connection=1,
            max_connection=5,
            pg_bigm=True,
        )
        # Mock Redis operations (lock + cache)
        mock_lock = MagicMock()
        mock_lock.__enter__ = MagicMock()
        mock_lock.__exit__ = MagicMock()
        mock_redis.lock.return_value = mock_lock
        mock_redis.get.return_value = None
        mock_redis.set.return_value = None
        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool
        # Mock connection and cursor
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.getconn.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        mock_cursor.fetchone.return_value = [1]  # vector extension exists
        pgvector = PGVector(self.collection_name, config)
        pgvector._create_collection(1536)
        # Check that pg_bigm index was created
        bigm_index_calls = [call for call in mock_cursor.execute.call_args_list if "gin_bigm_ops" in str(call)]
        assert len(bigm_index_calls) == 1
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
def test_create_collection_creates_vector_extension(self, mock_redis, mock_pool_class):
    """Test that vector extension is created if it doesn't exist."""
    # Redis reports no cached collection and hands out a no-op lock.
    lock_stub = MagicMock()
    lock_stub.__enter__ = MagicMock()
    lock_stub.__exit__ = MagicMock()
    mock_redis.lock.return_value = lock_stub
    mock_redis.get.return_value = None
    mock_redis.set.return_value = None
    # Pool -> connection -> cursor chain.
    pool_stub = MagicMock()
    mock_pool_class.return_value = pool_stub
    conn_stub = MagicMock()
    cursor_stub = MagicMock()
    pool_stub.getconn.return_value = conn_stub
    conn_stub.cursor.return_value = cursor_stub
    # The extension-presence query returns nothing -> extension must be created.
    cursor_stub.fetchone.return_value = None
    PGVector(self.collection_name, self.config)._create_collection(1536)
    extension_stmts = [c for c in cursor_stub.execute.call_args_list if "CREATE EXTENSION vector" in str(c)]
    assert len(extension_stmts) == 1
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
def test_create_collection_with_cache_hit(self, mock_redis, mock_pool_class):
    """Test that collection creation is skipped when cache exists.

    A truthy Redis value for the collection key means the collection was
    already created, so no SQL at all should be executed.
    """
    # Mock Redis operations - cache exists
    mock_lock = MagicMock()
    mock_lock.__enter__ = MagicMock()
    mock_lock.__exit__ = MagicMock()
    mock_redis.lock.return_value = mock_lock
    mock_redis.get.return_value = 1  # Cache exists
    # Mock the connection pool
    mock_pool = MagicMock()
    mock_pool_class.return_value = mock_pool
    # Mock connection and cursor
    mock_conn = MagicMock()
    mock_cursor = MagicMock()
    mock_pool.getconn.return_value = mock_conn
    mock_conn.cursor.return_value = mock_cursor
    pgvector = PGVector(self.collection_name, self.config)
    pgvector._create_collection(1536)
    # Check that no SQL was executed (early return due to cache)
    assert mock_cursor.execute.call_count == 0
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
def test_create_collection_with_redis_lock(self, mock_redis, mock_pool_class):
    """Test that Redis lock is used during collection creation.

    Creation must run inside a distributed lock so concurrent workers do
    not race on the DDL; the lock name embeds the collection name.
    """
    # Mock Redis operations
    mock_lock = MagicMock()
    mock_lock.__enter__ = MagicMock()
    mock_lock.__exit__ = MagicMock()
    mock_redis.lock.return_value = mock_lock
    mock_redis.get.return_value = None
    mock_redis.set.return_value = None
    # Mock the connection pool
    mock_pool = MagicMock()
    mock_pool_class.return_value = mock_pool
    # Mock connection and cursor
    mock_conn = MagicMock()
    mock_cursor = MagicMock()
    mock_pool.getconn.return_value = mock_conn
    mock_conn.cursor.return_value = mock_cursor
    mock_cursor.fetchone.return_value = [1]  # vector extension exists
    pgvector = PGVector(self.collection_name, self.config)
    pgvector._create_collection(1536)
    # Verify Redis lock was acquired with correct lock name (20s timeout)
    mock_redis.lock.assert_called_once_with("vector_indexing_test_collection_lock", timeout=20)
    # Verify lock context manager was entered and exited
    mock_lock.__enter__.assert_called_once()
    mock_lock.__exit__.assert_called_once()
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
def test_get_cursor_context_manager(self, mock_pool_class):
    """Test that _get_cursor properly manages connection lifecycle."""
    pool_stub = MagicMock()
    mock_pool_class.return_value = pool_stub
    conn_stub = MagicMock()
    cursor_stub = MagicMock()
    pool_stub.getconn.return_value = conn_stub
    conn_stub.cursor.return_value = cursor_stub
    store = PGVector(self.collection_name, self.config)
    # Entering the context manager must hand back the pooled cursor.
    with store._get_cursor() as cur:
        assert cur == cursor_stub
    # Leaving it must close the cursor, commit, and return the connection.
    pool_stub.getconn.assert_called_once()
    cursor_stub.close.assert_called_once()
    conn_stub.commit.assert_called_once()
    pool_stub.putconn.assert_called_once_with(conn_stub)
@pytest.mark.parametrize(
    "invalid_config_override",
    [
        {"host": ""},  # empty host
        {"port": 0},  # invalid port
        {"user": ""},  # empty user
        {"password": ""},  # empty password
        {"database": ""},  # empty database
        {"min_connection": 0},  # invalid min_connection
        {"max_connection": 0},  # invalid max_connection
        {"min_connection": 10, "max_connection": 5},  # min > max
    ],
)
def test_config_validation_parametrized(invalid_config_override):
    """Test configuration validation for various invalid inputs using parametrize."""
    base_config = {
        "host": "localhost",
        "port": 5432,
        "user": "test_user",
        "password": "test_password",
        "database": "test_db",
        "min_connection": 1,
        "max_connection": 5,
    }
    # Overlay the single invalid field (or pair) onto a known-good baseline.
    with pytest.raises(ValueError):
        PGVectorConfig(**{**base_config, **invalid_config_override})
# Allow running this module directly (in addition to pytest discovery).
if __name__ == "__main__":
    unittest.main()

View File

@ -0,0 +1,873 @@
"""
Unit tests for DatasetRetrieval.process_metadata_filter_func.
This module provides comprehensive test coverage for the process_metadata_filter_func
method in the DatasetRetrieval class, which is responsible for building SQLAlchemy
filter expressions based on metadata filtering conditions.
Conditions Tested:
==================
1. **String Conditions**: contains, not contains, start with, end with
2. **Equality Conditions**: is / =, is not / ≠
3. **Null Conditions**: empty, not empty
4. **Numeric Comparisons**: before / <, after / >, ≤ / <=, ≥ / >=
5. **List Conditions**: in
6. **Edge Cases**: None values, different data types (str, int, float)
Test Architecture:
==================
- Direct instantiation of DatasetRetrieval
- Mocking of DatasetDocument model attributes
- Verification of SQLAlchemy filter expressions
- Follows Arrange-Act-Assert (AAA) pattern
Running Tests:
==============
# Run all tests in this module
uv run --project api pytest \
api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py -v
# Run a specific test
uv run --project api pytest \
api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py::\
TestProcessMetadataFilterFunc::test_contains_condition -v
"""
from unittest.mock import MagicMock
import pytest
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
class TestProcessMetadataFilterFunc:
"""
Comprehensive test suite for process_metadata_filter_func method.
This test class validates all metadata filtering conditions supported by
the DatasetRetrieval class, including string operations, numeric comparisons,
null checks, and list operations.
Method Signature:
==================
def process_metadata_filter_func(
self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
) -> list:
The method builds SQLAlchemy filter expressions by:
1. Validating value is not None (except for empty/not empty conditions)
2. Using DatasetDocument.doc_metadata JSON field operations
3. Adding appropriate SQLAlchemy expressions to the filters list
4. Returning the updated filters list
Mocking Strategy:
==================
- Mock DatasetDocument.doc_metadata to avoid database dependencies
- Verify filter expressions are created correctly
- Test with various data types (str, int, float, list)
"""
@pytest.fixture
def retrieval(self):
    """Provide a fresh DatasetRetrieval instance for each test case."""
    instance = DatasetRetrieval()
    return instance
@pytest.fixture
def mock_doc_metadata(self):
    """
    Mock the DatasetDocument.doc_metadata JSON field.

    The production code accesses DatasetDocument.doc_metadata[metadata_name]
    and then uses string operators (like/notlike/==/!=/in_), numeric
    comparisons (via as_float) or null checks (is_/isnot) on the result.
    This fixture builds a MagicMock tree mirroring that access pattern so
    tests can run without a database.

    Returns:
        MagicMock: Mocked doc_metadata attribute.
    """
    mock_metadata_field = MagicMock()
    # String-style access: LIKE / NOT LIKE / equality / IN expressions.
    mock_string_access = MagicMock()
    mock_string_access.like = MagicMock()
    mock_string_access.notlike = MagicMock()
    mock_string_access.__eq__ = MagicMock(return_value=MagicMock())
    mock_string_access.__ne__ = MagicMock(return_value=MagicMock())
    mock_string_access.in_ = MagicMock(return_value=MagicMock())
    # Numeric-style access: comparison operators used after as_float().
    mock_float_access = MagicMock()
    mock_float_access.__eq__ = MagicMock(return_value=MagicMock())
    mock_float_access.__ne__ = MagicMock(return_value=MagicMock())
    mock_float_access.__lt__ = MagicMock(return_value=MagicMock())
    mock_float_access.__gt__ = MagicMock(return_value=MagicMock())
    mock_float_access.__le__ = MagicMock(return_value=MagicMock())
    mock_float_access.__ge__ = MagicMock(return_value=MagicMock())
    # Null checks (IS NULL / IS NOT NULL) are invoked on whatever
    # __getitem__ returns, so wire them onto both access mocks.
    # BUGFIX: the original indexed `mock_metadata_field[metadata_name:str]`,
    # where `metadata_name` is undefined in this scope — a NameError as soon
    # as the fixture was used.
    mock_null_access = MagicMock()
    mock_null_access.is_ = MagicMock(return_value=MagicMock())
    mock_null_access.isnot = MagicMock(return_value=MagicMock())
    mock_string_access.is_ = mock_null_access.is_
    mock_string_access.isnot = mock_null_access.isnot
    mock_float_access.is_ = mock_null_access.is_
    mock_float_access.isnot = mock_null_access.isnot

    # Route known numeric fields to the float mock; everything else is
    # treated as a string field.
    def getitem_side_effect(name):
        if name in ["year", "price", "rating"]:
            return mock_float_access
        return mock_string_access

    mock_metadata_field.__getitem__ = MagicMock(side_effect=getitem_side_effect)
    mock_metadata_field.as_string.return_value = mock_string_access
    mock_metadata_field.as_float.return_value = mock_float_access
    return mock_metadata_field
# ==================== String Condition Tests ====================
def test_contains_condition_string_value(self, retrieval):
    """
    Test 'contains' condition with string value.

    Exactly one LIKE-style (%value%) expression must be appended, and the
    method must hand back the same list it was given.
    """
    filters: list = []
    result = retrieval.process_metadata_filter_func(0, "contains", "author", "John", filters)
    assert result == filters
    assert len(filters) == 1
def test_not_contains_condition(self, retrieval):
"""
Test 'not contains' condition.
Verifies:
- Filters list is populated with NOT LIKE expression
- Pattern matching uses %value% syntax with negation
"""
filters = []
sequence = 0
condition = "not contains"
metadata_name = "title"
value = "banned"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_start_with_condition(self, retrieval):
"""
Test 'start with' condition.
Verifies:
- Filters list is populated with LIKE expression
- Pattern matching uses value% syntax
"""
filters = []
sequence = 0
condition = "start with"
metadata_name = "category"
value = "tech"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_end_with_condition(self, retrieval):
"""
Test 'end with' condition.
Verifies:
- Filters list is populated with LIKE expression
- Pattern matching uses %value syntax
"""
filters = []
sequence = 0
condition = "end with"
metadata_name = "filename"
value = ".pdf"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
# ==================== Equality Condition Tests ====================
def test_is_condition_with_string_value(self, retrieval):
"""
Test 'is' (=) condition with string value.
Verifies:
- Filters list is populated with equality expression
- String comparison is used
"""
filters = []
sequence = 0
condition = "is"
metadata_name = "author"
value = "Jane Doe"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_equals_condition_with_string_value(self, retrieval):
"""
Test '=' condition with string value.
Verifies:
- Same behavior as 'is' condition
- String comparison is used
"""
filters = []
sequence = 0
condition = "="
metadata_name = "category"
value = "technology"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_is_condition_with_int_value(self, retrieval):
"""
Test 'is' condition with integer value.
Verifies:
- Numeric comparison is used
- as_float() is called on the metadata field
"""
filters = []
sequence = 0
condition = "is"
metadata_name = "year"
value = 2023
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_is_condition_with_float_value(self, retrieval):
"""
Test 'is' condition with float value.
Verifies:
- Numeric comparison is used
- as_float() is called on the metadata field
"""
filters = []
sequence = 0
condition = "is"
metadata_name = "price"
value = 19.99
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_is_not_condition_with_string_value(self, retrieval):
"""
Test 'is not' () condition with string value.
Verifies:
- Filters list is populated with inequality expression
- String comparison is used
"""
filters = []
sequence = 0
condition = "is not"
metadata_name = "author"
value = "Unknown"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_not_equals_condition(self, retrieval):
    """
    Test '≠' condition with string value.

    Verifies:
    - Same behavior as 'is not' condition
    - Inequality expression is used
    """
    filters = []
    sequence = 0
    # BUGFIX: the Unicode not-equal operator was stripped to an empty
    # string by a bad encoding pass, so this test was exercising the
    # unknown-condition path instead of '≠'.
    condition = "≠"
    metadata_name = "category"
    value = "archived"
    result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
    assert result == filters
    assert len(filters) == 1
def test_is_not_condition_with_numeric_value(self, retrieval):
"""
Test 'is not' condition with numeric value.
Verifies:
- Numeric inequality comparison is used
- as_float() is called on the metadata field
"""
filters = []
sequence = 0
condition = "is not"
metadata_name = "year"
value = 2000
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
# ==================== Null Condition Tests ====================
def test_empty_condition(self, retrieval):
    """
    Test 'empty' condition (null check).

    A None value is legal for 'empty'; exactly one IS NULL style
    expression must be appended.
    """
    filters: list = []
    result = retrieval.process_metadata_filter_func(0, "empty", "author", None, filters)
    assert result == filters
    assert len(filters) == 1
def test_not_empty_condition(self, retrieval):
"""
Test 'not empty' condition (not null check).
Verifies:
- Filters list is populated with IS NOT NULL expression
- Value can be None for this condition
"""
filters = []
sequence = 0
condition = "not empty"
metadata_name = "description"
value = None
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
# ==================== Numeric Comparison Tests ====================
def test_before_condition(self, retrieval):
"""
Test 'before' (<) condition.
Verifies:
- Filters list is populated with less than expression
- Numeric comparison is used
"""
filters = []
sequence = 0
condition = "before"
metadata_name = "year"
value = 2020
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_less_than_condition(self, retrieval):
"""
Test '<' condition.
Verifies:
- Same behavior as 'before' condition
- Less than expression is used
"""
filters = []
sequence = 0
condition = "<"
metadata_name = "price"
value = 100.0
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_after_condition(self, retrieval):
"""
Test 'after' (>) condition.
Verifies:
- Filters list is populated with greater than expression
- Numeric comparison is used
"""
filters = []
sequence = 0
condition = "after"
metadata_name = "year"
value = 2020
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_greater_than_condition(self, retrieval):
"""
Test '>' condition.
Verifies:
- Same behavior as 'after' condition
- Greater than expression is used
"""
filters = []
sequence = 0
condition = ">"
metadata_name = "rating"
value = 4.5
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_less_than_or_equal_condition_unicode(self, retrieval):
    """
    Test '≤' condition.

    Verifies:
    - Filters list is populated with less than or equal expression
    - Numeric comparison is used
    """
    filters = []
    sequence = 0
    # BUGFIX: restore the Unicode '≤' operator that was stripped to an
    # empty string; the test name says "unicode" and the ASCII twin below
    # covers '<='.
    condition = "≤"
    metadata_name = "price"
    value = 50.0
    result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
    assert result == filters
    assert len(filters) == 1
def test_less_than_or_equal_condition_ascii(self, retrieval):
"""
Test '<=' condition.
Verifies:
- Same behavior as '' condition
- Less than or equal expression is used
"""
filters = []
sequence = 0
condition = "<="
metadata_name = "year"
value = 2023
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_greater_than_or_equal_condition_unicode(self, retrieval):
    """
    Test '≥' condition.

    Verifies:
    - Filters list is populated with greater than or equal expression
    - Numeric comparison is used
    """
    filters = []
    sequence = 0
    # BUGFIX: restore the Unicode '≥' operator that was stripped to an
    # empty string; the ASCII twin below covers '>='.
    condition = "≥"
    metadata_name = "rating"
    value = 3.5
    result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
    assert result == filters
    assert len(filters) == 1
def test_greater_than_or_equal_condition_ascii(self, retrieval):
"""
Test '>=' condition.
Verifies:
- Same behavior as '' condition
- Greater than or equal expression is used
"""
filters = []
sequence = 0
condition = ">="
metadata_name = "year"
value = 2000
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
# ==================== List/In Condition Tests ====================
def test_in_condition_with_comma_separated_string(self, retrieval):
"""
Test 'in' condition with comma-separated string value.
Verifies:
- String is split into list
- Whitespace is trimmed from each value
- IN expression is created
"""
filters = []
sequence = 0
condition = "in"
metadata_name = "category"
value = "tech, science, AI "
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_in_condition_with_list_value(self, retrieval):
"""
Test 'in' condition with list value.
Verifies:
- List is processed correctly
- None values are filtered out
- IN expression is created with valid values
"""
filters = []
sequence = 0
condition = "in"
metadata_name = "tags"
value = ["python", "javascript", None, "golang"]
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_in_condition_with_tuple_value(self, retrieval):
"""
Test 'in' condition with tuple value.
Verifies:
- Tuple is processed like a list
- IN expression is created
"""
filters = []
sequence = 0
condition = "in"
metadata_name = "category"
value = ("tech", "science", "ai")
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_in_condition_with_empty_string(self, retrieval):
"""
Test 'in' condition with empty string value.
Verifies:
- Empty string results in literal(False) filter
- No valid values to match
"""
filters = []
sequence = 0
condition = "in"
metadata_name = "category"
value = ""
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
# Verify it's a literal(False) expression
# This is a bit tricky to test without access to the actual expression
def test_in_condition_with_only_whitespace(self, retrieval):
"""
Test 'in' condition with whitespace-only string value.
Verifies:
- Whitespace-only string results in literal(False) filter
- All values are stripped and filtered out
"""
filters = []
sequence = 0
condition = "in"
metadata_name = "category"
value = " , , "
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_in_condition_with_single_string(self, retrieval):
"""
Test 'in' condition with single non-comma string.
Verifies:
- Single string is treated as single-item list
- IN expression is created with one value
"""
filters = []
sequence = 0
condition = "in"
metadata_name = "category"
value = "technology"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
# ==================== Edge Case Tests ====================
def test_none_value_with_non_empty_condition(self, retrieval):
"""
Test None value with conditions that require value.
Verifies:
- Original filters list is returned unchanged
- No filter is added for None values (except empty/not empty)
"""
filters = []
sequence = 0
condition = "contains"
metadata_name = "author"
value = None
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 0 # No filter added
def test_none_value_with_equals_condition(self, retrieval):
"""
Test None value with 'is' (=) condition.
Verifies:
- Original filters list is returned unchanged
- No filter is added for None values
"""
filters = []
sequence = 0
condition = "is"
metadata_name = "author"
value = None
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 0
def test_none_value_with_numeric_condition(self, retrieval):
"""
Test None value with numeric comparison condition.
Verifies:
- Original filters list is returned unchanged
- No filter is added for None values
"""
filters = []
sequence = 0
condition = ">"
metadata_name = "year"
value = None
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 0
def test_existing_filters_preserved(self, retrieval):
    """
    Test that existing filters are preserved.

    Pre-existing entries must stay at the front of the list; the new
    expression is appended after them.
    """
    sentinel = MagicMock()
    filters = [sentinel]
    result = retrieval.process_metadata_filter_func(0, "contains", "author", "test", filters)
    assert result == filters
    assert len(filters) == 2
    assert filters[0] == sentinel
def test_multiple_filters_accumulated(self, retrieval):
    """
    Test multiple calls to accumulate filters.

    Each invocation appends exactly one more expression to the shared list.
    """
    filters: list = []
    calls = [
        (0, "contains", "author", "John"),
        (1, ">", "year", 2020),
        (2, "is", "category", "tech"),
    ]
    for expected_len, (seq, cond, name, value) in enumerate(calls, start=1):
        retrieval.process_metadata_filter_func(seq, cond, name, value, filters)
        assert len(filters) == expected_len
def test_unknown_condition(self, retrieval):
    """
    Test unknown/unsupported condition.

    Unrecognized condition strings must leave the filters list untouched.
    """
    filters: list = []
    result = retrieval.process_metadata_filter_func(0, "unknown_condition", "author", "test", filters)
    assert result == filters
    assert len(filters) == 0
def test_empty_string_value_with_contains(self, retrieval):
"""
Test empty string value with 'contains' condition.
Verifies:
- Filter is added even with empty string
- LIKE expression is created
"""
filters = []
sequence = 0
condition = "contains"
metadata_name = "author"
value = ""
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_special_characters_in_value(self, retrieval):
"""
Test special characters in value string.
Verifies:
- Special characters are handled in value
- LIKE expression is created correctly
"""
filters = []
sequence = 0
condition = "contains"
metadata_name = "title"
value = "C++ & Python's features"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_zero_value_with_numeric_condition(self, retrieval):
"""
Test zero value with numeric comparison condition.
Verifies:
- Zero is treated as valid value
- Numeric comparison is performed
"""
filters = []
sequence = 0
condition = ">"
metadata_name = "price"
value = 0
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_negative_value_with_numeric_condition(self, retrieval):
"""
Test negative value with numeric comparison condition.
Verifies:
- Negative numbers are handled correctly
- Numeric comparison is performed
"""
filters = []
sequence = 0
condition = "<"
metadata_name = "temperature"
value = -10.5
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_float_value_with_integer_comparison(self, retrieval):
"""
Test float value with numeric comparison condition.
Verifies:
- Float values work correctly
- Numeric comparison is performed
"""
filters = []
sequence = 0
condition = ">="
metadata_name = "rating"
value = 4.5
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1

View File

@ -399,6 +399,7 @@ CONSOLE_CORS_ALLOW_ORIGINS=*
COOKIE_DOMAIN=
# When the frontend and backend run on different subdomains, set NEXT_PUBLIC_COOKIE_DOMAIN=1.
NEXT_PUBLIC_COOKIE_DOMAIN=
NEXT_PUBLIC_BATCH_CONCURRENCY=5
# ------------------------------
# File Storage Configuration

View File

@ -108,6 +108,7 @@ x-shared-env: &shared-api-worker-env
CONSOLE_CORS_ALLOW_ORIGINS: ${CONSOLE_CORS_ALLOW_ORIGINS:-*}
COOKIE_DOMAIN: ${COOKIE_DOMAIN:-}
NEXT_PUBLIC_COOKIE_DOMAIN: ${NEXT_PUBLIC_COOKIE_DOMAIN:-}
NEXT_PUBLIC_BATCH_CONCURRENCY: ${NEXT_PUBLIC_BATCH_CONCURRENCY:-5}
STORAGE_TYPE: ${STORAGE_TYPE:-opendal}
OPENDAL_SCHEME: ${OPENDAL_SCHEME:-fs}
OPENDAL_FS_ROOT: ${OPENDAL_FS_ROOT:-storage}

View File

@ -73,3 +73,6 @@ NEXT_PUBLIC_MAX_TREE_DEPTH=50
# The API key of amplitude
NEXT_PUBLIC_AMPLITUDE_API_KEY=
# Maximum number of concurrent batch requests
NEXT_PUBLIC_BATCH_CONCURRENCY=5

View File

@ -176,7 +176,7 @@ const DatasetConfig: FC = () => {
}))
}, [setDatasetConfigs, datasetConfigsRef])
const handleAddCondition = useCallback<HandleAddCondition>(({ name, type }) => {
const handleAddCondition = useCallback<HandleAddCondition>(({ id, name, type }) => {
let operator: ComparisonOperator = ComparisonOperator.is
if (type === MetadataFilteringVariableType.number)
@ -184,6 +184,7 @@ const DatasetConfig: FC = () => {
const newCondition = {
id: uuid4(),
metadata_id: id, // Save metadata.id for reliable reference
name,
comparison_operator: operator,
}

View File

@ -0,0 +1,141 @@
import type { DataSet } from '@/models/datasets'
import { act, fireEvent, render, screen } from '@testing-library/react'
import * as React from 'react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { IndexingType } from '@/app/components/datasets/create/step-two'
import { DatasetPermission } from '@/models/datasets'
import { RETRIEVE_METHOD } from '@/types/app'
import SelectDataSet from './index'
// Stub the i18n config so rendering does not pull real translation resources.
vi.mock('@/i18n-config/i18next-config', () => ({
  __esModule: true,
  default: {
    changeLanguage: vi.fn(),
    addResourceBundle: vi.fn(),
    use: vi.fn().mockReturnThis(),
    init: vi.fn(),
    addResource: vi.fn(),
    hasResourceBundle: vi.fn().mockReturnValue(true),
  },
}))
// Intercept ahooks' useInfiniteScroll while keeping the rest of the module intact.
const mockUseInfiniteScroll = vi.fn()
vi.mock('ahooks', async (importOriginal) => {
  const actual = await importOriginal()
  return {
    ...(typeof actual === 'object' && actual !== null ? actual : {}),
    useInfiniteScroll: (...args: any[]) => mockUseInfiniteScroll(...args),
  }
})
// Replace the dataset query hook so each test controls the page data returned.
const mockUseInfiniteDatasets = vi.fn()
vi.mock('@/service/knowledge/use-dataset', () => ({
  useInfiniteDatasets: (...args: any[]) => mockUseInfiniteDatasets(...args),
}))
// Deterministic indexing-technique label formatting for assertions.
vi.mock('@/hooks/use-knowledge', () => ({
  useKnowledge: () => ({
    formatIndexingTechniqueAndMethod: (tech: string, method: string) => `${tech}:${method}`,
  }),
}))
// Default props shared by every render; tests override onSelect/selectedIds.
const baseProps = {
  isShow: true,
  onClose: vi.fn(),
  selectedIds: [] as string[],
  onSelect: vi.fn(),
}
// Factory producing a fully-populated DataSet fixture; individual tests
// override only the fields they care about via `overrides`.
const makeDataset = (overrides: Partial<DataSet>): DataSet => ({
  id: 'dataset-id',
  name: 'Dataset Name',
  provider: 'internal',
  icon_info: {
    icon_type: 'emoji',
    icon: '💾',
    icon_background: '#fff',
    icon_url: '',
  },
  embedding_available: true,
  is_multimodal: false,
  description: '',
  permission: DatasetPermission.allTeamMembers,
  indexing_technique: IndexingType.ECONOMICAL,
  retrieval_model_dict: {
    search_method: RETRIEVE_METHOD.fullText,
    top_k: 5,
    reranking_enable: false,
    reranking_model: {
      reranking_model_name: '',
      reranking_provider_name: '',
    },
    score_threshold_enabled: false,
    score_threshold: 0,
  },
  ...overrides,
  // Cast is needed because the fixture intentionally omits rarely-used fields.
} as DataSet)
describe('SelectDataSet', () => {
  beforeEach(() => {
    vi.clearAllMocks()
  })
  it('renders dataset entries, allows selection, and fires onSelect', async () => {
    // Two datasets: one selectable, one external/unavailable — both listed.
    const datasetOne = makeDataset({
      id: 'set-1',
      name: 'Dataset One',
      is_multimodal: true,
      indexing_technique: IndexingType.ECONOMICAL,
    })
    const datasetTwo = makeDataset({
      id: 'set-2',
      name: 'Hidden Dataset',
      embedding_available: false,
      provider: 'external',
    })
    mockUseInfiniteDatasets.mockReturnValue({
      data: { pages: [{ data: [datasetOne, datasetTwo] }] },
      isLoading: false,
      isFetchingNextPage: false,
      fetchNextPage: vi.fn(),
      hasNextPage: false,
    })
    const onSelect = vi.fn()
    await act(async () => {
      render(<SelectDataSet {...baseProps} onSelect={onSelect} selectedIds={[]} />)
    })
    expect(screen.getByText('Dataset One')).toBeInTheDocument()
    expect(screen.getByText('Hidden Dataset')).toBeInTheDocument()
    // Clicking a row marks it selected and updates the counter text
    // (i18n keys render verbatim because translations are stubbed).
    await act(async () => {
      fireEvent.click(screen.getByText('Dataset One'))
    })
    expect(screen.getByText('1 appDebug.feature.dataSet.selected')).toBeInTheDocument()
    const addButton = screen.getByRole('button', { name: 'common.operation.add' })
    await act(async () => {
      fireEvent.click(addButton)
    })
    // Confirming hands the selected dataset objects back to the caller.
    expect(onSelect).toHaveBeenCalledWith([datasetOne])
  })
  it('shows empty state when no datasets are available and disables add', async () => {
    mockUseInfiniteDatasets.mockReturnValue({
      data: { pages: [{ data: [] }] },
      isLoading: false,
      isFetchingNextPage: false,
      fetchNextPage: vi.fn(),
      hasNextPage: false,
    })
    await act(async () => {
      render(<SelectDataSet {...baseProps} onSelect={vi.fn()} selectedIds={[]} />)
    })
    // Empty state: prompt text, a create link, and a disabled Add button.
    expect(screen.getByText('appDebug.feature.dataSet.noDataSet')).toBeInTheDocument()
    expect(screen.getByRole('link', { name: 'appDebug.feature.dataSet.toCreate' })).toHaveAttribute('href', '/datasets/create')
    expect(screen.getByRole('button', { name: 'common.operation.add' })).toBeDisabled()
  })
})

View File

@ -0,0 +1,125 @@
import type { IPromptValuePanelProps } from './index'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import * as React from 'react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { useStore } from '@/app/components/app/store'
import ConfigContext from '@/context/debug-configuration'
import { AppModeEnum, ModelModeType, Resolution } from '@/types/app'
import PromptValuePanel from './index'
// Replace the app store so the component's selector reads are fully controlled.
vi.mock('@/app/components/app/store', () => ({
  useStore: vi.fn(),
}))
// Replace the feature bar with a plain button that forwards its click handler.
vi.mock('@/app/components/base/features/new-feature-panel/feature-bar', () => ({
  __esModule: true,
  default: ({ onFeatureBarClick }: { onFeatureBarClick: () => void }) => (
    <button type="button" onClick={onFeatureBarClick}>
      feature bar
    </button>
  ),
}))
const mockSetShowAppConfigureFeaturesModal = vi.fn()
const mockUseStore = vi.mocked(useStore)
const mockSetInputs = vi.fn()
const mockOnSend = vi.fn()
// One required text variable and one checkbox variable to cover both input kinds.
const promptVariables = [
  { key: 'textVar', name: 'Text Var', type: 'string', required: true },
  { key: 'boolVar', name: 'Boolean Var', type: 'checkbox' },
] as const
// Minimal debug-configuration context describing a completion-mode app.
const baseContextValue: any = {
  modelModeType: ModelModeType.completion,
  modelConfig: {
    configs: {
      prompt_template: 'prompt template',
      prompt_variables: promptVariables,
    },
  },
  setInputs: mockSetInputs,
  mode: AppModeEnum.COMPLETION,
  isAdvancedMode: false,
  completionPromptConfig: {
    prompt: { text: 'completion' },
    conversation_histories_role: { user_prefix: 'user', assistant_prefix: 'assistant' },
  },
  chatPromptConfig: { prompt: [] },
} as any
// Default panel props; tests override appType/inputs per scenario.
const defaultProps: IPromptValuePanelProps = {
  appType: AppModeEnum.COMPLETION,
  onSend: mockOnSend,
  inputs: { textVar: 'initial', boolVar: false },
  visionConfig: { enabled: false, number_limits: 0, detail: Resolution.low, transfer_methods: [] },
  onVisionFilesChange: vi.fn(),
}
// Mount <PromptValuePanel> inside a debug-configuration provider, layering
// per-test context/prop overrides onto the shared defaults.
const renderPanel = (options: {
  context?: Partial<typeof baseContextValue>
  props?: Partial<IPromptValuePanelProps>
} = {}) => {
  const mergedContext = { ...baseContextValue, ...options.context }
  const mergedProps = { ...defaultProps, ...options.props }
  return render(
    <ConfigContext.Provider value={mergedContext}>
      <PromptValuePanel {...mergedProps} />
    </ConfigContext.Provider>,
  )
}
// Behavioral tests for the PromptValuePanel input/run surface.
describe('PromptValuePanel', () => {
  beforeEach(() => {
    // Feed the store selector a fixed state snapshot for every test.
    mockUseStore.mockImplementation(selector => selector({
      setShowAppConfigureFeaturesModal: mockSetShowAppConfigureFeaturesModal,
      appSidebarExpand: '',
      currentLogModalActiveTab: 'prompt',
      showPromptLogModal: false,
      showAgentLogModal: false,
      setShowPromptLogModal: vi.fn(),
      setShowAgentLogModal: vi.fn(),
      showMessageLogModal: false,
      showAppConfigureFeaturesModal: false,
    } as any))
    mockSetInputs.mockClear()
    mockOnSend.mockClear()
    mockSetShowAppConfigureFeaturesModal.mockClear()
  })
  it('updates inputs, clears values, and triggers run when ready', async () => {
    renderPanel()
    const textInput = screen.getByPlaceholderText('Text Var')
    fireEvent.change(textInput, { target: { value: 'updated' } })
    expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ textVar: 'updated' }))
    // Clear resets every variable to an empty string, including the checkbox var.
    const clearButton = screen.getByRole('button', { name: 'common.operation.clear' })
    fireEvent.click(clearButton)
    expect(mockSetInputs).toHaveBeenLastCalledWith({
      textVar: '',
      boolVar: '',
    })
    const runButton = screen.getByRole('button', { name: 'appDebug.inputs.run' })
    expect(runButton).not.toBeDisabled()
    fireEvent.click(runButton)
    await waitFor(() => expect(mockOnSend).toHaveBeenCalledTimes(1))
  })
  it('disables run when mode is not completion', () => {
    renderPanel({
      context: {
        mode: AppModeEnum.CHAT,
      },
      props: {
        appType: AppModeEnum.CHAT,
      },
    })
    // In chat mode the run button renders disabled and clicking it is a no-op.
    const runButton = screen.getByRole('button', { name: 'appDebug.inputs.run' })
    expect(runButton).toBeDisabled()
    fireEvent.click(runButton)
    expect(mockOnSend).not.toHaveBeenCalled()
  })
})

View File

@ -0,0 +1,29 @@
import type { PromptVariable } from '@/models/debug'
import { describe, expect, it } from 'vitest'
import { replaceStringWithValues } from './utils'
// Two string variables whose display names differ from their keys, so name
// fallback is observable in the assertions below.
const promptVariables: PromptVariable[] = [
  { key: 'user', name: 'User', type: 'string' },
  { key: 'topic', name: 'Topic', type: 'string' },
]
// Unit tests for the {{placeholder}} substitution helper.
describe('replaceStringWithValues', () => {
  it('should replace placeholders when inputs have values', () => {
    const output = replaceStringWithValues(
      'Hello {{user}} talking about {{topic}}',
      promptVariables,
      { user: 'Alice', topic: 'cats' },
    )
    expect(output).toBe('Hello Alice talking about cats')
  })
  it('should use prompt variable name when value is missing', () => {
    const output = replaceStringWithValues('Hi {{user}} from {{topic}}', promptVariables, {})
    expect(output).toBe('Hi {{User}} from {{Topic}}')
  })
  it('should leave placeholder untouched when no variable is defined', () => {
    const output = replaceStringWithValues('Unknown {{missing}} placeholder', promptVariables, {})
    expect(output).toBe('Unknown {{missing}} placeholder')
  })
})

View File

@ -0,0 +1,162 @@
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { useRouter } from 'next/navigation'
import { afterAll, beforeEach, describe, expect, it, vi } from 'vitest'
import { trackEvent } from '@/app/components/base/amplitude'
import { ToastContext } from '@/app/components/base/toast'
import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
import { useAppContext } from '@/context/app-context'
import { useProviderContext } from '@/context/provider-context'
import { createApp } from '@/service/apps'
import { AppModeEnum } from '@/types/app'
import { getRedirection } from '@/utils/app-redirection'
import CreateAppModal from './index'
// ahooks: make debounced callbacks fire synchronously so tests need no timers.
vi.mock('ahooks', () => ({
  useDebounceFn: (fn: (...args: any[]) => any) => {
    const run = (...args: any[]) => fn(...args)
    const cancel = vi.fn()
    const flush = vi.fn()
    return { run, cancel, flush }
  },
  useKeyPress: vi.fn(),
  useHover: () => false,
}))
vi.mock('next/navigation', () => ({
  useRouter: vi.fn(),
}))
vi.mock('@/app/components/base/amplitude', () => ({
  trackEvent: vi.fn(),
}))
vi.mock('@/service/apps', () => ({
  createApp: vi.fn(),
}))
vi.mock('@/utils/app-redirection', () => ({
  getRedirection: vi.fn(),
}))
vi.mock('@/context/provider-context', () => ({
  useProviderContext: vi.fn(),
}))
vi.mock('@/context/app-context', () => ({
  useAppContext: vi.fn(),
}))
vi.mock('@/context/i18n', () => ({
  useDocLink: () => () => '/guides',
}))
vi.mock('@/hooks/use-theme', () => ({
  __esModule: true,
  default: () => ({ theme: 'light' }),
}))
// Typed handles onto the mocked modules used in assertions below.
const mockNotify = vi.fn()
const mockUseRouter = vi.mocked(useRouter)
const mockPush = vi.fn()
const mockCreateApp = vi.mocked(createApp)
const mockTrackEvent = vi.mocked(trackEvent)
const mockGetRedirection = vi.mocked(getRedirection)
const mockUseProviderContext = vi.mocked(useProviderContext)
const mockUseAppContext = vi.mocked(useAppContext)
// Zeroed plan usage; the total is bumped per-test where quota matters.
const defaultPlanUsage = {
  buildApps: 0,
  teamMembers: 0,
  annotatedResponse: 0,
  documentsUploadQuota: 0,
  apiRateLimit: 0,
  triggerEvents: 0,
  vectorSpace: 0,
}
// Mount CreateAppModal inside a ToastContext and hand back the callback spies.
const renderModal = () => {
  const handlers = { onClose: vi.fn(), onSuccess: vi.fn() }
  render(
    <ToastContext.Provider value={{ notify: mockNotify, close: vi.fn() }}>
      <CreateAppModal show onClose={handlers.onClose} onSuccess={handlers.onSuccess} defaultAppMode={AppModeEnum.ADVANCED_CHAT} />
    </ToastContext.Provider>,
  )
  return handlers
}
// End-to-end-ish tests of the create-app flow: form → service call → toasts,
// callbacks, localStorage flag, and redirection.
describe('CreateAppModal', () => {
  const mockSetItem = vi.fn()
  // Keep a handle to the real localStorage so afterAll can restore it.
  const originalLocalStorage = window.localStorage
  beforeEach(() => {
    vi.clearAllMocks()
    mockUseRouter.mockReturnValue({ push: mockPush } as any)
    mockUseProviderContext.mockReturnValue({
      plan: {
        type: AppModeEnum.ADVANCED_CHAT,
        usage: defaultPlanUsage,
        total: { ...defaultPlanUsage, buildApps: 1 },
        reset: {},
      },
      enableBilling: true,
    } as any)
    mockUseAppContext.mockReturnValue({
      isCurrentWorkspaceEditor: true,
    } as any)
    mockSetItem.mockClear()
    // Swap in a spy-backed localStorage so the refresh flag write is observable.
    Object.defineProperty(window, 'localStorage', {
      value: {
        setItem: mockSetItem,
        getItem: vi.fn(),
        removeItem: vi.fn(),
        clear: vi.fn(),
        key: vi.fn(),
        length: 0,
      },
      writable: true,
    })
  })
  afterAll(() => {
    Object.defineProperty(window, 'localStorage', {
      value: originalLocalStorage,
      writable: true,
    })
  })
  it('creates an app, notifies success, and fires callbacks', async () => {
    const mockApp = { id: 'app-1', mode: AppModeEnum.ADVANCED_CHAT }
    mockCreateApp.mockResolvedValue(mockApp as any)
    const { onClose, onSuccess } = renderModal()
    const nameInput = screen.getByPlaceholderText('app.newApp.appNamePlaceholder')
    fireEvent.change(nameInput, { target: { value: 'My App' } })
    fireEvent.click(screen.getByRole('button', { name: 'app.newApp.Create' }))
    // Service receives the typed name plus the default icon/mode payload.
    await waitFor(() => expect(mockCreateApp).toHaveBeenCalledWith({
      name: 'My App',
      description: '',
      icon_type: 'emoji',
      icon: '🤖',
      icon_background: '#FFEAD5',
      mode: AppModeEnum.ADVANCED_CHAT,
    }))
    expect(mockTrackEvent).toHaveBeenCalledWith('create_app', {
      app_mode: AppModeEnum.ADVANCED_CHAT,
      description: '',
    })
    expect(mockNotify).toHaveBeenCalledWith({ type: 'success', message: 'app.newApp.appCreated' })
    expect(onSuccess).toHaveBeenCalled()
    expect(onClose).toHaveBeenCalled()
    // Success path also flags the app list for refresh and redirects.
    await waitFor(() => expect(mockSetItem).toHaveBeenCalledWith(NEED_REFRESH_APP_LIST_KEY, '1'))
    await waitFor(() => expect(mockGetRedirection).toHaveBeenCalledWith(true, mockApp, mockPush))
  })
  it('shows error toast when creation fails', async () => {
    mockCreateApp.mockRejectedValue(new Error('boom'))
    const { onClose } = renderModal()
    const nameInput = screen.getByPlaceholderText('app.newApp.appNamePlaceholder')
    fireEvent.change(nameInput, { target: { value: 'My App' } })
    fireEvent.click(screen.getByRole('button', { name: 'app.newApp.Create' }))
    await waitFor(() => expect(mockCreateApp).toHaveBeenCalled())
    // Failure surfaces the error message and keeps the modal open.
    expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: 'boom' })
    expect(onClose).not.toHaveBeenCalled()
  })
})

View File

@ -0,0 +1,121 @@
import type { SiteInfo } from '@/models/share'
import { fireEvent, render, screen } from '@testing-library/react'
import copy from 'copy-to-clipboard'
import * as React from 'react'
import { act } from 'react'
import { afterAll, afterEach, describe, expect, it, vi } from 'vitest'
import Embedded from './index'
// Provide stable class names for the CSS module so option tabs can be queried.
vi.mock('./style.module.css', () => ({
  __esModule: true,
  default: {
    option: 'option',
    active: 'active',
    iframeIcon: 'iframeIcon',
    scriptsIcon: 'scriptsIcon',
    chromePluginIcon: 'chromePluginIcon',
    pluginInstallIcon: 'pluginInstallIcon',
  },
}))
// Theme builder spy: lets tests assert buildTheme is called with site colors.
const mockThemeBuilder = {
  buildTheme: vi.fn(),
  theme: {
    primaryColor: '#123456',
  },
}
const mockUseAppContext = vi.fn(() => ({
  langGeniusVersionInfo: {
    current_env: 'PRODUCTION',
    current_version: '',
    latest_version: '',
    release_date: '',
    release_notes: '',
    version: '',
    can_auto_update: false,
  },
}))
vi.mock('copy-to-clipboard', () => ({
  __esModule: true,
  default: vi.fn(),
}))
vi.mock('@/app/components/base/chat/embedded-chatbot/theme/theme-context', () => ({
  useThemeContext: () => mockThemeBuilder,
}))
vi.mock('@/context/app-context', () => ({
  useAppContext: () => mockUseAppContext(),
}))
// Intercept window.open so the chrome-store navigation can be asserted.
const mockWindowOpen = vi.spyOn(window, 'open').mockImplementation(() => null)
const mockedCopy = vi.mocked(copy)
const siteInfo: SiteInfo = {
  title: 'test site',
  chat_color_theme: '#000000',
  chat_color_theme_inverted: false,
}
const baseProps = {
  isShow: true,
  siteInfo,
  onClose: vi.fn(),
  appBaseUrl: 'https://app.example.com',
  accessToken: 'token',
  className: 'custom-modal',
}
// The copy control has no accessible name; locate it by its action-btn class.
const getCopyButton = () => {
  const candidate = screen
    .getAllByRole('button')
    .find(btn => btn.className.includes('action-btn'))
  expect(candidate).toBeDefined()
  return candidate!
}
// Tests for the embed-snippet modal: theme wiring, clipboard, and store link.
describe('Embedded', () => {
  afterEach(() => {
    vi.clearAllMocks()
    mockWindowOpen.mockClear()
  })
  afterAll(() => {
    mockWindowOpen.mockRestore()
  })
  it('builds theme and copies iframe snippet', async () => {
    await act(async () => {
      render(<Embedded {...baseProps} />)
    })
    const actionButton = getCopyButton()
    // The click handler lives on the inner div when present; fall back to the button.
    const innerDiv = actionButton.querySelector('div')
    act(() => {
      fireEvent.click(innerDiv ?? actionButton)
    })
    expect(mockThemeBuilder.buildTheme).toHaveBeenCalledWith(siteInfo.chat_color_theme, siteInfo.chat_color_theme_inverted)
    // The copied snippet embeds the chatbot URL containing the access token.
    expect(mockedCopy).toHaveBeenCalledWith(expect.stringContaining('/chatbot/token'))
  })
  it('opens chrome plugin store link when chrome option selected', async () => {
    await act(async () => {
      render(<Embedded {...baseProps} />)
    })
    // Option tabs are found via the mocked CSS-module class name.
    const optionButtons = document.body.querySelectorAll('[class*="option"]')
    expect(optionButtons.length).toBeGreaterThanOrEqual(3)
    act(() => {
      fireEvent.click(optionButtons[2])
    })
    const [chromeText] = screen.getAllByText('appOverview.overview.appInfo.embedded.chromePlugin')
    act(() => {
      fireEvent.click(chromeText)
    })
    expect(mockWindowOpen).toHaveBeenCalledWith(
      'https://chrome.google.com/webstore/detail/dify-chatbot/ceehdapohffmjmkdcifjofadiaoeggaf',
      '_blank',
      'noopener,noreferrer',
    )
  })
})

View File

@ -0,0 +1,67 @@
import type { ISavedItemsProps } from './index'
import { fireEvent, render, screen } from '@testing-library/react'
import copy from 'copy-to-clipboard'
import * as React from 'react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import Toast from '@/app/components/base/toast'
import SavedItems from './index'
// Spy on clipboard writes so copy actions can be asserted.
vi.mock('copy-to-clipboard', () => ({
  __esModule: true,
  default: vi.fn(),
}))
// Minimal next/navigation stubs for components that read routing state.
vi.mock('next/navigation', () => ({
  useParams: () => ({}),
  usePathname: () => '/',
}))
const mockCopy = vi.mocked(copy)
const toastNotifySpy = vi.spyOn(Toast, 'notify')
// A single saved answer; 'hello world' is 11 chars, matching the char counter assertion.
const baseProps: ISavedItemsProps = {
  list: [
    { id: '1', answer: 'hello world' },
  ],
  isShowTextToSpeech: true,
  onRemove: vi.fn(),
  onStartCreateContent: vi.fn(),
}
// Tests for the saved-answers list: rendering, copy-with-toast, and removal.
describe('SavedItems', () => {
  beforeEach(() => {
    vi.clearAllMocks()
    toastNotifySpy.mockClear()
  })
  it('renders saved answers with metadata and controls', () => {
    const { container } = render(<SavedItems {...baseProps} />)
    const markdownElement = container.querySelector('.markdown-body')
    expect(markdownElement).toBeInTheDocument()
    // 'hello world' is 11 characters; the counter reflects that.
    expect(screen.getByText('11 common.unit.char')).toBeInTheDocument()
    const actionArea = container.querySelector('[class*="bg-components-actionbar-bg"]')
    const actionButtons = actionArea?.querySelectorAll('button') ?? []
    expect(actionButtons.length).toBeGreaterThanOrEqual(3)
  })
  it('copies content and notifies, and triggers remove callback', () => {
    const handleRemove = vi.fn()
    const { container } = render(<SavedItems {...baseProps} onRemove={handleRemove} />)
    const actionArea = container.querySelector('[class*="bg-components-actionbar-bg"]')
    const actionButtons = actionArea?.querySelectorAll('button') ?? []
    expect(actionButtons.length).toBeGreaterThanOrEqual(3)
    // Action bar order: [0]=tts, [1]=copy, [2]=delete (positional — no accessible names).
    const copyButton = actionButtons[1]
    const deleteButton = actionButtons[2]
    fireEvent.click(copyButton)
    expect(mockCopy).toHaveBeenCalledWith('hello world')
    expect(toastNotifySpy).toHaveBeenCalledWith({ type: 'success', message: 'common.actionMsg.copySuccessfully' })
    fireEvent.click(deleteButton)
    expect(handleRemove).toHaveBeenCalledWith('1')
  })
})

View File

@ -0,0 +1,22 @@
import { fireEvent, render, screen } from '@testing-library/react'
import { describe, expect, it, vi } from 'vitest'
import NoData from './index'
// Tests for the empty-state placeholder of the saved-content panel.
describe('NoData', () => {
  it('renders title/description and calls callback when button clicked', () => {
    const onStart = vi.fn()
    render(<NoData onStartCreateContent={onStart} />)
    expect(screen.getByText('share.generation.savedNoData.title')).toBeInTheDocument()
    expect(screen.getByText('share.generation.savedNoData.description')).toBeInTheDocument()
    const startButton = screen.getByRole('button', { name: 'share.generation.savedNoData.startCreateContent' })
    expect(startButton).toBeInTheDocument()
    fireEvent.click(startButton)
    expect(onStart).toHaveBeenCalledTimes(1)
  })
})

View File

@ -0,0 +1,147 @@
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { getImageUploadErrorMessage, imageUpload } from '@/app/components/base/image-uploader/utils'
import { useToastContext } from '@/app/components/base/toast'
import { Plan } from '@/app/components/billing/type'
import { useAppContext } from '@/context/app-context'
import { useGlobalPublicStore } from '@/context/global-public-context'
import { useProviderContext } from '@/context/provider-context'
import { updateCurrentWorkspace } from '@/service/common'
import CustomWebAppBrand from './index'
// Mock every external surface the brand-settings panel touches: toasts,
// workspace service, contexts, and the image-upload utilities.
vi.mock('@/app/components/base/toast', () => ({
  useToastContext: vi.fn(),
}))
vi.mock('@/service/common', () => ({
  updateCurrentWorkspace: vi.fn(),
}))
vi.mock('@/context/app-context', () => ({
  useAppContext: vi.fn(),
}))
vi.mock('@/context/provider-context', () => ({
  useProviderContext: vi.fn(),
}))
vi.mock('@/context/global-public-context', () => ({
  useGlobalPublicStore: vi.fn(),
}))
vi.mock('@/app/components/base/image-uploader/utils', () => ({
  imageUpload: vi.fn(),
  getImageUploadErrorMessage: vi.fn(),
}))
// Typed handles onto the mocked modules used in assertions below.
const mockNotify = vi.fn()
const mockUseToastContext = vi.mocked(useToastContext)
const mockUpdateCurrentWorkspace = vi.mocked(updateCurrentWorkspace)
const mockUseAppContext = vi.mocked(useAppContext)
const mockUseProviderContext = vi.mocked(useProviderContext)
const mockUseGlobalPublicStore = vi.mocked(useGlobalPublicStore)
const mockImageUpload = vi.mocked(imageUpload)
const mockGetImageUploadErrorMessage = vi.mocked(getImageUploadErrorMessage)
// Zeroed plan usage shared by the provider-context mock.
const defaultPlanUsage = {
  buildApps: 0,
  teamMembers: 0,
  annotatedResponse: 0,
  documentsUploadQuota: 0,
  apiRateLimit: 0,
  triggerEvents: 0,
  vectorSpace: 0,
}
const renderComponent = () => render(<CustomWebAppBrand />)
// Tests for the web-app branding settings: permissions, the remove-brand
// toggle, and the logo upload/cancel flow.
describe('CustomWebAppBrand', () => {
  beforeEach(() => {
    vi.clearAllMocks()
    mockUseToastContext.mockReturnValue({ notify: mockNotify } as any)
    mockUpdateCurrentWorkspace.mockResolvedValue({} as any)
    mockUseAppContext.mockReturnValue({
      currentWorkspace: {
        custom_config: {
          replace_webapp_logo: 'https://example.com/replace.png',
          remove_webapp_brand: false,
        },
      },
      mutateCurrentWorkspace: vi.fn(),
      isCurrentWorkspaceManager: true,
    } as any)
    mockUseProviderContext.mockReturnValue({
      plan: {
        type: Plan.professional,
        usage: defaultPlanUsage,
        total: defaultPlanUsage,
        reset: {},
      },
      enableBilling: false,
    } as any)
    const systemFeaturesState = {
      branding: {
        enabled: true,
        workspace_logo: 'https://example.com/workspace-logo.png',
      },
    }
    // Support both selector-style and bare calls into the global public store.
    mockUseGlobalPublicStore.mockImplementation(selector => selector ? selector({ systemFeatures: systemFeaturesState } as any) : { systemFeatures: systemFeaturesState })
    mockGetImageUploadErrorMessage.mockReturnValue('upload error')
  })
  it('disables upload controls when the user cannot manage the workspace', () => {
    mockUseAppContext.mockReturnValue({
      currentWorkspace: {
        custom_config: {
          replace_webapp_logo: '',
          remove_webapp_brand: false,
        },
      },
      mutateCurrentWorkspace: vi.fn(),
      isCurrentWorkspaceManager: false,
    } as any)
    const { container } = renderComponent()
    const fileInput = container.querySelector('input[type="file"]') as HTMLInputElement
    expect(fileInput).toBeDisabled()
  })
  it('toggles remove brand switch and calls the backend + mutate', async () => {
    const mutateMock = vi.fn()
    mockUseAppContext.mockReturnValue({
      currentWorkspace: {
        custom_config: {
          replace_webapp_logo: '',
          remove_webapp_brand: false,
        },
      },
      mutateCurrentWorkspace: mutateMock,
      isCurrentWorkspaceManager: true,
    } as any)
    renderComponent()
    const switchInput = screen.getByRole('switch')
    fireEvent.click(switchInput)
    // Toggling persists the flag then refreshes the cached workspace.
    await waitFor(() => expect(mockUpdateCurrentWorkspace).toHaveBeenCalledWith({
      url: '/workspaces/custom-config',
      body: { remove_webapp_brand: true },
    }))
    await waitFor(() => expect(mutateMock).toHaveBeenCalled())
  })
  it('shows cancel/apply buttons after successful upload and cancels properly', async () => {
    // Drive the upload callbacks synchronously: report progress, then succeed.
    mockImageUpload.mockImplementation(({ onProgressCallback, onSuccessCallback }) => {
      onProgressCallback(50)
      onSuccessCallback({ id: 'new-logo' })
    })
    const { container } = renderComponent()
    const fileInput = container.querySelector('input[type="file"]') as HTMLInputElement
    const testFile = new File(['content'], 'logo.png', { type: 'image/png' })
    fireEvent.change(fileInput, { target: { files: [testFile] } })
    await waitFor(() => expect(mockImageUpload).toHaveBeenCalled())
    await waitFor(() => screen.getByRole('button', { name: 'custom.apply' }))
    // Cancel discards the pending logo and removes the apply button again.
    const cancelButton = screen.getByRole('button', { name: 'common.operation.cancel' })
    fireEvent.click(cancelButton)
    await waitFor(() => expect(screen.queryByRole('button', { name: 'custom.apply' })).toBeNull())
  })
})

View File

@ -26,7 +26,7 @@ import DifyLogo from '@/app/components/base/logo/dify-logo'
import Toast from '@/app/components/base/toast'
import Res from '@/app/components/share/text-generation/result'
import RunOnce from '@/app/components/share/text-generation/run-once'
import { appDefaultIconBackground, DEFAULT_VALUE_MAX_LEN } from '@/config'
import { appDefaultIconBackground, BATCH_CONCURRENCY, DEFAULT_VALUE_MAX_LEN } from '@/config'
import { useGlobalPublicStore } from '@/context/global-public-context'
import { useWebAppStore } from '@/context/web-app-context'
import { useAppFavicon } from '@/hooks/use-app-favicon'
@ -43,7 +43,7 @@ import MenuDropdown from './menu-dropdown'
import RunBatch from './run-batch'
import ResDownload from './run-batch/res-download'
const GROUP_SIZE = 5 // to avoid RPM(Request per minute) limit. The group task finished then the next group.
const GROUP_SIZE = BATCH_CONCURRENCY // to avoid RPM(Request per minute) limit. The group task finished then the next group.
enum TaskStatus {
pending = 'pending',
running = 'running',

View File

@ -62,8 +62,15 @@ const ConditionItem = ({
}, [onRemoveCondition, condition.id])
const currentMetadata = useMemo(() => {
// Try to match by metadata_id first (reliable reference)
if (condition.metadata_id) {
const found = metadataList.find(metadata => metadata.id === condition.metadata_id)
if (found)
return found
}
// Fallback to name matching for backward compatibility with old conditions
return metadataList.find(metadata => metadata.name === condition.name)
}, [metadataList, condition.name])
}, [metadataList, condition.metadata_id, condition.name])
const handleConditionOperatorChange = useCallback((operator: ComparisonOperator) => {
onUpdateCondition?.(

View File

@ -27,11 +27,17 @@ const MetadataTrigger = ({
useEffect(() => {
if (selectedDatasetsLoaded) {
conditions.forEach((condition) => {
if (!metadataList.find(metadata => metadata.name === condition.name))
// First try to match by metadata_id for reliable reference
const foundById = condition.metadata_id && metadataList.find(metadata => metadata.id === condition.metadata_id)
// Fallback to name matching only for backward compatibility with old conditions
const foundByName = !condition.metadata_id && metadataList.find(metadata => metadata.name === condition.name)
// Only remove condition if both metadata_id and name matching fail
if (!foundById && !foundByName)
handleRemoveCondition(condition.id)
})
}
}, [metadataList, handleRemoveCondition, selectedDatasetsLoaded])
}, [metadataFilteringConditions, metadataList, handleRemoveCondition, selectedDatasetsLoaded])
return (
<PortalToFollowElem

View File

@ -86,6 +86,7 @@ export enum MetadataFilteringVariableType {
// One metadata filtering clause on a knowledge-retrieval node.
export type MetadataFilteringCondition = {
  id: string
  // Display name of the metadata field; retained for matching old conditions.
  name: string
  // Stable reference to the metadata entry; preferred over name matching.
  metadata_id?: string
  comparison_operator: ComparisonOperator
  value?: string | number
}

View File

@ -305,7 +305,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
}))
}, [setInputs])
const handleAddCondition = useCallback<HandleAddCondition>(({ name, type }) => {
const handleAddCondition = useCallback<HandleAddCondition>(({ id, name, type }) => {
let operator: ComparisonOperator = ComparisonOperator.is
if (type === MetadataFilteringVariableType.number)
@ -313,6 +313,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
const newCondition = {
id: uuid4(),
metadata_id: id, // Save metadata.id for reliable reference
name,
comparison_operator: operator,
}

View File

@ -67,6 +67,7 @@ const LocaleLayout = async ({
[DatasetAttr.NEXT_PUBLIC_ZENDESK_FIELD_ID_EMAIL]: process.env.NEXT_PUBLIC_ZENDESK_FIELD_ID_EMAIL,
[DatasetAttr.NEXT_PUBLIC_ZENDESK_FIELD_ID_WORKSPACE_ID]: process.env.NEXT_PUBLIC_ZENDESK_FIELD_ID_WORKSPACE_ID,
[DatasetAttr.NEXT_PUBLIC_ZENDESK_FIELD_ID_PLAN]: process.env.NEXT_PUBLIC_ZENDESK_FIELD_ID_PLAN,
[DatasetAttr.DATA_PUBLIC_BATCH_CONCURRENCY]: process.env.NEXT_PUBLIC_BATCH_CONCURRENCY,
}
return (

View File

@ -164,6 +164,13 @@ const COOKIE_DOMAIN = getStringConfig(
DatasetAttr.DATA_PUBLIC_COOKIE_DOMAIN,
'',
).trim()
// Number of batch-generation tasks run concurrently; overridable via
// NEXT_PUBLIC_BATCH_CONCURRENCY (falls back to the data attribute, then 5).
export const BATCH_CONCURRENCY = getNumberConfig(
  process.env.NEXT_PUBLIC_BATCH_CONCURRENCY,
  DatasetAttr.DATA_PUBLIC_BATCH_CONCURRENCY,
  5, // default
)
export const CSRF_COOKIE_NAME = () => {
if (COOKIE_DOMAIN)
return 'csrf_token'

View File

@ -131,4 +131,5 @@ export enum DatasetAttr {
NEXT_PUBLIC_ZENDESK_FIELD_ID_EMAIL = 'next-public-zendesk-field-id-email',
NEXT_PUBLIC_ZENDESK_FIELD_ID_WORKSPACE_ID = 'next-public-zendesk-field-id-workspace-id',
NEXT_PUBLIC_ZENDESK_FIELD_ID_PLAN = 'next-public-zendesk-field-id-plan',
DATA_PUBLIC_BATCH_CONCURRENCY = 'data-public-batch-concurrency',
}