mirror of
https://github.com/langgenius/dify.git
synced 2026-04-29 04:26:30 +08:00
test: migrate metadata partial update tests to testcontainers (#34088)
This commit is contained in:
parent
87a25e326c
commit
e6ab9abf19
@ -0,0 +1,182 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from models.dataset import Dataset, DatasetMetadataBinding, Document
|
||||||
|
from models.enums import DataSourceType, DocumentCreatedFrom
|
||||||
|
from services.entities.knowledge_entities.knowledge_entities import (
|
||||||
|
DocumentMetadataOperation,
|
||||||
|
MetadataDetail,
|
||||||
|
MetadataOperationData,
|
||||||
|
)
|
||||||
|
from services.metadata_service import MetadataService
|
||||||
|
|
||||||
|
|
||||||
|
def _create_dataset(db_session, *, tenant_id: str, built_in_field_enabled: bool = False) -> Dataset:
    """Persist and return a minimal Dataset row for *tenant_id*.

    ``built_in_field_enabled`` controls whether the dataset uses built-in
    metadata fields (off by default for these tests).
    """
    attrs = {
        "tenant_id": tenant_id,
        "name": f"dataset-{uuid4()}",
        "data_source_type": DataSourceType.UPLOAD_FILE,
        "created_by": str(uuid4()),
    }
    dataset = Dataset(**attrs)
    # id and the built-in-field flag are assigned after construction.
    dataset.id = str(uuid4())
    dataset.built_in_field_enabled = built_in_field_enabled
    db_session.add(dataset)
    db_session.commit()
    return dataset
|
||||||
|
|
||||||
|
|
||||||
|
def _create_document(db_session, *, dataset_id: str, tenant_id: str, doc_metadata: dict | None = None) -> Document:
    """Persist and return a minimal Document row attached to *dataset_id*.

    ``doc_metadata`` seeds the document's metadata dict (may be ``None``).
    """
    attrs = {
        "tenant_id": tenant_id,
        "dataset_id": dataset_id,
        "position": 1,
        "data_source_type": DataSourceType.UPLOAD_FILE,
        "data_source_info": "{}",
        "batch": f"batch-{uuid4()}",
        "name": f"doc-{uuid4()}",
        "created_from": DocumentCreatedFrom.WEB,
        "created_by": str(uuid4()),
    }
    document = Document(**attrs)
    # id and doc_metadata are assigned after construction.
    document.id = str(uuid4())
    document.doc_metadata = doc_metadata
    db_session.add(document)
    db_session.commit()
    return document
|
||||||
|
|
||||||
|
|
||||||
|
class TestMetadataPartialUpdate:
    """Integration tests for MetadataService.update_documents_metadata.

    Runs against a real database via the ``flask_app_with_containers`` /
    ``db_session_with_containers`` fixtures (testcontainers-backed, supplied
    by the surrounding conftest — not visible here).
    """

    @pytest.fixture
    def tenant_id(self) -> str:
        """Random tenant id shared by the dataset/document rows of one test."""
        return str(uuid4())

    @pytest.fixture
    def user_id(self) -> str:
        """Random user id used as the acting account's id."""
        return str(uuid4())

    @pytest.fixture
    def mock_current_account(self, user_id, tenant_id):
        """Patch the service's current-account lookup to a mock (account, tenant) pair."""
        account = Mock(id=user_id, current_tenant_id=tenant_id)
        with patch("services.metadata_service.current_account_with_tenant", return_value=(account, tenant_id)):
            yield account

    def test_partial_update_merges_metadata(
        self, flask_app_with_containers, db_session_with_containers, tenant_id, mock_current_account
    ):
        """partial_update=True merges new keys into doc_metadata, keeping existing ones."""
        dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id)
        document = _create_document(
            db_session_with_containers,
            dataset_id=dataset.id,
            tenant_id=tenant_id,
            doc_metadata={"existing_key": "existing_value"},
        )

        meta_id = str(uuid4())
        operation = DocumentMetadataOperation(
            document_id=document.id,
            metadata_list=[MetadataDetail(id=meta_id, name="new_key", value="new_value")],
            partial_update=True,
        )
        metadata_args = MetadataOperationData(operation_data=[operation])

        MetadataService.update_documents_metadata(dataset, metadata_args)
        # Drop cached state so the assertions below re-read from the database.
        db_session_with_containers.expire_all()

        updated_doc = db_session_with_containers.get(Document, document.id)
        assert updated_doc is not None
        # Both the pre-existing and the newly supplied key must survive the merge.
        assert updated_doc.doc_metadata["existing_key"] == "existing_value"
        assert updated_doc.doc_metadata["new_key"] == "new_value"

    def test_full_update_replaces_metadata(
        self, flask_app_with_containers, db_session_with_containers, tenant_id, mock_current_account
    ):
        """partial_update=False replaces doc_metadata wholesale with the supplied list."""
        dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id)
        document = _create_document(
            db_session_with_containers,
            dataset_id=dataset.id,
            tenant_id=tenant_id,
            doc_metadata={"existing_key": "existing_value"},
        )

        meta_id = str(uuid4())
        operation = DocumentMetadataOperation(
            document_id=document.id,
            metadata_list=[MetadataDetail(id=meta_id, name="new_key", value="new_value")],
            partial_update=False,
        )
        metadata_args = MetadataOperationData(operation_data=[operation])

        MetadataService.update_documents_metadata(dataset, metadata_args)
        # Drop cached state so the assertions below re-read from the database.
        db_session_with_containers.expire_all()

        updated_doc = db_session_with_containers.get(Document, document.id)
        assert updated_doc is not None
        # Full update: only the new key remains; the old one is gone.
        assert updated_doc.doc_metadata == {"new_key": "new_value"}
        assert "existing_key" not in updated_doc.doc_metadata

    def test_partial_update_skips_existing_binding(
        self, flask_app_with_containers, db_session_with_containers, tenant_id, user_id, mock_current_account
    ):
        """An existing (document, metadata) binding is reused, not duplicated."""
        dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id)
        document = _create_document(
            db_session_with_containers,
            dataset_id=dataset.id,
            tenant_id=tenant_id,
            doc_metadata={"existing_key": "existing_value"},
        )

        # Pre-create the binding the operation below would otherwise insert.
        meta_id = str(uuid4())
        existing_binding = DatasetMetadataBinding(
            tenant_id=tenant_id,
            dataset_id=dataset.id,
            document_id=document.id,
            metadata_id=meta_id,
            created_by=user_id,
        )
        db_session_with_containers.add(existing_binding)
        db_session_with_containers.commit()

        operation = DocumentMetadataOperation(
            document_id=document.id,
            metadata_list=[MetadataDetail(id=meta_id, name="existing_key", value="existing_value")],
            partial_update=True,
        )
        metadata_args = MetadataOperationData(operation_data=[operation])

        MetadataService.update_documents_metadata(dataset, metadata_args)
        # Drop cached state so the query below re-reads from the database.
        db_session_with_containers.expire_all()

        bindings = db_session_with_containers.scalars(
            select(DatasetMetadataBinding).where(
                DatasetMetadataBinding.document_id == document.id,
                DatasetMetadataBinding.metadata_id == meta_id,
            )
        ).all()
        # Still exactly one binding row: the pre-existing one, no duplicate.
        assert len(bindings) == 1

    def test_rollback_called_on_commit_failure(
        self, flask_app_with_containers, db_session_with_containers, tenant_id, mock_current_account
    ):
        """A failing commit inside the service propagates its exception to the caller."""
        dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id)
        document = _create_document(
            db_session_with_containers,
            dataset_id=dataset.id,
            tenant_id=tenant_id,
            doc_metadata={"existing_key": "existing_value"},
        )

        meta_id = str(uuid4())
        operation = DocumentMetadataOperation(
            document_id=document.id,
            metadata_list=[MetadataDetail(id=meta_id, name="key", value="value")],
            partial_update=True,
        )
        metadata_args = MetadataOperationData(operation_data=[operation])

        # Force the service's commit to fail and assert the error is not swallowed.
        with patch("services.metadata_service.db.session.commit", side_effect=RuntimeError("database connection lost")):
            with pytest.raises(RuntimeError, match="database connection lost"):
                MetadataService.update_documents_metadata(dataset, metadata_args)
