mirror of
https://github.com/langgenius/dify.git
synced 2026-04-29 04:26:30 +08:00
test: migrate dataset service update-dataset SQL tests to testcontainers (#32533)
Co-authored-by: KinomotoMio <200703522+KinomotoMio@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
parent
b48f36a4e5
commit
5cb1b53b47
@ -0,0 +1,529 @@
|
|||||||
|
from unittest.mock import Mock, patch
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||||
|
from models.dataset import Dataset, ExternalKnowledgeBindings
|
||||||
|
from services.dataset_service import DatasetService
|
||||||
|
from services.errors.account import NoPermissionError
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetUpdateTestDataFactory:
    """Persist real accounts, tenants, datasets and external bindings for update_dataset tests.

    Every public helper writes through the shared ``db.session`` and commits
    before returning the object it created.
    """

    @staticmethod
    def _persist(record):
        """Add ``record`` to ``db.session``, commit, and return the same object."""
        db.session.add(record)
        db.session.commit()
        return record

    @staticmethod
    def create_account_with_tenant(role: TenantAccountRole = TenantAccountRole.OWNER) -> tuple[Account, Tenant]:
        """Create an account, a tenant, and the membership row that links them.

        Args:
            role: Role stored on the ``TenantAccountJoin`` row.

        Returns:
            The ``(account, tenant)`` pair; ``account.current_tenant`` is set to the tenant.
        """
        save = DatasetUpdateTestDataFactory._persist

        # Commit after each insert: the following row reads the previous row's id.
        member = save(
            Account(
                email=f"{uuid4()}@example.com",
                name=f"user-{uuid4()}",
                interface_language="en-US",
                status="active",
            )
        )
        workspace = save(Tenant(name=f"tenant-{member.id}", status="normal"))
        save(
            TenantAccountJoin(
                tenant_id=workspace.id,
                account_id=member.id,
                role=role,
                current=True,
            )
        )

        member.current_tenant = workspace
        return member, workspace

    @staticmethod
    def create_dataset(
        tenant_id: str,
        created_by: str,
        provider: str = "vendor",
        name: str = "old_name",
        description: str = "old_description",
        indexing_technique: str = "high_quality",
        retrieval_model: str = "old_model",
        permission: str = "only_me",
        embedding_model_provider: str | None = None,
        embedding_model: str | None = None,
        collection_binding_id: str | None = None,
    ) -> Dataset:
        """Create and commit a ``Dataset`` row.

        All arguments are forwarded unchanged to the ``Dataset`` constructor;
        ``data_source_type`` is always ``"upload_file"``.

        Returns:
            The committed ``Dataset`` instance.
        """
        return DatasetUpdateTestDataFactory._persist(
            Dataset(
                tenant_id=tenant_id,
                name=name,
                description=description,
                data_source_type="upload_file",
                indexing_technique=indexing_technique,
                created_by=created_by,
                provider=provider,
                retrieval_model=retrieval_model,
                permission=permission,
                embedding_model_provider=embedding_model_provider,
                embedding_model=embedding_model,
                collection_binding_id=collection_binding_id,
            )
        )

    @staticmethod
    def create_external_binding(
        tenant_id: str,
        dataset_id: str,
        created_by: str,
        external_knowledge_id: str = "old_knowledge_id",
        external_knowledge_api_id: str | None = None,
    ) -> ExternalKnowledgeBindings:
        """Create and commit an ``ExternalKnowledgeBindings`` row for a dataset.

        Args:
            tenant_id: Tenant that owns the binding.
            dataset_id: Dataset the binding belongs to.
            created_by: Id of the creating account.
            external_knowledge_id: External knowledge id stored on the binding.
            external_knowledge_api_id: External API id; a random UUID string is
                generated when omitted.

        Returns:
            The committed binding instance.
        """
        api_id = str(uuid4()) if external_knowledge_api_id is None else external_knowledge_api_id
        return DatasetUpdateTestDataFactory._persist(
            ExternalKnowledgeBindings(
                tenant_id=tenant_id,
                dataset_id=dataset_id,
                created_by=created_by,
                external_knowledge_id=external_knowledge_id,
                external_knowledge_api_id=api_id,
            )
        )
|
||||||
|
|
||||||
|
|
||||||
|
class TestDatasetServiceUpdateDataset:
    """
    Integration tests for DatasetService.update_dataset using rows persisted through db.session.

    Scenarios exercised:
    - External datasets: successful update, missing ids, missing binding
    - Internal datasets: basic field updates and None handling
    - Switching the indexing technique in both directions
    - Keeping or replacing the embedding model
    - Dataset-not-found, permission, and embedding-model failures
    """

    @staticmethod
    def _create_high_quality_dataset(tenant_id, created_by, collection_binding_id):
        """Persist a vendor dataset set to high_quality with openai/text-embedding-ada-002."""
        return DatasetUpdateTestDataFactory.create_dataset(
            tenant_id=tenant_id,
            created_by=created_by,
            provider="vendor",
            indexing_technique="high_quality",
            embedding_model_provider="openai",
            embedding_model="text-embedding-ada-002",
            collection_binding_id=collection_binding_id,
        )

    # ==================== External Dataset Tests ====================

    def test_update_external_dataset_success(self, db_session_with_containers):
        """An external dataset and its binding both pick up the submitted values."""
        owner, workspace = DatasetUpdateTestDataFactory.create_account_with_tenant()
        stored = DatasetUpdateTestDataFactory.create_dataset(
            tenant_id=workspace.id,
            created_by=owner.id,
            provider="external",
            name="old_name",
            description="old_description",
            retrieval_model="old_model",
        )
        original_binding = DatasetUpdateTestDataFactory.create_external_binding(
            tenant_id=workspace.id,
            dataset_id=stored.id,
            created_by=owner.id,
        )
        binding_id = original_binding.id
        # Detach the instance so the binding is looked up again by id after the update.
        db.session.expunge(original_binding)

        new_api_id = str(uuid4())
        payload = {
            "name": "new_name",
            "description": "new_description",
            "external_retrieval_model": "new_model",
            "permission": "only_me",
            "external_knowledge_id": "new_knowledge_id",
            "external_knowledge_api_id": new_api_id,
        }

        updated = DatasetService.update_dataset(stored.id, payload, owner)

        db.session.refresh(stored)
        reloaded_binding = db.session.query(ExternalKnowledgeBindings).filter_by(id=binding_id).first()

        assert stored.name == "new_name"
        assert stored.description == "new_description"
        assert stored.retrieval_model == "new_model"
        assert reloaded_binding is not None
        assert reloaded_binding.external_knowledge_id == "new_knowledge_id"
        assert reloaded_binding.external_knowledge_api_id == payload["external_knowledge_api_id"]
        assert updated.id == stored.id

    def test_update_external_dataset_missing_knowledge_id_error(self, db_session_with_containers):
        """Omitting external_knowledge_id raises ValueError."""
        owner, workspace = DatasetUpdateTestDataFactory.create_account_with_tenant()
        stored = DatasetUpdateTestDataFactory.create_dataset(
            tenant_id=workspace.id,
            created_by=owner.id,
            provider="external",
        )
        DatasetUpdateTestDataFactory.create_external_binding(
            tenant_id=workspace.id,
            dataset_id=stored.id,
            created_by=owner.id,
        )

        payload = {"name": "new_name", "external_knowledge_api_id": str(uuid4())}

        with pytest.raises(ValueError) as excinfo:
            DatasetService.update_dataset(stored.id, payload, owner)

        message = str(excinfo.value)
        assert "External knowledge id is required" in message
        db.session.rollback()

    def test_update_external_dataset_missing_api_id_error(self, db_session_with_containers):
        """Omitting external_knowledge_api_id raises ValueError."""
        owner, workspace = DatasetUpdateTestDataFactory.create_account_with_tenant()
        stored = DatasetUpdateTestDataFactory.create_dataset(
            tenant_id=workspace.id,
            created_by=owner.id,
            provider="external",
        )
        DatasetUpdateTestDataFactory.create_external_binding(
            tenant_id=workspace.id,
            dataset_id=stored.id,
            created_by=owner.id,
        )

        payload = {"name": "new_name", "external_knowledge_id": "knowledge_id"}

        with pytest.raises(ValueError) as excinfo:
            DatasetService.update_dataset(stored.id, payload, owner)

        message = str(excinfo.value)
        assert "External knowledge api id is required" in message
        db.session.rollback()

    def test_update_external_dataset_binding_not_found_error(self, db_session_with_containers):
        """Updating an external dataset that has no binding row raises ValueError."""
        owner, workspace = DatasetUpdateTestDataFactory.create_account_with_tenant()
        stored = DatasetUpdateTestDataFactory.create_dataset(
            tenant_id=workspace.id,
            created_by=owner.id,
            provider="external",
        )

        payload = {
            "name": "new_name",
            "external_knowledge_id": "knowledge_id",
            "external_knowledge_api_id": str(uuid4()),
        }

        with pytest.raises(ValueError) as excinfo:
            DatasetService.update_dataset(stored.id, payload, owner)

        message = str(excinfo.value)
        assert "External knowledge binding not found" in message
        db.session.rollback()

    # ==================== Internal Dataset Basic Tests ====================

    def test_update_internal_dataset_basic_success(self, db_session_with_containers):
        """Basic fields of an internal dataset are updated and persisted."""
        owner, workspace = DatasetUpdateTestDataFactory.create_account_with_tenant()
        existing_binding_id = str(uuid4())
        stored = self._create_high_quality_dataset(workspace.id, owner.id, existing_binding_id)

        payload = {
            "name": "new_name",
            "description": "new_description",
            "indexing_technique": "high_quality",
            "retrieval_model": "new_model",
            "embedding_model_provider": "openai",
            "embedding_model": "text-embedding-ada-002",
        }

        updated = DatasetService.update_dataset(stored.id, payload, owner)
        db.session.refresh(stored)

        assert stored.name == "new_name"
        assert stored.description == "new_description"
        assert stored.indexing_technique == "high_quality"
        assert stored.retrieval_model == "new_model"
        assert stored.embedding_model_provider == "openai"
        assert stored.embedding_model == "text-embedding-ada-002"
        assert updated.id == stored.id

    def test_update_internal_dataset_filter_none_values(self, db_session_with_containers):
        """None clears the description, while None embedding fields leave stored values untouched."""
        owner, workspace = DatasetUpdateTestDataFactory.create_account_with_tenant()
        existing_binding_id = str(uuid4())
        stored = self._create_high_quality_dataset(workspace.id, owner.id, existing_binding_id)

        payload = {
            "name": "new_name",
            "description": None,
            "indexing_technique": "high_quality",
            "retrieval_model": "new_model",
            "embedding_model_provider": None,
            "embedding_model": None,
        }

        updated = DatasetService.update_dataset(stored.id, payload, owner)
        db.session.refresh(stored)

        assert stored.name == "new_name"
        assert stored.description is None
        assert stored.embedding_model_provider == "openai"
        assert stored.embedding_model == "text-embedding-ada-002"
        assert stored.retrieval_model == "new_model"
        assert updated.id == stored.id

    # ==================== Indexing Technique Switch Tests ====================

    def test_update_internal_dataset_indexing_technique_to_economy(self, db_session_with_containers):
        """Switching to economy queues a "remove" task and clears the embedding settings."""
        owner, workspace = DatasetUpdateTestDataFactory.create_account_with_tenant()
        existing_binding_id = str(uuid4())
        stored = self._create_high_quality_dataset(workspace.id, owner.id, existing_binding_id)

        payload = {
            "indexing_technique": "economy",
            "retrieval_model": "new_model",
        }

        with patch("services.dataset_service.deal_dataset_vector_index_task") as vector_index_task:
            updated = DatasetService.update_dataset(stored.id, payload, owner)
            vector_index_task.delay.assert_called_once_with(stored.id, "remove")

        db.session.refresh(stored)
        assert stored.indexing_technique == "economy"
        assert stored.embedding_model is None
        assert stored.embedding_model_provider is None
        assert stored.collection_binding_id is None
        assert stored.retrieval_model == "new_model"
        assert updated.id == stored.id

    def test_update_internal_dataset_indexing_technique_to_high_quality(self, db_session_with_containers):
        """Switching to high_quality resolves the embedding model and queues an "add" task."""
        owner, workspace = DatasetUpdateTestDataFactory.create_account_with_tenant()
        stored = DatasetUpdateTestDataFactory.create_dataset(
            tenant_id=workspace.id,
            created_by=owner.id,
            provider="vendor",
            indexing_technique="economy",
        )

        fake_embedding_model = Mock(model="text-embedding-ada-002", provider="openai")
        fake_binding = Mock(id=str(uuid4()))

        payload = {
            "indexing_technique": "high_quality",
            "embedding_model_provider": "openai",
            "embedding_model": "text-embedding-ada-002",
            "retrieval_model": "new_model",
        }

        with (
            patch("services.dataset_service.current_user", owner),
            patch("services.dataset_service.ModelManager") as model_manager_cls,
            patch(
                "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding"
            ) as get_collection_binding,
            patch("services.dataset_service.deal_dataset_vector_index_task") as vector_index_task,
        ):
            get_model_instance = model_manager_cls.return_value.get_model_instance
            get_model_instance.return_value = fake_embedding_model
            get_collection_binding.return_value = fake_binding

            updated = DatasetService.update_dataset(stored.id, payload, owner)

            get_model_instance.assert_called_once_with(
                tenant_id=workspace.id,
                provider="openai",
                model_type=ModelType.TEXT_EMBEDDING,
                model="text-embedding-ada-002",
            )
            get_collection_binding.assert_called_once_with("openai", "text-embedding-ada-002")
            vector_index_task.delay.assert_called_once_with(stored.id, "add")

        db.session.refresh(stored)
        assert stored.indexing_technique == "high_quality"
        assert stored.embedding_model == "text-embedding-ada-002"
        assert stored.embedding_model_provider == "openai"
        assert stored.collection_binding_id == fake_binding.id
        assert stored.retrieval_model == "new_model"
        assert updated.id == stored.id

    # ==================== Embedding Model Update Tests ====================

    def test_update_internal_dataset_keep_existing_embedding_model_when_indexing_technique_unchanged(
        self, db_session_with_containers
    ):
        """Embedding settings and binding id survive an update that keeps high_quality."""
        owner, workspace = DatasetUpdateTestDataFactory.create_account_with_tenant()
        existing_binding_id = str(uuid4())
        stored = self._create_high_quality_dataset(workspace.id, owner.id, existing_binding_id)

        payload = {
            "name": "new_name",
            "indexing_technique": "high_quality",
            "retrieval_model": "new_model",
        }

        updated = DatasetService.update_dataset(stored.id, payload, owner)
        db.session.refresh(stored)

        assert stored.name == "new_name"
        assert stored.indexing_technique == "high_quality"
        assert stored.embedding_model_provider == "openai"
        assert stored.embedding_model == "text-embedding-ada-002"
        assert stored.collection_binding_id == existing_binding_id
        assert stored.retrieval_model == "new_model"
        assert updated.id == stored.id

    def test_update_internal_dataset_embedding_model_update(self, db_session_with_containers):
        """Changing the embedding model queues an "update" task and a summary regeneration."""
        owner, workspace = DatasetUpdateTestDataFactory.create_account_with_tenant()
        existing_binding_id = str(uuid4())
        stored = self._create_high_quality_dataset(workspace.id, owner.id, existing_binding_id)

        fake_embedding_model = Mock(model="text-embedding-3-small", provider="openai")
        fake_binding = Mock(id=str(uuid4()))

        payload = {
            "indexing_technique": "high_quality",
            "embedding_model_provider": "openai",
            "embedding_model": "text-embedding-3-small",
            "retrieval_model": "new_model",
        }

        with (
            patch("services.dataset_service.current_user", owner),
            patch("services.dataset_service.ModelManager") as model_manager_cls,
            patch(
                "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding"
            ) as get_collection_binding,
            patch("services.dataset_service.deal_dataset_vector_index_task") as vector_index_task,
            patch("services.dataset_service.regenerate_summary_index_task") as regenerate_summary_task,
        ):
            get_model_instance = model_manager_cls.return_value.get_model_instance
            get_model_instance.return_value = fake_embedding_model
            get_collection_binding.return_value = fake_binding

            updated = DatasetService.update_dataset(stored.id, payload, owner)

            get_model_instance.assert_called_once_with(
                tenant_id=workspace.id,
                provider="openai",
                model_type=ModelType.TEXT_EMBEDDING,
                model="text-embedding-3-small",
            )
            get_collection_binding.assert_called_once_with("openai", "text-embedding-3-small")
            vector_index_task.delay.assert_called_once_with(stored.id, "update")
            regenerate_summary_task.delay.assert_called_once_with(
                stored.id,
                regenerate_reason="embedding_model_changed",
                regenerate_vectors_only=True,
            )

        db.session.refresh(stored)
        assert stored.embedding_model == "text-embedding-3-small"
        assert stored.embedding_model_provider == "openai"
        assert stored.collection_binding_id == fake_binding.id
        assert stored.retrieval_model == "new_model"
        assert updated.id == stored.id

    # ==================== Error Handling Tests ====================

    def test_update_dataset_not_found_error(self, db_session_with_containers):
        """An unknown dataset id raises ValueError."""
        owner, _ = DatasetUpdateTestDataFactory.create_account_with_tenant()
        payload = {"name": "new_name"}

        with pytest.raises(ValueError) as excinfo:
            DatasetService.update_dataset(str(uuid4()), payload, owner)

        message = str(excinfo.value)
        assert "Dataset not found" in message

    def test_update_dataset_permission_error(self, db_session_with_containers):
        """An account from a different tenant cannot update an only_me dataset."""
        owner, workspace = DatasetUpdateTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
        outsider, _ = DatasetUpdateTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL)
        stored = DatasetUpdateTestDataFactory.create_dataset(
            tenant_id=workspace.id,
            created_by=owner.id,
            provider="vendor",
            permission="only_me",
        )

        payload = {"name": "new_name"}

        with pytest.raises(NoPermissionError):
            DatasetService.update_dataset(stored.id, payload, outsider)

    def test_update_internal_dataset_embedding_model_error(self, db_session_with_containers):
        """A failure while resolving the embedding model surfaces to the caller."""
        owner, workspace = DatasetUpdateTestDataFactory.create_account_with_tenant()
        stored = DatasetUpdateTestDataFactory.create_dataset(
            tenant_id=workspace.id,
            created_by=owner.id,
            provider="vendor",
            indexing_technique="economy",
        )

        payload = {
            "indexing_technique": "high_quality",
            "embedding_model_provider": "invalid_provider",
            "embedding_model": "invalid_model",
            "retrieval_model": "new_model",
        }

        with (
            patch("services.dataset_service.current_user", owner),
            patch("services.dataset_service.ModelManager") as model_manager_cls,
        ):
            get_model_instance = model_manager_cls.return_value.get_model_instance
            get_model_instance.side_effect = Exception("No Embedding Model available")

            with pytest.raises(Exception) as excinfo:
                DatasetService.update_dataset(stored.id, payload, owner)

        # Case-insensitive match on the error text.
        assert "No Embedding Model available".lower() in str(excinfo.value).lower()
