# Source: dify/api/tests/unit_tests/services/test_dataset_service_dataset.py
# Commit: James 693080aa12
# test: migrate dataset service dataset mock tests to testcontainers (#35194)
# Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
# Date: 2026-04-14 19:52:31 +00:00
# (1296 lines, 60 KiB, Python)

"""Unit tests for DatasetService and dataset-related collaborators."""
from .dataset_service_test_helpers import (
DatasetNameDuplicateError,
DatasetPermissionEnum,
DatasetPermissionService,
DatasetService,
DatasetServiceUnitDataFactory,
LLMBadRequestError,
MagicMock,
ModelFeature,
ModelType,
NoPermissionError,
PipelineIconInfo,
ProviderTokenNotInitError,
RagPipelineDatasetCreateEntity,
SimpleNamespace,
_make_knowledge_configuration,
_make_retrieval_model,
_make_session_context,
json,
patch,
pytest,
)
class TestDatasetServiceValidation:
    """Unit tests for DatasetService validation helpers."""

    @pytest.mark.parametrize(
        ("dataset_doc_form", "incoming_doc_form"),
        [(None, "text_model"), ("text_model", "text_model")],
    )
    def test_check_doc_form_allows_matching_or_missing_dataset_doc_form(self, dataset_doc_form, incoming_doc_form):
        # check_doc_form must not raise when the dataset has no doc_form yet
        # (None) or when the incoming form matches the stored one.
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(doc_form=dataset_doc_form)
        DatasetService.check_doc_form(dataset, incoming_doc_form)

    def test_check_doc_form_rejects_mismatched_doc_form(self):
        # A dataset already fixed to "qa_model" must reject "text_model" documents.
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(doc_form="qa_model")
        with pytest.raises(ValueError, match="doc_form is different"):
            DatasetService.check_doc_form(dataset, "text_model")

    def test_check_dataset_model_setting_skips_non_high_quality_datasets(self):
        # "economy" indexing does not use an embedding model, so the
        # ModelManager must never be touched.
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(indexing_technique="economy")
        with patch("services.dataset_service.ModelManager") as model_manager_cls:
            DatasetService.check_dataset_model_setting(dataset)
            model_manager_cls.assert_not_called()

    def test_check_dataset_model_setting_validates_embedding_model_for_high_quality_dataset(self):
        # "high_quality" indexing must resolve the dataset's embedding model
        # through ModelManager with the dataset's own provider/model settings.
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(indexing_technique="high_quality")
        with patch("services.dataset_service.ModelManager") as model_manager_cls:
            DatasetService.check_dataset_model_setting(dataset)
            model_manager_cls.for_tenant.return_value.get_model_instance.assert_called_once_with(
                tenant_id=dataset.tenant_id,
                provider=dataset.embedding_model_provider,
                model_type=ModelType.TEXT_EMBEDDING,
                model=dataset.embedding_model,
            )

    def test_check_dataset_model_setting_wraps_llm_bad_request_error(self):
        # LLMBadRequestError from the model manager is translated into a
        # user-facing ValueError.
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(indexing_technique="high_quality")
        with patch("services.dataset_service.ModelManager") as model_manager_cls:
            model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = LLMBadRequestError()
            with pytest.raises(ValueError, match="No Embedding Model available"):
                DatasetService.check_dataset_model_setting(dataset)

    def test_check_dataset_model_setting_wraps_provider_token_error(self):
        # ProviderTokenNotInitError is wrapped and its description surfaced in
        # the raised ValueError message.
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(indexing_technique="high_quality")
        with patch("services.dataset_service.ModelManager") as model_manager_cls:
            model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = ProviderTokenNotInitError(
                "token missing"
            )
            with pytest.raises(ValueError, match="The dataset is unavailable, due to: token missing"):
                DatasetService.check_dataset_model_setting(dataset)

    def test_check_embedding_model_setting_wraps_provider_token_error_description(self):
        # The provider-token error description must appear in the wrapped error.
        with patch("services.dataset_service.ModelManager") as model_manager_cls:
            model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = ProviderTokenNotInitError(
                "provider setup"
            )
            with pytest.raises(ValueError, match="provider setup"):
                DatasetService.check_embedding_model_setting("tenant-1", "provider", "embedding-model")

    def test_check_reranking_model_setting_uses_rerank_model_type(self):
        # Reranking validation must request a RERANK-typed model instance.
        with patch("services.dataset_service.ModelManager") as model_manager_cls:
            DatasetService.check_reranking_model_setting("tenant-1", "provider", "reranker")
            model_manager_cls.for_tenant.return_value.get_model_instance.assert_called_once_with(
                tenant_id="tenant-1",
                provider="provider",
                model_type=ModelType.RERANK,
                model="reranker",
            )

    def test_check_reranking_model_setting_wraps_bad_request(self):
        with patch("services.dataset_service.ModelManager") as model_manager_cls:
            model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = LLMBadRequestError()
            with pytest.raises(ValueError, match="No Rerank Model available"):
                DatasetService.check_reranking_model_setting("tenant-1", "provider", "reranker")

    def test_check_is_multimodal_model_returns_true_when_model_supports_vision(self):
        # A model schema advertising the VISION feature marks the model as
        # multimodal.
        model_schema = SimpleNamespace(features=[ModelFeature.VISION])
        model_type_instance = MagicMock()
        model_type_instance.get_model_schema.return_value = model_schema
        model_instance = SimpleNamespace(
            model_type_instance=model_type_instance,
            model_name="embedding-model",
            credentials={"api_key": "secret"},
        )
        with patch("services.dataset_service.ModelManager") as model_manager_cls:
            model_manager_cls.for_tenant.return_value.get_model_instance.return_value = model_instance
            result = DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model")
            assert result is True

    def test_check_is_multimodal_model_returns_false_when_vision_feature_is_absent(self):
        # No VISION feature in the schema -> not multimodal.
        model_schema = SimpleNamespace(features=[])
        model_type_instance = MagicMock()
        model_type_instance.get_model_schema.return_value = model_schema
        model_instance = SimpleNamespace(
            model_type_instance=model_type_instance,
            model_name="embedding-model",
            credentials={"api_key": "secret"},
        )
        with patch("services.dataset_service.ModelManager") as model_manager_cls:
            model_manager_cls.for_tenant.return_value.get_model_instance.return_value = model_instance
            result = DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model")
            assert result is False

    def test_check_is_multimodal_model_raises_when_schema_is_missing(self):
        # A None schema from get_model_schema is treated as an error condition.
        model_type_instance = MagicMock()
        model_type_instance.get_model_schema.return_value = None
        model_instance = SimpleNamespace(
            model_type_instance=model_type_instance,
            model_name="embedding-model",
            credentials={"api_key": "secret"},
        )
        with patch("services.dataset_service.ModelManager") as model_manager_cls:
            model_manager_cls.for_tenant.return_value.get_model_instance.return_value = model_instance
            with pytest.raises(ValueError, match="Model schema not found"):
                DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model")

    def test_check_is_multimodal_model_wraps_bad_request_error(self):
        with patch("services.dataset_service.ModelManager") as model_manager_cls:
            model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = LLMBadRequestError()
            with pytest.raises(ValueError, match="No Model available"):
                DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model")