|
||||||
@ -1,187 +0,0 @@
|
|||||||
import unittest
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from models.dataset import Dataset, Document
|
|
||||||
from services.entities.knowledge_entities.knowledge_entities import (
|
|
||||||
DocumentMetadataOperation,
|
|
||||||
MetadataDetail,
|
|
||||||
MetadataOperationData,
|
|
||||||
)
|
|
||||||
from services.metadata_service import MetadataService
|
|
||||||
|
|
||||||
|
|
||||||
class TestMetadataPartialUpdate(unittest.TestCase):
    """Unit tests for MetadataService.update_documents_metadata with fully mocked
    db / redis / account dependencies (no real database involved)."""

    def setUp(self):
        """Build a mock dataset and a mock document carrying one metadata key."""
        self.dataset = MagicMock(spec=Dataset)
        self.dataset.id = "dataset_id"
        self.dataset.built_in_field_enabled = False

        self.document = MagicMock(spec=Document)
        self.document.id = "doc_id"
        self.document.doc_metadata = {"existing_key": "existing_value"}
        self.document.data_source_type = "upload_file"

    @patch("services.metadata_service.db")
    @patch("services.metadata_service.DocumentService")
    @patch("services.metadata_service.current_account_with_tenant")
    @patch("services.metadata_service.redis_client")
    def test_partial_update_merges_metadata(self, mock_redis, mock_current_account, mock_document_service, mock_db):
        """partial_update=True merges new keys into doc_metadata and keeps bindings."""
        # Setup mocks: no lock held, document lookup returns our mock document.
        mock_redis.get.return_value = None
        mock_document_service.get_document.return_value = self.document
        mock_current_account.return_value = (MagicMock(id="user_id"), "tenant_id")

        # No existing binding for the new key.
        mock_db.session.query.return_value.filter_by.return_value.first.return_value = None

        operation = DocumentMetadataOperation(
            document_id="doc_id",
            metadata_list=[MetadataDetail(id="new_meta_id", name="new_key", value="new_value")],
            partial_update=True,
        )
        metadata_args = MetadataOperationData(operation_data=[operation])

        MetadataService.update_documents_metadata(self.dataset, metadata_args)

        # doc_metadata must contain BOTH the existing and the new key.
        expected_metadata = {"existing_key": "existing_value", "new_key": "new_value"}
        assert self.document.doc_metadata == expected_metadata

        # In a partial update the existing bindings must NOT be deleted
        # (full update clears them via query(...).filter_by(...).delete()).
        mock_db.session.query.return_value.filter_by.return_value.delete.assert_not_called()

    @patch("services.metadata_service.db")
    @patch("services.metadata_service.DocumentService")
    @patch("services.metadata_service.current_account_with_tenant")
    @patch("services.metadata_service.redis_client")
    def test_full_update_replaces_metadata(self, mock_redis, mock_current_account, mock_document_service, mock_db):
        """partial_update=False replaces doc_metadata and clears old bindings."""
        # Setup mocks
        mock_redis.get.return_value = None
        mock_document_service.get_document.return_value = self.document
        mock_current_account.return_value = (MagicMock(id="user_id"), "tenant_id")

        operation = DocumentMetadataOperation(
            document_id="doc_id",
            metadata_list=[MetadataDetail(id="new_meta_id", name="new_key", value="new_value")],
            partial_update=False,
        )
        metadata_args = MetadataOperationData(operation_data=[operation])

        MetadataService.update_documents_metadata(self.dataset, metadata_args)

        # doc_metadata must contain ONLY the new key.
        expected_metadata = {"new_key": "new_value"}
        assert self.document.doc_metadata == expected_metadata

        # In a full update the pre-existing bindings ARE cleared.
        mock_db.session.query.return_value.filter_by.return_value.delete.assert_called()

    @patch("services.metadata_service.db")
    @patch("services.metadata_service.DocumentService")
    @patch("services.metadata_service.current_account_with_tenant")
    @patch("services.metadata_service.redis_client")
    def test_partial_update_skips_existing_binding(
        self, mock_redis, mock_current_account, mock_document_service, mock_db
    ):
        """When a binding already exists for the metadata id, no new binding is added."""
        # Setup mocks
        mock_redis.get.return_value = None
        mock_document_service.get_document.return_value = self.document
        mock_current_account.return_value = (MagicMock(id="user_id"), "tenant_id")

        # Simulate that the document ALREADY has the metadata binding being added.
        mock_existing_binding = MagicMock()
        mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_existing_binding

        operation = DocumentMetadataOperation(
            document_id="doc_id",
            metadata_list=[MetadataDetail(id="existing_meta_id", name="existing_key", value="existing_value")],
            partial_update=True,
        )
        metadata_args = MetadataOperationData(operation_data=[operation])

        MetadataService.update_documents_metadata(self.dataset, metadata_args)

        # Inspect everything passed to db.session.add.
        add_calls = mock_db.session.add.call_args_list
        added_objects = [call.args[0] for call in add_calls]

        from models.dataset import DatasetMetadataBinding

        # isinstance also covers MagicMock(spec=DatasetMetadataBinding), whose
        # __class__ is set to the spec class, so a single check suffices.
        has_binding_add = any(isinstance(obj, DatasetMetadataBinding) for obj in added_objects)
        # Fix: this flag was previously computed but never asserted.
        assert not has_binding_add

        # Exactly one add call is expected: the document update only.
        # A second call would mean a duplicate binding was created.
        assert mock_db.session.add.call_count == 1

    @patch("services.metadata_service.db")
    @patch("services.metadata_service.DocumentService")
    @patch("services.metadata_service.current_account_with_tenant")
    @patch("services.metadata_service.redis_client")
    def test_rollback_called_on_commit_failure(self, mock_redis, mock_current_account, mock_document_service, mock_db):
        """When db.session.commit() raises, rollback must be called and the exception must propagate."""
        # Setup mocks
        mock_redis.get.return_value = None
        mock_document_service.get_document.return_value = self.document
        mock_current_account.return_value = (MagicMock(id="user_id"), "tenant_id")
        mock_db.session.query.return_value.filter_by.return_value.first.return_value = None

        # Make commit raise an exception.
        mock_db.session.commit.side_effect = RuntimeError("database connection lost")

        operation = DocumentMetadataOperation(
            document_id="doc_id",
            metadata_list=[MetadataDetail(id="meta_id", name="key", value="value")],
            partial_update=True,
        )
        metadata_args = MetadataOperationData(operation_data=[operation])

        # Act & Assert: the exception must propagate.
        with pytest.raises(RuntimeError, match="database connection lost"):
            MetadataService.update_documents_metadata(self.dataset, metadata_args)

        # Verify rollback was called.
        mock_db.session.rollback.assert_called_once()

        # Verify the lock key was cleaned up despite the failure.
        mock_redis.delete.assert_called_with("document_metadata_lock_doc_id")
|
|
||||||
|
|
||||||
|
|
||||||
# Allow running this test module directly (outside a pytest invocation).
if __name__ == "__main__":
    unittest.main()
|
|
||||||
Loading…
Reference in New Issue
Block a user