|
||||||
@ -1,661 +0,0 @@
|
|||||||
import datetime
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
# Mock redis_client before importing dataset_service
|
|
||||||
from unittest.mock import Mock, create_autospec, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from core.model_runtime.entities.model_entities import ModelType
|
|
||||||
from models.account import Account
|
|
||||||
from models.dataset import Dataset, ExternalKnowledgeBindings
|
|
||||||
from services.dataset_service import DatasetService
|
|
||||||
from services.errors.account import NoPermissionError
|
|
||||||
|
|
||||||
|
|
||||||
class DatasetUpdateTestDataFactory:
|
|
||||||
"""Factory class for creating test data and mock objects for dataset update tests."""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def create_dataset_mock(
|
|
||||||
dataset_id: str = "dataset-123",
|
|
||||||
provider: str = "vendor",
|
|
||||||
name: str = "old_name",
|
|
||||||
description: str = "old_description",
|
|
||||||
indexing_technique: str = "high_quality",
|
|
||||||
retrieval_model: str = "old_model",
|
|
||||||
embedding_model_provider: str | None = None,
|
|
||||||
embedding_model: str | None = None,
|
|
||||||
collection_binding_id: str | None = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> Mock:
|
|
||||||
"""Create a mock dataset with specified attributes."""
|
|
||||||
dataset = Mock(spec=Dataset)
|
|
||||||
dataset.id = dataset_id
|
|
||||||
dataset.provider = provider
|
|
||||||
dataset.name = name
|
|
||||||
dataset.description = description
|
|
||||||
dataset.indexing_technique = indexing_technique
|
|
||||||
dataset.retrieval_model = retrieval_model
|
|
||||||
dataset.embedding_model_provider = embedding_model_provider
|
|
||||||
dataset.embedding_model = embedding_model
|
|
||||||
dataset.collection_binding_id = collection_binding_id
|
|
||||||
for key, value in kwargs.items():
|
|
||||||
setattr(dataset, key, value)
|
|
||||||
return dataset
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def create_user_mock(user_id: str = "user-789") -> Mock:
|
|
||||||
"""Create a mock user."""
|
|
||||||
user = Mock()
|
|
||||||
user.id = user_id
|
|
||||||
return user
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def create_external_binding_mock(
|
|
||||||
external_knowledge_id: str = "old_knowledge_id", external_knowledge_api_id: str = "old_api_id"
|
|
||||||
) -> Mock:
|
|
||||||
"""Create a mock external knowledge binding."""
|
|
||||||
binding = Mock(spec=ExternalKnowledgeBindings)
|
|
||||||
binding.external_knowledge_id = external_knowledge_id
|
|
||||||
binding.external_knowledge_api_id = external_knowledge_api_id
|
|
||||||
return binding
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def create_embedding_model_mock(model: str = "text-embedding-ada-002", provider: str = "openai") -> Mock:
|
|
||||||
"""Create a mock embedding model."""
|
|
||||||
embedding_model = Mock()
|
|
||||||
embedding_model.model = model
|
|
||||||
embedding_model.provider = provider
|
|
||||||
return embedding_model
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def create_collection_binding_mock(binding_id: str = "binding-456") -> Mock:
|
|
||||||
"""Create a mock collection binding."""
|
|
||||||
binding = Mock()
|
|
||||||
binding.id = binding_id
|
|
||||||
return binding
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def create_current_user_mock(tenant_id: str = "tenant-123") -> Mock:
|
|
||||||
"""Create a mock current user."""
|
|
||||||
current_user = create_autospec(Account, instance=True)
|
|
||||||
current_user.current_tenant_id = tenant_id
|
|
||||||
return current_user
|
|
||||||
|
|
||||||
|
|
||||||
class TestDatasetServiceUpdateDataset:
|
|
||||||
"""
|
|
||||||
Comprehensive unit tests for DatasetService.update_dataset method.
|
|
||||||
|
|
||||||
This test suite covers all supported scenarios including:
|
|
||||||
- External dataset updates
|
|
||||||
- Internal dataset updates with different indexing techniques
|
|
||||||
- Embedding model updates
|
|
||||||
- Permission checks
|
|
||||||
- Error conditions and edge cases
|
|
||||||
"""
|
|
||||||
|
|
||||||
@pytest.fixture
def mock_dataset_service_dependencies(self):
    """Patch the DatasetService collaborators shared by every test and yield them by name.

    Yields:
        A dict with the ``get_dataset``, ``check_permission``, ``db_session``,
        ``naive_utc_now`` and ``has_dataset_same_name`` mocks, plus
        ``current_time`` (the fixed datetime returned by ``naive_utc_now``).
    """
    frozen_now = datetime.datetime(2023, 1, 1, 12, 0, 0)

    with (
        patch("services.dataset_service.DatasetService.get_dataset") as get_dataset_mock,
        patch("services.dataset_service.DatasetService.check_dataset_permission") as check_permission_mock,
        patch("extensions.ext_database.db.session") as db_session_mock,
        patch("services.dataset_service.naive_utc_now") as naive_utc_now_mock,
        patch("services.dataset_service.DatasetService._has_dataset_same_name") as same_name_mock,
    ):
        # Pin the clock so code that calls naive_utc_now() sees a known value.
        naive_utc_now_mock.return_value = frozen_now

        patched = {
            "get_dataset": get_dataset_mock,
            "check_permission": check_permission_mock,
            "db_session": db_session_mock,
            "naive_utc_now": naive_utc_now_mock,
            "current_time": frozen_now,
            "has_dataset_same_name": same_name_mock,
        }
        yield patched
|
|
||||||
|
|
||||||
@pytest.fixture
def mock_external_provider_dependencies(self):
    """Patch ``Session`` in the dataset service and yield the session its context manager returns.

    Yields:
        The ``Mock`` returned when the patched ``Session(...)`` is entered as a
        context manager.
    """
    with patch("services.dataset_service.Session") as session_cls:
        from extensions.ext_database import db

        # Replace ``engine`` on the db object's class with a Mock for the duration of the test.
        with patch.object(db.__class__, "engine", new_callable=Mock):
            entered_session = Mock()
            session_cls.return_value.__enter__.return_value = entered_session
            yield entered_session
|
|
||||||
|
|
||||||
@pytest.fixture
def mock_internal_provider_dependencies(self):
    """Patch the collaborators used on the internal (vendor) provider path and yield them by name.

    Yields:
        A dict with the ``model_manager``, ``get_binding``, ``task``,
        ``regenerate_task`` and ``current_user`` mocks. ``current_user`` is an
        autospecced ``Account`` whose ``current_tenant_id`` is ``"tenant-123"``.
    """
    account_spec = create_autospec(Account, instance=True)

    with (
        patch("services.dataset_service.ModelManager") as model_manager_mock,
        patch(
            "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding"
        ) as get_binding_mock,
        patch("services.dataset_service.deal_dataset_vector_index_task") as vector_task_mock,
        patch("services.dataset_service.regenerate_summary_index_task") as regenerate_task_mock,
        patch("services.dataset_service.current_user", account_spec) as current_user_mock,
    ):
        current_user_mock.current_tenant_id = "tenant-123"

        patched = {
            "model_manager": model_manager_mock,
            "get_binding": get_binding_mock,
            "task": vector_task_mock,
            "regenerate_task": regenerate_task_mock,
            "current_user": current_user_mock,
        }
        yield patched
|
|
||||||
|
|
||||||
def _assert_database_update_called(self, mock_db, dataset_id: str, expected_updates: dict[str, Any]):
    """Check that one ``query().filter_by().update()`` and one ``commit()`` happened.

    Args:
        mock_db: Mocked database session to inspect.
        dataset_id: Accepted for call-site symmetry; not used by the checks.
        expected_updates: Exact argument the ``update`` call must have received.
    """
    update_call = mock_db.query.return_value.filter_by.return_value.update
    update_call.assert_called_once_with(expected_updates)

    commit_call = mock_db.commit
    commit_call.assert_called_once()
|
|
||||||
|
|
||||||
def _assert_external_dataset_update(self, mock_dataset, mock_binding, update_data: dict[str, Any]):
    """Check that the dataset and binding reflect ``update_data``.

    Dataset attributes are compared against the matching key in
    ``update_data`` and fall back to their own current value when the key is
    absent. Binding ids are only compared when their key is present.
    """
    # (attribute on the dataset, key in update_data)
    dataset_fields = (
        ("name", "name"),
        ("description", "description"),
        ("retrieval_model", "external_retrieval_model"),
    )
    for attribute, key in dataset_fields:
        current = getattr(mock_dataset, attribute)
        assert current == update_data.get(key, current)

    for key in ("external_knowledge_id", "external_knowledge_api_id"):
        if key in update_data:
            assert getattr(mock_binding, key) == update_data[key]
|
|
||||||
|
|
||||||
# ==================== External Dataset Tests ====================
|
|
||||||
|
|
||||||
def test_update_external_dataset_success(
    self, mock_dataset_service_dependencies, mock_external_provider_dependencies
):
    """An external dataset update writes the new fields to the dataset and its binding.

    Also checks that permission is verified once, that both objects are
    added to the session, that a single commit happens, and that the
    dataset is returned.
    """
    deps = mock_dataset_service_dependencies

    external_dataset = DatasetUpdateTestDataFactory.create_dataset_mock(
        provider="external", name="old_name", description="old_description", retrieval_model="old_model"
    )
    operator = DatasetUpdateTestDataFactory.create_user_mock()
    knowledge_binding = DatasetUpdateTestDataFactory.create_external_binding_mock()

    deps["get_dataset"].return_value = external_dataset
    deps["has_dataset_same_name"].return_value = False

    # The binding lookup performed during the update resolves to our binding.
    binding_lookup = mock_external_provider_dependencies.query.return_value.filter_by.return_value.first
    binding_lookup.return_value = knowledge_binding

    payload = {
        "name": "new_name",
        "description": "new_description",
        "external_retrieval_model": "new_model",
        "permission": "only_me",
        "external_knowledge_id": "new_knowledge_id",
        "external_knowledge_api_id": "new_api_id",
    }

    updated = DatasetService.update_dataset("dataset-123", payload, operator)

    deps["check_permission"].assert_called_once_with(external_dataset, operator)

    # Dataset and binding attributes must mirror the payload.
    self._assert_external_dataset_update(external_dataset, knowledge_binding, payload)

    # Both objects are staged on the session and committed exactly once.
    session = deps["db_session"]
    for staged in (external_dataset, knowledge_binding):
        session.add.assert_any_call(staged)
    session.commit.assert_called_once()

    assert updated == external_dataset
|
|
||||||
|
|
||||||
def test_update_external_dataset_missing_knowledge_id_error(self, mock_dataset_service_dependencies):
    """An external dataset update without ``external_knowledge_id`` raises ``ValueError``."""
    deps = mock_dataset_service_dependencies
    deps["get_dataset"].return_value = DatasetUpdateTestDataFactory.create_dataset_mock(provider="external")
    deps["has_dataset_same_name"].return_value = False

    operator = DatasetUpdateTestDataFactory.create_user_mock()
    # Only the API id is supplied; the knowledge id is left out on purpose.
    payload = {"name": "new_name", "external_knowledge_api_id": "api_id"}

    with pytest.raises(ValueError) as exc_info:
        DatasetService.update_dataset("dataset-123", payload, operator)

    assert "External knowledge id is required" in str(exc_info.value)
|
|
||||||
|
|
||||||
def test_update_external_dataset_missing_api_id_error(self, mock_dataset_service_dependencies):
    """An external dataset update without ``external_knowledge_api_id`` raises ``ValueError``."""
    deps = mock_dataset_service_dependencies
    deps["get_dataset"].return_value = DatasetUpdateTestDataFactory.create_dataset_mock(provider="external")
    deps["has_dataset_same_name"].return_value = False

    operator = DatasetUpdateTestDataFactory.create_user_mock()
    # Only the knowledge id is supplied; the API id is left out on purpose.
    payload = {"name": "new_name", "external_knowledge_id": "knowledge_id"}

    with pytest.raises(ValueError) as exc_info:
        DatasetService.update_dataset("dataset-123", payload, operator)

    assert "External knowledge api id is required" in str(exc_info.value)
|
|
||||||
|
|
||||||
def test_update_external_dataset_binding_not_found_error(
    self, mock_dataset_service_dependencies, mock_external_provider_dependencies
):
    """An external dataset update raises ``ValueError`` when the binding lookup returns ``None``."""
    deps = mock_dataset_service_dependencies
    deps["get_dataset"].return_value = DatasetUpdateTestDataFactory.create_dataset_mock(provider="external")
    deps["has_dataset_same_name"].return_value = False

    operator = DatasetUpdateTestDataFactory.create_user_mock()

    # Make the external knowledge binding query come back empty.
    binding_lookup = mock_external_provider_dependencies.query.return_value.filter_by.return_value.first
    binding_lookup.return_value = None

    payload = {
        "name": "new_name",
        "external_knowledge_id": "knowledge_id",
        "external_knowledge_api_id": "api_id",
    }

    with pytest.raises(ValueError) as exc_info:
        DatasetService.update_dataset("dataset-123", payload, operator)

    assert "External knowledge binding not found" in str(exc_info.value)
|
|
||||||
|
|
||||||
# ==================== Internal Dataset Basic Tests ====================
|
|
||||||
|
|
||||||
def test_update_internal_dataset_basic_success(self, mock_dataset_service_dependencies):
    """Updating basic fields of an internal dataset issues one update with the expected mapping.

    The payload keeps the dataset's existing indexing technique and
    embedding model; the expected update is the payload plus
    ``updated_by`` and ``updated_at``.
    """
    deps = mock_dataset_service_dependencies

    internal_dataset = DatasetUpdateTestDataFactory.create_dataset_mock(
        provider="vendor",
        indexing_technique="high_quality",
        embedding_model_provider="openai",
        embedding_model="text-embedding-ada-002",
        collection_binding_id="binding-123",
    )
    operator = DatasetUpdateTestDataFactory.create_user_mock()

    deps["get_dataset"].return_value = internal_dataset
    deps["has_dataset_same_name"].return_value = False

    payload = {
        "name": "new_name",
        "description": "new_description",
        "indexing_technique": "high_quality",
        "retrieval_model": "new_model",
        "embedding_model_provider": "openai",
        "embedding_model": "text-embedding-ada-002",
    }

    updated = DatasetService.update_dataset("dataset-123", payload, operator)

    # Permission must have been checked for this dataset/user pair.
    deps["check_permission"].assert_called_once_with(internal_dataset, operator)

    # The update mapping carries the payload plus the audit fields.
    expected_update = {
        "name": "new_name",
        "description": "new_description",
        "indexing_technique": "high_quality",
        "retrieval_model": "new_model",
        "embedding_model_provider": "openai",
        "embedding_model": "text-embedding-ada-002",
        "updated_by": operator.id,
        "updated_at": deps["current_time"],
    }
    self._assert_database_update_called(deps["db_session"], "dataset-123", expected_update)

    assert updated == internal_dataset
|
|
||||||
|
|
||||||
def test_update_internal_dataset_filter_none_values(self, mock_dataset_service_dependencies):
    """``None`` values are dropped from the update mapping, except for ``description``."""
    deps = mock_dataset_service_dependencies

    internal_dataset = DatasetUpdateTestDataFactory.create_dataset_mock(
        provider="vendor", indexing_technique="high_quality"
    )
    operator = DatasetUpdateTestDataFactory.create_user_mock()

    deps["get_dataset"].return_value = internal_dataset
    deps["has_dataset_same_name"].return_value = False

    payload = {
        "name": "new_name",
        "description": None,  # expected to survive filtering
        "indexing_technique": "high_quality",
        "retrieval_model": "new_model",
        "embedding_model_provider": None,  # expected to be dropped
        "embedding_model": None,  # expected to be dropped
    }

    updated = DatasetService.update_dataset("dataset-123", payload, operator)

    expected_update = {
        "name": "new_name",
        "description": None,  # kept even though it is None
        "indexing_technique": "high_quality",
        "retrieval_model": "new_model",
        "updated_by": operator.id,
        "updated_at": deps["current_time"],
    }

    # First positional argument of the recorded ``update`` call.
    update_mock = deps["db_session"].query.return_value.filter_by.return_value.update
    recorded_update = update_mock.call_args.args[0]

    # The timestamp is dynamic, so take it out of both sides before comparing.
    recorded_update.pop("updated_at")
    expected_update.pop("updated_at")

    assert recorded_update == expected_update
    assert updated == internal_dataset
|
|
||||||
|
|
||||||
# ==================== Indexing Technique Switch Tests ====================
|
|
||||||
|
|
||||||
def test_update_internal_dataset_indexing_technique_to_economy(
    self, mock_dataset_service_dependencies, mock_internal_provider_dependencies
):
    """Switching an internal dataset to ``economy`` writes ``None`` for the embedding fields.

    ``mock_internal_provider_dependencies`` is requested as a fixture but
    not referenced directly in the body.
    """
    deps = mock_dataset_service_dependencies

    internal_dataset = DatasetUpdateTestDataFactory.create_dataset_mock(
        provider="vendor", indexing_technique="high_quality"
    )
    operator = DatasetUpdateTestDataFactory.create_user_mock()

    deps["get_dataset"].return_value = internal_dataset
    deps["has_dataset_same_name"].return_value = False

    payload = {"indexing_technique": "economy", "retrieval_model": "new_model"}

    updated = DatasetService.update_dataset("dataset-123", payload, operator)

    # Embedding model, provider and collection binding are all cleared.
    expected_update = {
        "indexing_technique": "economy",
        "embedding_model": None,
        "embedding_model_provider": None,
        "collection_binding_id": None,
        "retrieval_model": "new_model",
        "updated_by": operator.id,
        "updated_at": deps["current_time"],
    }
    self._assert_database_update_called(deps["db_session"], "dataset-123", expected_update)

    assert updated == internal_dataset
|
|
||||||
|
|
||||||
def test_update_internal_dataset_indexing_technique_to_high_quality(
    self, mock_dataset_service_dependencies, mock_internal_provider_dependencies
):
    """Switching an internal dataset from ``economy`` to ``high_quality``.

    Checks that the embedding model is looked up, the collection binding is
    fetched, the update mapping carries the new embedding fields, and the
    vector index task is queued with ``"add"``.
    """
    deps = mock_dataset_service_dependencies
    provider_deps = mock_internal_provider_dependencies

    internal_dataset = DatasetUpdateTestDataFactory.create_dataset_mock(
        provider="vendor", indexing_technique="economy"
    )
    operator = DatasetUpdateTestDataFactory.create_user_mock()

    deps["get_dataset"].return_value = internal_dataset
    deps["has_dataset_same_name"].return_value = False

    # Embedding model returned by the mocked model manager.
    get_model_instance = provider_deps["model_manager"].return_value.get_model_instance
    get_model_instance.return_value = DatasetUpdateTestDataFactory.create_embedding_model_mock()

    # Collection binding returned by the mocked binding service.
    get_binding = provider_deps["get_binding"]
    get_binding.return_value = DatasetUpdateTestDataFactory.create_collection_binding_mock()

    payload = {
        "indexing_technique": "high_quality",
        "embedding_model_provider": "openai",
        "embedding_model": "text-embedding-ada-002",
        "retrieval_model": "new_model",
    }

    updated = DatasetService.update_dataset("dataset-123", payload, operator)

    # The embedding model was resolved for the current tenant.
    get_model_instance.assert_called_once_with(
        tenant_id=provider_deps["current_user"].current_tenant_id,
        provider="openai",
        model_type=ModelType.TEXT_EMBEDDING,
        model="text-embedding-ada-002",
    )

    # The collection binding was fetched for that provider/model pair.
    get_binding.assert_called_once_with("openai", "text-embedding-ada-002")

    expected_update = {
        "indexing_technique": "high_quality",
        "embedding_model": "text-embedding-ada-002",
        "embedding_model_provider": "openai",
        "collection_binding_id": "binding-456",
        "retrieval_model": "new_model",
        "updated_by": operator.id,
        "updated_at": deps["current_time"],
    }
    self._assert_database_update_called(deps["db_session"], "dataset-123", expected_update)

    # The vector index task was queued with the "add" action.
    provider_deps["task"].delay.assert_called_once_with("dataset-123", "add")

    assert updated == internal_dataset
|
|
||||||
|
|
||||||
# ==================== Embedding Model Update Tests ====================
|
|
||||||
|
|
||||||
def test_update_internal_dataset_keep_existing_embedding_model(self, mock_dataset_service_dependencies):
    """A payload without embedding fields keeps the dataset's existing embedding settings.

    The expected update mapping carries the dataset's current provider,
    model and collection binding id alongside the payload values.
    """
    deps = mock_dataset_service_dependencies

    internal_dataset = DatasetUpdateTestDataFactory.create_dataset_mock(
        provider="vendor",
        indexing_technique="high_quality",
        embedding_model_provider="openai",
        embedding_model="text-embedding-ada-002",
        collection_binding_id="binding-123",
    )
    operator = DatasetUpdateTestDataFactory.create_user_mock()

    deps["get_dataset"].return_value = internal_dataset
    deps["has_dataset_same_name"].return_value = False

    payload = {"name": "new_name", "indexing_technique": "high_quality", "retrieval_model": "new_model"}

    updated = DatasetService.update_dataset("dataset-123", payload, operator)

    # Existing embedding settings are written back unchanged.
    expected_update = {
        "name": "new_name",
        "indexing_technique": "high_quality",
        "embedding_model_provider": "openai",
        "embedding_model": "text-embedding-ada-002",
        "collection_binding_id": "binding-123",
        "retrieval_model": "new_model",
        "updated_by": operator.id,
        "updated_at": deps["current_time"],
    }
    self._assert_database_update_called(deps["db_session"], "dataset-123", expected_update)

    assert updated == internal_dataset
|
|
||||||
|
|
||||||
def test_update_internal_dataset_embedding_model_update(
    self, mock_dataset_service_dependencies, mock_internal_provider_dependencies
):
    """Changing the embedding model of a ``high_quality`` internal dataset.

    Checks the model lookup, the collection binding lookup, the update
    mapping, the vector index task queued with ``"update"``, and the
    summary regeneration task queued with ``embedding_model_changed``.
    """
    deps = mock_dataset_service_dependencies
    provider_deps = mock_internal_provider_dependencies

    internal_dataset = DatasetUpdateTestDataFactory.create_dataset_mock(
        provider="vendor",
        indexing_technique="high_quality",
        embedding_model_provider="openai",
        embedding_model="text-embedding-ada-002",
    )
    operator = DatasetUpdateTestDataFactory.create_user_mock()

    deps["get_dataset"].return_value = internal_dataset
    deps["has_dataset_same_name"].return_value = False

    # Embedding model returned by the mocked model manager.
    get_model_instance = provider_deps["model_manager"].return_value.get_model_instance
    get_model_instance.return_value = DatasetUpdateTestDataFactory.create_embedding_model_mock(
        "text-embedding-3-small"
    )

    # Collection binding returned by the mocked binding service.
    get_binding = provider_deps["get_binding"]
    get_binding.return_value = DatasetUpdateTestDataFactory.create_collection_binding_mock("binding-789")

    payload = {
        "indexing_technique": "high_quality",
        "embedding_model_provider": "openai",
        "embedding_model": "text-embedding-3-small",
        "retrieval_model": "new_model",
    }

    updated = DatasetService.update_dataset("dataset-123", payload, operator)

    # The new embedding model was resolved for the current tenant.
    get_model_instance.assert_called_once_with(
        tenant_id=provider_deps["current_user"].current_tenant_id,
        provider="openai",
        model_type=ModelType.TEXT_EMBEDDING,
        model="text-embedding-3-small",
    )

    # The collection binding was fetched for the new provider/model pair.
    get_binding.assert_called_once_with("openai", "text-embedding-3-small")

    expected_update = {
        "indexing_technique": "high_quality",
        "embedding_model": "text-embedding-3-small",
        "embedding_model_provider": "openai",
        "collection_binding_id": "binding-789",
        "retrieval_model": "new_model",
        "updated_by": operator.id,
        "updated_at": deps["current_time"],
    }
    self._assert_database_update_called(deps["db_session"], "dataset-123", expected_update)

    # The vector index task was queued with the "update" action.
    provider_deps["task"].delay.assert_called_once_with("dataset-123", "update")

    # The summary index regeneration was queued for the embedding model change.
    provider_deps["regenerate_task"].delay.assert_called_once_with(
        "dataset-123",
        regenerate_reason="embedding_model_changed",
        regenerate_vectors_only=True,
    )

    assert updated == internal_dataset
|
|
||||||
|
|
||||||
def test_update_internal_dataset_no_indexing_technique_change(self, mock_dataset_service_dependencies):
    """An update that repeats the current indexing technique keeps the embedding settings as they are."""
    deps = mock_dataset_service_dependencies

    internal_dataset = DatasetUpdateTestDataFactory.create_dataset_mock(
        provider="vendor",
        indexing_technique="high_quality",
        embedding_model_provider="openai",
        embedding_model="text-embedding-ada-002",
        collection_binding_id="binding-123",
    )
    operator = DatasetUpdateTestDataFactory.create_user_mock()

    deps["get_dataset"].return_value = internal_dataset
    deps["has_dataset_same_name"].return_value = False

    payload = {
        "name": "new_name",
        "indexing_technique": "high_quality",  # identical to the dataset's current value
        "retrieval_model": "new_model",
    }

    updated = DatasetService.update_dataset("dataset-123", payload, operator)

    expected_update = {
        "name": "new_name",
        "indexing_technique": "high_quality",
        "embedding_model_provider": "openai",
        "embedding_model": "text-embedding-ada-002",
        "collection_binding_id": "binding-123",
        "retrieval_model": "new_model",
        "updated_by": operator.id,
        "updated_at": deps["current_time"],
    }
    self._assert_database_update_called(deps["db_session"], "dataset-123", expected_update)

    assert updated == internal_dataset
|
|
||||||
|
|
||||||
# ==================== Error Handling Tests ====================
|
|
||||||
|
|
||||||
def test_update_dataset_not_found_error(self, mock_dataset_service_dependencies):
    """``update_dataset`` raises ``ValueError`` when the dataset lookup returns ``None``."""
    deps = mock_dataset_service_dependencies
    deps["get_dataset"].return_value = None
    deps["has_dataset_same_name"].return_value = False

    operator = DatasetUpdateTestDataFactory.create_user_mock()
    payload = {"name": "new_name"}

    with pytest.raises(ValueError) as exc_info:
        DatasetService.update_dataset("dataset-123", payload, operator)

    assert "Dataset not found" in str(exc_info.value)
|
|
||||||
|
|
||||||
def test_update_dataset_permission_error(self, mock_dataset_service_dependencies):
    """A ``NoPermissionError`` raised by the permission check propagates out of ``update_dataset``."""
    deps = mock_dataset_service_dependencies
    deps["get_dataset"].return_value = DatasetUpdateTestDataFactory.create_dataset_mock()
    deps["has_dataset_same_name"].return_value = False

    operator = DatasetUpdateTestDataFactory.create_user_mock()

    # Make the mocked permission check fail.
    deps["check_permission"].side_effect = NoPermissionError("No permission")

    payload = {"name": "new_name"}

    with pytest.raises(NoPermissionError):
        DatasetService.update_dataset("dataset-123", payload, operator)
|
|
||||||
|
|
||||||
def test_update_internal_dataset_embedding_model_error(
    self, mock_dataset_service_dependencies, mock_internal_provider_dependencies
):
    """A failing embedding model lookup surfaces as an exception from ``update_dataset``.

    The raised message is compared case-insensitively against
    "No Embedding Model available".
    """
    deps = mock_dataset_service_dependencies

    deps["get_dataset"].return_value = DatasetUpdateTestDataFactory.create_dataset_mock(
        provider="vendor", indexing_technique="economy"
    )
    deps["has_dataset_same_name"].return_value = False

    operator = DatasetUpdateTestDataFactory.create_user_mock()

    # The mocked model manager fails when asked for the embedding model.
    get_model_instance = mock_internal_provider_dependencies["model_manager"].return_value.get_model_instance
    get_model_instance.side_effect = Exception("No Embedding Model available")

    payload = {
        "indexing_technique": "high_quality",
        "embedding_model_provider": "invalid_provider",
        "embedding_model": "invalid_model",
        "retrieval_model": "new_model",
    }

    with pytest.raises(Exception) as exc_info:
        DatasetService.update_dataset("dataset-123", payload, operator)

    raised_message = str(exc_info.value).lower()
    assert "No Embedding Model available".lower() in raised_message
|
|
||||||
Loading…
Reference in New Issue
Block a user