class TestDatasetServiceCreationAndUpdate:
"""Unit tests for dataset creation and update helpers."""
    def test_create_empty_dataset_raises_when_name_already_exists(self):
        # A non-None scalar() result means a dataset with this name exists,
        # which must trigger DatasetNameDuplicateError.
        account = SimpleNamespace(id="user-1")
        with patch("services.dataset_service.db") as mock_db:
            mock_db.session.scalar.return_value = object()
            with pytest.raises(DatasetNameDuplicateError, match="Dataset with name Dataset already exists"):
                DatasetService.create_empty_dataset("tenant-1", "Dataset", None, "economy", account)

    def test_create_empty_dataset_uses_default_embedding_model_for_high_quality_dataset(self):
        # When no embedding model is supplied for a high_quality dataset, the
        # tenant's default TEXT_EMBEDDING model is used and the explicit
        # embedding-model check is skipped.
        account = SimpleNamespace(id="user-1")
        default_embedding_model = SimpleNamespace(provider="provider", model_name="default-embedding")
        with (
            patch("services.dataset_service.db") as mock_db,
            patch("services.dataset_service.select"),
            patch(
                "services.dataset_service.Dataset",
                side_effect=lambda **kwargs: SimpleNamespace(id="dataset-1", **kwargs),
            ),
            patch("services.dataset_service.ModelManager") as model_manager_cls,
            patch.object(DatasetService, "check_embedding_model_setting") as check_embedding,
        ):
            mock_db.session.scalar.return_value = None  # no name collision
            model_manager_cls.for_tenant.return_value.get_default_model_instance.return_value = default_embedding_model
            dataset = DatasetService.create_empty_dataset(
                tenant_id="tenant-1",
                name="Dataset",
                description="Description",
                indexing_technique="high_quality",
                account=account,
            )
            assert dataset.embedding_model_provider == "provider"
            assert dataset.embedding_model == "default-embedding"
            assert dataset.permission == DatasetPermissionEnum.ONLY_ME
            assert dataset.provider == "vendor"
            model_manager_cls.for_tenant.return_value.get_default_model_instance.assert_called_once_with(
                tenant_id="tenant-1",
                model_type=ModelType.TEXT_EMBEDDING,
            )
            check_embedding.assert_not_called()
            mock_db.session.commit.assert_called_once()

    def test_create_empty_dataset_creates_external_binding_for_high_quality_dataset(self):
        # provider="external" must also create an ExternalKnowledgeBindings row
        # and validate both the embedding and reranking models.
        account = SimpleNamespace(id="user-1")
        retrieval_model = _make_retrieval_model()
        embedding_model = SimpleNamespace(provider="provider", model_name="embedding-model")
        with (
            patch("services.dataset_service.db") as mock_db,
            patch("services.dataset_service.select"),
            patch(
                "services.dataset_service.Dataset",
                side_effect=lambda **kwargs: SimpleNamespace(id="dataset-1", **kwargs),
            ),
            patch(
                "services.dataset_service.ExternalKnowledgeBindings",
                side_effect=lambda **kwargs: SimpleNamespace(**kwargs),
            ) as binding_cls,
            patch("services.dataset_service.ModelManager") as model_manager_cls,
            patch("services.dataset_service.ExternalDatasetService.get_external_knowledge_api", return_value=object()),
            patch.object(DatasetService, "check_embedding_model_setting") as check_embedding,
            patch.object(DatasetService, "check_reranking_model_setting") as check_reranking,
        ):
            mock_db.session.scalar.return_value = None
            model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
            dataset = DatasetService.create_empty_dataset(
                tenant_id="tenant-1",
                name="External Dataset",
                description="Description",
                indexing_technique="high_quality",
                account=account,
                permission=DatasetPermissionEnum.ALL_TEAM,
                provider="external",
                external_knowledge_api_id="api-1",
                external_knowledge_id="knowledge-1",
                embedding_model_provider="provider",
                embedding_model_name="embedding-model",
                retrieval_model=retrieval_model,
                summary_index_setting={"enable": True},
            )
            assert dataset.embedding_model_provider == "provider"
            assert dataset.embedding_model == "embedding-model"
            assert dataset.retrieval_model == retrieval_model.model_dump()
            assert dataset.summary_index_setting == {"enable": True}
            check_embedding.assert_called_once_with("tenant-1", "provider", "embedding-model")
            check_reranking.assert_called_once_with("tenant-1", "rerank-provider", "rerank-model")
            binding_cls.assert_called_once_with(
                tenant_id="tenant-1",
                dataset_id="dataset-1",
                external_knowledge_api_id="api-1",
                external_knowledge_id="knowledge-1",
                created_by="user-1",
            )
            # One add for the dataset, one for the binding.
            assert mock_db.session.add.call_count == 2
            mock_db.session.commit.assert_called_once()
    def test_create_empty_rag_pipeline_dataset_raises_for_duplicate_name(self):
        # An existing row returned by scalar() means the pipeline dataset name
        # is already taken.
        entity = RagPipelineDatasetCreateEntity(
            name="Existing Dataset",
            description="Description",
            icon_info=PipelineIconInfo(icon="book", icon_background="#fff"),
            permission=DatasetPermissionEnum.ALL_TEAM,
        )
        with patch("services.dataset_service.db") as mock_db:
            mock_db.session.scalar.return_value = object()
            with pytest.raises(DatasetNameDuplicateError, match="Existing Dataset already exists"):
                DatasetService.create_empty_rag_pipeline_dataset("tenant-1", entity)

    def test_create_empty_rag_pipeline_dataset_generates_name_and_creates_dataset(self):
        # An empty entity name must be auto-generated from existing "Untitled"
        # names, and the created dataset must be bound to the new pipeline.
        entity = RagPipelineDatasetCreateEntity(
            name="",
            description="Description",
            icon_info=PipelineIconInfo(icon="book", icon_background="#fff"),
            permission=DatasetPermissionEnum.ALL_TEAM,
        )
        pipeline = SimpleNamespace(id="pipeline-1")

        # Factory stand-ins so the service's model constructors return plain
        # namespaces instead of SQLAlchemy models.
        def pipeline_factory(**kwargs):
            pipeline.__dict__.update(kwargs)
            return pipeline

        def dataset_factory(**kwargs):
            return SimpleNamespace(id="dataset-1", **kwargs)

        with (
            patch("services.dataset_service.db") as mock_db,
            patch("services.dataset_service.select"),
            patch("services.dataset_service.current_user", SimpleNamespace(id="user-1")),
            patch("services.dataset_service.generate_incremental_name", return_value="Untitled 2") as generate_name,
            patch("services.dataset_service.Pipeline", side_effect=pipeline_factory),
            patch("services.dataset_service.Dataset", side_effect=dataset_factory),
        ):
            mock_db.session.scalars.return_value.all.return_value = [
                SimpleNamespace(name="Untitled"),
                SimpleNamespace(name="Untitled 1"),
            ]
            dataset = DatasetService.create_empty_rag_pipeline_dataset("tenant-1", entity)
            assert entity.name == "Untitled 2"
            assert dataset.pipeline_id == "pipeline-1"
            assert dataset.runtime_mode == "rag_pipeline"
            generate_name.assert_called_once_with(["Untitled", "Untitled 1"], "Untitled")
            mock_db.session.commit.assert_called_once()

    def test_create_empty_rag_pipeline_dataset_requires_current_user_id(self):
        # A current_user without an id must abort creation.
        entity = RagPipelineDatasetCreateEntity(
            name="Dataset",
            description="Description",
            icon_info=PipelineIconInfo(icon="book", icon_background="#fff"),
            permission=DatasetPermissionEnum.ALL_TEAM,
        )
        with (
            patch("services.dataset_service.db") as mock_db,
            patch("services.dataset_service.current_user", SimpleNamespace(id=None)),
        ):
            mock_db.session.scalar.return_value = None
            with pytest.raises(ValueError, match="Current user or current user id not found"):
                DatasetService.create_empty_rag_pipeline_dataset("tenant-1", entity)
def test_update_dataset_raises_when_dataset_is_missing(self):
with patch.object(DatasetService, "get_dataset", return_value=None):
with pytest.raises(ValueError, match="Dataset not found"):
DatasetService.update_dataset("dataset-1", {}, SimpleNamespace(id="user-1"))
    def test_update_dataset_raises_when_new_name_conflicts(self):
        # Renaming to a name that _has_dataset_same_name reports as taken must
        # raise before any update is attempted.
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1", tenant_id="tenant-1")
        dataset.name = "Old Dataset"
        with (
            patch.object(DatasetService, "get_dataset", return_value=dataset),
            patch.object(DatasetService, "_has_dataset_same_name", return_value=True),
        ):
            with pytest.raises(ValueError, match="Dataset name already exists"):
                DatasetService.update_dataset("dataset-1", {"name": "New Dataset"}, SimpleNamespace(id="user-1"))

    def test_update_dataset_routes_external_datasets_to_external_helper(self):
        # provider == "external" dispatches to _update_external_dataset after
        # the permission check.
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1", tenant_id="tenant-1")
        dataset.provider = "external"
        user = DatasetServiceUnitDataFactory.create_user_mock()
        with (
            patch.object(DatasetService, "get_dataset", return_value=dataset),
            patch.object(DatasetService, "check_dataset_permission") as check_permission,
            patch.object(DatasetService, "_update_external_dataset", return_value="updated") as update_external,
        ):
            result = DatasetService.update_dataset("dataset-1", {"name": dataset.name}, user)
            assert result == "updated"
            check_permission.assert_called_once_with(dataset, user)
            update_external.assert_called_once_with(dataset, {"name": dataset.name}, user)

    def test_update_dataset_routes_internal_datasets_to_internal_helper(self):
        # Any non-external provider dispatches to _update_internal_dataset.
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1", tenant_id="tenant-1")
        dataset.provider = "vendor"
        user = DatasetServiceUnitDataFactory.create_user_mock()
        with (
            patch.object(DatasetService, "get_dataset", return_value=dataset),
            patch.object(DatasetService, "check_dataset_permission") as check_permission,
            patch.object(DatasetService, "_update_internal_dataset", return_value="updated") as update_internal,
        ):
            result = DatasetService.update_dataset("dataset-1", {"name": dataset.name}, user)
            assert result == "updated"
            check_permission.assert_called_once_with(dataset, user)
            update_internal.assert_called_once_with(dataset, {"name": dataset.name}, user)
def test_has_dataset_same_name_returns_true_when_query_matches(self):
with patch("services.dataset_service.db") as mock_db:
mock_db.session.scalar.return_value = object()
result = DatasetService._has_dataset_same_name("tenant-1", "dataset-1", "Dataset")
assert result is True
    def test_update_external_dataset_updates_dataset_and_binding(self):
        # The happy path: dataset fields are copied from the payload, the
        # external-knowledge binding is refreshed, and the row is committed.
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1")
        user = SimpleNamespace(id="user-1")
        now = object()  # sentinel so updated_at can be checked by identity
        with (
            patch.object(DatasetService, "_update_external_knowledge_binding") as update_binding,
            patch(
                "services.dataset_service.ExternalDatasetService.get_external_knowledge_api", return_value=object()
            ) as get_external_knowledge_api,
            patch("services.dataset_service.naive_utc_now", return_value=now),
            patch("services.dataset_service.db") as mock_db,
        ):
            result = DatasetService._update_external_dataset(
                dataset,
                {
                    "external_retrieval_model": {"top_k": 3},
                    "summary_index_setting": {"enable": True},
                    "name": "Updated Dataset",
                    "description": "Updated description",
                    "permission": DatasetPermissionEnum.PARTIAL_TEAM,
                    "external_knowledge_id": "knowledge-1",
                    "external_knowledge_api_id": "api-1",
                },
                user,
            )
            assert result is dataset
            assert dataset.retrieval_model == {"top_k": 3}
            assert dataset.summary_index_setting == {"enable": True}
            assert dataset.name == "Updated Dataset"
            assert dataset.description == "Updated description"
            assert dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM
            assert dataset.updated_by == "user-1"
            assert dataset.updated_at is now
            get_external_knowledge_api.assert_called_once_with("api-1", dataset.tenant_id)
            update_binding.assert_called_once_with("dataset-1", "knowledge-1", "api-1")
            mock_db.session.add.assert_called_once_with(dataset)
            mock_db.session.commit.assert_called_once()

    @pytest.mark.parametrize(
        ("payload", "message"),
        [
            ({"external_knowledge_api_id": "api-1"}, "External knowledge id is required"),
            ({"external_knowledge_id": "knowledge-1"}, "External knowledge api id is required"),
        ],
    )
    def test_update_external_dataset_requires_external_binding_fields(self, payload, message):
        # Supplying only one half of the external binding pair must fail.
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1")
        with pytest.raises(ValueError, match=message):
            DatasetService._update_external_dataset(dataset, payload, SimpleNamespace(id="user-1"))

    def test_update_external_dataset_rejects_cross_tenant_external_api_id(self):
        # If the external API lookup fails (e.g. belongs to another tenant),
        # the error propagates and no binding update or commit happens.
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1")
        with (
            patch(
                "services.dataset_service.ExternalDatasetService.get_external_knowledge_api",
                side_effect=ValueError("api template not found"),
            ) as get_external_knowledge_api,
            patch.object(DatasetService, "_update_external_knowledge_binding") as update_binding,
            patch("services.dataset_service.db") as mock_db,
        ):
            with pytest.raises(ValueError, match="api template not found"):
                DatasetService._update_external_dataset(
                    dataset,
                    {
                        "external_knowledge_id": "knowledge-1",
                        "external_knowledge_api_id": "foreign-api",
                    },
                    SimpleNamespace(id="user-1"),
                )
            get_external_knowledge_api.assert_called_once_with("foreign-api", dataset.tenant_id)
            update_binding.assert_not_called()
            mock_db.session.commit.assert_not_called()

    def test_update_external_knowledge_binding_updates_changed_binding_values(self):
        # Changed knowledge/api ids are written back and the binding re-added
        # to the session opened via sessionmaker().begin().
        binding = SimpleNamespace(external_knowledge_id="old-knowledge", external_knowledge_api_id="old-api")
        session = MagicMock()
        session.scalar.return_value = binding
        session.add = MagicMock()
        session_context = _make_session_context(session)
        mock_sessionmaker = MagicMock()
        mock_sessionmaker.return_value.begin.return_value = session_context
        with (
            patch("services.dataset_service.db") as mock_db,
            patch("services.dataset_service.sessionmaker", mock_sessionmaker),
        ):
            DatasetService._update_external_knowledge_binding("dataset-1", "new-knowledge", "new-api")
            assert binding.external_knowledge_id == "new-knowledge"
            assert binding.external_knowledge_api_id == "new-api"
            session.add.assert_called_once_with(binding)

    def test_update_external_knowledge_binding_raises_for_missing_binding(self):
        # A missing binding row is a hard error.
        session = MagicMock()
        session.scalar.return_value = None
        session_context = _make_session_context(session)
        mock_sessionmaker = MagicMock()
        mock_sessionmaker.return_value.begin.return_value = session_context
        with (
            patch("services.dataset_service.db"),
            patch("services.dataset_service.sessionmaker", mock_sessionmaker),
        ):
            with pytest.raises(ValueError, match="External knowledge binding not found"):
                DatasetService._update_external_knowledge_binding("dataset-1", "knowledge-1", "api-1")
    def test_update_internal_dataset_updates_fields_and_dispatches_regeneration_tasks(self):
        # Internal update must filter out external-only keys, persist the rest
        # via an UPDATE statement, and dispatch both async tasks.
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1")
        user = SimpleNamespace(id="user-1")
        now = object()
        update_payload = {
            "name": "Updated Dataset",
            "description": None,
            "partial_member_list": [{"user_id": "member-1"}],
            "external_knowledge_api_id": "api-1",
            "external_knowledge_id": "knowledge-1",
            "external_retrieval_model": {"top_k": 2},
            "retrieval_model": {"top_k": 4},
            "summary_index_setting": {"enable": True},
            "icon_info": {"icon": "book"},
        }
        with (
            patch.object(DatasetService, "_handle_indexing_technique_change", return_value="update"),
            patch.object(DatasetService, "_update_pipeline_knowledge_base_node_data") as update_pipeline,
            patch("services.dataset_service.naive_utc_now", return_value=now),
            patch("services.dataset_service.db") as mock_db,
            patch("services.dataset_service.deal_dataset_vector_index_task") as vector_task,
            patch("services.dataset_service.regenerate_summary_index_task") as regenerate_task,
        ):
            result = DatasetService._update_internal_dataset(dataset, update_payload.copy(), user)
            assert result is dataset
            # Inspect the compiled UPDATE statement's bound parameters.
            updated_values = mock_db.session.execute.call_args.args[0].compile().params
            assert updated_values["name"] == "Updated Dataset"
            assert updated_values["description"] is None
            assert updated_values["retrieval_model"] == {"top_k": 4}
            assert updated_values["summary_index_setting"] == {"enable": True}
            assert updated_values["icon_info"] == {"icon": "book"}
            assert updated_values["updated_by"] == "user-1"
            assert updated_values["updated_at"] is now
            # External/member keys must not leak into the internal update.
            assert "partial_member_list" not in updated_values
            assert "external_knowledge_api_id" not in updated_values
            assert "external_knowledge_id" not in updated_values
            assert "external_retrieval_model" not in updated_values
            mock_db.session.commit.assert_called_once()
            mock_db.session.refresh.assert_called_once_with(dataset)
            update_pipeline.assert_called_once_with(dataset, "user-1")
            vector_task.delay.assert_called_once_with("dataset-1", "update")
            regenerate_task.delay.assert_called_once_with(
                "dataset-1",
                regenerate_reason="embedding_model_changed",
                regenerate_vectors_only=True,
            )

    def test_update_pipeline_knowledge_base_node_data_returns_early_for_non_pipeline_dataset(self):
        # Non-rag_pipeline runtime modes must not even fetch the pipeline.
        dataset = SimpleNamespace(runtime_mode="workflow", pipeline_id="pipeline-1")
        with patch("services.dataset_service.db") as mock_db:
            DatasetService._update_pipeline_knowledge_base_node_data(dataset, "user-1")
            mock_db.session.get.assert_not_called()

    def test_update_pipeline_knowledge_base_node_data_returns_when_pipeline_is_missing(self):
        # A missing pipeline row short-circuits without committing.
        dataset = SimpleNamespace(runtime_mode="rag_pipeline", pipeline_id="pipeline-1")
        with patch("services.dataset_service.db") as mock_db:
            mock_db.session.get.return_value = None
            DatasetService._update_pipeline_knowledge_base_node_data(dataset, "user-1")
            mock_db.session.commit.assert_not_called()

    def test_update_pipeline_knowledge_base_node_data_updates_published_and_draft_workflows(self):
        # The knowledge-index node in both published and draft workflow graphs
        # must receive the dataset's current embedding/summary settings; the
        # published workflow is re-created via Workflow.new.
        dataset = SimpleNamespace(
            id="dataset-1",
            runtime_mode="rag_pipeline",
            pipeline_id="pipeline-1",
            embedding_model="embedding-model",
            embedding_model_provider="provider",
            retrieval_model={"top_k": 5},
            chunk_structure="paragraph",
            indexing_technique="high_quality",
            keyword_number=8,
            summary_index_setting={"enable": True},
        )
        pipeline = SimpleNamespace(id="pipeline-1", tenant_id="tenant-1")
        published_workflow = SimpleNamespace(
            graph=json.dumps({"nodes": [{"data": {"type": "knowledge-index"}}, {"data": {"type": "start"}}]}),
            type="chat",
            features={"feature": True},
            environment_variables=[],
            conversation_variables=[],
            rag_pipeline_variables=[],
        )
        draft_workflow = SimpleNamespace(graph=json.dumps({"nodes": [{"data": {"type": "knowledge-index"}}]}))
        new_workflow = SimpleNamespace(id="workflow-1")
        rag_pipeline_service = MagicMock()
        rag_pipeline_service.get_published_workflow.return_value = published_workflow
        rag_pipeline_service.get_draft_workflow.return_value = draft_workflow
        with (
            patch("services.dataset_service.db") as mock_db,
            patch("services.dataset_service.RagPipelineService", return_value=rag_pipeline_service),
            patch("services.dataset_service.Workflow.new", return_value=new_workflow) as workflow_new,
        ):
            mock_db.session.get.return_value = pipeline
            DatasetService._update_pipeline_knowledge_base_node_data(dataset, "user-1")
            published_graph = json.loads(workflow_new.call_args.kwargs["graph"])
            assert published_graph["nodes"][0]["data"]["embedding_model"] == "embedding-model"
            assert published_graph["nodes"][0]["data"]["summary_index_setting"] == {"enable": True}
            assert json.loads(draft_workflow.graph)["nodes"][0]["data"]["embedding_model_provider"] == "provider"
            mock_db.session.add.assert_any_call(new_workflow)
            mock_db.session.add.assert_any_call(draft_workflow)
            mock_db.session.commit.assert_called_once()

    def test_update_pipeline_knowledge_base_node_data_rolls_back_when_update_fails(self):
        # Any failure during the workflow update must roll the session back
        # and re-raise.
        dataset = SimpleNamespace(runtime_mode="rag_pipeline", pipeline_id="pipeline-1")
        pipeline = SimpleNamespace(id="pipeline-1", tenant_id="tenant-1")
        rag_pipeline_service = MagicMock()
        rag_pipeline_service.get_published_workflow.side_effect = RuntimeError("boom")
        with (
            patch("services.dataset_service.db") as mock_db,
            patch("services.dataset_service.RagPipelineService", return_value=rag_pipeline_service),
        ):
            mock_db.session.get.return_value = pipeline
            with pytest.raises(RuntimeError, match="boom"):
                DatasetService._update_pipeline_knowledge_base_node_data(dataset, "user-1")
            mock_db.session.rollback.assert_called_once()
    def test_handle_indexing_technique_change_returns_none_without_indexing_technique(self):
        # No "indexing_technique" key in the payload -> no action, no changes.
        filtered_data: dict[str, object] = {}
        dataset = SimpleNamespace(indexing_technique="economy")
        result = DatasetService._handle_indexing_technique_change(dataset, {}, filtered_data)
        assert result is None
        assert filtered_data == {}

    def test_handle_indexing_technique_change_switches_to_economy(self):
        # high_quality -> economy clears embedding settings and signals a
        # vector-index "remove".
        filtered_data: dict[str, object] = {}
        dataset = SimpleNamespace(indexing_technique="high_quality")
        result = DatasetService._handle_indexing_technique_change(
            dataset,
            {"indexing_technique": "economy"},
            filtered_data,
        )
        assert result == "remove"
        assert filtered_data == {
            "embedding_model": None,
            "embedding_model_provider": None,
            "collection_binding_id": None,
        }

    def test_handle_indexing_technique_change_switches_to_high_quality(self):
        # economy -> high_quality configures an embedding model and signals "add".
        filtered_data: dict[str, object] = {}
        dataset = SimpleNamespace(indexing_technique="economy")
        with patch.object(DatasetService, "_configure_embedding_model_for_high_quality") as configure_embedding:
            result = DatasetService._handle_indexing_technique_change(
                dataset,
                {"indexing_technique": "high_quality"},
                filtered_data,
            )
            assert result == "add"
            configure_embedding.assert_called_once_with({"indexing_technique": "high_quality"}, filtered_data)

    def test_handle_indexing_technique_change_delegates_when_technique_is_unchanged(self):
        # Same technique -> delegate to the embedding-model-update helper.
        filtered_data: dict[str, object] = {}
        dataset = SimpleNamespace(indexing_technique="high_quality")
        with patch.object(
            DatasetService,
            "_handle_embedding_model_update_when_technique_unchanged",
            return_value="update",
        ) as update_embedding:
            result = DatasetService._handle_indexing_technique_change(
                dataset,
                {"indexing_technique": "high_quality"},
                filtered_data,
            )
            assert result == "update"
            update_embedding.assert_called_once_with(dataset, {"indexing_technique": "high_quality"}, filtered_data)

    def test_configure_embedding_model_for_high_quality_updates_filtered_data(self):
        # The resolved model and its collection binding id are written into
        # filtered_data. FakeAccount doubles as the Account class so the
        # service's isinstance check on current_user passes.
        class FakeAccount:
            pass

        current_user = FakeAccount()
        current_user.current_tenant_id = "tenant-1"
        embedding_model = SimpleNamespace(provider="provider", model_name="embedding-model")
        filtered_data: dict[str, object] = {}
        with (
            patch("services.dataset_service.Account", FakeAccount),
            patch("services.dataset_service.current_user", current_user),
            patch("services.dataset_service.ModelManager") as model_manager_cls,
            patch(
                "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding",
                return_value=SimpleNamespace(id="binding-1"),
            ),
        ):
            model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
            DatasetService._configure_embedding_model_for_high_quality(
                {"embedding_model_provider": "provider", "embedding_model": "embedding-model"},
                filtered_data,
            )
            assert filtered_data == {
                "embedding_model": "embedding-model",
                "embedding_model_provider": "provider",
                "collection_binding_id": "binding-1",
            }

    @pytest.mark.parametrize(
        ("error", "message"),
        [
            (LLMBadRequestError(), "No Embedding Model available"),
            (ProviderTokenNotInitError("provider setup"), "provider setup"),
        ],
    )
    def test_configure_embedding_model_for_high_quality_wraps_model_errors(self, error, message):
        # Model-resolution failures are wrapped in ValueError with a readable
        # message.
        class FakeAccount:
            pass

        current_user = FakeAccount()
        current_user.current_tenant_id = "tenant-1"
        with (
            patch("services.dataset_service.Account", FakeAccount),
            patch("services.dataset_service.current_user", current_user),
            patch("services.dataset_service.ModelManager") as model_manager_cls,
        ):
            model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = error
            with pytest.raises(ValueError, match=message):
                DatasetService._configure_embedding_model_for_high_quality(
                    {"embedding_model_provider": "provider", "embedding_model": "embedding-model"},
                    {},
                )
    def test_handle_embedding_model_update_when_technique_unchanged_preserves_existing_settings(self):
        # No embedding model in the payload -> keep the dataset's current
        # settings; no vector-index action is returned.
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(
            embedding_model_provider="provider",
            embedding_model="embedding-model",
        )
        filtered_data: dict[str, object] = {}
        with patch.object(DatasetService, "_preserve_existing_embedding_settings") as preserve_settings:
            result = DatasetService._handle_embedding_model_update_when_technique_unchanged(
                dataset,
                {},
                filtered_data,
            )
            assert result is None
            preserve_settings.assert_called_once_with(dataset, filtered_data)

    def test_handle_embedding_model_update_when_technique_unchanged_updates_when_model_is_provided(self):
        # A new provider/model pair in the payload delegates to
        # _update_embedding_model_settings.
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(
            embedding_model_provider="provider",
            embedding_model="embedding-model",
        )
        with patch.object(DatasetService, "_update_embedding_model_settings", return_value="update") as update_settings:
            result = DatasetService._handle_embedding_model_update_when_technique_unchanged(
                dataset,
                {"embedding_model_provider": "provider-two", "embedding_model": "embedding-model-two"},
                {},
            )
            assert result == "update"
            update_settings.assert_called_once()

    def test_preserve_existing_embedding_settings_keeps_current_binding(self):
        # Empty-string placeholders are replaced with the dataset's existing
        # provider/model/binding values.
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(
            embedding_model_provider="provider",
            embedding_model="embedding-model",
            collection_binding_id="binding-1",
        )
        filtered_data = {"embedding_model_provider": "", "embedding_model": ""}
        DatasetService._preserve_existing_embedding_settings(dataset, filtered_data)
        assert filtered_data == {
            "embedding_model_provider": "provider",
            "embedding_model": "embedding-model",
            "collection_binding_id": "binding-1",
        }

    def test_preserve_existing_embedding_settings_removes_empty_placeholders_without_existing_values(self):
        # With nothing to preserve, the placeholder keys are dropped entirely.
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(
            embedding_model_provider=None,
            embedding_model=None,
            collection_binding_id=None,
        )
        filtered_data = {"embedding_model_provider": "", "embedding_model": ""}
        DatasetService._preserve_existing_embedding_settings(dataset, filtered_data)
        assert filtered_data == {}

    def test_update_embedding_model_settings_returns_update_for_changed_values(self):
        # A different provider/model applies new settings and returns "update".
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(
            embedding_model_provider="provider",
            embedding_model="embedding-model",
        )
        with patch.object(DatasetService, "_apply_new_embedding_settings") as apply_settings:
            result = DatasetService._update_embedding_model_settings(
                dataset,
                {"embedding_model_provider": "provider-two", "embedding_model": "embedding-model-two"},
                {},
            )
            assert result == "update"
            apply_settings.assert_called_once()

    def test_update_embedding_model_settings_returns_none_for_unchanged_values(self):
        # Same provider/model as the dataset -> nothing to do.
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(
            embedding_model_provider="provider",
            embedding_model="embedding-model",
        )
        result = DatasetService._update_embedding_model_settings(
            dataset,
            {"embedding_model_provider": "provider", "embedding_model": "embedding-model"},
            {},
        )
        assert result is None

    def test_update_embedding_model_settings_wraps_bad_request_errors(self):
        # LLMBadRequestError from applying the new settings is wrapped.
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(
            embedding_model_provider="provider",
            embedding_model="embedding-model",
        )
        with patch.object(DatasetService, "_apply_new_embedding_settings", side_effect=LLMBadRequestError()):
            with pytest.raises(ValueError, match="No Embedding Model available"):
                DatasetService._update_embedding_model_settings(
                    dataset,
                    {"embedding_model_provider": "provider-two", "embedding_model": "embedding-model-two"},
                    {},
                )
def test_apply_new_embedding_settings_updates_binding_for_new_model(self):
class FakeAccount:
pass
current_user = FakeAccount()
current_user.current_tenant_id = "tenant-1"
dataset = DatasetServiceUnitDataFactory.create_dataset_mock(collection_binding_id="binding-1")
filtered_data: dict[str, object] = {}
with (
patch("services.dataset_service.Account", FakeAccount),
patch("services.dataset_service.current_user", current_user),
patch("services.dataset_service.ModelManager") as model_manager_cls,
patch(
"services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding",
return_value=SimpleNamespace(id="binding-2"),
),
):
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = SimpleNamespace(
provider="provider-two",
model_name="embedding-model-two",
)
DatasetService._apply_new_embedding_settings(
dataset,
{"embedding_model_provider": "provider-two", "embedding_model": "embedding-model-two"},
filtered_data,
)
assert filtered_data == {
"embedding_model": "embedding-model-two",
"embedding_model_provider": "provider-two",
"collection_binding_id": "binding-2",
}
def test_apply_new_embedding_settings_preserves_existing_values_when_provider_token_is_missing(self):
class FakeAccount:
pass
current_user = FakeAccount()
current_user.current_tenant_id = "tenant-1"
dataset = DatasetServiceUnitDataFactory.create_dataset_mock(
embedding_model_provider="provider",
embedding_model="embedding-model",
collection_binding_id="binding-1",
)
filtered_data: dict[str, object] = {}
with (
patch("services.dataset_service.Account", FakeAccount),
patch("services.dataset_service.current_user", current_user),
patch("services.dataset_service.ModelManager") as model_manager_cls,
):
model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = ProviderTokenNotInitError(
"token missing"
)
DatasetService._apply_new_embedding_settings(
dataset,
{"embedding_model_provider": "provider-two", "embedding_model": "embedding-model-two"},
filtered_data,
)
assert filtered_data == {
"embedding_model_provider": "provider",
"embedding_model": "embedding-model",
"collection_binding_id": "binding-1",
}
@pytest.mark.parametrize(
("summary_index_setting", "expected"),
[
(None, False),
({"enable": False}, False),
({"enable": True, "model_name": "old-model", "model_provider_name": "provider"}, False),
({"enable": True, "model_name": "new-model", "model_provider_name": "provider-two"}, True),
],
)
def test_check_summary_index_setting_model_changed(self, summary_index_setting, expected):
dataset = DatasetServiceUnitDataFactory.create_dataset_mock(
dataset_id="dataset-1",
summary_index_setting={"enable": True, "model_name": "old-model", "model_provider_name": "provider"},
)
result = DatasetService._check_summary_index_setting_model_changed(
dataset,
{"summary_index_setting": summary_index_setting} if summary_index_setting is not None else {},
)
assert result is expected
class TestDatasetServiceRagPipelineSettings:
    """Unit tests for rag-pipeline dataset setting updates."""

    def test_update_rag_pipeline_dataset_settings_requires_current_tenant(self):
        """The update must fail fast when the current user has no tenant bound."""
        session = MagicMock()
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1")
        knowledge_configuration = _make_knowledge_configuration()
        # Simulate a user without a tenant; the service should reject before touching the session.
        with patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id=None)):
            with pytest.raises(ValueError, match="Current user or current tenant not found"):
                DatasetService.update_rag_pipeline_dataset_settings(session, dataset, knowledge_configuration)

    def test_update_rag_pipeline_dataset_settings_without_published_high_quality_updates_embedding_settings(self):
        """Unpublished high-quality config applies embedding model, binding, and summary settings."""
        session = MagicMock()
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1")
        session.merge.return_value = dataset
        knowledge_configuration = _make_knowledge_configuration(summary_index_setting={"enable": True})
        embedding_model = SimpleNamespace(provider="provider", model_name="embedding-model")
        with (
            patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id="tenant-1")),
            patch("services.dataset_service.ModelManager") as model_manager_cls,
            patch.object(DatasetService, "check_is_multimodal_model", return_value=True) as check_multimodal,
            patch(
                "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding",
                return_value=SimpleNamespace(id="binding-1"),
            ),
        ):
            model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
            DatasetService.update_rag_pipeline_dataset_settings(session, dataset, knowledge_configuration)
        # Embedding fields should be copied from the resolved model instance and binding lookup.
        assert dataset.chunk_structure == "paragraph"
        assert dataset.indexing_technique == "high_quality"
        assert dataset.embedding_model == "embedding-model"
        assert dataset.embedding_model_provider == "provider"
        assert dataset.collection_binding_id == "binding-1"
        assert dataset.is_multimodal is True
        assert dataset.retrieval_model == knowledge_configuration.retrieval_model.model_dump()
        assert dataset.summary_index_setting == {"enable": True}
        check_multimodal.assert_called_once_with("tenant-1", "provider", "embedding-model")
        session.add.assert_called_once_with(dataset)
        # Without a published pipeline, no commit (and no index task) is expected.
        session.commit.assert_not_called()

    def test_update_rag_pipeline_dataset_settings_without_published_economy_updates_keyword_number(self):
        """Unpublished economy config only updates the keyword number and retrieval model."""
        session = MagicMock()
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1")
        session.merge.return_value = dataset
        knowledge_configuration = _make_knowledge_configuration(
            indexing_technique="economy",
            embedding_model_provider="",
            embedding_model="",
            keyword_number=12,
        )
        with patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id="tenant-1")):
            DatasetService.update_rag_pipeline_dataset_settings(session, dataset, knowledge_configuration)
        assert dataset.indexing_technique == "economy"
        assert dataset.keyword_number == 12
        assert dataset.retrieval_model == knowledge_configuration.retrieval_model.model_dump()
        session.add.assert_called_once_with(dataset)

    def test_update_rag_pipeline_dataset_settings_with_published_rejects_chunk_structure_changes(self):
        """A published dataset may not change its chunk structure."""
        session = MagicMock()
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1")
        dataset.chunk_structure = "paragraph"
        session.merge.return_value = dataset
        knowledge_configuration = _make_knowledge_configuration(chunk_structure="sentence")
        with patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id="tenant-1")):
            with pytest.raises(ValueError, match="Chunk structure is not allowed to be updated"):
                DatasetService.update_rag_pipeline_dataset_settings(
                    session,
                    dataset,
                    knowledge_configuration,
                    has_published=True,
                )

    def test_update_rag_pipeline_dataset_settings_with_published_rejects_switch_to_economy(self):
        """A published high-quality dataset may not downgrade to economy indexing."""
        session = MagicMock()
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1")
        dataset.chunk_structure = "paragraph"
        dataset.indexing_technique = "high_quality"
        session.merge.return_value = dataset
        knowledge_configuration = _make_knowledge_configuration(
            indexing_technique="economy",
            embedding_model_provider="",
            embedding_model="",
        )
        with patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id="tenant-1")):
            with pytest.raises(
                ValueError,
                match="Knowledge base indexing technique is not allowed to be updated to economy",
            ):
                DatasetService.update_rag_pipeline_dataset_settings(
                    session,
                    dataset,
                    knowledge_configuration,
                    has_published=True,
                )

    def test_update_rag_pipeline_dataset_settings_with_published_adds_high_quality_index(self):
        """Upgrading a published economy dataset to high-quality schedules an 'add' index task."""
        session = MagicMock()
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1")
        dataset.chunk_structure = "paragraph"
        dataset.indexing_technique = "economy"
        session.merge.return_value = dataset
        knowledge_configuration = _make_knowledge_configuration()
        embedding_model = SimpleNamespace(provider="provider", model_name="embedding-model")
        with (
            patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id="tenant-1")),
            patch("services.dataset_service.ModelManager") as model_manager_cls,
            patch.object(DatasetService, "check_is_multimodal_model", return_value=False),
            patch(
                "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding",
                return_value=SimpleNamespace(id="binding-1"),
            ),
            patch("services.dataset_service.deal_dataset_index_update_task") as update_task,
        ):
            model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
            DatasetService.update_rag_pipeline_dataset_settings(
                session,
                dataset,
                knowledge_configuration,
                has_published=True,
            )
        assert dataset.indexing_technique == "high_quality"
        assert dataset.embedding_model == "embedding-model"
        assert dataset.embedding_model_provider == "provider"
        assert dataset.collection_binding_id == "binding-1"
        assert dataset.is_multimodal is False
        assert dataset.retrieval_model == knowledge_configuration.retrieval_model.model_dump()
        session.add.assert_called_once_with(dataset)
        session.commit.assert_called_once()
        # The async task is dispatched with action "add" because the index is newly created.
        update_task.delay.assert_called_once_with("dataset-1", "add")

    def test_update_rag_pipeline_dataset_settings_with_published_updates_changed_embedding_model(self):
        """A changed embedding model on a published dataset schedules an 'update' index task."""
        session = MagicMock()
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1")
        dataset.chunk_structure = "paragraph"
        dataset.indexing_technique = "high_quality"
        dataset.embedding_model_provider = "provider"
        dataset.embedding_model = "embedding-model"
        session.merge.return_value = dataset
        knowledge_configuration = _make_knowledge_configuration(
            embedding_model_provider="provider-two",
            embedding_model="embedding-model-two",
            summary_index_setting={"enable": True},
        )
        with (
            patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id="tenant-1")),
            patch("services.dataset_service.ModelManager") as model_manager_cls,
            patch.object(DatasetService, "check_is_multimodal_model", return_value=True),
            patch(
                "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding",
                return_value=SimpleNamespace(id="binding-2"),
            ),
            patch("services.dataset_service.deal_dataset_index_update_task") as update_task,
        ):
            model_manager_cls.for_tenant.return_value.get_model_instance.return_value = SimpleNamespace(
                provider="provider-two",
                model_name="embedding-model-two",
            )
            DatasetService.update_rag_pipeline_dataset_settings(
                session,
                dataset,
                knowledge_configuration,
                has_published=True,
            )
        assert dataset.embedding_model_provider == "provider-two"
        assert dataset.embedding_model == "embedding-model-two"
        assert dataset.collection_binding_id == "binding-2"
        assert dataset.is_multimodal is True
        assert dataset.summary_index_setting == {"enable": True}
        session.add.assert_called_once_with(dataset)
        session.commit.assert_called_once()
        update_task.delay.assert_called_once_with("dataset-1", "update")

    def test_update_rag_pipeline_dataset_settings_with_published_skips_embedding_update_when_token_is_missing(self):
        """A missing provider token keeps the existing embedding model but still commits."""
        session = MagicMock()
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1")
        dataset.chunk_structure = "paragraph"
        dataset.indexing_technique = "high_quality"
        dataset.embedding_model_provider = "provider"
        dataset.embedding_model = "embedding-model"
        session.merge.return_value = dataset
        knowledge_configuration = _make_knowledge_configuration(
            embedding_model_provider="provider-two",
            embedding_model="embedding-model-two",
        )
        with (
            patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id="tenant-1")),
            patch("services.dataset_service.ModelManager") as model_manager_cls,
            patch("services.dataset_service.deal_dataset_index_update_task") as update_task,
        ):
            model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = ProviderTokenNotInitError(
                "token missing"
            )
            DatasetService.update_rag_pipeline_dataset_settings(
                session,
                dataset,
                knowledge_configuration,
                has_published=True,
            )
        # Embedding fields are unchanged because the new model could not be resolved.
        assert dataset.embedding_model_provider == "provider"
        assert dataset.embedding_model == "embedding-model"
        assert dataset.retrieval_model == knowledge_configuration.retrieval_model.model_dump()
        session.add.assert_called_once_with(dataset)
        session.commit.assert_called_once()
        update_task.delay.assert_called_once_with("dataset-1", "update")

    def test_update_rag_pipeline_dataset_settings_with_published_updates_economy_keyword_number(self):
        """Economy-to-economy updates change the keyword number without dispatching an index task."""
        session = MagicMock()
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1")
        dataset.chunk_structure = "paragraph"
        dataset.indexing_technique = "economy"
        dataset.keyword_number = 5
        session.merge.return_value = dataset
        knowledge_configuration = _make_knowledge_configuration(
            indexing_technique="economy",
            embedding_model_provider="",
            embedding_model="",
            keyword_number=9,
        )
        with (
            patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id="tenant-1")),
            patch("services.dataset_service.deal_dataset_index_update_task") as update_task,
        ):
            DatasetService.update_rag_pipeline_dataset_settings(
                session,
                dataset,
                knowledge_configuration,
                has_published=True,
            )
        assert dataset.keyword_number == 9
        assert dataset.retrieval_model == knowledge_configuration.retrieval_model.model_dump()
        session.add.assert_called_once_with(dataset)
        session.commit.assert_called_once()
        # No vector index exists for economy datasets, so no async update is needed.
        update_task.delay.assert_not_called()
class TestDatasetServicePermissionsAndLifecycle:
    """Unit tests for dataset permissions, deletion, and metadata helpers."""

    def test_check_dataset_operator_permission_validates_required_arguments(self):
        """Missing dataset or user arguments raise descriptive ValueErrors."""
        operator = SimpleNamespace(id="user-1")
        with pytest.raises(ValueError, match="Dataset not found"):
            DatasetService.check_dataset_operator_permission(user=operator, dataset=None)
        with pytest.raises(ValueError, match="User not found"):
            DatasetService.check_dataset_operator_permission(user=None, dataset=SimpleNamespace(id="dataset-1"))
class TestDatasetCollectionBindingService:
    """Unit tests covering collection-binding lookup and creation helpers."""
class TestDatasetPermissionService:
    """Unit tests for dataset partial-member management helpers."""

    def test_update_partial_member_list_rolls_back_on_exception(self):
        """A persistence failure must roll the session back and re-raise."""
        with patch("services.dataset_service.db") as fake_db:
            fake_db.session.add_all.side_effect = RuntimeError("boom")
            with pytest.raises(RuntimeError, match="boom"):
                DatasetPermissionService.update_partial_member_list(
                    "tenant-1", "dataset-1", [{"user_id": "user-1"}]
                )
            fake_db.session.rollback.assert_called_once()

    def test_check_permission_requires_dataset_editor(self):
        """Users lacking dataset-editor rights cannot manage permissions at all."""
        non_editor = SimpleNamespace(is_dataset_editor=False, is_dataset_operator=False)
        target = DatasetServiceUnitDataFactory.create_dataset_mock()
        with pytest.raises(NoPermissionError, match="does not have permission"):
            DatasetPermissionService.check_permission(non_editor, target, "all_team", [])

    def test_check_permission_prevents_dataset_operator_from_changing_permission_mode(self):
        """Dataset operators may not switch the dataset's permission mode."""
        operator = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True)
        target = DatasetServiceUnitDataFactory.create_dataset_mock(permission="all_team")
        with pytest.raises(NoPermissionError, match="cannot change the dataset permissions"):
            DatasetPermissionService.check_permission(operator, target, "only_me", [])

    def test_check_permission_requires_partial_member_list_for_partial_members_mode(self):
        """Choosing partial_members without supplying a member list is invalid."""
        operator = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True)
        target = DatasetServiceUnitDataFactory.create_dataset_mock(permission="partial_members")
        with pytest.raises(ValueError, match="Partial member list is required"):
            DatasetPermissionService.check_permission(operator, target, "partial_members", [])

    def test_check_permission_rejects_dataset_operator_member_list_changes(self):
        """Operators cannot alter the member list of a partial_members dataset."""
        operator = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True)
        target = DatasetServiceUnitDataFactory.create_dataset_mock(
            dataset_id="dataset-1", permission="partial_members"
        )
        with (
            patch.object(DatasetPermissionService, "get_dataset_partial_member_list", return_value=["user-1"]),
            pytest.raises(ValueError, match="cannot change the dataset permissions"),
        ):
            DatasetPermissionService.check_permission(
                operator, target, "partial_members", [{"user_id": "user-2"}]
            )

    def test_check_permission_allows_dataset_operator_when_member_list_is_unchanged(self):
        """Submitting the identical member list is a permitted no-op for operators."""
        operator = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True)
        target = DatasetServiceUnitDataFactory.create_dataset_mock(
            dataset_id="dataset-1", permission="partial_members"
        )
        with patch.object(DatasetPermissionService, "get_dataset_partial_member_list", return_value=["user-1"]):
            DatasetPermissionService.check_permission(operator, target, "partial_members", [{"user_id": "user-1"}])

    def test_clear_partial_member_list_rolls_back_on_exception(self):
        """A failed delete statement must roll the session back and re-raise."""
        with patch("services.dataset_service.db") as fake_db:
            fake_db.session.execute.side_effect = RuntimeError("boom")
            with pytest.raises(RuntimeError, match="boom"):
                DatasetPermissionService.clear_partial_member_list("dataset-1")
            fake_db.session.rollback.assert_called_once()