feat: replace db.session with db_session_with_containers (#32942)

This commit is contained in:
Renzo 2026-03-03 21:50:41 -08:00 committed by GitHub
parent 2f4c740d46
commit ad000c42b7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
43 changed files with 3078 additions and 2669 deletions

View File

@ -9,8 +9,8 @@ from itertools import starmap
from uuid import uuid4 from uuid import uuid4
import pytest import pytest
from sqlalchemy.orm import Session
from extensions.ext_database import db
from models.dataset import DatasetCollectionBinding from models.dataset import DatasetCollectionBinding
from services.dataset_service import DatasetCollectionBindingService from services.dataset_service import DatasetCollectionBindingService
@ -28,6 +28,7 @@ class DatasetCollectionBindingTestDataFactory:
@staticmethod @staticmethod
def create_collection_binding( def create_collection_binding(
db_session_with_containers: Session,
provider_name: str = "openai", provider_name: str = "openai",
model_name: str = "text-embedding-ada-002", model_name: str = "text-embedding-ada-002",
collection_name: str = "collection-abc", collection_name: str = "collection-abc",
@ -51,8 +52,8 @@ class DatasetCollectionBindingTestDataFactory:
collection_name=collection_name, collection_name=collection_name,
type=collection_type, type=collection_type,
) )
db.session.add(binding) db_session_with_containers.add(binding)
db.session.commit() db_session_with_containers.commit()
return binding return binding
@ -64,7 +65,7 @@ class TestDatasetCollectionBindingServiceGetBinding:
including various provider/model combinations, collection types, and edge cases. including various provider/model combinations, collection types, and edge cases.
""" """
def test_get_dataset_collection_binding_existing_binding_success(self, db_session_with_containers): def test_get_dataset_collection_binding_existing_binding_success(self, db_session_with_containers: Session):
""" """
Test successful retrieval of an existing collection binding. Test successful retrieval of an existing collection binding.
@ -77,6 +78,7 @@ class TestDatasetCollectionBindingServiceGetBinding:
model_name = "text-embedding-ada-002" model_name = "text-embedding-ada-002"
collection_type = "dataset" collection_type = "dataset"
existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding( existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding(
db_session_with_containers,
provider_name=provider_name, provider_name=provider_name,
model_name=model_name, model_name=model_name,
collection_name="existing-collection", collection_name="existing-collection",
@ -92,7 +94,7 @@ class TestDatasetCollectionBindingServiceGetBinding:
assert result.id == existing_binding.id assert result.id == existing_binding.id
assert result.collection_name == "existing-collection" assert result.collection_name == "existing-collection"
def test_get_dataset_collection_binding_create_new_binding_success(self, db_session_with_containers): def test_get_dataset_collection_binding_create_new_binding_success(self, db_session_with_containers: Session):
""" """
Test successful creation of a new collection binding when none exists. Test successful creation of a new collection binding when none exists.
@ -116,7 +118,7 @@ class TestDatasetCollectionBindingServiceGetBinding:
assert result.type == collection_type assert result.type == collection_type
assert result.collection_name is not None assert result.collection_name is not None
def test_get_dataset_collection_binding_different_collection_type(self, db_session_with_containers): def test_get_dataset_collection_binding_different_collection_type(self, db_session_with_containers: Session):
"""Test get_dataset_collection_binding with different collection type.""" """Test get_dataset_collection_binding with different collection type."""
# Arrange # Arrange
provider_name = "openai" provider_name = "openai"
@ -133,7 +135,7 @@ class TestDatasetCollectionBindingServiceGetBinding:
assert result.provider_name == provider_name assert result.provider_name == provider_name
assert result.model_name == model_name assert result.model_name == model_name
def test_get_dataset_collection_binding_default_collection_type(self, db_session_with_containers): def test_get_dataset_collection_binding_default_collection_type(self, db_session_with_containers: Session):
"""Test get_dataset_collection_binding with default collection type parameter.""" """Test get_dataset_collection_binding with default collection type parameter."""
# Arrange # Arrange
provider_name = "openai" provider_name = "openai"
@ -147,7 +149,9 @@ class TestDatasetCollectionBindingServiceGetBinding:
assert result.provider_name == provider_name assert result.provider_name == provider_name
assert result.model_name == model_name assert result.model_name == model_name
def test_get_dataset_collection_binding_different_provider_model_combination(self, db_session_with_containers): def test_get_dataset_collection_binding_different_provider_model_combination(
self, db_session_with_containers: Session
):
"""Test get_dataset_collection_binding with various provider/model combinations.""" """Test get_dataset_collection_binding with various provider/model combinations."""
# Arrange # Arrange
combinations = [ combinations = [
@ -174,10 +178,11 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType:
including successful retrieval and error handling for missing bindings. including successful retrieval and error handling for missing bindings.
""" """
def test_get_dataset_collection_binding_by_id_and_type_success(self, db_session_with_containers): def test_get_dataset_collection_binding_by_id_and_type_success(self, db_session_with_containers: Session):
"""Test successful retrieval of collection binding by ID and type.""" """Test successful retrieval of collection binding by ID and type."""
# Arrange # Arrange
binding = DatasetCollectionBindingTestDataFactory.create_collection_binding( binding = DatasetCollectionBindingTestDataFactory.create_collection_binding(
db_session_with_containers,
provider_name="openai", provider_name="openai",
model_name="text-embedding-ada-002", model_name="text-embedding-ada-002",
collection_name="test-collection", collection_name="test-collection",
@ -194,7 +199,7 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType:
assert result.collection_name == "test-collection" assert result.collection_name == "test-collection"
assert result.type == "dataset" assert result.type == "dataset"
def test_get_dataset_collection_binding_by_id_and_type_not_found_error(self, db_session_with_containers): def test_get_dataset_collection_binding_by_id_and_type_not_found_error(self, db_session_with_containers: Session):
"""Test error handling when collection binding is not found by ID and type.""" """Test error handling when collection binding is not found by ID and type."""
# Arrange # Arrange
non_existent_id = str(uuid4()) non_existent_id = str(uuid4())
@ -203,10 +208,13 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType:
with pytest.raises(ValueError, match="Dataset collection binding not found"): with pytest.raises(ValueError, match="Dataset collection binding not found"):
DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(non_existent_id, "dataset") DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(non_existent_id, "dataset")
def test_get_dataset_collection_binding_by_id_and_type_different_collection_type(self, db_session_with_containers): def test_get_dataset_collection_binding_by_id_and_type_different_collection_type(
self, db_session_with_containers: Session
):
"""Test retrieval by ID and type with different collection type.""" """Test retrieval by ID and type with different collection type."""
# Arrange # Arrange
binding = DatasetCollectionBindingTestDataFactory.create_collection_binding( binding = DatasetCollectionBindingTestDataFactory.create_collection_binding(
db_session_with_containers,
provider_name="openai", provider_name="openai",
model_name="text-embedding-ada-002", model_name="text-embedding-ada-002",
collection_name="test-collection", collection_name="test-collection",
@ -222,10 +230,13 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType:
assert result.id == binding.id assert result.id == binding.id
assert result.type == "custom_type" assert result.type == "custom_type"
def test_get_dataset_collection_binding_by_id_and_type_default_collection_type(self, db_session_with_containers): def test_get_dataset_collection_binding_by_id_and_type_default_collection_type(
self, db_session_with_containers: Session
):
"""Test retrieval by ID with default collection type.""" """Test retrieval by ID with default collection type."""
# Arrange # Arrange
binding = DatasetCollectionBindingTestDataFactory.create_collection_binding( binding = DatasetCollectionBindingTestDataFactory.create_collection_binding(
db_session_with_containers,
provider_name="openai", provider_name="openai",
model_name="text-embedding-ada-002", model_name="text-embedding-ada-002",
collection_name="test-collection", collection_name="test-collection",
@ -239,10 +250,11 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType:
assert result.id == binding.id assert result.id == binding.id
assert result.type == "dataset" assert result.type == "dataset"
def test_get_dataset_collection_binding_by_id_and_type_wrong_type_error(self, db_session_with_containers): def test_get_dataset_collection_binding_by_id_and_type_wrong_type_error(self, db_session_with_containers: Session):
"""Test error when binding exists but with wrong collection type.""" """Test error when binding exists but with wrong collection type."""
# Arrange # Arrange
binding = DatasetCollectionBindingTestDataFactory.create_collection_binding( binding = DatasetCollectionBindingTestDataFactory.create_collection_binding(
db_session_with_containers,
provider_name="openai", provider_name="openai",
model_name="text-embedding-ada-002", model_name="text-embedding-ada-002",
collection_name="test-collection", collection_name="test-collection",

View File

@ -10,9 +10,9 @@ from unittest.mock import patch
from uuid import uuid4 from uuid import uuid4
import pytest import pytest
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from extensions.ext_database import db
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import AppDatasetJoin, Dataset, DatasetPermissionEnum from models.dataset import AppDatasetJoin, Dataset, DatasetPermissionEnum
from models.model import App from models.model import App
@ -27,6 +27,7 @@ class DatasetUpdateDeleteTestDataFactory:
@staticmethod @staticmethod
def create_account_with_tenant( def create_account_with_tenant(
db_session_with_containers: Session,
role: TenantAccountRole = TenantAccountRole.NORMAL, role: TenantAccountRole = TenantAccountRole.NORMAL,
tenant: Tenant | None = None, tenant: Tenant | None = None,
) -> tuple[Account, Tenant]: ) -> tuple[Account, Tenant]:
@ -37,13 +38,13 @@ class DatasetUpdateDeleteTestDataFactory:
interface_language="en-US", interface_language="en-US",
status="active", status="active",
) )
db.session.add(account) db_session_with_containers.add(account)
db.session.commit() db_session_with_containers.commit()
if tenant is None: if tenant is None:
tenant = Tenant(name=f"tenant-{uuid4()}", status="normal") tenant = Tenant(name=f"tenant-{uuid4()}", status="normal")
db.session.add(tenant) db_session_with_containers.add(tenant)
db.session.commit() db_session_with_containers.commit()
join = TenantAccountJoin( join = TenantAccountJoin(
tenant_id=tenant.id, tenant_id=tenant.id,
@ -51,14 +52,15 @@ class DatasetUpdateDeleteTestDataFactory:
role=role, role=role,
current=True, current=True,
) )
db.session.add(join) db_session_with_containers.add(join)
db.session.commit() db_session_with_containers.commit()
account.current_tenant = tenant account.current_tenant = tenant
return account, tenant return account, tenant
@staticmethod @staticmethod
def create_dataset( def create_dataset(
db_session_with_containers: Session,
tenant_id: str, tenant_id: str,
created_by: str, created_by: str,
name: str = "Test Dataset", name: str = "Test Dataset",
@ -78,12 +80,12 @@ class DatasetUpdateDeleteTestDataFactory:
retrieval_model={"top_k": 2}, retrieval_model={"top_k": 2},
enable_api=enable_api, enable_api=enable_api,
) )
db.session.add(dataset) db_session_with_containers.add(dataset)
db.session.commit() db_session_with_containers.commit()
return dataset return dataset
@staticmethod @staticmethod
def create_app(tenant_id: str, created_by: str, name: str = "Test App") -> App: def create_app(db_session_with_containers: Session, tenant_id: str, created_by: str, name: str = "Test App") -> App:
"""Create a real app for AppDatasetJoin.""" """Create a real app for AppDatasetJoin."""
app = App( app = App(
tenant_id=tenant_id, tenant_id=tenant_id,
@ -96,16 +98,16 @@ class DatasetUpdateDeleteTestDataFactory:
enable_api=True, enable_api=True,
created_by=created_by, created_by=created_by,
) )
db.session.add(app) db_session_with_containers.add(app)
db.session.commit() db_session_with_containers.commit()
return app return app
@staticmethod @staticmethod
def create_app_dataset_join(app_id: str, dataset_id: str) -> AppDatasetJoin: def create_app_dataset_join(db_session_with_containers: Session, app_id: str, dataset_id: str) -> AppDatasetJoin:
"""Create a real AppDatasetJoin record.""" """Create a real AppDatasetJoin record."""
join = AppDatasetJoin(app_id=app_id, dataset_id=dataset_id) join = AppDatasetJoin(app_id=app_id, dataset_id=dataset_id)
db.session.add(join) db_session_with_containers.add(join)
db.session.commit() db_session_with_containers.commit()
return join return join
@ -114,7 +116,7 @@ class TestDatasetServiceDeleteDataset:
Comprehensive integration tests for DatasetService.delete_dataset method. Comprehensive integration tests for DatasetService.delete_dataset method.
""" """
def test_delete_dataset_success(self, db_session_with_containers): def test_delete_dataset_success(self, db_session_with_containers: Session):
""" """
Test successful deletion of a dataset. Test successful deletion of a dataset.
@ -130,8 +132,10 @@ class TestDatasetServiceDeleteDataset:
- Method returns True - Method returns True
""" """
# Arrange # Arrange
owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id) db_session_with_containers, role=TenantAccountRole.OWNER
)
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id)
# Act # Act
with patch("services.dataset_service.dataset_was_deleted") as mock_dataset_was_deleted: with patch("services.dataset_service.dataset_was_deleted") as mock_dataset_was_deleted:
@ -139,10 +143,10 @@ class TestDatasetServiceDeleteDataset:
# Assert # Assert
assert result is True assert result is True
assert db.session.get(Dataset, dataset.id) is None assert db_session_with_containers.get(Dataset, dataset.id) is None
mock_dataset_was_deleted.send.assert_called_once_with(dataset) mock_dataset_was_deleted.send.assert_called_once_with(dataset)
def test_delete_dataset_not_found(self, db_session_with_containers): def test_delete_dataset_not_found(self, db_session_with_containers: Session):
""" """
Test handling when dataset is not found. Test handling when dataset is not found.
@ -156,7 +160,9 @@ class TestDatasetServiceDeleteDataset:
- No database operations are performed - No database operations are performed
""" """
# Arrange # Arrange
owner, _ = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) owner, _ = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(
db_session_with_containers, role=TenantAccountRole.OWNER
)
dataset_id = str(uuid4()) dataset_id = str(uuid4())
# Act # Act
@ -165,7 +171,7 @@ class TestDatasetServiceDeleteDataset:
# Assert # Assert
assert result is False assert result is False
def test_delete_dataset_permission_denied_error(self, db_session_with_containers): def test_delete_dataset_permission_denied_error(self, db_session_with_containers: Session):
""" """
Test error handling when user lacks permission. Test error handling when user lacks permission.
@ -178,19 +184,22 @@ class TestDatasetServiceDeleteDataset:
- No database operations are performed - No database operations are performed
""" """
# Arrange # Arrange
owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(
db_session_with_containers, role=TenantAccountRole.OWNER
)
normal_user, _ = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant( normal_user, _ = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(
db_session_with_containers,
role=TenantAccountRole.NORMAL, role=TenantAccountRole.NORMAL,
tenant=tenant, tenant=tenant,
) )
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id) dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id)
# Act & Assert # Act & Assert
with pytest.raises(NoPermissionError): with pytest.raises(NoPermissionError):
DatasetService.delete_dataset(dataset.id, normal_user) DatasetService.delete_dataset(dataset.id, normal_user)
# Verify no deletion was attempted # Verify no deletion was attempted
assert db.session.get(Dataset, dataset.id) is not None assert db_session_with_containers.get(Dataset, dataset.id) is not None
class TestDatasetServiceDatasetUseCheck: class TestDatasetServiceDatasetUseCheck:
@ -198,7 +207,7 @@ class TestDatasetServiceDatasetUseCheck:
Comprehensive integration tests for DatasetService.dataset_use_check method. Comprehensive integration tests for DatasetService.dataset_use_check method.
""" """
def test_dataset_use_check_in_use(self, db_session_with_containers): def test_dataset_use_check_in_use(self, db_session_with_containers: Session):
""" """
Test detection when dataset is in use. Test detection when dataset is in use.
@ -211,10 +220,12 @@ class TestDatasetServiceDatasetUseCheck:
- Database query is executed - Database query is executed
""" """
# Arrange # Arrange
owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id) db_session_with_containers, role=TenantAccountRole.OWNER
app = DatasetUpdateDeleteTestDataFactory.create_app(tenant.id, owner.id) )
DatasetUpdateDeleteTestDataFactory.create_app_dataset_join(app.id, dataset.id) dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id)
app = DatasetUpdateDeleteTestDataFactory.create_app(db_session_with_containers, tenant.id, owner.id)
DatasetUpdateDeleteTestDataFactory.create_app_dataset_join(db_session_with_containers, app.id, dataset.id)
# Act # Act
result = DatasetService.dataset_use_check(dataset.id) result = DatasetService.dataset_use_check(dataset.id)
@ -222,7 +233,7 @@ class TestDatasetServiceDatasetUseCheck:
# Assert # Assert
assert result is True assert result is True
def test_dataset_use_check_not_in_use(self, db_session_with_containers): def test_dataset_use_check_not_in_use(self, db_session_with_containers: Session):
""" """
Test detection when dataset is not in use. Test detection when dataset is not in use.
@ -235,8 +246,10 @@ class TestDatasetServiceDatasetUseCheck:
- Database query is executed - Database query is executed
""" """
# Arrange # Arrange
owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id) db_session_with_containers, role=TenantAccountRole.OWNER
)
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id)
# Act # Act
result = DatasetService.dataset_use_check(dataset.id) result = DatasetService.dataset_use_check(dataset.id)
@ -250,7 +263,7 @@ class TestDatasetServiceUpdateDatasetApiStatus:
Comprehensive integration tests for DatasetService.update_dataset_api_status method. Comprehensive integration tests for DatasetService.update_dataset_api_status method.
""" """
def test_update_dataset_api_status_enable_success(self, db_session_with_containers): def test_update_dataset_api_status_enable_success(self, db_session_with_containers: Session):
""" """
Test successful enabling of dataset API access. Test successful enabling of dataset API access.
@ -264,8 +277,12 @@ class TestDatasetServiceUpdateDatasetApiStatus:
- Transaction is committed - Transaction is committed
""" """
# Arrange # Arrange
owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id, enable_api=False) db_session_with_containers, role=TenantAccountRole.OWNER
)
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(
db_session_with_containers, tenant.id, owner.id, enable_api=False
)
current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) current_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
# Act # Act
@ -276,12 +293,12 @@ class TestDatasetServiceUpdateDatasetApiStatus:
DatasetService.update_dataset_api_status(dataset.id, True) DatasetService.update_dataset_api_status(dataset.id, True)
# Assert # Assert
db.session.refresh(dataset) db_session_with_containers.refresh(dataset)
assert dataset.enable_api is True assert dataset.enable_api is True
assert dataset.updated_by == owner.id assert dataset.updated_by == owner.id
assert dataset.updated_at == current_time assert dataset.updated_at == current_time
def test_update_dataset_api_status_disable_success(self, db_session_with_containers): def test_update_dataset_api_status_disable_success(self, db_session_with_containers: Session):
""" """
Test successful disabling of dataset API access. Test successful disabling of dataset API access.
@ -295,8 +312,12 @@ class TestDatasetServiceUpdateDatasetApiStatus:
- Transaction is committed - Transaction is committed
""" """
# Arrange # Arrange
owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id, enable_api=True) db_session_with_containers, role=TenantAccountRole.OWNER
)
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(
db_session_with_containers, tenant.id, owner.id, enable_api=True
)
current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) current_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
# Act # Act
@ -307,11 +328,11 @@ class TestDatasetServiceUpdateDatasetApiStatus:
DatasetService.update_dataset_api_status(dataset.id, False) DatasetService.update_dataset_api_status(dataset.id, False)
# Assert # Assert
db.session.refresh(dataset) db_session_with_containers.refresh(dataset)
assert dataset.enable_api is False assert dataset.enable_api is False
assert dataset.updated_by == owner.id assert dataset.updated_by == owner.id
def test_update_dataset_api_status_not_found_error(self, db_session_with_containers): def test_update_dataset_api_status_not_found_error(self, db_session_with_containers: Session):
""" """
Test error handling when dataset is not found. Test error handling when dataset is not found.
@ -330,7 +351,7 @@ class TestDatasetServiceUpdateDatasetApiStatus:
with pytest.raises(NotFound, match="Dataset not found"): with pytest.raises(NotFound, match="Dataset not found"):
DatasetService.update_dataset_api_status(dataset_id, True) DatasetService.update_dataset_api_status(dataset_id, True)
def test_update_dataset_api_status_missing_current_user_error(self, db_session_with_containers): def test_update_dataset_api_status_missing_current_user_error(self, db_session_with_containers: Session):
""" """
Test error handling when current_user is missing. Test error handling when current_user is missing.
@ -343,8 +364,12 @@ class TestDatasetServiceUpdateDatasetApiStatus:
- No updates are committed - No updates are committed
""" """
# Arrange # Arrange
owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id, enable_api=False) db_session_with_containers, role=TenantAccountRole.OWNER
)
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(
db_session_with_containers, tenant.id, owner.id, enable_api=False
)
# Act & Assert # Act & Assert
with ( with (
@ -354,6 +379,6 @@ class TestDatasetServiceUpdateDatasetApiStatus:
DatasetService.update_dataset_api_status(dataset.id, True) DatasetService.update_dataset_api_status(dataset.id, True)
# Verify no commit was attempted # Verify no commit was attempted
db.session.rollback() db_session_with_containers.rollback()
db.session.refresh(dataset) db_session_with_containers.refresh(dataset)
assert dataset.enable_api is False assert dataset.enable_api is False

View File

@ -3,6 +3,7 @@ from unittest.mock import MagicMock, create_autospec, patch
import pytest import pytest
from faker import Faker from faker import Faker
from sqlalchemy.orm import Session
from core.plugin.impl.exc import PluginDaemonClientSideError from core.plugin.impl.exc import PluginDaemonClientSideError
from models import Account from models import Account
@ -87,7 +88,7 @@ class TestAgentService:
"account_feature_service": mock_account_feature_service, "account_feature_service": mock_account_feature_service,
} }
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Helper method to create a test app and account for testing. Helper method to create a test app and account for testing.
@ -133,13 +134,12 @@ class TestAgentService:
# Update the app model config to set agent_mode for agent-chat mode # Update the app model config to set agent_mode for agent-chat mode
if app.mode == "agent-chat" and app.app_model_config: if app.mode == "agent-chat" and app.app_model_config:
app.app_model_config.agent_mode = json.dumps({"enabled": True, "strategy": "react", "tools": []}) app.app_model_config.agent_mode = json.dumps({"enabled": True, "strategy": "react", "tools": []})
from extensions.ext_database import db
db.session.commit() db_session_with_containers.commit()
return app, account return app, account
def _create_test_conversation_and_message(self, db_session_with_containers, app, account): def _create_test_conversation_and_message(self, db_session_with_containers: Session, app, account):
""" """
Helper method to create a test conversation and message with agent thoughts. Helper method to create a test conversation and message with agent thoughts.
@ -153,8 +153,6 @@ class TestAgentService:
""" """
fake = Faker() fake = Faker()
from extensions.ext_database import db
# Create conversation # Create conversation
conversation = Conversation( conversation = Conversation(
id=fake.uuid4(), id=fake.uuid4(),
@ -167,8 +165,8 @@ class TestAgentService:
mode="chat", mode="chat",
from_source="api", from_source="api",
) )
db.session.add(conversation) db_session_with_containers.add(conversation)
db.session.commit() db_session_with_containers.commit()
# Create app model config # Create app model config
app_model_config = AppModelConfig( app_model_config = AppModelConfig(
@ -180,12 +178,12 @@ class TestAgentService:
agent_mode=json.dumps({"enabled": True, "strategy": "react", "tools": []}), agent_mode=json.dumps({"enabled": True, "strategy": "react", "tools": []}),
) )
app_model_config.id = fake.uuid4() app_model_config.id = fake.uuid4()
db.session.add(app_model_config) db_session_with_containers.add(app_model_config)
db.session.commit() db_session_with_containers.commit()
# Update conversation with app model config # Update conversation with app model config
conversation.app_model_config_id = app_model_config.id conversation.app_model_config_id = app_model_config.id
db.session.commit() db_session_with_containers.commit()
# Create message # Create message
message = Message( message = Message(
@ -206,12 +204,12 @@ class TestAgentService:
currency="USD", currency="USD",
from_source="api", from_source="api",
) )
db.session.add(message) db_session_with_containers.add(message)
db.session.commit() db_session_with_containers.commit()
return conversation, message return conversation, message
def _create_test_agent_thoughts(self, db_session_with_containers, message): def _create_test_agent_thoughts(self, db_session_with_containers: Session, message):
""" """
Helper method to create test agent thoughts for a message. Helper method to create test agent thoughts for a message.
@ -224,8 +222,6 @@ class TestAgentService:
""" """
fake = Faker() fake = Faker()
from extensions.ext_database import db
agent_thoughts = [] agent_thoughts = []
# Create first agent thought # Create first agent thought
@ -251,7 +247,7 @@ class TestAgentService:
created_by_role="account", created_by_role="account",
created_by=message.from_account_id, created_by=message.from_account_id,
) )
db.session.add(thought1) db_session_with_containers.add(thought1)
agent_thoughts.append(thought1) agent_thoughts.append(thought1)
# Create second agent thought # Create second agent thought
@ -277,14 +273,14 @@ class TestAgentService:
created_by_role="account", created_by_role="account",
created_by=message.from_account_id, created_by=message.from_account_id,
) )
db.session.add(thought2) db_session_with_containers.add(thought2)
agent_thoughts.append(thought2) agent_thoughts.append(thought2)
db.session.commit() db_session_with_containers.commit()
return agent_thoughts return agent_thoughts
def test_get_agent_logs_success(self, db_session_with_containers, mock_external_service_dependencies): def test_get_agent_logs_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test successful retrieval of agent logs with complete data. Test successful retrieval of agent logs with complete data.
""" """
@ -344,7 +340,7 @@ class TestAgentService:
assert dataset_tool_call["tool_icon"] == "" # dataset-retrieval tools have empty icon assert dataset_tool_call["tool_icon"] == "" # dataset-retrieval tools have empty icon
def test_get_agent_logs_conversation_not_found( def test_get_agent_logs_conversation_not_found(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test error handling when conversation is not found. Test error handling when conversation is not found.
@ -358,7 +354,9 @@ class TestAgentService:
with pytest.raises(ValueError, match="Conversation not found"): with pytest.raises(ValueError, match="Conversation not found"):
AgentService.get_agent_logs(app, fake.uuid4(), fake.uuid4()) AgentService.get_agent_logs(app, fake.uuid4(), fake.uuid4())
def test_get_agent_logs_message_not_found(self, db_session_with_containers, mock_external_service_dependencies): def test_get_agent_logs_message_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test error handling when message is not found. Test error handling when message is not found.
""" """
@ -372,7 +370,9 @@ class TestAgentService:
with pytest.raises(ValueError, match="Message not found"): with pytest.raises(ValueError, match="Message not found"):
AgentService.get_agent_logs(app, str(conversation.id), fake.uuid4()) AgentService.get_agent_logs(app, str(conversation.id), fake.uuid4())
def test_get_agent_logs_with_end_user(self, db_session_with_containers, mock_external_service_dependencies): def test_get_agent_logs_with_end_user(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test agent logs retrieval when conversation is from end user. Test agent logs retrieval when conversation is from end user.
""" """
@ -381,8 +381,6 @@ class TestAgentService:
# Create test data # Create test data
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
from extensions.ext_database import db
# Create end user # Create end user
end_user = EndUser( end_user = EndUser(
id=fake.uuid4(), id=fake.uuid4(),
@ -393,8 +391,8 @@ class TestAgentService:
session_id=fake.uuid4(), session_id=fake.uuid4(),
name=fake.name(), name=fake.name(),
) )
db.session.add(end_user) db_session_with_containers.add(end_user)
db.session.commit() db_session_with_containers.commit()
# Create conversation with end user # Create conversation with end user
conversation = Conversation( conversation = Conversation(
@ -408,8 +406,8 @@ class TestAgentService:
mode="chat", mode="chat",
from_source="api", from_source="api",
) )
db.session.add(conversation) db_session_with_containers.add(conversation)
db.session.commit() db_session_with_containers.commit()
# Create app model config # Create app model config
app_model_config = AppModelConfig( app_model_config = AppModelConfig(
@ -421,12 +419,12 @@ class TestAgentService:
agent_mode=json.dumps({"enabled": True, "strategy": "react", "tools": []}), agent_mode=json.dumps({"enabled": True, "strategy": "react", "tools": []}),
) )
app_model_config.id = fake.uuid4() app_model_config.id = fake.uuid4()
db.session.add(app_model_config) db_session_with_containers.add(app_model_config)
db.session.commit() db_session_with_containers.commit()
# Update conversation with app model config # Update conversation with app model config
conversation.app_model_config_id = app_model_config.id conversation.app_model_config_id = app_model_config.id
db.session.commit() db_session_with_containers.commit()
# Create message # Create message
message = Message( message = Message(
@ -447,8 +445,8 @@ class TestAgentService:
currency="USD", currency="USD",
from_source="api", from_source="api",
) )
db.session.add(message) db_session_with_containers.add(message)
db.session.commit() db_session_with_containers.commit()
# Execute the method under test # Execute the method under test
result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
@ -457,7 +455,9 @@ class TestAgentService:
assert result is not None assert result is not None
assert result["meta"]["executor"] == end_user.name assert result["meta"]["executor"] == end_user.name
def test_get_agent_logs_with_unknown_executor(self, db_session_with_containers, mock_external_service_dependencies): def test_get_agent_logs_with_unknown_executor(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test agent logs retrieval when executor is unknown. Test agent logs retrieval when executor is unknown.
""" """
@ -466,8 +466,6 @@ class TestAgentService:
# Create test data # Create test data
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
from extensions.ext_database import db
# Create conversation with non-existent account # Create conversation with non-existent account
conversation = Conversation( conversation = Conversation(
id=fake.uuid4(), id=fake.uuid4(),
@ -480,8 +478,8 @@ class TestAgentService:
mode="chat", mode="chat",
from_source="api", from_source="api",
) )
db.session.add(conversation) db_session_with_containers.add(conversation)
db.session.commit() db_session_with_containers.commit()
# Create app model config # Create app model config
app_model_config = AppModelConfig( app_model_config = AppModelConfig(
@ -493,12 +491,12 @@ class TestAgentService:
agent_mode=json.dumps({"enabled": True, "strategy": "react", "tools": []}), agent_mode=json.dumps({"enabled": True, "strategy": "react", "tools": []}),
) )
app_model_config.id = fake.uuid4() app_model_config.id = fake.uuid4()
db.session.add(app_model_config) db_session_with_containers.add(app_model_config)
db.session.commit() db_session_with_containers.commit()
# Update conversation with app model config # Update conversation with app model config
conversation.app_model_config_id = app_model_config.id conversation.app_model_config_id = app_model_config.id
db.session.commit() db_session_with_containers.commit()
# Create message # Create message
message = Message( message = Message(
@ -519,8 +517,8 @@ class TestAgentService:
currency="USD", currency="USD",
from_source="api", from_source="api",
) )
db.session.add(message) db_session_with_containers.add(message)
db.session.commit() db_session_with_containers.commit()
# Execute the method under test # Execute the method under test
result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
@ -529,7 +527,9 @@ class TestAgentService:
assert result is not None assert result is not None
assert result["meta"]["executor"] == "Unknown" assert result["meta"]["executor"] == "Unknown"
def test_get_agent_logs_with_tool_error(self, db_session_with_containers, mock_external_service_dependencies): def test_get_agent_logs_with_tool_error(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test agent logs retrieval with tool errors. Test agent logs retrieval with tool errors.
""" """
@ -539,8 +539,6 @@ class TestAgentService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account) conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account)
from extensions.ext_database import db
# Create agent thought with tool error # Create agent thought with tool error
thought_with_error = MessageAgentThought( thought_with_error = MessageAgentThought(
message_id=message.id, message_id=message.id,
@ -564,8 +562,8 @@ class TestAgentService:
created_by_role="account", created_by_role="account",
created_by=message.from_account_id, created_by=message.from_account_id,
) )
db.session.add(thought_with_error) db_session_with_containers.add(thought_with_error)
db.session.commit() db_session_with_containers.commit()
# Execute the method under test # Execute the method under test
result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
@ -580,7 +578,7 @@ class TestAgentService:
assert tool_call["error"] == "Tool execution failed" assert tool_call["error"] == "Tool execution failed"
def test_get_agent_logs_without_agent_thoughts( def test_get_agent_logs_without_agent_thoughts(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test agent logs retrieval when message has no agent thoughts. Test agent logs retrieval when message has no agent thoughts.
@ -600,7 +598,7 @@ class TestAgentService:
assert len(result["iterations"]) == 0 assert len(result["iterations"]) == 0
def test_get_agent_logs_app_model_config_not_found( def test_get_agent_logs_app_model_config_not_found(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test error handling when app model config is not found. Test error handling when app model config is not found.
@ -610,11 +608,9 @@ class TestAgentService:
# Create test data # Create test data
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
from extensions.ext_database import db
# Remove app model config to test error handling # Remove app model config to test error handling
app.app_model_config_id = None app.app_model_config_id = None
db.session.commit() db_session_with_containers.commit()
# Create conversation without app model config # Create conversation without app model config
conversation = Conversation( conversation = Conversation(
@ -629,8 +625,8 @@ class TestAgentService:
from_source="api", from_source="api",
app_model_config_id=None, # Explicitly set to None app_model_config_id=None, # Explicitly set to None
) )
db.session.add(conversation) db_session_with_containers.add(conversation)
db.session.commit() db_session_with_containers.commit()
# Create message # Create message
message = Message( message = Message(
@ -651,15 +647,15 @@ class TestAgentService:
currency="USD", currency="USD",
from_source="api", from_source="api",
) )
db.session.add(message) db_session_with_containers.add(message)
db.session.commit() db_session_with_containers.commit()
# Execute the method under test # Execute the method under test
with pytest.raises(ValueError, match="App model config not found"): with pytest.raises(ValueError, match="App model config not found"):
AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
def test_get_agent_logs_agent_config_not_found( def test_get_agent_logs_agent_config_not_found(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test error handling when agent config is not found. Test error handling when agent config is not found.
@ -677,7 +673,9 @@ class TestAgentService:
with pytest.raises(ValueError, match="Agent config not found"): with pytest.raises(ValueError, match="Agent config not found"):
AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
def test_list_agent_providers_success(self, db_session_with_containers, mock_external_service_dependencies): def test_list_agent_providers_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful listing of agent providers. Test successful listing of agent providers.
""" """
@ -698,7 +696,7 @@ class TestAgentService:
mock_plugin_client = mock_external_service_dependencies["plugin_agent_client"].return_value mock_plugin_client = mock_external_service_dependencies["plugin_agent_client"].return_value
mock_plugin_client.fetch_agent_strategy_providers.assert_called_once_with(str(app.tenant_id)) mock_plugin_client.fetch_agent_strategy_providers.assert_called_once_with(str(app.tenant_id))
def test_get_agent_provider_success(self, db_session_with_containers, mock_external_service_dependencies): def test_get_agent_provider_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test successful retrieval of specific agent provider. Test successful retrieval of specific agent provider.
""" """
@ -720,7 +718,9 @@ class TestAgentService:
mock_plugin_client = mock_external_service_dependencies["plugin_agent_client"].return_value mock_plugin_client = mock_external_service_dependencies["plugin_agent_client"].return_value
mock_plugin_client.fetch_agent_strategy_provider.assert_called_once_with(str(app.tenant_id), provider_name) mock_plugin_client.fetch_agent_strategy_provider.assert_called_once_with(str(app.tenant_id), provider_name)
def test_get_agent_provider_plugin_error(self, db_session_with_containers, mock_external_service_dependencies): def test_get_agent_provider_plugin_error(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test error handling when plugin daemon client raises an error. Test error handling when plugin daemon client raises an error.
""" """
@ -741,7 +741,7 @@ class TestAgentService:
AgentService.get_agent_provider(str(account.id), str(app.tenant_id), provider_name) AgentService.get_agent_provider(str(account.id), str(app.tenant_id), provider_name)
def test_get_agent_logs_with_complex_tool_data( def test_get_agent_logs_with_complex_tool_data(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test agent logs retrieval with complex tool data and multiple tools. Test agent logs retrieval with complex tool data and multiple tools.
@ -752,8 +752,6 @@ class TestAgentService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account) conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account)
from extensions.ext_database import db
# Create agent thought with multiple tools # Create agent thought with multiple tools
complex_thought = MessageAgentThought( complex_thought = MessageAgentThought(
message_id=message.id, message_id=message.id,
@ -799,8 +797,8 @@ class TestAgentService:
created_by_role="account", created_by_role="account",
created_by=message.from_account_id, created_by=message.from_account_id,
) )
db.session.add(complex_thought) db_session_with_containers.add(complex_thought)
db.session.commit() db_session_with_containers.commit()
# Execute the method under test # Execute the method under test
result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
@ -831,7 +829,7 @@ class TestAgentService:
assert tool_calls[2]["status"] == "success" assert tool_calls[2]["status"] == "success"
assert tool_calls[2]["tool_icon"] == "" # dataset-retrieval tools have empty icon assert tool_calls[2]["tool_icon"] == "" # dataset-retrieval tools have empty icon
def test_get_agent_logs_with_files(self, db_session_with_containers, mock_external_service_dependencies): def test_get_agent_logs_with_files(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test agent logs retrieval with message files and agent thought files. Test agent logs retrieval with message files and agent thought files.
""" """
@ -842,7 +840,6 @@ class TestAgentService:
conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account) conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account)
from dify_graph.file import FileTransferMethod, FileType from dify_graph.file import FileTransferMethod, FileType
from extensions.ext_database import db
from models.enums import CreatorUserRole from models.enums import CreatorUserRole
# Add files to message # Add files to message
@ -867,9 +864,9 @@ class TestAgentService:
created_by_role=CreatorUserRole.ACCOUNT, created_by_role=CreatorUserRole.ACCOUNT,
created_by=message.from_account_id, created_by=message.from_account_id,
) )
db.session.add(message_file1) db_session_with_containers.add(message_file1)
db.session.add(message_file2) db_session_with_containers.add(message_file2)
db.session.commit() db_session_with_containers.commit()
# Create agent thought with files # Create agent thought with files
thought_with_files = MessageAgentThought( thought_with_files = MessageAgentThought(
@ -895,8 +892,8 @@ class TestAgentService:
created_by_role="account", created_by_role="account",
created_by=message.from_account_id, created_by=message.from_account_id,
) )
db.session.add(thought_with_files) db_session_with_containers.add(thought_with_files)
db.session.commit() db_session_with_containers.commit()
# Execute the method under test # Execute the method under test
result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
@ -912,7 +909,7 @@ class TestAgentService:
assert "file2" in iterations[0]["files"] assert "file2" in iterations[0]["files"]
def test_get_agent_logs_with_different_timezone( def test_get_agent_logs_with_different_timezone(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test agent logs retrieval with different timezone settings. Test agent logs retrieval with different timezone settings.
@ -938,7 +935,9 @@ class TestAgentService:
assert "T" in start_time # ISO format assert "T" in start_time # ISO format
assert "+08:00" in start_time or "Z" in start_time # Timezone offset assert "+08:00" in start_time or "Z" in start_time # Timezone offset
def test_get_agent_logs_with_empty_tool_data(self, db_session_with_containers, mock_external_service_dependencies): def test_get_agent_logs_with_empty_tool_data(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test agent logs retrieval with empty tool data. Test agent logs retrieval with empty tool data.
""" """
@ -948,8 +947,6 @@ class TestAgentService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account) conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account)
from extensions.ext_database import db
# Create agent thought with empty tool data # Create agent thought with empty tool data
empty_thought = MessageAgentThought( empty_thought = MessageAgentThought(
message_id=message.id, message_id=message.id,
@ -964,8 +961,8 @@ class TestAgentService:
created_by_role="account", created_by_role="account",
created_by=message.from_account_id, created_by=message.from_account_id,
) )
db.session.add(empty_thought) db_session_with_containers.add(empty_thought)
db.session.commit() db_session_with_containers.commit()
# Execute the method under test # Execute the method under test
result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
@ -979,7 +976,9 @@ class TestAgentService:
tool_calls = iterations[0]["tool_calls"] tool_calls = iterations[0]["tool_calls"]
assert len(tool_calls) == 0 # No tools to process assert len(tool_calls) == 0 # No tools to process
def test_get_agent_logs_with_malformed_json(self, db_session_with_containers, mock_external_service_dependencies): def test_get_agent_logs_with_malformed_json(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test agent logs retrieval with malformed JSON data in tool fields. Test agent logs retrieval with malformed JSON data in tool fields.
""" """
@ -989,8 +988,6 @@ class TestAgentService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account) conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account)
from extensions.ext_database import db
# Create agent thought with malformed JSON # Create agent thought with malformed JSON
malformed_thought = MessageAgentThought( malformed_thought = MessageAgentThought(
message_id=message.id, message_id=message.id,
@ -1005,8 +1002,8 @@ class TestAgentService:
created_by_role="account", created_by_role="account",
created_by=message.from_account_id, created_by=message.from_account_id,
) )
db.session.add(malformed_thought) db_session_with_containers.add(malformed_thought)
db.session.commit() db_session_with_containers.commit()
# Execute the method under test # Execute the method under test
result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id))

View File

@ -2,6 +2,7 @@ from unittest.mock import create_autospec, patch
import pytest import pytest
from faker import Faker from faker import Faker
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from models import Account from models import Account
@ -52,7 +53,7 @@ class TestAnnotationService:
"current_user": mock_user, "current_user": mock_user,
} }
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Helper method to create a test app and account for testing. Helper method to create a test app and account for testing.
@ -115,11 +116,10 @@ class TestAnnotationService:
tenant_id, tenant_id,
) )
def _create_test_conversation(self, app, account, fake): def _create_test_conversation(self, db_session_with_containers: Session, app, account, fake):
""" """
Helper method to create a test conversation with all required fields. Helper method to create a test conversation with all required fields.
""" """
from extensions.ext_database import db
from models.model import Conversation from models.model import Conversation
conversation = Conversation( conversation = Conversation(
@ -141,17 +141,16 @@ class TestAnnotationService:
from_account_id=account.id, from_account_id=account.id,
) )
db.session.add(conversation) db_session_with_containers.add(conversation)
db.session.flush() db_session_with_containers.flush()
return conversation return conversation
def _create_test_message(self, app, conversation, account, fake): def _create_test_message(self, db_session_with_containers: Session, app, conversation, account, fake):
""" """
Helper method to create a test message with all required fields. Helper method to create a test message with all required fields.
""" """
import json import json
from extensions.ext_database import db
from models.model import Message from models.model import Message
message = Message( message = Message(
@ -180,12 +179,12 @@ class TestAnnotationService:
from_account_id=account.id, from_account_id=account.id,
) )
db.session.add(message) db_session_with_containers.add(message)
db.session.commit() db_session_with_containers.commit()
return message return message
def test_insert_app_annotation_directly_success( def test_insert_app_annotation_directly_success(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful direct insertion of app annotation. Test successful direct insertion of app annotation.
@ -211,9 +210,8 @@ class TestAnnotationService:
assert annotation.id is not None assert annotation.id is not None
# Verify annotation was saved to database # Verify annotation was saved to database
from extensions.ext_database import db
db.session.refresh(annotation) db_session_with_containers.refresh(annotation)
assert annotation.id is not None assert annotation.id is not None
# Verify add_annotation_to_index_task was called (when annotation setting exists) # Verify add_annotation_to_index_task was called (when annotation setting exists)
@ -221,7 +219,7 @@ class TestAnnotationService:
mock_external_service_dependencies["add_task"].delay.assert_not_called() mock_external_service_dependencies["add_task"].delay.assert_not_called()
def test_insert_app_annotation_directly_requires_question( def test_insert_app_annotation_directly_requires_question(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Question must be provided when inserting annotations directly. Question must be provided when inserting annotations directly.
@ -238,7 +236,7 @@ class TestAnnotationService:
AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id) AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id)
def test_insert_app_annotation_directly_app_not_found( def test_insert_app_annotation_directly_app_not_found(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test direct insertion of app annotation when app is not found. Test direct insertion of app annotation when app is not found.
@ -260,7 +258,7 @@ class TestAnnotationService:
AppAnnotationService.insert_app_annotation_directly(annotation_args, non_existent_app_id) AppAnnotationService.insert_app_annotation_directly(annotation_args, non_existent_app_id)
def test_update_app_annotation_directly_success( def test_update_app_annotation_directly_success(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful direct update of app annotation. Test successful direct update of app annotation.
@ -298,7 +296,7 @@ class TestAnnotationService:
mock_external_service_dependencies["update_task"].delay.assert_not_called() mock_external_service_dependencies["update_task"].delay.assert_not_called()
def test_up_insert_app_annotation_from_message_new( def test_up_insert_app_annotation_from_message_new(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test creating new annotation from message. Test creating new annotation from message.
@ -307,8 +305,8 @@ class TestAnnotationService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and message first # Create a conversation and message first
conversation = self._create_test_conversation(app, account, fake) conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
message = self._create_test_message(app, conversation, account, fake) message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
# Setup annotation data with message_id # Setup annotation data with message_id
annotation_args = { annotation_args = {
@ -333,7 +331,7 @@ class TestAnnotationService:
mock_external_service_dependencies["add_task"].delay.assert_not_called() mock_external_service_dependencies["add_task"].delay.assert_not_called()
def test_up_insert_app_annotation_from_message_update( def test_up_insert_app_annotation_from_message_update(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test updating existing annotation from message. Test updating existing annotation from message.
@ -342,8 +340,8 @@ class TestAnnotationService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and message first # Create a conversation and message first
conversation = self._create_test_conversation(app, account, fake) conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
message = self._create_test_message(app, conversation, account, fake) message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
# Create initial annotation # Create initial annotation
initial_args = { initial_args = {
@ -373,7 +371,7 @@ class TestAnnotationService:
mock_external_service_dependencies["add_task"].delay.assert_not_called() mock_external_service_dependencies["add_task"].delay.assert_not_called()
def test_up_insert_app_annotation_from_message_app_not_found( def test_up_insert_app_annotation_from_message_app_not_found(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test creating annotation from message when app is not found. Test creating annotation from message when app is not found.
@ -395,7 +393,7 @@ class TestAnnotationService:
AppAnnotationService.up_insert_app_annotation_from_message(annotation_args, non_existent_app_id) AppAnnotationService.up_insert_app_annotation_from_message(annotation_args, non_existent_app_id)
def test_get_annotation_list_by_app_id_success( def test_get_annotation_list_by_app_id_success(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful retrieval of annotation list by app ID. Test successful retrieval of annotation list by app ID.
@ -428,7 +426,7 @@ class TestAnnotationService:
assert annotation.account_id == account.id assert annotation.account_id == account.id
def test_get_annotation_list_by_app_id_with_keyword( def test_get_annotation_list_by_app_id_with_keyword(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test retrieval of annotation list with keyword search. Test retrieval of annotation list with keyword search.
@ -462,7 +460,7 @@ class TestAnnotationService:
assert unique_keyword in annotation_list[0].question or unique_keyword in annotation_list[0].content assert unique_keyword in annotation_list[0].question or unique_keyword in annotation_list[0].content
def test_get_annotation_list_by_app_id_with_special_characters_in_keyword( def test_get_annotation_list_by_app_id_with_special_characters_in_keyword(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
r""" r"""
Test retrieval of annotation list with special characters in keyword to verify SQL injection prevention. Test retrieval of annotation list with special characters in keyword to verify SQL injection prevention.
@ -534,7 +532,7 @@ class TestAnnotationService:
assert all("50%" in (item.question or "") or "50%" in (item.content or "") for item in annotation_list) assert all("50%" in (item.question or "") or "50%" in (item.content or "") for item in annotation_list)
def test_get_annotation_list_by_app_id_app_not_found( def test_get_annotation_list_by_app_id_app_not_found(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test retrieval of annotation list when app is not found. Test retrieval of annotation list when app is not found.
@ -549,7 +547,9 @@ class TestAnnotationService:
with pytest.raises(NotFound, match="App not found"): with pytest.raises(NotFound, match="App not found"):
AppAnnotationService.get_annotation_list_by_app_id(non_existent_app_id, page=1, limit=10, keyword="") AppAnnotationService.get_annotation_list_by_app_id(non_existent_app_id, page=1, limit=10, keyword="")
def test_delete_app_annotation_success(self, db_session_with_containers, mock_external_service_dependencies): def test_delete_app_annotation_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful deletion of app annotation. Test successful deletion of app annotation.
""" """
@ -568,16 +568,19 @@ class TestAnnotationService:
AppAnnotationService.delete_app_annotation(app.id, annotation_id) AppAnnotationService.delete_app_annotation(app.id, annotation_id)
# Verify annotation was deleted # Verify annotation was deleted
from extensions.ext_database import db
deleted_annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first() deleted_annotation = (
db_session_with_containers.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first()
)
assert deleted_annotation is None assert deleted_annotation is None
# Verify delete_annotation_index_task was called (when annotation setting exists) # Verify delete_annotation_index_task was called (when annotation setting exists)
# Note: In this test, no annotation setting exists, so task should not be called # Note: In this test, no annotation setting exists, so task should not be called
mock_external_service_dependencies["delete_task"].delay.assert_not_called() mock_external_service_dependencies["delete_task"].delay.assert_not_called()
def test_delete_app_annotation_app_not_found(self, db_session_with_containers, mock_external_service_dependencies): def test_delete_app_annotation_app_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test deletion of app annotation when app is not found. Test deletion of app annotation when app is not found.
""" """
@ -593,7 +596,7 @@ class TestAnnotationService:
AppAnnotationService.delete_app_annotation(non_existent_app_id, annotation_id) AppAnnotationService.delete_app_annotation(non_existent_app_id, annotation_id)
def test_delete_app_annotation_annotation_not_found( def test_delete_app_annotation_annotation_not_found(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test deletion of app annotation when annotation is not found. Test deletion of app annotation when annotation is not found.
@ -606,7 +609,9 @@ class TestAnnotationService:
with pytest.raises(NotFound, match="Annotation not found"): with pytest.raises(NotFound, match="Annotation not found"):
AppAnnotationService.delete_app_annotation(app.id, non_existent_annotation_id) AppAnnotationService.delete_app_annotation(app.id, non_existent_annotation_id)
def test_enable_app_annotation_success(self, db_session_with_containers, mock_external_service_dependencies): def test_enable_app_annotation_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful enabling of app annotation. Test successful enabling of app annotation.
""" """
@ -632,7 +637,9 @@ class TestAnnotationService:
# Verify task was called # Verify task was called
mock_external_service_dependencies["enable_task"].delay.assert_called_once() mock_external_service_dependencies["enable_task"].delay.assert_called_once()
def test_disable_app_annotation_success(self, db_session_with_containers, mock_external_service_dependencies): def test_disable_app_annotation_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful disabling of app annotation. Test successful disabling of app annotation.
""" """
@ -651,7 +658,9 @@ class TestAnnotationService:
# Verify task was called # Verify task was called
mock_external_service_dependencies["disable_task"].delay.assert_called_once() mock_external_service_dependencies["disable_task"].delay.assert_called_once()
def test_enable_app_annotation_cached_job(self, db_session_with_containers, mock_external_service_dependencies): def test_enable_app_annotation_cached_job(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test enabling app annotation when job is already cached. Test enabling app annotation when job is already cached.
""" """
@ -685,7 +694,9 @@ class TestAnnotationService:
# Clean up # Clean up
redis_client.delete(enable_app_annotation_key) redis_client.delete(enable_app_annotation_key)
def test_get_annotation_hit_histories_success(self, db_session_with_containers, mock_external_service_dependencies): def test_get_annotation_hit_histories_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful retrieval of annotation hit histories. Test successful retrieval of annotation hit histories.
""" """
@ -728,7 +739,9 @@ class TestAnnotationService:
assert history.app_id == app.id assert history.app_id == app.id
assert history.account_id == account.id assert history.account_id == account.id
def test_add_annotation_history_success(self, db_session_with_containers, mock_external_service_dependencies): def test_add_annotation_history_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful addition of annotation history. Test successful addition of annotation history.
""" """
@ -763,16 +776,15 @@ class TestAnnotationService:
) )
# Verify hit count was incremented # Verify hit count was incremented
from extensions.ext_database import db
db.session.refresh(annotation) db_session_with_containers.refresh(annotation)
assert annotation.hit_count == initial_hit_count + 1 assert annotation.hit_count == initial_hit_count + 1
# Verify history was created # Verify history was created
from models.model import AppAnnotationHitHistory from models.model import AppAnnotationHitHistory
history = ( history = (
db.session.query(AppAnnotationHitHistory) db_session_with_containers.query(AppAnnotationHitHistory)
.where( .where(
AppAnnotationHitHistory.annotation_id == annotation.id, AppAnnotationHitHistory.message_id == message_id AppAnnotationHitHistory.annotation_id == annotation.id, AppAnnotationHitHistory.message_id == message_id
) )
@ -786,7 +798,9 @@ class TestAnnotationService:
assert history.score == score assert history.score == score
assert history.source == "console" assert history.source == "console"
def test_get_annotation_by_id_success(self, db_session_with_containers, mock_external_service_dependencies): def test_get_annotation_by_id_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful retrieval of annotation by ID. Test successful retrieval of annotation by ID.
""" """
@ -811,7 +825,9 @@ class TestAnnotationService:
assert retrieved_annotation.content == annotation_args["answer"] assert retrieved_annotation.content == annotation_args["answer"]
assert retrieved_annotation.account_id == account.id assert retrieved_annotation.account_id == account.id
def test_batch_import_app_annotations_success(self, db_session_with_containers, mock_external_service_dependencies): def test_batch_import_app_annotations_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful batch import of app annotations. Test successful batch import of app annotations.
""" """
@ -854,7 +870,7 @@ class TestAnnotationService:
mock_external_service_dependencies["batch_import_task"].delay.assert_called_once() mock_external_service_dependencies["batch_import_task"].delay.assert_called_once()
def test_batch_import_app_annotations_empty_file( def test_batch_import_app_annotations_empty_file(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test batch import with empty CSV file. Test batch import with empty CSV file.
@ -889,7 +905,7 @@ class TestAnnotationService:
assert "empty" in result["error_msg"].lower() assert "empty" in result["error_msg"].lower()
def test_batch_import_app_annotations_quota_exceeded( def test_batch_import_app_annotations_quota_exceeded(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test batch import when quota is exceeded. Test batch import when quota is exceeded.
@ -935,7 +951,7 @@ class TestAnnotationService:
assert "limit" in result["error_msg"].lower() assert "limit" in result["error_msg"].lower()
def test_get_app_annotation_setting_by_app_id_enabled( def test_get_app_annotation_setting_by_app_id_enabled(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test getting enabled app annotation setting by app ID. Test getting enabled app annotation setting by app ID.
@ -944,7 +960,6 @@ class TestAnnotationService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create annotation setting # Create annotation setting
from extensions.ext_database import db
from models.dataset import DatasetCollectionBinding from models.dataset import DatasetCollectionBinding
from models.model import AppAnnotationSetting from models.model import AppAnnotationSetting
@ -956,8 +971,8 @@ class TestAnnotationService:
collection_name=f"annotation_collection_{fake.uuid4()}", collection_name=f"annotation_collection_{fake.uuid4()}",
) )
collection_binding.id = str(fake.uuid4()) collection_binding.id = str(fake.uuid4())
db.session.add(collection_binding) db_session_with_containers.add(collection_binding)
db.session.flush() db_session_with_containers.flush()
# Create annotation setting # Create annotation setting
annotation_setting = AppAnnotationSetting( annotation_setting = AppAnnotationSetting(
@ -967,8 +982,8 @@ class TestAnnotationService:
created_user_id=account.id, created_user_id=account.id,
updated_user_id=account.id, updated_user_id=account.id,
) )
db.session.add(annotation_setting) db_session_with_containers.add(annotation_setting)
db.session.commit() db_session_with_containers.commit()
# Get annotation setting # Get annotation setting
result = AppAnnotationService.get_app_annotation_setting_by_app_id(app.id) result = AppAnnotationService.get_app_annotation_setting_by_app_id(app.id)
@ -981,7 +996,7 @@ class TestAnnotationService:
assert result["embedding_model"]["embedding_model_name"] == "text-embedding-ada-002" assert result["embedding_model"]["embedding_model_name"] == "text-embedding-ada-002"
def test_get_app_annotation_setting_by_app_id_disabled( def test_get_app_annotation_setting_by_app_id_disabled(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test getting disabled app annotation setting by app ID. Test getting disabled app annotation setting by app ID.
@ -996,7 +1011,7 @@ class TestAnnotationService:
assert result["enabled"] is False assert result["enabled"] is False
def test_update_app_annotation_setting_success( def test_update_app_annotation_setting_success(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful update of app annotation setting. Test successful update of app annotation setting.
@ -1005,7 +1020,6 @@ class TestAnnotationService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create annotation setting first # Create annotation setting first
from extensions.ext_database import db
from models.dataset import DatasetCollectionBinding from models.dataset import DatasetCollectionBinding
from models.model import AppAnnotationSetting from models.model import AppAnnotationSetting
@ -1017,8 +1031,8 @@ class TestAnnotationService:
collection_name=f"annotation_collection_{fake.uuid4()}", collection_name=f"annotation_collection_{fake.uuid4()}",
) )
collection_binding.id = str(fake.uuid4()) collection_binding.id = str(fake.uuid4())
db.session.add(collection_binding) db_session_with_containers.add(collection_binding)
db.session.flush() db_session_with_containers.flush()
# Create annotation setting # Create annotation setting
annotation_setting = AppAnnotationSetting( annotation_setting = AppAnnotationSetting(
@ -1028,8 +1042,8 @@ class TestAnnotationService:
created_user_id=account.id, created_user_id=account.id,
updated_user_id=account.id, updated_user_id=account.id,
) )
db.session.add(annotation_setting) db_session_with_containers.add(annotation_setting)
db.session.commit() db_session_with_containers.commit()
# Update annotation setting # Update annotation setting
update_args = { update_args = {
@ -1046,11 +1060,11 @@ class TestAnnotationService:
assert result["embedding_model"]["embedding_model_name"] == "text-embedding-ada-002" assert result["embedding_model"]["embedding_model_name"] == "text-embedding-ada-002"
# Verify database was updated # Verify database was updated
db.session.refresh(annotation_setting) db_session_with_containers.refresh(annotation_setting)
assert annotation_setting.score_threshold == 0.9 assert annotation_setting.score_threshold == 0.9
def test_export_annotation_list_by_app_id_success( def test_export_annotation_list_by_app_id_success(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful export of annotation list by app ID. Test successful export of annotation list by app ID.
@ -1083,7 +1097,7 @@ class TestAnnotationService:
assert annotation.created_at <= exported_annotations[i - 1].created_at assert annotation.created_at <= exported_annotations[i - 1].created_at
def test_export_annotation_list_by_app_id_app_not_found( def test_export_annotation_list_by_app_id_app_not_found(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test export of annotation list when app is not found. Test export of annotation list when app is not found.
@ -1099,7 +1113,7 @@ class TestAnnotationService:
AppAnnotationService.export_annotation_list_by_app_id(non_existent_app_id) AppAnnotationService.export_annotation_list_by_app_id(non_existent_app_id)
def test_insert_app_annotation_directly_with_setting_success( def test_insert_app_annotation_directly_with_setting_success(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful direct insertion of app annotation with annotation setting enabled. Test successful direct insertion of app annotation with annotation setting enabled.
@ -1108,7 +1122,6 @@ class TestAnnotationService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create annotation setting first # Create annotation setting first
from extensions.ext_database import db
from models.dataset import DatasetCollectionBinding from models.dataset import DatasetCollectionBinding
from models.model import AppAnnotationSetting from models.model import AppAnnotationSetting
@ -1120,8 +1133,8 @@ class TestAnnotationService:
collection_name=f"annotation_collection_{fake.uuid4()}", collection_name=f"annotation_collection_{fake.uuid4()}",
) )
collection_binding.id = str(fake.uuid4()) collection_binding.id = str(fake.uuid4())
db.session.add(collection_binding) db_session_with_containers.add(collection_binding)
db.session.flush() db_session_with_containers.flush()
# Create annotation setting # Create annotation setting
annotation_setting = AppAnnotationSetting( annotation_setting = AppAnnotationSetting(
@ -1131,8 +1144,8 @@ class TestAnnotationService:
created_user_id=account.id, created_user_id=account.id,
updated_user_id=account.id, updated_user_id=account.id,
) )
db.session.add(annotation_setting) db_session_with_containers.add(annotation_setting)
db.session.commit() db_session_with_containers.commit()
# Setup annotation data # Setup annotation data
annotation_args = { annotation_args = {
@ -1161,7 +1174,7 @@ class TestAnnotationService:
assert call_args[4] == collection_binding.id # collection_binding_id assert call_args[4] == collection_binding.id # collection_binding_id
def test_update_app_annotation_directly_with_setting_success( def test_update_app_annotation_directly_with_setting_success(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful direct update of app annotation with annotation setting enabled. Test successful direct update of app annotation with annotation setting enabled.
@ -1170,7 +1183,6 @@ class TestAnnotationService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create annotation setting first # Create annotation setting first
from extensions.ext_database import db
from models.dataset import DatasetCollectionBinding from models.dataset import DatasetCollectionBinding
from models.model import AppAnnotationSetting from models.model import AppAnnotationSetting
@ -1182,8 +1194,8 @@ class TestAnnotationService:
collection_name=f"annotation_collection_{fake.uuid4()}", collection_name=f"annotation_collection_{fake.uuid4()}",
) )
collection_binding.id = str(fake.uuid4()) collection_binding.id = str(fake.uuid4())
db.session.add(collection_binding) db_session_with_containers.add(collection_binding)
db.session.flush() db_session_with_containers.flush()
# Create annotation setting # Create annotation setting
annotation_setting = AppAnnotationSetting( annotation_setting = AppAnnotationSetting(
@ -1193,8 +1205,8 @@ class TestAnnotationService:
created_user_id=account.id, created_user_id=account.id,
updated_user_id=account.id, updated_user_id=account.id,
) )
db.session.add(annotation_setting) db_session_with_containers.add(annotation_setting)
db.session.commit() db_session_with_containers.commit()
# First, create an annotation # First, create an annotation
original_args = { original_args = {
@ -1234,7 +1246,7 @@ class TestAnnotationService:
assert call_args[4] == collection_binding.id # collection_binding_id assert call_args[4] == collection_binding.id # collection_binding_id
def test_delete_app_annotation_with_setting_success( def test_delete_app_annotation_with_setting_success(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful deletion of app annotation with annotation setting enabled. Test successful deletion of app annotation with annotation setting enabled.
@ -1243,7 +1255,6 @@ class TestAnnotationService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create annotation setting first # Create annotation setting first
from extensions.ext_database import db
from models.dataset import DatasetCollectionBinding from models.dataset import DatasetCollectionBinding
from models.model import AppAnnotationSetting from models.model import AppAnnotationSetting
@ -1255,8 +1266,8 @@ class TestAnnotationService:
collection_name=f"annotation_collection_{fake.uuid4()}", collection_name=f"annotation_collection_{fake.uuid4()}",
) )
collection_binding.id = str(fake.uuid4()) collection_binding.id = str(fake.uuid4())
db.session.add(collection_binding) db_session_with_containers.add(collection_binding)
db.session.flush() db_session_with_containers.flush()
# Create annotation setting # Create annotation setting
annotation_setting = AppAnnotationSetting( annotation_setting = AppAnnotationSetting(
@ -1267,8 +1278,8 @@ class TestAnnotationService:
updated_user_id=account.id, updated_user_id=account.id,
) )
db.session.add(annotation_setting) db_session_with_containers.add(annotation_setting)
db.session.commit() db_session_with_containers.commit()
# Create an annotation first # Create an annotation first
annotation_args = { annotation_args = {
@ -1285,7 +1296,9 @@ class TestAnnotationService:
AppAnnotationService.delete_app_annotation(app.id, annotation_id) AppAnnotationService.delete_app_annotation(app.id, annotation_id)
# Verify annotation was deleted # Verify annotation was deleted
deleted_annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first() deleted_annotation = (
db_session_with_containers.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first()
)
assert deleted_annotation is None assert deleted_annotation is None
# Verify delete_annotation_index_task was called # Verify delete_annotation_index_task was called
@ -1297,7 +1310,7 @@ class TestAnnotationService:
assert call_args[3] == collection_binding.id # collection_binding_id assert call_args[3] == collection_binding.id # collection_binding_id
def test_up_insert_app_annotation_from_message_with_setting_success( def test_up_insert_app_annotation_from_message_with_setting_success(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test creating annotation from message with annotation setting enabled. Test creating annotation from message with annotation setting enabled.
@ -1306,7 +1319,6 @@ class TestAnnotationService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create annotation setting first # Create annotation setting first
from extensions.ext_database import db
from models.dataset import DatasetCollectionBinding from models.dataset import DatasetCollectionBinding
from models.model import AppAnnotationSetting from models.model import AppAnnotationSetting
@ -1318,8 +1330,8 @@ class TestAnnotationService:
collection_name=f"annotation_collection_{fake.uuid4()}", collection_name=f"annotation_collection_{fake.uuid4()}",
) )
collection_binding.id = str(fake.uuid4()) collection_binding.id = str(fake.uuid4())
db.session.add(collection_binding) db_session_with_containers.add(collection_binding)
db.session.flush() db_session_with_containers.flush()
# Create annotation setting # Create annotation setting
annotation_setting = AppAnnotationSetting( annotation_setting = AppAnnotationSetting(
@ -1329,12 +1341,12 @@ class TestAnnotationService:
created_user_id=account.id, created_user_id=account.id,
updated_user_id=account.id, updated_user_id=account.id,
) )
db.session.add(annotation_setting) db_session_with_containers.add(annotation_setting)
db.session.commit() db_session_with_containers.commit()
# Create a conversation and message first # Create a conversation and message first
conversation = self._create_test_conversation(app, account, fake) conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
message = self._create_test_message(app, conversation, account, fake) message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
# Setup annotation data with message_id # Setup annotation data with message_id
annotation_args = { annotation_args = {

View File

@ -2,6 +2,7 @@ from unittest.mock import patch
import pytest import pytest
from faker import Faker from faker import Faker
from sqlalchemy.orm import Session
from models.api_based_extension import APIBasedExtension from models.api_based_extension import APIBasedExtension
from services.account_service import AccountService, TenantService from services.account_service import AccountService, TenantService
@ -31,7 +32,7 @@ class TestAPIBasedExtensionService:
"requestor_instance": mock_requestor_instance, "requestor_instance": mock_requestor_instance,
} }
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Helper method to create a test account and tenant for testing. Helper method to create a test account and tenant for testing.
@ -61,7 +62,7 @@ class TestAPIBasedExtensionService:
return account, tenant return account, tenant
def test_save_extension_success(self, db_session_with_containers, mock_external_service_dependencies): def test_save_extension_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test successful saving of API-based extension. Test successful saving of API-based extension.
""" """
@ -90,15 +91,16 @@ class TestAPIBasedExtensionService:
assert saved_extension.created_at is not None assert saved_extension.created_at is not None
# Verify extension was saved to database # Verify extension was saved to database
from extensions.ext_database import db
db.session.refresh(saved_extension) db_session_with_containers.refresh(saved_extension)
assert saved_extension.id is not None assert saved_extension.id is not None
# Verify ping connection was called # Verify ping connection was called
mock_external_service_dependencies["requestor_instance"].request.assert_called_once() mock_external_service_dependencies["requestor_instance"].request.assert_called_once()
def test_save_extension_validation_errors(self, db_session_with_containers, mock_external_service_dependencies): def test_save_extension_validation_errors(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test validation errors when saving extension with invalid data. Test validation errors when saving extension with invalid data.
""" """
@ -132,7 +134,9 @@ class TestAPIBasedExtensionService:
with pytest.raises(ValueError, match="api_key must not be empty"): with pytest.raises(ValueError, match="api_key must not be empty"):
APIBasedExtensionService.save(extension_data) APIBasedExtensionService.save(extension_data)
def test_get_all_by_tenant_id_success(self, db_session_with_containers, mock_external_service_dependencies): def test_get_all_by_tenant_id_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful retrieval of all extensions by tenant ID. Test successful retrieval of all extensions by tenant ID.
""" """
@ -169,7 +173,7 @@ class TestAPIBasedExtensionService:
# Verify descending order (newer first) # Verify descending order (newer first)
assert extension.created_at <= extension_list[i - 1].created_at assert extension.created_at <= extension_list[i - 1].created_at
def test_get_with_tenant_id_success(self, db_session_with_containers, mock_external_service_dependencies): def test_get_with_tenant_id_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test successful retrieval of extension by tenant ID and extension ID. Test successful retrieval of extension by tenant ID and extension ID.
""" """
@ -200,7 +204,9 @@ class TestAPIBasedExtensionService:
assert retrieved_extension.api_key == extension_data.api_key # Should be decrypted assert retrieved_extension.api_key == extension_data.api_key # Should be decrypted
assert retrieved_extension.created_at is not None assert retrieved_extension.created_at is not None
def test_get_with_tenant_id_not_found(self, db_session_with_containers, mock_external_service_dependencies): def test_get_with_tenant_id_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test retrieval of extension when extension is not found. Test retrieval of extension when extension is not found.
""" """
@ -214,7 +220,7 @@ class TestAPIBasedExtensionService:
with pytest.raises(ValueError, match="API based extension is not found"): with pytest.raises(ValueError, match="API based extension is not found"):
APIBasedExtensionService.get_with_tenant_id(tenant.id, non_existent_extension_id) APIBasedExtensionService.get_with_tenant_id(tenant.id, non_existent_extension_id)
def test_delete_extension_success(self, db_session_with_containers, mock_external_service_dependencies): def test_delete_extension_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test successful deletion of extension. Test successful deletion of extension.
""" """
@ -238,12 +244,15 @@ class TestAPIBasedExtensionService:
APIBasedExtensionService.delete(created_extension) APIBasedExtensionService.delete(created_extension)
# Verify extension was deleted # Verify extension was deleted
from extensions.ext_database import db
deleted_extension = db.session.query(APIBasedExtension).where(APIBasedExtension.id == extension_id).first() deleted_extension = (
db_session_with_containers.query(APIBasedExtension).where(APIBasedExtension.id == extension_id).first()
)
assert deleted_extension is None assert deleted_extension is None
def test_save_extension_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies): def test_save_extension_duplicate_name(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test validation error when saving extension with duplicate name. Test validation error when saving extension with duplicate name.
""" """
@ -272,7 +281,9 @@ class TestAPIBasedExtensionService:
with pytest.raises(ValueError, match="name must be unique, it is already existed"): with pytest.raises(ValueError, match="name must be unique, it is already existed"):
APIBasedExtensionService.save(extension_data2) APIBasedExtensionService.save(extension_data2)
def test_save_extension_update_existing(self, db_session_with_containers, mock_external_service_dependencies): def test_save_extension_update_existing(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful update of existing extension. Test successful update of existing extension.
""" """
@ -329,7 +340,9 @@ class TestAPIBasedExtensionService:
assert retrieved_extension.api_endpoint == new_endpoint assert retrieved_extension.api_endpoint == new_endpoint
assert retrieved_extension.api_key == new_api_key # Should be decrypted when retrieved assert retrieved_extension.api_key == new_api_key # Should be decrypted when retrieved
def test_save_extension_connection_error(self, db_session_with_containers, mock_external_service_dependencies): def test_save_extension_connection_error(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test connection error when saving extension with invalid endpoint. Test connection error when saving extension with invalid endpoint.
""" """
@ -356,7 +369,7 @@ class TestAPIBasedExtensionService:
APIBasedExtensionService.save(extension_data) APIBasedExtensionService.save(extension_data)
def test_save_extension_invalid_api_key_length( def test_save_extension_invalid_api_key_length(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test validation error when saving extension with API key that is too short. Test validation error when saving extension with API key that is too short.
@ -378,7 +391,7 @@ class TestAPIBasedExtensionService:
with pytest.raises(ValueError, match="api_key must be at least 5 characters"): with pytest.raises(ValueError, match="api_key must be at least 5 characters"):
APIBasedExtensionService.save(extension_data) APIBasedExtensionService.save(extension_data)
def test_save_extension_empty_fields(self, db_session_with_containers, mock_external_service_dependencies): def test_save_extension_empty_fields(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test validation errors when saving extension with empty required fields. Test validation errors when saving extension with empty required fields.
""" """
@ -412,7 +425,9 @@ class TestAPIBasedExtensionService:
with pytest.raises(ValueError, match="api_key must not be empty"): with pytest.raises(ValueError, match="api_key must not be empty"):
APIBasedExtensionService.save(extension_data) APIBasedExtensionService.save(extension_data)
def test_get_all_by_tenant_id_empty_list(self, db_session_with_containers, mock_external_service_dependencies): def test_get_all_by_tenant_id_empty_list(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test retrieval of extensions when no extensions exist for tenant. Test retrieval of extensions when no extensions exist for tenant.
""" """
@ -428,7 +443,9 @@ class TestAPIBasedExtensionService:
assert len(extension_list) == 0 assert len(extension_list) == 0
assert extension_list == [] assert extension_list == []
def test_save_extension_invalid_ping_response(self, db_session_with_containers, mock_external_service_dependencies): def test_save_extension_invalid_ping_response(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test validation error when ping response is invalid. Test validation error when ping response is invalid.
""" """
@ -452,7 +469,9 @@ class TestAPIBasedExtensionService:
with pytest.raises(ValueError, match="{'result': 'invalid'}"): with pytest.raises(ValueError, match="{'result': 'invalid'}"):
APIBasedExtensionService.save(extension_data) APIBasedExtensionService.save(extension_data)
def test_save_extension_missing_ping_result(self, db_session_with_containers, mock_external_service_dependencies): def test_save_extension_missing_ping_result(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test validation error when ping response is missing result field. Test validation error when ping response is missing result field.
""" """
@ -476,7 +495,9 @@ class TestAPIBasedExtensionService:
with pytest.raises(ValueError, match="{'status': 'ok'}"): with pytest.raises(ValueError, match="{'status': 'ok'}"):
APIBasedExtensionService.save(extension_data) APIBasedExtensionService.save(extension_data)
def test_get_with_tenant_id_wrong_tenant(self, db_session_with_containers, mock_external_service_dependencies): def test_get_with_tenant_id_wrong_tenant(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test retrieval of extension when tenant ID doesn't match. Test retrieval of extension when tenant ID doesn't match.
""" """

View File

@ -3,6 +3,7 @@ from unittest.mock import ANY, MagicMock, patch
import pytest import pytest
from faker import Faker from faker import Faker
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from models.model import EndUser from models.model import EndUser
@ -118,7 +119,9 @@ class TestAppGenerateService:
"global_dify_config": mock_global_dify_config, "global_dify_config": mock_global_dify_config,
} }
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies, mode="chat"): def _create_test_app_and_account(
self, db_session_with_containers: Session, mock_external_service_dependencies, mode="chat"
):
""" """
Helper method to create a test app and account for testing. Helper method to create a test app and account for testing.
@ -169,7 +172,7 @@ class TestAppGenerateService:
return app, account return app, account
def _create_test_workflow(self, db_session_with_containers, app): def _create_test_workflow(self, db_session_with_containers: Session, app):
""" """
Helper method to create a test workflow for testing. Helper method to create a test workflow for testing.
@ -191,14 +194,14 @@ class TestAppGenerateService:
status="published", status="published",
) )
from extensions.ext_database import db db_session_with_containers.add(workflow)
db_session_with_containers.commit()
db.session.add(workflow)
db.session.commit()
return workflow return workflow
def test_generate_completion_mode_success(self, db_session_with_containers, mock_external_service_dependencies): def test_generate_completion_mode_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful generation for completion mode app. Test successful generation for completion mode app.
""" """
@ -226,7 +229,7 @@ class TestAppGenerateService:
mock_external_service_dependencies["completion_generator"].return_value.generate.assert_called_once() mock_external_service_dependencies["completion_generator"].return_value.generate.assert_called_once()
mock_external_service_dependencies["completion_generator"].convert_to_event_stream.assert_called_once() mock_external_service_dependencies["completion_generator"].convert_to_event_stream.assert_called_once()
def test_generate_chat_mode_success(self, db_session_with_containers, mock_external_service_dependencies): def test_generate_chat_mode_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test successful generation for chat mode app. Test successful generation for chat mode app.
""" """
@ -250,7 +253,9 @@ class TestAppGenerateService:
mock_external_service_dependencies["chat_generator"].return_value.generate.assert_called_once() mock_external_service_dependencies["chat_generator"].return_value.generate.assert_called_once()
mock_external_service_dependencies["chat_generator"].convert_to_event_stream.assert_called_once() mock_external_service_dependencies["chat_generator"].convert_to_event_stream.assert_called_once()
def test_generate_agent_chat_mode_success(self, db_session_with_containers, mock_external_service_dependencies): def test_generate_agent_chat_mode_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful generation for agent chat mode app. Test successful generation for agent chat mode app.
""" """
@ -274,7 +279,9 @@ class TestAppGenerateService:
mock_external_service_dependencies["agent_chat_generator"].return_value.generate.assert_called_once() mock_external_service_dependencies["agent_chat_generator"].return_value.generate.assert_called_once()
mock_external_service_dependencies["agent_chat_generator"].convert_to_event_stream.assert_called_once() mock_external_service_dependencies["agent_chat_generator"].convert_to_event_stream.assert_called_once()
def test_generate_advanced_chat_mode_success(self, db_session_with_containers, mock_external_service_dependencies): def test_generate_advanced_chat_mode_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful generation for advanced chat mode app. Test successful generation for advanced chat mode app.
""" """
@ -300,7 +307,9 @@ class TestAppGenerateService:
"advanced_chat_generator" "advanced_chat_generator"
].return_value.convert_to_event_stream.assert_called_once() ].return_value.convert_to_event_stream.assert_called_once()
def test_generate_workflow_mode_success(self, db_session_with_containers, mock_external_service_dependencies): def test_generate_workflow_mode_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful generation for workflow mode app. Test successful generation for workflow mode app.
""" """
@ -324,7 +333,9 @@ class TestAppGenerateService:
mock_external_service_dependencies["message_based_generator"].retrieve_events.assert_called_once() mock_external_service_dependencies["message_based_generator"].retrieve_events.assert_called_once()
mock_external_service_dependencies["workflow_generator"].convert_to_event_stream.assert_called_once() mock_external_service_dependencies["workflow_generator"].convert_to_event_stream.assert_called_once()
def test_generate_with_specific_workflow_id(self, db_session_with_containers, mock_external_service_dependencies): def test_generate_with_specific_workflow_id(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test generation with a specific workflow ID. Test generation with a specific workflow ID.
""" """
@ -355,7 +366,9 @@ class TestAppGenerateService:
"workflow_service" "workflow_service"
].return_value.get_published_workflow_by_id.assert_called_once() ].return_value.get_published_workflow_by_id.assert_called_once()
def test_generate_with_debugger_invoke_from(self, db_session_with_containers, mock_external_service_dependencies): def test_generate_with_debugger_invoke_from(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test generation with debugger invoke from. Test generation with debugger invoke from.
""" """
@ -378,7 +391,9 @@ class TestAppGenerateService:
# Verify draft workflow was fetched for debugger # Verify draft workflow was fetched for debugger
mock_external_service_dependencies["workflow_service"].return_value.get_draft_workflow.assert_called_once() mock_external_service_dependencies["workflow_service"].return_value.get_draft_workflow.assert_called_once()
def test_generate_with_non_streaming_mode(self, db_session_with_containers, mock_external_service_dependencies): def test_generate_with_non_streaming_mode(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test generation with non-streaming mode. Test generation with non-streaming mode.
""" """
@ -401,7 +416,7 @@ class TestAppGenerateService:
# Verify rate limit exit was called for non-streaming mode # Verify rate limit exit was called for non-streaming mode
mock_external_service_dependencies["rate_limit"].return_value.exit.assert_called_once() mock_external_service_dependencies["rate_limit"].return_value.exit.assert_called_once()
def test_generate_with_end_user(self, db_session_with_containers, mock_external_service_dependencies): def test_generate_with_end_user(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test generation with EndUser instead of Account. Test generation with EndUser instead of Account.
""" """
@ -421,10 +436,8 @@ class TestAppGenerateService:
session_id=fake.uuid4(), session_id=fake.uuid4(),
) )
from extensions.ext_database import db db_session_with_containers.add(end_user)
db_session_with_containers.commit()
db.session.add(end_user)
db.session.commit()
# Setup test arguments # Setup test arguments
args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"} args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"}
@ -438,7 +451,7 @@ class TestAppGenerateService:
assert result == ["test_response"] assert result == ["test_response"]
def test_generate_with_billing_enabled_sandbox_plan( def test_generate_with_billing_enabled_sandbox_plan(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test generation with billing enabled and sandbox plan. Test generation with billing enabled and sandbox plan.
@ -466,7 +479,9 @@ class TestAppGenerateService:
# Verify billing service was called to consume quota # Verify billing service was called to consume quota
mock_external_service_dependencies["billing_service"].update_tenant_feature_plan_usage.assert_called_once() mock_external_service_dependencies["billing_service"].update_tenant_feature_plan_usage.assert_called_once()
def test_generate_with_invalid_app_mode(self, db_session_with_containers, mock_external_service_dependencies): def test_generate_with_invalid_app_mode(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test generation with invalid app mode. Test generation with invalid app mode.
""" """
@ -491,7 +506,7 @@ class TestAppGenerateService:
assert "Invalid app mode" in str(exc_info.value) assert "Invalid app mode" in str(exc_info.value)
def test_generate_with_workflow_id_format_error( def test_generate_with_workflow_id_format_error(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test generation with invalid workflow ID format. Test generation with invalid workflow ID format.
@ -518,7 +533,7 @@ class TestAppGenerateService:
assert "Invalid workflow_id format" in str(exc_info.value) assert "Invalid workflow_id format" in str(exc_info.value)
def test_generate_with_workflow_not_found_error( def test_generate_with_workflow_not_found_error(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test generation when workflow is not found. Test generation when workflow is not found.
@ -552,7 +567,7 @@ class TestAppGenerateService:
assert f"Workflow not found with id: {workflow_id}" in str(exc_info.value) assert f"Workflow not found with id: {workflow_id}" in str(exc_info.value)
def test_generate_with_workflow_not_initialized_error( def test_generate_with_workflow_not_initialized_error(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test generation when workflow is not initialized for debugger. Test generation when workflow is not initialized for debugger.
@ -578,7 +593,7 @@ class TestAppGenerateService:
assert "Workflow not initialized" in str(exc_info.value) assert "Workflow not initialized" in str(exc_info.value)
def test_generate_with_workflow_not_published_error( def test_generate_with_workflow_not_published_error(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test generation when workflow is not published for non-debugger. Test generation when workflow is not published for non-debugger.
@ -604,7 +619,7 @@ class TestAppGenerateService:
assert "Workflow not published" in str(exc_info.value) assert "Workflow not published" in str(exc_info.value)
def test_generate_single_iteration_advanced_chat_success( def test_generate_single_iteration_advanced_chat_success(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful single iteration generation for advanced chat mode. Test successful single iteration generation for advanced chat mode.
@ -631,7 +646,7 @@ class TestAppGenerateService:
].return_value.single_iteration_generate.assert_called_once() ].return_value.single_iteration_generate.assert_called_once()
def test_generate_single_iteration_workflow_success( def test_generate_single_iteration_workflow_success(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful single iteration generation for workflow mode. Test successful single iteration generation for workflow mode.
@ -658,7 +673,7 @@ class TestAppGenerateService:
].return_value.single_iteration_generate.assert_called_once() ].return_value.single_iteration_generate.assert_called_once()
def test_generate_single_iteration_invalid_mode( def test_generate_single_iteration_invalid_mode(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test single iteration generation with invalid app mode. Test single iteration generation with invalid app mode.
@ -681,7 +696,7 @@ class TestAppGenerateService:
assert "Invalid app mode" in str(exc_info.value) assert "Invalid app mode" in str(exc_info.value)
def test_generate_single_loop_advanced_chat_success( def test_generate_single_loop_advanced_chat_success(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful single loop generation for advanced chat mode. Test successful single loop generation for advanced chat mode.
@ -708,7 +723,7 @@ class TestAppGenerateService:
].return_value.single_loop_generate.assert_called_once() ].return_value.single_loop_generate.assert_called_once()
def test_generate_single_loop_workflow_success( def test_generate_single_loop_workflow_success(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful single loop generation for workflow mode. Test successful single loop generation for workflow mode.
@ -732,7 +747,9 @@ class TestAppGenerateService:
# Verify workflow generator was called # Verify workflow generator was called
mock_external_service_dependencies["workflow_generator"].return_value.single_loop_generate.assert_called_once() mock_external_service_dependencies["workflow_generator"].return_value.single_loop_generate.assert_called_once()
def test_generate_single_loop_invalid_mode(self, db_session_with_containers, mock_external_service_dependencies): def test_generate_single_loop_invalid_mode(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test single loop generation with invalid app mode. Test single loop generation with invalid app mode.
""" """
@ -753,7 +770,9 @@ class TestAppGenerateService:
# Verify error message # Verify error message
assert "Invalid app mode" in str(exc_info.value) assert "Invalid app mode" in str(exc_info.value)
def test_generate_more_like_this_success(self, db_session_with_containers, mock_external_service_dependencies): def test_generate_more_like_this_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful more like this generation. Test successful more like this generation.
""" """
@ -778,7 +797,7 @@ class TestAppGenerateService:
].return_value.generate_more_like_this.assert_called_once() ].return_value.generate_more_like_this.assert_called_once()
def test_generate_more_like_this_with_end_user( def test_generate_more_like_this_with_end_user(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test more like this generation with EndUser. Test more like this generation with EndUser.
@ -799,10 +818,8 @@ class TestAppGenerateService:
session_id=fake.uuid4(), session_id=fake.uuid4(),
) )
from extensions.ext_database import db db_session_with_containers.add(end_user)
db_session_with_containers.commit()
db.session.add(end_user)
db.session.commit()
message_id = fake.uuid4() message_id = fake.uuid4()
@ -815,7 +832,7 @@ class TestAppGenerateService:
assert result == ["more_like_this_response"] assert result == ["more_like_this_response"]
def test_get_max_active_requests_with_app_limit( def test_get_max_active_requests_with_app_limit(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test getting max active requests with app-specific limit. Test getting max active requests with app-specific limit.
@ -835,7 +852,7 @@ class TestAppGenerateService:
assert result == 10 assert result == 10
def test_get_max_active_requests_with_config_limit( def test_get_max_active_requests_with_config_limit(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test getting max active requests with config limit being smaller. Test getting max active requests with config limit being smaller.
@ -856,7 +873,7 @@ class TestAppGenerateService:
assert result <= 100 assert result <= 100
def test_get_max_active_requests_with_zero_limits( def test_get_max_active_requests_with_zero_limits(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test getting max active requests with zero limits (infinite). Test getting max active requests with zero limits (infinite).
@ -875,7 +892,9 @@ class TestAppGenerateService:
# Verify the result (should return config limit when app limit is 0) # Verify the result (should return config limit when app limit is 0)
assert result == 100 # dify_config.APP_MAX_ACTIVE_REQUESTS assert result == 100 # dify_config.APP_MAX_ACTIVE_REQUESTS
def test_generate_with_exception_cleanup(self, db_session_with_containers, mock_external_service_dependencies): def test_generate_with_exception_cleanup(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test that rate limit exit is called when an exception occurs. Test that rate limit exit is called when an exception occurs.
""" """
@ -904,7 +923,9 @@ class TestAppGenerateService:
# Verify rate limit exit was called for cleanup # Verify rate limit exit was called for cleanup
mock_external_service_dependencies["rate_limit"].return_value.exit.assert_called_once() mock_external_service_dependencies["rate_limit"].return_value.exit.assert_called_once()
def test_generate_with_agent_mode_detection(self, db_session_with_containers, mock_external_service_dependencies): def test_generate_with_agent_mode_detection(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test generation with agent mode detection based on app configuration. Test generation with agent mode detection based on app configuration.
""" """
@ -932,7 +953,7 @@ class TestAppGenerateService:
mock_external_service_dependencies["agent_chat_generator"].convert_to_event_stream.assert_called_once() mock_external_service_dependencies["agent_chat_generator"].convert_to_event_stream.assert_called_once()
def test_generate_with_different_invoke_from_values( def test_generate_with_different_invoke_from_values(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test generation with different invoke from values. Test generation with different invoke from values.
@ -962,7 +983,7 @@ class TestAppGenerateService:
# Verify the result # Verify the result
assert result == ["test_response"] assert result == ["test_response"]
def test_generate_with_complex_args(self, db_session_with_containers, mock_external_service_dependencies): def test_generate_with_complex_args(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test generation with complex arguments including files and external trace ID. Test generation with complex arguments including files and external trace ID.
""" """

View File

@ -2,6 +2,7 @@ from unittest.mock import create_autospec, patch
import pytest import pytest
from faker import Faker from faker import Faker
from sqlalchemy.orm import Session
from constants.model_template import default_app_templates from constants.model_template import default_app_templates
from models import Account from models import Account
@ -44,7 +45,7 @@ class TestAppService:
"account_feature_service": mock_account_feature_service, "account_feature_service": mock_account_feature_service,
} }
def test_create_app_success(self, db_session_with_containers, mock_external_service_dependencies): def test_create_app_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test successful app creation with basic parameters. Test successful app creation with basic parameters.
""" """
@ -98,7 +99,9 @@ class TestAppService:
assert app.is_public is False assert app.is_public is False
assert app.is_universal is False assert app.is_universal is False
def test_create_app_with_different_modes(self, db_session_with_containers, mock_external_service_dependencies): def test_create_app_with_different_modes(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test app creation with different app modes. Test app creation with different app modes.
""" """
@ -141,7 +144,7 @@ class TestAppService:
assert app.tenant_id == tenant.id assert app.tenant_id == tenant.id
assert app.created_by == account.id assert app.created_by == account.id
def test_get_app_success(self, db_session_with_containers, mock_external_service_dependencies): def test_get_app_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test successful app retrieval. Test successful app retrieval.
""" """
@ -189,7 +192,7 @@ class TestAppService:
assert retrieved_app.tenant_id == created_app.tenant_id assert retrieved_app.tenant_id == created_app.tenant_id
assert retrieved_app.created_by == created_app.created_by assert retrieved_app.created_by == created_app.created_by
def test_get_paginate_apps_success(self, db_session_with_containers, mock_external_service_dependencies): def test_get_paginate_apps_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test successful paginated app list retrieval. Test successful paginated app list retrieval.
""" """
@ -243,7 +246,9 @@ class TestAppService:
assert app.tenant_id == tenant.id assert app.tenant_id == tenant.id
assert app.mode == "chat" assert app.mode == "chat"
def test_get_paginate_apps_with_filters(self, db_session_with_containers, mock_external_service_dependencies): def test_get_paginate_apps_with_filters(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test paginated app list with various filters. Test paginated app list with various filters.
""" """
@ -316,7 +321,9 @@ class TestAppService:
my_apps = app_service.get_paginate_apps(account.id, tenant.id, created_by_me_args) my_apps = app_service.get_paginate_apps(account.id, tenant.id, created_by_me_args)
assert len(my_apps.items) == 1 assert len(my_apps.items) == 1
def test_get_paginate_apps_with_tag_filters(self, db_session_with_containers, mock_external_service_dependencies): def test_get_paginate_apps_with_tag_filters(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test paginated app list with tag filters. Test paginated app list with tag filters.
""" """
@ -386,7 +393,7 @@ class TestAppService:
# Should return None when no apps match tag filter # Should return None when no apps match tag filter
assert paginated_apps is None assert paginated_apps is None
def test_update_app_success(self, db_session_with_containers, mock_external_service_dependencies): def test_update_app_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test successful app update with all fields. Test successful app update with all fields.
""" """
@ -455,7 +462,7 @@ class TestAppService:
assert updated_app.tenant_id == app.tenant_id assert updated_app.tenant_id == app.tenant_id
assert updated_app.created_by == app.created_by assert updated_app.created_by == app.created_by
def test_update_app_name_success(self, db_session_with_containers, mock_external_service_dependencies): def test_update_app_name_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test successful app name update. Test successful app name update.
""" """
@ -508,7 +515,7 @@ class TestAppService:
assert updated_app.tenant_id == app.tenant_id assert updated_app.tenant_id == app.tenant_id
assert updated_app.created_by == app.created_by assert updated_app.created_by == app.created_by
def test_update_app_icon_success(self, db_session_with_containers, mock_external_service_dependencies): def test_update_app_icon_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test successful app icon update. Test successful app icon update.
""" """
@ -565,7 +572,9 @@ class TestAppService:
assert updated_app.tenant_id == app.tenant_id assert updated_app.tenant_id == app.tenant_id
assert updated_app.created_by == app.created_by assert updated_app.created_by == app.created_by
def test_update_app_site_status_success(self, db_session_with_containers, mock_external_service_dependencies): def test_update_app_site_status_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful app site status update. Test successful app site status update.
""" """
@ -623,7 +632,9 @@ class TestAppService:
assert updated_app.tenant_id == app.tenant_id assert updated_app.tenant_id == app.tenant_id
assert updated_app.created_by == app.created_by assert updated_app.created_by == app.created_by
def test_update_app_api_status_success(self, db_session_with_containers, mock_external_service_dependencies): def test_update_app_api_status_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful app API status update. Test successful app API status update.
""" """
@ -681,7 +692,9 @@ class TestAppService:
assert updated_app.tenant_id == app.tenant_id assert updated_app.tenant_id == app.tenant_id
assert updated_app.created_by == app.created_by assert updated_app.created_by == app.created_by
def test_update_app_site_status_no_change(self, db_session_with_containers, mock_external_service_dependencies): def test_update_app_site_status_no_change(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test app site status update when status doesn't change. Test app site status update when status doesn't change.
""" """
@ -732,7 +745,7 @@ class TestAppService:
assert updated_app.tenant_id == app.tenant_id assert updated_app.tenant_id == app.tenant_id
assert updated_app.created_by == app.created_by assert updated_app.created_by == app.created_by
def test_delete_app_success(self, db_session_with_containers, mock_external_service_dependencies): def test_delete_app_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test successful app deletion. Test successful app deletion.
""" """
@ -778,12 +791,13 @@ class TestAppService:
mock_delete_task.delay.assert_called_once_with(tenant_id=tenant.id, app_id=app_id) mock_delete_task.delay.assert_called_once_with(tenant_id=tenant.id, app_id=app_id)
# Verify app was deleted from database # Verify app was deleted from database
from extensions.ext_database import db
deleted_app = db.session.query(App).filter_by(id=app_id).first() deleted_app = db_session_with_containers.query(App).filter_by(id=app_id).first()
assert deleted_app is None assert deleted_app is None
def test_delete_app_with_related_data(self, db_session_with_containers, mock_external_service_dependencies): def test_delete_app_with_related_data(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test app deletion with related data cleanup. Test app deletion with related data cleanup.
""" """
@ -839,12 +853,11 @@ class TestAppService:
mock_delete_task.delay.assert_called_once_with(tenant_id=tenant.id, app_id=app_id) mock_delete_task.delay.assert_called_once_with(tenant_id=tenant.id, app_id=app_id)
# Verify app was deleted from database # Verify app was deleted from database
from extensions.ext_database import db
deleted_app = db.session.query(App).filter_by(id=app_id).first() deleted_app = db_session_with_containers.query(App).filter_by(id=app_id).first()
assert deleted_app is None assert deleted_app is None
def test_get_app_meta_success(self, db_session_with_containers, mock_external_service_dependencies): def test_get_app_meta_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test successful app metadata retrieval. Test successful app metadata retrieval.
""" """
@ -883,7 +896,7 @@ class TestAppService:
assert "tool_icons" in app_meta assert "tool_icons" in app_meta
# Note: get_app_meta currently only returns tool_icons # Note: get_app_meta currently only returns tool_icons
def test_get_app_code_by_id_success(self, db_session_with_containers, mock_external_service_dependencies): def test_get_app_code_by_id_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test successful app code retrieval by app ID. Test successful app code retrieval by app ID.
""" """
@ -923,7 +936,7 @@ class TestAppService:
assert app_code is not None assert app_code is not None
assert len(app_code) > 0 assert len(app_code) > 0
def test_get_app_id_by_code_success(self, db_session_with_containers, mock_external_service_dependencies): def test_get_app_id_by_code_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test successful app ID retrieval by app code. Test successful app ID retrieval by app code.
""" """
@ -963,10 +976,9 @@ class TestAppService:
site.status = "normal" site.status = "normal"
site.default_language = "en-US" site.default_language = "en-US"
site.customize_token_strategy = "uuid" site.customize_token_strategy = "uuid"
from extensions.ext_database import db
db.session.add(site) db_session_with_containers.add(site)
db.session.commit() db_session_with_containers.commit()
# Get app ID by code # Get app ID by code
app_id = AppService.get_app_id_by_code(site.code) app_id = AppService.get_app_id_by_code(site.code)
@ -974,7 +986,7 @@ class TestAppService:
# Verify app ID was retrieved correctly # Verify app ID was retrieved correctly
assert app_id == app.id assert app_id == app.id
def test_create_app_invalid_mode(self, db_session_with_containers, mock_external_service_dependencies): def test_create_app_invalid_mode(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test app creation with invalid mode. Test app creation with invalid mode.
""" """
@ -1010,7 +1022,7 @@ class TestAppService:
app_service.create_app(tenant.id, app_args, account) app_service.create_app(tenant.id, app_args, account)
def test_get_apps_with_special_characters_in_name( def test_get_apps_with_special_characters_in_name(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
r""" r"""
Test app retrieval with special characters in name search to verify SQL injection prevention. Test app retrieval with special characters in name search to verify SQL injection prevention.

View File

@ -9,10 +9,10 @@ from unittest.mock import Mock, patch
from uuid import uuid4 from uuid import uuid4
import pytest import pytest
from sqlalchemy.orm import Session
from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.retrieval.retrieval_methods import RetrievalMethod
from dify_graph.model_runtime.entities.model_entities import ModelType from dify_graph.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, DatasetPermissionEnum, Document, ExternalKnowledgeBindings, Pipeline from models.dataset import Dataset, DatasetPermissionEnum, Document, ExternalKnowledgeBindings, Pipeline
from services.dataset_service import DatasetService from services.dataset_service import DatasetService
@ -25,7 +25,9 @@ class DatasetServiceIntegrationDataFactory:
"""Factory for creating real database entities used by integration tests.""" """Factory for creating real database entities used by integration tests."""
@staticmethod @staticmethod
def create_account_with_tenant(role: TenantAccountRole = TenantAccountRole.OWNER) -> tuple[Account, Tenant]: def create_account_with_tenant(
db_session_with_containers: Session, role: TenantAccountRole = TenantAccountRole.OWNER
) -> tuple[Account, Tenant]:
"""Create an account and tenant, then bind the account as current tenant member.""" """Create an account and tenant, then bind the account as current tenant member."""
account = Account( account = Account(
email=f"{uuid4()}@example.com", email=f"{uuid4()}@example.com",
@ -34,8 +36,8 @@ class DatasetServiceIntegrationDataFactory:
status="active", status="active",
) )
tenant = Tenant(name=f"tenant-{uuid4()}", status="normal") tenant = Tenant(name=f"tenant-{uuid4()}", status="normal")
db.session.add_all([account, tenant]) db_session_with_containers.add_all([account, tenant])
db.session.flush() db_session_with_containers.flush()
join = TenantAccountJoin( join = TenantAccountJoin(
tenant_id=tenant.id, tenant_id=tenant.id,
@ -43,8 +45,8 @@ class DatasetServiceIntegrationDataFactory:
role=role, role=role,
current=True, current=True,
) )
db.session.add(join) db_session_with_containers.add(join)
db.session.flush() db_session_with_containers.flush()
# Keep tenant context on the in-memory user without opening a separate session. # Keep tenant context on the in-memory user without opening a separate session.
account.role = role account.role = role
@ -53,6 +55,7 @@ class DatasetServiceIntegrationDataFactory:
@staticmethod @staticmethod
def create_dataset( def create_dataset(
db_session_with_containers: Session,
tenant_id: str, tenant_id: str,
created_by: str, created_by: str,
name: str = "Test Dataset", name: str = "Test Dataset",
@ -82,12 +85,14 @@ class DatasetServiceIntegrationDataFactory:
collection_binding_id=collection_binding_id, collection_binding_id=collection_binding_id,
chunk_structure=chunk_structure, chunk_structure=chunk_structure,
) )
db.session.add(dataset) db_session_with_containers.add(dataset)
db.session.flush() db_session_with_containers.flush()
return dataset return dataset
@staticmethod @staticmethod
def create_document(dataset: Dataset, created_by: str, name: str = "doc.txt") -> Document: def create_document(
db_session_with_containers: Session, dataset: Dataset, created_by: str, name: str = "doc.txt"
) -> Document:
"""Create a document row belonging to the given dataset.""" """Create a document row belonging to the given dataset."""
document = Document( document = Document(
tenant_id=dataset.tenant_id, tenant_id=dataset.tenant_id,
@ -102,8 +107,8 @@ class DatasetServiceIntegrationDataFactory:
indexing_status="completed", indexing_status="completed",
doc_form="text_model", doc_form="text_model",
) )
db.session.add(document) db_session_with_containers.add(document)
db.session.flush() db_session_with_containers.flush()
return document return document
@staticmethod @staticmethod
@ -118,10 +123,10 @@ class DatasetServiceIntegrationDataFactory:
class TestDatasetServiceCreateDataset: class TestDatasetServiceCreateDataset:
"""Integration coverage for DatasetService.create_empty_dataset.""" """Integration coverage for DatasetService.create_empty_dataset."""
def test_create_internal_dataset_basic_success(self, db_session_with_containers): def test_create_internal_dataset_basic_success(self, db_session_with_containers: Session):
"""Create a basic internal dataset with minimal configuration.""" """Create a basic internal dataset with minimal configuration."""
# Arrange # Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
# Act # Act
result = DatasetService.create_empty_dataset( result = DatasetService.create_empty_dataset(
@ -133,17 +138,17 @@ class TestDatasetServiceCreateDataset:
) )
# Assert # Assert
created_dataset = db.session.get(Dataset, result.id) created_dataset = db_session_with_containers.get(Dataset, result.id)
assert created_dataset is not None assert created_dataset is not None
assert created_dataset.provider == "vendor" assert created_dataset.provider == "vendor"
assert created_dataset.permission == DatasetPermissionEnum.ONLY_ME assert created_dataset.permission == DatasetPermissionEnum.ONLY_ME
assert created_dataset.embedding_model_provider is None assert created_dataset.embedding_model_provider is None
assert created_dataset.embedding_model is None assert created_dataset.embedding_model is None
def test_create_internal_dataset_with_economy_indexing(self, db_session_with_containers): def test_create_internal_dataset_with_economy_indexing(self, db_session_with_containers: Session):
"""Create an internal dataset with economy indexing and no embedding model.""" """Create an internal dataset with economy indexing and no embedding model."""
# Arrange # Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
# Act # Act
result = DatasetService.create_empty_dataset( result = DatasetService.create_empty_dataset(
@ -155,15 +160,15 @@ class TestDatasetServiceCreateDataset:
) )
# Assert # Assert
db.session.refresh(result) db_session_with_containers.refresh(result)
assert result.indexing_technique == "economy" assert result.indexing_technique == "economy"
assert result.embedding_model_provider is None assert result.embedding_model_provider is None
assert result.embedding_model is None assert result.embedding_model is None
def test_create_internal_dataset_with_high_quality_indexing(self, db_session_with_containers): def test_create_internal_dataset_with_high_quality_indexing(self, db_session_with_containers: Session):
"""Create a high-quality dataset and persist embedding model settings.""" """Create a high-quality dataset and persist embedding model settings."""
# Arrange # Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
embedding_model = DatasetServiceIntegrationDataFactory.create_embedding_model() embedding_model = DatasetServiceIntegrationDataFactory.create_embedding_model()
# Act # Act
@ -179,7 +184,7 @@ class TestDatasetServiceCreateDataset:
) )
# Assert # Assert
db.session.refresh(result) db_session_with_containers.refresh(result)
assert result.indexing_technique == "high_quality" assert result.indexing_technique == "high_quality"
assert result.embedding_model_provider == embedding_model.provider assert result.embedding_model_provider == embedding_model.provider
assert result.embedding_model == embedding_model.model_name assert result.embedding_model == embedding_model.model_name
@ -188,11 +193,12 @@ class TestDatasetServiceCreateDataset:
model_type=ModelType.TEXT_EMBEDDING, model_type=ModelType.TEXT_EMBEDDING,
) )
def test_create_dataset_duplicate_name_error(self, db_session_with_containers): def test_create_dataset_duplicate_name_error(self, db_session_with_containers: Session):
"""Raise duplicate-name error when the same tenant already has the name.""" """Raise duplicate-name error when the same tenant already has the name."""
# Arrange # Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
DatasetServiceIntegrationDataFactory.create_dataset( DatasetServiceIntegrationDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
created_by=account.id, created_by=account.id,
name="Duplicate Dataset", name="Duplicate Dataset",
@ -209,10 +215,10 @@ class TestDatasetServiceCreateDataset:
account=account, account=account,
) )
def test_create_external_dataset_success(self, db_session_with_containers): def test_create_external_dataset_success(self, db_session_with_containers: Session):
"""Create an external dataset and persist external knowledge binding.""" """Create an external dataset and persist external knowledge binding."""
# Arrange # Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
external_knowledge_api_id = str(uuid4()) external_knowledge_api_id = str(uuid4())
external_knowledge_id = "knowledge-123" external_knowledge_id = "knowledge-123"
@ -231,16 +237,16 @@ class TestDatasetServiceCreateDataset:
) )
# Assert # Assert
binding = db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=result.id).first() binding = db_session_with_containers.query(ExternalKnowledgeBindings).filter_by(dataset_id=result.id).first()
assert result.provider == "external" assert result.provider == "external"
assert binding is not None assert binding is not None
assert binding.external_knowledge_id == external_knowledge_id assert binding.external_knowledge_id == external_knowledge_id
assert binding.external_knowledge_api_id == external_knowledge_api_id assert binding.external_knowledge_api_id == external_knowledge_api_id
def test_create_dataset_with_retrieval_model_and_reranking(self, db_session_with_containers): def test_create_dataset_with_retrieval_model_and_reranking(self, db_session_with_containers: Session):
"""Create a high-quality dataset with retrieval/reranking settings.""" """Create a high-quality dataset with retrieval/reranking settings."""
# Arrange # Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
embedding_model = DatasetServiceIntegrationDataFactory.create_embedding_model() embedding_model = DatasetServiceIntegrationDataFactory.create_embedding_model()
retrieval_model = RetrievalModel( retrieval_model = RetrievalModel(
search_method=RetrievalMethod.SEMANTIC_SEARCH, search_method=RetrievalMethod.SEMANTIC_SEARCH,
@ -271,14 +277,16 @@ class TestDatasetServiceCreateDataset:
) )
# Assert # Assert
db.session.refresh(result) db_session_with_containers.refresh(result)
assert result.retrieval_model == retrieval_model.model_dump() assert result.retrieval_model == retrieval_model.model_dump()
mock_check_reranking.assert_called_once_with(tenant.id, "cohere", "rerank-english-v2.0") mock_check_reranking.assert_called_once_with(tenant.id, "cohere", "rerank-english-v2.0")
def test_create_internal_dataset_with_high_quality_indexing_custom_embedding(self, db_session_with_containers): def test_create_internal_dataset_with_high_quality_indexing_custom_embedding(
self, db_session_with_containers: Session
):
"""Create high-quality dataset with explicitly configured embedding model.""" """Create high-quality dataset with explicitly configured embedding model."""
# Arrange # Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
embedding_provider = "openai" embedding_provider = "openai"
embedding_model_name = "text-embedding-3-small" embedding_model_name = "text-embedding-3-small"
embedding_model = DatasetServiceIntegrationDataFactory.create_embedding_model( embedding_model = DatasetServiceIntegrationDataFactory.create_embedding_model(
@ -303,7 +311,7 @@ class TestDatasetServiceCreateDataset:
) )
# Assert # Assert
db.session.refresh(result) db_session_with_containers.refresh(result)
assert result.indexing_technique == "high_quality" assert result.indexing_technique == "high_quality"
assert result.embedding_model_provider == embedding_provider assert result.embedding_model_provider == embedding_provider
assert result.embedding_model == embedding_model_name assert result.embedding_model == embedding_model_name
@ -315,10 +323,10 @@ class TestDatasetServiceCreateDataset:
model=embedding_model_name, model=embedding_model_name,
) )
def test_create_internal_dataset_with_retrieval_model(self, db_session_with_containers): def test_create_internal_dataset_with_retrieval_model(self, db_session_with_containers: Session):
"""Persist retrieval model settings when creating an internal dataset.""" """Persist retrieval model settings when creating an internal dataset."""
# Arrange # Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
retrieval_model = RetrievalModel( retrieval_model = RetrievalModel(
search_method=RetrievalMethod.SEMANTIC_SEARCH, search_method=RetrievalMethod.SEMANTIC_SEARCH,
reranking_enable=False, reranking_enable=False,
@ -338,13 +346,13 @@ class TestDatasetServiceCreateDataset:
) )
# Assert # Assert
db.session.refresh(result) db_session_with_containers.refresh(result)
assert result.retrieval_model == retrieval_model.model_dump() assert result.retrieval_model == retrieval_model.model_dump()
def test_create_internal_dataset_with_custom_permission(self, db_session_with_containers): def test_create_internal_dataset_with_custom_permission(self, db_session_with_containers: Session):
"""Persist canonical custom permission when creating an internal dataset.""" """Persist canonical custom permission when creating an internal dataset."""
# Arrange # Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
# Act # Act
result = DatasetService.create_empty_dataset( result = DatasetService.create_empty_dataset(
@ -357,13 +365,13 @@ class TestDatasetServiceCreateDataset:
) )
# Assert # Assert
db.session.refresh(result) db_session_with_containers.refresh(result)
assert result.permission == DatasetPermissionEnum.ALL_TEAM assert result.permission == DatasetPermissionEnum.ALL_TEAM
def test_create_external_dataset_missing_api_id_error(self, db_session_with_containers): def test_create_external_dataset_missing_api_id_error(self, db_session_with_containers: Session):
"""Raise error when external API template does not exist.""" """Raise error when external API template does not exist."""
# Arrange # Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
external_knowledge_api_id = str(uuid4()) external_knowledge_api_id = str(uuid4())
# Act / Assert # Act / Assert
@ -381,10 +389,10 @@ class TestDatasetServiceCreateDataset:
external_knowledge_id="knowledge-123", external_knowledge_id="knowledge-123",
) )
def test_create_external_dataset_missing_knowledge_id_error(self, db_session_with_containers): def test_create_external_dataset_missing_knowledge_id_error(self, db_session_with_containers: Session):
"""Raise error when external knowledge id is missing for external dataset creation.""" """Raise error when external knowledge id is missing for external dataset creation."""
# Arrange # Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
external_knowledge_api_id = str(uuid4()) external_knowledge_api_id = str(uuid4())
# Act / Assert # Act / Assert
@ -406,10 +414,10 @@ class TestDatasetServiceCreateDataset:
class TestDatasetServiceCreateRagPipelineDataset: class TestDatasetServiceCreateRagPipelineDataset:
"""Integration coverage for DatasetService.create_empty_rag_pipeline_dataset.""" """Integration coverage for DatasetService.create_empty_rag_pipeline_dataset."""
def test_create_rag_pipeline_dataset_with_name_success(self, db_session_with_containers): def test_create_rag_pipeline_dataset_with_name_success(self, db_session_with_containers: Session):
"""Create rag-pipeline dataset and pipeline rows when a name is provided.""" """Create rag-pipeline dataset and pipeline rows when a name is provided."""
# Arrange # Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji") icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji")
entity = RagPipelineDatasetCreateEntity( entity = RagPipelineDatasetCreateEntity(
name="RAG Pipeline Dataset", name="RAG Pipeline Dataset",
@ -425,8 +433,8 @@ class TestDatasetServiceCreateRagPipelineDataset:
) )
# Assert # Assert
created_dataset = db.session.get(Dataset, result.id) created_dataset = db_session_with_containers.get(Dataset, result.id)
created_pipeline = db.session.get(Pipeline, result.pipeline_id) created_pipeline = db_session_with_containers.get(Pipeline, result.pipeline_id)
assert created_dataset is not None assert created_dataset is not None
assert created_dataset.name == entity.name assert created_dataset.name == entity.name
assert created_dataset.runtime_mode == "rag_pipeline" assert created_dataset.runtime_mode == "rag_pipeline"
@ -436,10 +444,10 @@ class TestDatasetServiceCreateRagPipelineDataset:
assert created_pipeline.name == entity.name assert created_pipeline.name == entity.name
assert created_pipeline.created_by == account.id assert created_pipeline.created_by == account.id
def test_create_rag_pipeline_dataset_with_auto_generated_name(self, db_session_with_containers): def test_create_rag_pipeline_dataset_with_auto_generated_name(self, db_session_with_containers: Session):
"""Create rag-pipeline dataset with generated incremental name when input name is empty.""" """Create rag-pipeline dataset with generated incremental name when input name is empty."""
# Arrange # Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
generated_name = "Untitled 1" generated_name = "Untitled 1"
icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji") icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji")
entity = RagPipelineDatasetCreateEntity( entity = RagPipelineDatasetCreateEntity(
@ -460,25 +468,26 @@ class TestDatasetServiceCreateRagPipelineDataset:
) )
# Assert # Assert
db.session.refresh(result) db_session_with_containers.refresh(result)
created_pipeline = db.session.get(Pipeline, result.pipeline_id) created_pipeline = db_session_with_containers.get(Pipeline, result.pipeline_id)
assert result.name == generated_name assert result.name == generated_name
assert created_pipeline is not None assert created_pipeline is not None
assert created_pipeline.name == generated_name assert created_pipeline.name == generated_name
mock_generate_name.assert_called_once() mock_generate_name.assert_called_once()
def test_create_rag_pipeline_dataset_duplicate_name_error(self, db_session_with_containers): def test_create_rag_pipeline_dataset_duplicate_name_error(self, db_session_with_containers: Session):
"""Raise duplicate-name error when rag-pipeline dataset name already exists.""" """Raise duplicate-name error when rag-pipeline dataset name already exists."""
# Arrange # Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
duplicate_name = "Duplicate RAG Dataset" duplicate_name = "Duplicate RAG Dataset"
DatasetServiceIntegrationDataFactory.create_dataset( DatasetServiceIntegrationDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
created_by=account.id, created_by=account.id,
name=duplicate_name, name=duplicate_name,
indexing_technique=None, indexing_technique=None,
) )
db.session.commit() db_session_with_containers.commit()
icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji") icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji")
entity = RagPipelineDatasetCreateEntity( entity = RagPipelineDatasetCreateEntity(
name=duplicate_name, name=duplicate_name,
@ -496,10 +505,10 @@ class TestDatasetServiceCreateRagPipelineDataset:
tenant_id=tenant.id, rag_pipeline_dataset_create_entity=entity tenant_id=tenant.id, rag_pipeline_dataset_create_entity=entity
) )
def test_create_rag_pipeline_dataset_with_custom_permission(self, db_session_with_containers): def test_create_rag_pipeline_dataset_with_custom_permission(self, db_session_with_containers: Session):
"""Persist canonical custom permission for rag-pipeline dataset creation.""" """Persist canonical custom permission for rag-pipeline dataset creation."""
# Arrange # Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji") icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji")
entity = RagPipelineDatasetCreateEntity( entity = RagPipelineDatasetCreateEntity(
name="Custom Permission RAG Dataset", name="Custom Permission RAG Dataset",
@ -515,13 +524,13 @@ class TestDatasetServiceCreateRagPipelineDataset:
) )
# Assert # Assert
db.session.refresh(result) db_session_with_containers.refresh(result)
assert result.permission == DatasetPermissionEnum.ALL_TEAM assert result.permission == DatasetPermissionEnum.ALL_TEAM
def test_create_rag_pipeline_dataset_with_icon_info(self, db_session_with_containers): def test_create_rag_pipeline_dataset_with_icon_info(self, db_session_with_containers: Session):
"""Persist icon metadata when creating rag-pipeline dataset.""" """Persist icon metadata when creating rag-pipeline dataset."""
# Arrange # Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
icon_info = IconInfo( icon_info = IconInfo(
icon="📚", icon="📚",
icon_background="#E8F5E9", icon_background="#E8F5E9",
@ -542,23 +551,25 @@ class TestDatasetServiceCreateRagPipelineDataset:
) )
# Assert # Assert
db.session.refresh(result) db_session_with_containers.refresh(result)
assert result.icon_info == icon_info.model_dump() assert result.icon_info == icon_info.model_dump()
class TestDatasetServiceUpdateAndDeleteDataset: class TestDatasetServiceUpdateAndDeleteDataset:
"""Integration coverage for SQL-backed update and delete behavior.""" """Integration coverage for SQL-backed update and delete behavior."""
def test_update_dataset_duplicate_name_error(self, db_session_with_containers): def test_update_dataset_duplicate_name_error(self, db_session_with_containers: Session):
"""Reject update when target name already exists within the same tenant.""" """Reject update when target name already exists within the same tenant."""
# Arrange # Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
source_dataset = DatasetServiceIntegrationDataFactory.create_dataset( source_dataset = DatasetServiceIntegrationDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
created_by=account.id, created_by=account.id,
name="Source Dataset", name="Source Dataset",
) )
DatasetServiceIntegrationDataFactory.create_dataset( DatasetServiceIntegrationDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
created_by=account.id, created_by=account.id,
name="Existing Dataset", name="Existing Dataset",
@ -568,17 +579,20 @@ class TestDatasetServiceUpdateAndDeleteDataset:
with pytest.raises(ValueError, match="Dataset name already exists"): with pytest.raises(ValueError, match="Dataset name already exists"):
DatasetService.update_dataset(source_dataset.id, {"name": "Existing Dataset"}, account) DatasetService.update_dataset(source_dataset.id, {"name": "Existing Dataset"}, account)
def test_delete_dataset_with_documents_success(self, db_session_with_containers): def test_delete_dataset_with_documents_success(self, db_session_with_containers: Session):
"""Delete a dataset that already has documents.""" """Delete a dataset that already has documents."""
# Arrange # Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = DatasetServiceIntegrationDataFactory.create_dataset( dataset = DatasetServiceIntegrationDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
created_by=account.id, created_by=account.id,
indexing_technique="high_quality", indexing_technique="high_quality",
chunk_structure="text_model", chunk_structure="text_model",
) )
DatasetServiceIntegrationDataFactory.create_document(dataset=dataset, created_by=account.id) DatasetServiceIntegrationDataFactory.create_document(
db_session_with_containers, dataset=dataset, created_by=account.id
)
# Act # Act
with patch("services.dataset_service.dataset_was_deleted") as dataset_deleted_signal: with patch("services.dataset_service.dataset_was_deleted") as dataset_deleted_signal:
@ -586,14 +600,15 @@ class TestDatasetServiceUpdateAndDeleteDataset:
# Assert # Assert
assert result is True assert result is True
assert db.session.get(Dataset, dataset.id) is None assert db_session_with_containers.get(Dataset, dataset.id) is None
dataset_deleted_signal.send.assert_called_once_with(dataset) dataset_deleted_signal.send.assert_called_once_with(dataset)
def test_delete_empty_dataset_success(self, db_session_with_containers): def test_delete_empty_dataset_success(self, db_session_with_containers: Session):
"""Delete a dataset that has no documents and no indexing technique.""" """Delete a dataset that has no documents and no indexing technique."""
# Arrange # Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = DatasetServiceIntegrationDataFactory.create_dataset( dataset = DatasetServiceIntegrationDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
created_by=account.id, created_by=account.id,
indexing_technique=None, indexing_technique=None,
@ -606,14 +621,15 @@ class TestDatasetServiceUpdateAndDeleteDataset:
# Assert # Assert
assert result is True assert result is True
assert db.session.get(Dataset, dataset.id) is None assert db_session_with_containers.get(Dataset, dataset.id) is None
dataset_deleted_signal.send.assert_called_once_with(dataset) dataset_deleted_signal.send.assert_called_once_with(dataset)
def test_delete_dataset_with_partial_none_values(self, db_session_with_containers): def test_delete_dataset_with_partial_none_values(self, db_session_with_containers: Session):
"""Delete dataset when indexing_technique is None but doc_form path still exists.""" """Delete dataset when indexing_technique is None but doc_form path still exists."""
# Arrange # Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = DatasetServiceIntegrationDataFactory.create_dataset( dataset = DatasetServiceIntegrationDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
created_by=account.id, created_by=account.id,
indexing_technique=None, indexing_technique=None,
@ -626,17 +642,17 @@ class TestDatasetServiceUpdateAndDeleteDataset:
# Assert # Assert
assert result is True assert result is True
assert db.session.get(Dataset, dataset.id) is None assert db_session_with_containers.get(Dataset, dataset.id) is None
dataset_deleted_signal.send.assert_called_once_with(dataset) dataset_deleted_signal.send.assert_called_once_with(dataset)
class TestDatasetServiceRetrievalConfiguration: class TestDatasetServiceRetrievalConfiguration:
"""Integration coverage for retrieval configuration persistence.""" """Integration coverage for retrieval configuration persistence."""
def test_get_dataset_retrieval_configuration(self, db_session_with_containers): def test_get_dataset_retrieval_configuration(self, db_session_with_containers: Session):
"""Return retrieval configuration that is persisted in SQL.""" """Return retrieval configuration that is persisted in SQL."""
# Arrange # Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
retrieval_model = { retrieval_model = {
"search_method": "semantic_search", "search_method": "semantic_search",
"top_k": 5, "top_k": 5,
@ -644,6 +660,7 @@ class TestDatasetServiceRetrievalConfiguration:
"reranking_enable": True, "reranking_enable": True,
} }
dataset = DatasetServiceIntegrationDataFactory.create_dataset( dataset = DatasetServiceIntegrationDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
created_by=account.id, created_by=account.id,
retrieval_model=retrieval_model, retrieval_model=retrieval_model,
@ -658,11 +675,12 @@ class TestDatasetServiceRetrievalConfiguration:
assert result.retrieval_model["search_method"] == "semantic_search" assert result.retrieval_model["search_method"] == "semantic_search"
assert result.retrieval_model["top_k"] == 5 assert result.retrieval_model["top_k"] == 5
def test_update_dataset_retrieval_configuration(self, db_session_with_containers): def test_update_dataset_retrieval_configuration(self, db_session_with_containers: Session):
"""Persist retrieval configuration updates through DatasetService.update_dataset.""" """Persist retrieval configuration updates through DatasetService.update_dataset."""
# Arrange # Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = DatasetServiceIntegrationDataFactory.create_dataset( dataset = DatasetServiceIntegrationDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
created_by=account.id, created_by=account.id,
indexing_technique="high_quality", indexing_technique="high_quality",
@ -684,6 +702,6 @@ class TestDatasetServiceRetrievalConfiguration:
result = DatasetService.update_dataset(dataset.id, update_data, account) result = DatasetService.update_dataset(dataset.id, update_data, account)
# Assert # Assert
db.session.refresh(dataset) db_session_with_containers.refresh(dataset)
assert result.id == dataset.id assert result.id == dataset.id
assert dataset.retrieval_model == update_data["retrieval_model"] assert dataset.retrieval_model == update_data["retrieval_model"]

View File

@ -11,8 +11,8 @@ from unittest.mock import call, patch
from uuid import uuid4 from uuid import uuid4
import pytest import pytest
from sqlalchemy.orm import Session
from extensions.ext_database import db
from models.dataset import Dataset, Document from models.dataset import Dataset, Document
from services.dataset_service import DocumentService from services.dataset_service import DocumentService
from services.errors.document import DocumentIndexingError from services.errors.document import DocumentIndexingError
@ -32,6 +32,7 @@ class DocumentBatchUpdateIntegrationDataFactory:
@staticmethod @staticmethod
def create_dataset( def create_dataset(
db_session_with_containers: Session,
dataset_id: str | None = None, dataset_id: str | None = None,
tenant_id: str | None = None, tenant_id: str | None = None,
name: str = "Test Dataset", name: str = "Test Dataset",
@ -47,12 +48,13 @@ class DocumentBatchUpdateIntegrationDataFactory:
if dataset_id: if dataset_id:
dataset.id = dataset_id dataset.id = dataset_id
db.session.add(dataset) db_session_with_containers.add(dataset)
db.session.commit() db_session_with_containers.commit()
return dataset return dataset
@staticmethod @staticmethod
def create_document( def create_document(
db_session_with_containers: Session,
dataset: Dataset, dataset: Dataset,
document_id: str | None = None, document_id: str | None = None,
name: str = "test_document.pdf", name: str = "test_document.pdf",
@ -89,13 +91,14 @@ class DocumentBatchUpdateIntegrationDataFactory:
for key, value in kwargs.items(): for key, value in kwargs.items():
setattr(document, key, value) setattr(document, key, value)
db.session.add(document) db_session_with_containers.add(document)
if commit: if commit:
db.session.commit() db_session_with_containers.commit()
return document return document
@staticmethod @staticmethod
def create_multiple_documents( def create_multiple_documents(
db_session_with_containers: Session,
dataset: Dataset, dataset: Dataset,
document_ids: list[str], document_ids: list[str],
enabled: bool = True, enabled: bool = True,
@ -106,6 +109,7 @@ class DocumentBatchUpdateIntegrationDataFactory:
documents: list[Document] = [] documents: list[Document] = []
for index, doc_id in enumerate(document_ids, start=1): for index, doc_id in enumerate(document_ids, start=1):
document = DocumentBatchUpdateIntegrationDataFactory.create_document( document = DocumentBatchUpdateIntegrationDataFactory.create_document(
db_session_with_containers,
dataset=dataset, dataset=dataset,
document_id=doc_id, document_id=doc_id,
name=f"document_{doc_id}.pdf", name=f"document_{doc_id}.pdf",
@ -116,7 +120,7 @@ class DocumentBatchUpdateIntegrationDataFactory:
commit=False, commit=False,
) )
documents.append(document) documents.append(document)
db.session.commit() db_session_with_containers.commit()
return documents return documents
@staticmethod @staticmethod
@ -173,13 +177,14 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
assert document.archived_at is None assert document.archived_at is None
assert document.archived_by is None assert document.archived_by is None
def test_batch_update_enable_documents_success(self, db_session_with_containers, patched_dependencies): def test_batch_update_enable_documents_success(self, db_session_with_containers: Session, patched_dependencies):
"""Enable disabled documents and trigger indexing side effects.""" """Enable disabled documents and trigger indexing side effects."""
# Arrange # Arrange
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
user = DocumentBatchUpdateIntegrationDataFactory.create_user() user = DocumentBatchUpdateIntegrationDataFactory.create_user()
document_ids = [str(uuid4()), str(uuid4())] document_ids = [str(uuid4()), str(uuid4())]
disabled_docs = DocumentBatchUpdateIntegrationDataFactory.create_multiple_documents( disabled_docs = DocumentBatchUpdateIntegrationDataFactory.create_multiple_documents(
db_session_with_containers,
dataset=dataset, dataset=dataset,
document_ids=document_ids, document_ids=document_ids,
enabled=False, enabled=False,
@ -192,7 +197,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
# Assert # Assert
for document in disabled_docs: for document in disabled_docs:
db.session.refresh(document) db_session_with_containers.refresh(document)
self._assert_document_enabled(document, FIXED_TIME) self._assert_document_enabled(document, FIXED_TIME)
expected_get_calls = [call(f"document_{doc_id}_indexing") for doc_id in document_ids] expected_get_calls = [call(f"document_{doc_id}_indexing") for doc_id in document_ids]
@ -203,13 +208,15 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
patched_dependencies["add_task"].delay.assert_has_calls(expected_add_calls) patched_dependencies["add_task"].delay.assert_has_calls(expected_add_calls)
def test_batch_update_enable_already_enabled_document_skipped( def test_batch_update_enable_already_enabled_document_skipped(
self, db_session_with_containers, patched_dependencies self, db_session_with_containers: Session, patched_dependencies
): ):
"""Skip enable operation for already-enabled documents.""" """Skip enable operation for already-enabled documents."""
# Arrange # Arrange
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
user = DocumentBatchUpdateIntegrationDataFactory.create_user() user = DocumentBatchUpdateIntegrationDataFactory.create_user()
document = DocumentBatchUpdateIntegrationDataFactory.create_document(dataset=dataset, enabled=True) document = DocumentBatchUpdateIntegrationDataFactory.create_document(
db_session_with_containers, dataset=dataset, enabled=True
)
# Act # Act
DocumentService.batch_update_document_status( DocumentService.batch_update_document_status(
@ -220,18 +227,19 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
) )
# Assert # Assert
db.session.refresh(document) db_session_with_containers.refresh(document)
assert document.enabled is True assert document.enabled is True
patched_dependencies["redis_client"].setex.assert_not_called() patched_dependencies["redis_client"].setex.assert_not_called()
patched_dependencies["add_task"].delay.assert_not_called() patched_dependencies["add_task"].delay.assert_not_called()
def test_batch_update_disable_documents_success(self, db_session_with_containers, patched_dependencies): def test_batch_update_disable_documents_success(self, db_session_with_containers: Session, patched_dependencies):
"""Disable completed documents and trigger remove-index tasks.""" """Disable completed documents and trigger remove-index tasks."""
# Arrange # Arrange
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
user = DocumentBatchUpdateIntegrationDataFactory.create_user() user = DocumentBatchUpdateIntegrationDataFactory.create_user()
document_ids = [str(uuid4()), str(uuid4())] document_ids = [str(uuid4()), str(uuid4())]
enabled_docs = DocumentBatchUpdateIntegrationDataFactory.create_multiple_documents( enabled_docs = DocumentBatchUpdateIntegrationDataFactory.create_multiple_documents(
db_session_with_containers,
dataset=dataset, dataset=dataset,
document_ids=document_ids, document_ids=document_ids,
enabled=True, enabled=True,
@ -248,7 +256,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
# Assert # Assert
for document in enabled_docs: for document in enabled_docs:
db.session.refresh(document) db_session_with_containers.refresh(document)
self._assert_document_disabled(document, user.id, FIXED_TIME) self._assert_document_disabled(document, user.id, FIXED_TIME)
expected_get_calls = [call(f"document_{doc_id}_indexing") for doc_id in document_ids] expected_get_calls = [call(f"document_{doc_id}_indexing") for doc_id in document_ids]
@ -259,13 +267,14 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
patched_dependencies["remove_task"].delay.assert_has_calls(expected_remove_calls) patched_dependencies["remove_task"].delay.assert_has_calls(expected_remove_calls)
def test_batch_update_disable_already_disabled_document_skipped( def test_batch_update_disable_already_disabled_document_skipped(
self, db_session_with_containers, patched_dependencies self, db_session_with_containers: Session, patched_dependencies
): ):
"""Skip disable operation for already-disabled documents.""" """Skip disable operation for already-disabled documents."""
# Arrange # Arrange
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
user = DocumentBatchUpdateIntegrationDataFactory.create_user() user = DocumentBatchUpdateIntegrationDataFactory.create_user()
disabled_doc = DocumentBatchUpdateIntegrationDataFactory.create_document( disabled_doc = DocumentBatchUpdateIntegrationDataFactory.create_document(
db_session_with_containers,
dataset=dataset, dataset=dataset,
enabled=False, enabled=False,
indexing_status="completed", indexing_status="completed",
@ -281,17 +290,20 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
) )
# Assert # Assert
db.session.refresh(disabled_doc) db_session_with_containers.refresh(disabled_doc)
assert disabled_doc.enabled is False assert disabled_doc.enabled is False
patched_dependencies["redis_client"].setex.assert_not_called() patched_dependencies["redis_client"].setex.assert_not_called()
patched_dependencies["remove_task"].delay.assert_not_called() patched_dependencies["remove_task"].delay.assert_not_called()
def test_batch_update_disable_non_completed_document_error(self, db_session_with_containers, patched_dependencies): def test_batch_update_disable_non_completed_document_error(
self, db_session_with_containers: Session, patched_dependencies
):
"""Raise error when disabling a non-completed document.""" """Raise error when disabling a non-completed document."""
# Arrange # Arrange
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
user = DocumentBatchUpdateIntegrationDataFactory.create_user() user = DocumentBatchUpdateIntegrationDataFactory.create_user()
non_completed_doc = DocumentBatchUpdateIntegrationDataFactory.create_document( non_completed_doc = DocumentBatchUpdateIntegrationDataFactory.create_document(
db_session_with_containers,
dataset=dataset, dataset=dataset,
enabled=True, enabled=True,
indexing_status="indexing", indexing_status="indexing",
@ -307,13 +319,13 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
user=user, user=user,
) )
def test_batch_update_archive_documents_success(self, db_session_with_containers, patched_dependencies): def test_batch_update_archive_documents_success(self, db_session_with_containers: Session, patched_dependencies):
"""Archive enabled documents and trigger remove-index task.""" """Archive enabled documents and trigger remove-index task."""
# Arrange # Arrange
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
user = DocumentBatchUpdateIntegrationDataFactory.create_user() user = DocumentBatchUpdateIntegrationDataFactory.create_user()
document = DocumentBatchUpdateIntegrationDataFactory.create_document( document = DocumentBatchUpdateIntegrationDataFactory.create_document(
dataset=dataset, enabled=True, archived=False db_session_with_containers, dataset=dataset, enabled=True, archived=False
) )
# Act # Act
@ -325,21 +337,21 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
) )
# Assert # Assert
db.session.refresh(document) db_session_with_containers.refresh(document)
self._assert_document_archived(document, user.id, FIXED_TIME) self._assert_document_archived(document, user.id, FIXED_TIME)
patched_dependencies["redis_client"].get.assert_called_once_with(f"document_{document.id}_indexing") patched_dependencies["redis_client"].get.assert_called_once_with(f"document_{document.id}_indexing")
patched_dependencies["redis_client"].setex.assert_called_once_with(f"document_{document.id}_indexing", 600, 1) patched_dependencies["redis_client"].setex.assert_called_once_with(f"document_{document.id}_indexing", 600, 1)
patched_dependencies["remove_task"].delay.assert_called_once_with(document.id) patched_dependencies["remove_task"].delay.assert_called_once_with(document.id)
def test_batch_update_archive_already_archived_document_skipped( def test_batch_update_archive_already_archived_document_skipped(
self, db_session_with_containers, patched_dependencies self, db_session_with_containers: Session, patched_dependencies
): ):
"""Skip archive operation for already-archived documents.""" """Skip archive operation for already-archived documents."""
# Arrange # Arrange
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
user = DocumentBatchUpdateIntegrationDataFactory.create_user() user = DocumentBatchUpdateIntegrationDataFactory.create_user()
document = DocumentBatchUpdateIntegrationDataFactory.create_document( document = DocumentBatchUpdateIntegrationDataFactory.create_document(
dataset=dataset, enabled=True, archived=True db_session_with_containers, dataset=dataset, enabled=True, archived=True
) )
# Act # Act
@ -351,20 +363,20 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
) )
# Assert # Assert
db.session.refresh(document) db_session_with_containers.refresh(document)
assert document.archived is True assert document.archived is True
patched_dependencies["redis_client"].setex.assert_not_called() patched_dependencies["redis_client"].setex.assert_not_called()
patched_dependencies["remove_task"].delay.assert_not_called() patched_dependencies["remove_task"].delay.assert_not_called()
def test_batch_update_archive_disabled_document_no_index_removal( def test_batch_update_archive_disabled_document_no_index_removal(
self, db_session_with_containers, patched_dependencies self, db_session_with_containers: Session, patched_dependencies
): ):
"""Archive disabled document without index-removal side effects.""" """Archive disabled document without index-removal side effects."""
# Arrange # Arrange
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
user = DocumentBatchUpdateIntegrationDataFactory.create_user() user = DocumentBatchUpdateIntegrationDataFactory.create_user()
document = DocumentBatchUpdateIntegrationDataFactory.create_document( document = DocumentBatchUpdateIntegrationDataFactory.create_document(
dataset=dataset, enabled=False, archived=False db_session_with_containers, dataset=dataset, enabled=False, archived=False
) )
# Act # Act
@ -376,18 +388,18 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
) )
# Assert # Assert
db.session.refresh(document) db_session_with_containers.refresh(document)
self._assert_document_archived(document, user.id, FIXED_TIME) self._assert_document_archived(document, user.id, FIXED_TIME)
patched_dependencies["redis_client"].setex.assert_not_called() patched_dependencies["redis_client"].setex.assert_not_called()
patched_dependencies["remove_task"].delay.assert_not_called() patched_dependencies["remove_task"].delay.assert_not_called()
def test_batch_update_unarchive_documents_success(self, db_session_with_containers, patched_dependencies): def test_batch_update_unarchive_documents_success(self, db_session_with_containers: Session, patched_dependencies):
"""Unarchive enabled documents and trigger add-index task.""" """Unarchive enabled documents and trigger add-index task."""
# Arrange # Arrange
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
user = DocumentBatchUpdateIntegrationDataFactory.create_user() user = DocumentBatchUpdateIntegrationDataFactory.create_user()
document = DocumentBatchUpdateIntegrationDataFactory.create_document( document = DocumentBatchUpdateIntegrationDataFactory.create_document(
dataset=dataset, enabled=True, archived=True db_session_with_containers, dataset=dataset, enabled=True, archived=True
) )
# Act # Act
@ -399,7 +411,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
) )
# Assert # Assert
db.session.refresh(document) db_session_with_containers.refresh(document)
self._assert_document_unarchived(document) self._assert_document_unarchived(document)
assert document.updated_at == FIXED_TIME assert document.updated_at == FIXED_TIME
patched_dependencies["redis_client"].get.assert_called_once_with(f"document_{document.id}_indexing") patched_dependencies["redis_client"].get.assert_called_once_with(f"document_{document.id}_indexing")
@ -407,14 +419,14 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
patched_dependencies["add_task"].delay.assert_called_once_with(document.id) patched_dependencies["add_task"].delay.assert_called_once_with(document.id)
def test_batch_update_unarchive_already_unarchived_document_skipped( def test_batch_update_unarchive_already_unarchived_document_skipped(
self, db_session_with_containers, patched_dependencies self, db_session_with_containers: Session, patched_dependencies
): ):
"""Skip unarchive operation for already-unarchived documents.""" """Skip unarchive operation for already-unarchived documents."""
# Arrange # Arrange
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
user = DocumentBatchUpdateIntegrationDataFactory.create_user() user = DocumentBatchUpdateIntegrationDataFactory.create_user()
document = DocumentBatchUpdateIntegrationDataFactory.create_document( document = DocumentBatchUpdateIntegrationDataFactory.create_document(
dataset=dataset, enabled=True, archived=False db_session_with_containers, dataset=dataset, enabled=True, archived=False
) )
# Act # Act
@ -426,20 +438,20 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
) )
# Assert # Assert
db.session.refresh(document) db_session_with_containers.refresh(document)
assert document.archived is False assert document.archived is False
patched_dependencies["redis_client"].setex.assert_not_called() patched_dependencies["redis_client"].setex.assert_not_called()
patched_dependencies["add_task"].delay.assert_not_called() patched_dependencies["add_task"].delay.assert_not_called()
def test_batch_update_unarchive_disabled_document_no_index_addition( def test_batch_update_unarchive_disabled_document_no_index_addition(
self, db_session_with_containers, patched_dependencies self, db_session_with_containers: Session, patched_dependencies
): ):
"""Unarchive disabled document without index-add side effects.""" """Unarchive disabled document without index-add side effects."""
# Arrange # Arrange
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
user = DocumentBatchUpdateIntegrationDataFactory.create_user() user = DocumentBatchUpdateIntegrationDataFactory.create_user()
document = DocumentBatchUpdateIntegrationDataFactory.create_document( document = DocumentBatchUpdateIntegrationDataFactory.create_document(
dataset=dataset, enabled=False, archived=True db_session_with_containers, dataset=dataset, enabled=False, archived=True
) )
# Act # Act
@ -451,20 +463,21 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
) )
# Assert # Assert
db.session.refresh(document) db_session_with_containers.refresh(document)
self._assert_document_unarchived(document) self._assert_document_unarchived(document)
assert document.updated_at == FIXED_TIME assert document.updated_at == FIXED_TIME
patched_dependencies["redis_client"].setex.assert_not_called() patched_dependencies["redis_client"].setex.assert_not_called()
patched_dependencies["add_task"].delay.assert_not_called() patched_dependencies["add_task"].delay.assert_not_called()
def test_batch_update_document_indexing_error_redis_cache_hit( def test_batch_update_document_indexing_error_redis_cache_hit(
self, db_session_with_containers, patched_dependencies self, db_session_with_containers: Session, patched_dependencies
): ):
"""Raise DocumentIndexingError when redis indicates active indexing.""" """Raise DocumentIndexingError when redis indicates active indexing."""
# Arrange # Arrange
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
user = DocumentBatchUpdateIntegrationDataFactory.create_user() user = DocumentBatchUpdateIntegrationDataFactory.create_user()
document = DocumentBatchUpdateIntegrationDataFactory.create_document( document = DocumentBatchUpdateIntegrationDataFactory.create_document(
db_session_with_containers,
dataset=dataset, dataset=dataset,
name="test_document.pdf", name="test_document.pdf",
enabled=True, enabled=True,
@ -483,12 +496,14 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
assert "test_document.pdf" in str(exc_info.value) assert "test_document.pdf" in str(exc_info.value)
patched_dependencies["redis_client"].get.assert_called_once_with(f"document_{document.id}_indexing") patched_dependencies["redis_client"].get.assert_called_once_with(f"document_{document.id}_indexing")
def test_batch_update_async_task_error_handling(self, db_session_with_containers, patched_dependencies): def test_batch_update_async_task_error_handling(self, db_session_with_containers: Session, patched_dependencies):
"""Persist DB update, then propagate async task error.""" """Persist DB update, then propagate async task error."""
# Arrange # Arrange
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
user = DocumentBatchUpdateIntegrationDataFactory.create_user() user = DocumentBatchUpdateIntegrationDataFactory.create_user()
document = DocumentBatchUpdateIntegrationDataFactory.create_document(dataset=dataset, enabled=False) document = DocumentBatchUpdateIntegrationDataFactory.create_document(
db_session_with_containers, dataset=dataset, enabled=False
)
patched_dependencies["add_task"].delay.side_effect = Exception("Celery task error") patched_dependencies["add_task"].delay.side_effect = Exception("Celery task error")
# Act / Assert # Act / Assert
@ -500,14 +515,14 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
user=user, user=user,
) )
db.session.refresh(document) db_session_with_containers.refresh(document)
self._assert_document_enabled(document, FIXED_TIME) self._assert_document_enabled(document, FIXED_TIME)
patched_dependencies["redis_client"].setex.assert_called_once_with(f"document_{document.id}_indexing", 600, 1) patched_dependencies["redis_client"].setex.assert_called_once_with(f"document_{document.id}_indexing", 600, 1)
def test_batch_update_empty_document_list(self, db_session_with_containers, patched_dependencies): def test_batch_update_empty_document_list(self, db_session_with_containers: Session, patched_dependencies):
"""Return early when document_ids is empty.""" """Return early when document_ids is empty."""
# Arrange # Arrange
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
user = DocumentBatchUpdateIntegrationDataFactory.create_user() user = DocumentBatchUpdateIntegrationDataFactory.create_user()
# Act # Act
@ -520,10 +535,10 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
patched_dependencies["redis_client"].get.assert_not_called() patched_dependencies["redis_client"].get.assert_not_called()
patched_dependencies["redis_client"].setex.assert_not_called() patched_dependencies["redis_client"].setex.assert_not_called()
def test_batch_update_document_not_found_skipped(self, db_session_with_containers, patched_dependencies): def test_batch_update_document_not_found_skipped(self, db_session_with_containers: Session, patched_dependencies):
"""Skip IDs that do not map to existing dataset documents.""" """Skip IDs that do not map to existing dataset documents."""
# Arrange # Arrange
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
user = DocumentBatchUpdateIntegrationDataFactory.create_user() user = DocumentBatchUpdateIntegrationDataFactory.create_user()
missing_document_id = str(uuid4()) missing_document_id = str(uuid4())
@ -540,18 +555,24 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
patched_dependencies["redis_client"].setex.assert_not_called() patched_dependencies["redis_client"].setex.assert_not_called()
patched_dependencies["add_task"].delay.assert_not_called() patched_dependencies["add_task"].delay.assert_not_called()
def test_batch_update_mixed_document_states_and_actions(self, db_session_with_containers, patched_dependencies): def test_batch_update_mixed_document_states_and_actions(
self, db_session_with_containers: Session, patched_dependencies
):
"""Process only the applicable document in a mixed-state enable batch.""" """Process only the applicable document in a mixed-state enable batch."""
# Arrange # Arrange
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
user = DocumentBatchUpdateIntegrationDataFactory.create_user() user = DocumentBatchUpdateIntegrationDataFactory.create_user()
disabled_doc = DocumentBatchUpdateIntegrationDataFactory.create_document(dataset=dataset, enabled=False) disabled_doc = DocumentBatchUpdateIntegrationDataFactory.create_document(
db_session_with_containers, dataset=dataset, enabled=False
)
enabled_doc = DocumentBatchUpdateIntegrationDataFactory.create_document( enabled_doc = DocumentBatchUpdateIntegrationDataFactory.create_document(
db_session_with_containers,
dataset=dataset, dataset=dataset,
enabled=True, enabled=True,
position=2, position=2,
) )
archived_doc = DocumentBatchUpdateIntegrationDataFactory.create_document( archived_doc = DocumentBatchUpdateIntegrationDataFactory.create_document(
db_session_with_containers,
dataset=dataset, dataset=dataset,
enabled=True, enabled=True,
archived=True, archived=True,
@ -568,9 +589,9 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
) )
# Assert # Assert
db.session.refresh(disabled_doc) db_session_with_containers.refresh(disabled_doc)
db.session.refresh(enabled_doc) db_session_with_containers.refresh(enabled_doc)
db.session.refresh(archived_doc) db_session_with_containers.refresh(archived_doc)
self._assert_document_enabled(disabled_doc, FIXED_TIME) self._assert_document_enabled(disabled_doc, FIXED_TIME)
assert enabled_doc.enabled is True assert enabled_doc.enabled is True
assert archived_doc.enabled is True assert archived_doc.enabled is True
@ -582,13 +603,16 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
) )
patched_dependencies["add_task"].delay.assert_called_once_with(disabled_doc.id) patched_dependencies["add_task"].delay.assert_called_once_with(disabled_doc.id)
def test_batch_update_large_document_list_performance(self, db_session_with_containers, patched_dependencies): def test_batch_update_large_document_list_performance(
self, db_session_with_containers: Session, patched_dependencies
):
"""Handle large document lists with consistent updates and side effects.""" """Handle large document lists with consistent updates and side effects."""
# Arrange # Arrange
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
user = DocumentBatchUpdateIntegrationDataFactory.create_user() user = DocumentBatchUpdateIntegrationDataFactory.create_user()
document_ids = [str(uuid4()) for _ in range(100)] document_ids = [str(uuid4()) for _ in range(100)]
documents = DocumentBatchUpdateIntegrationDataFactory.create_multiple_documents( documents = DocumentBatchUpdateIntegrationDataFactory.create_multiple_documents(
db_session_with_containers,
dataset=dataset, dataset=dataset,
document_ids=document_ids, document_ids=document_ids,
enabled=False, enabled=False,
@ -604,7 +628,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
# Assert # Assert
for document in documents: for document in documents:
db.session.refresh(document) db_session_with_containers.refresh(document)
self._assert_document_enabled(document, FIXED_TIME) self._assert_document_enabled(document, FIXED_TIME)
assert patched_dependencies["redis_client"].setex.call_count == len(document_ids) assert patched_dependencies["redis_client"].setex.call_count == len(document_ids)
@ -616,17 +640,26 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
patched_dependencies["add_task"].delay.assert_has_calls(expected_task_calls) patched_dependencies["add_task"].delay.assert_has_calls(expected_task_calls)
def test_batch_update_mixed_document_states_complex_scenario( def test_batch_update_mixed_document_states_complex_scenario(
self, db_session_with_containers, patched_dependencies self, db_session_with_containers: Session, patched_dependencies
): ):
"""Process a complex mixed-state batch and update only eligible records.""" """Process a complex mixed-state batch and update only eligible records."""
# Arrange # Arrange
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
user = DocumentBatchUpdateIntegrationDataFactory.create_user() user = DocumentBatchUpdateIntegrationDataFactory.create_user()
doc1 = DocumentBatchUpdateIntegrationDataFactory.create_document(dataset=dataset, enabled=False) doc1 = DocumentBatchUpdateIntegrationDataFactory.create_document(
doc2 = DocumentBatchUpdateIntegrationDataFactory.create_document(dataset=dataset, enabled=True, position=2) db_session_with_containers, dataset=dataset, enabled=False
doc3 = DocumentBatchUpdateIntegrationDataFactory.create_document(dataset=dataset, enabled=True, position=3) )
doc4 = DocumentBatchUpdateIntegrationDataFactory.create_document(dataset=dataset, enabled=True, position=4) doc2 = DocumentBatchUpdateIntegrationDataFactory.create_document(
db_session_with_containers, dataset=dataset, enabled=True, position=2
)
doc3 = DocumentBatchUpdateIntegrationDataFactory.create_document(
db_session_with_containers, dataset=dataset, enabled=True, position=3
)
doc4 = DocumentBatchUpdateIntegrationDataFactory.create_document(
db_session_with_containers, dataset=dataset, enabled=True, position=4
)
doc5 = DocumentBatchUpdateIntegrationDataFactory.create_document( doc5 = DocumentBatchUpdateIntegrationDataFactory.create_document(
db_session_with_containers,
dataset=dataset, dataset=dataset,
enabled=True, enabled=True,
archived=True, archived=True,
@ -645,11 +678,11 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
) )
# Assert # Assert
db.session.refresh(doc1) db_session_with_containers.refresh(doc1)
db.session.refresh(doc2) db_session_with_containers.refresh(doc2)
db.session.refresh(doc3) db_session_with_containers.refresh(doc3)
db.session.refresh(doc4) db_session_with_containers.refresh(doc4)
db.session.refresh(doc5) db_session_with_containers.refresh(doc5)
self._assert_document_enabled(doc1, FIXED_TIME) self._assert_document_enabled(doc1, FIXED_TIME)
assert doc2.enabled is True assert doc2.enabled is True
assert doc3.enabled is True assert doc3.enabled is True

View File

@ -10,7 +10,8 @@ Tests the retrieval of document segments with pagination and filtering:
from uuid import uuid4 from uuid import uuid4
from extensions.ext_database import db from sqlalchemy.orm import Session
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, DatasetPermissionEnum, Document, DocumentSegment from models.dataset import Dataset, DatasetPermissionEnum, Document, DocumentSegment
from services.dataset_service import SegmentService from services.dataset_service import SegmentService
@ -23,6 +24,7 @@ class SegmentServiceTestDataFactory:
@staticmethod @staticmethod
def create_account_with_tenant( def create_account_with_tenant(
db_session_with_containers: Session,
role: TenantAccountRole = TenantAccountRole.OWNER, role: TenantAccountRole = TenantAccountRole.OWNER,
tenant: Tenant | None = None, tenant: Tenant | None = None,
) -> tuple[Account, Tenant]: ) -> tuple[Account, Tenant]:
@ -33,13 +35,13 @@ class SegmentServiceTestDataFactory:
interface_language="en-US", interface_language="en-US",
status="active", status="active",
) )
db.session.add(account) db_session_with_containers.add(account)
db.session.commit() db_session_with_containers.commit()
if tenant is None: if tenant is None:
tenant = Tenant(name=f"tenant-{uuid4()}", status="normal") tenant = Tenant(name=f"tenant-{uuid4()}", status="normal")
db.session.add(tenant) db_session_with_containers.add(tenant)
db.session.commit() db_session_with_containers.commit()
join = TenantAccountJoin( join = TenantAccountJoin(
tenant_id=tenant.id, tenant_id=tenant.id,
@ -47,14 +49,14 @@ class SegmentServiceTestDataFactory:
role=role, role=role,
current=True, current=True,
) )
db.session.add(join) db_session_with_containers.add(join)
db.session.commit() db_session_with_containers.commit()
account.current_tenant = tenant account.current_tenant = tenant
return account, tenant return account, tenant
@staticmethod @staticmethod
def create_dataset(tenant_id: str, created_by: str) -> Dataset: def create_dataset(db_session_with_containers: Session, tenant_id: str, created_by: str) -> Dataset:
"""Create a real dataset.""" """Create a real dataset."""
dataset = Dataset( dataset = Dataset(
tenant_id=tenant_id, tenant_id=tenant_id,
@ -67,12 +69,14 @@ class SegmentServiceTestDataFactory:
provider="vendor", provider="vendor",
retrieval_model={"top_k": 2}, retrieval_model={"top_k": 2},
) )
db.session.add(dataset) db_session_with_containers.add(dataset)
db.session.commit() db_session_with_containers.commit()
return dataset return dataset
@staticmethod @staticmethod
def create_document(tenant_id: str, dataset_id: str, created_by: str) -> Document: def create_document(
db_session_with_containers: Session, tenant_id: str, dataset_id: str, created_by: str
) -> Document:
"""Create a real document.""" """Create a real document."""
document = Document( document = Document(
tenant_id=tenant_id, tenant_id=tenant_id,
@ -84,12 +88,13 @@ class SegmentServiceTestDataFactory:
created_from="api", created_from="api",
created_by=created_by, created_by=created_by,
) )
db.session.add(document) db_session_with_containers.add(document)
db.session.commit() db_session_with_containers.commit()
return document return document
@staticmethod @staticmethod
def create_segment( def create_segment(
db_session_with_containers: Session,
tenant_id: str, tenant_id: str,
dataset_id: str, dataset_id: str,
document_id: str, document_id: str,
@ -112,8 +117,8 @@ class SegmentServiceTestDataFactory:
tokens=tokens, tokens=tokens,
created_by=created_by, created_by=created_by,
) )
db.session.add(segment) db_session_with_containers.add(segment)
db.session.commit() db_session_with_containers.commit()
return segment return segment
@ -130,7 +135,7 @@ class TestSegmentServiceGetSegments:
- Combined filters - Combined filters
""" """
def test_get_segments_basic_pagination(self, db_session_with_containers): def test_get_segments_basic_pagination(self, db_session_with_containers: Session):
""" """
Test basic pagination functionality. Test basic pagination functionality.
@ -140,11 +145,14 @@ class TestSegmentServiceGetSegments:
- Returns segments and total count - Returns segments and total count
""" """
# Arrange # Arrange
owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant() owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id) dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id)
document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id) document = SegmentServiceTestDataFactory.create_document(
db_session_with_containers, tenant.id, dataset.id, owner.id
)
segment1 = SegmentServiceTestDataFactory.create_segment( segment1 = SegmentServiceTestDataFactory.create_segment(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
dataset_id=dataset.id, dataset_id=dataset.id,
document_id=document.id, document_id=document.id,
@ -153,6 +161,7 @@ class TestSegmentServiceGetSegments:
content="First segment", content="First segment",
) )
segment2 = SegmentServiceTestDataFactory.create_segment( segment2 = SegmentServiceTestDataFactory.create_segment(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
dataset_id=dataset.id, dataset_id=dataset.id,
document_id=document.id, document_id=document.id,
@ -170,7 +179,7 @@ class TestSegmentServiceGetSegments:
assert items[0].id == segment1.id assert items[0].id == segment1.id
assert items[1].id == segment2.id assert items[1].id == segment2.id
def test_get_segments_with_status_filter(self, db_session_with_containers): def test_get_segments_with_status_filter(self, db_session_with_containers: Session):
""" """
Test filtering by status list. Test filtering by status list.
@ -179,11 +188,14 @@ class TestSegmentServiceGetSegments:
- Only segments with matching status are returned - Only segments with matching status are returned
""" """
# Arrange # Arrange
owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant() owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id) dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id)
document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id) document = SegmentServiceTestDataFactory.create_document(
db_session_with_containers, tenant.id, dataset.id, owner.id
)
SegmentServiceTestDataFactory.create_segment( SegmentServiceTestDataFactory.create_segment(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
dataset_id=dataset.id, dataset_id=dataset.id,
document_id=document.id, document_id=document.id,
@ -192,6 +204,7 @@ class TestSegmentServiceGetSegments:
status="completed", status="completed",
) )
SegmentServiceTestDataFactory.create_segment( SegmentServiceTestDataFactory.create_segment(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
dataset_id=dataset.id, dataset_id=dataset.id,
document_id=document.id, document_id=document.id,
@ -200,6 +213,7 @@ class TestSegmentServiceGetSegments:
status="indexing", status="indexing",
) )
SegmentServiceTestDataFactory.create_segment( SegmentServiceTestDataFactory.create_segment(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
dataset_id=dataset.id, dataset_id=dataset.id,
document_id=document.id, document_id=document.id,
@ -219,7 +233,7 @@ class TestSegmentServiceGetSegments:
statuses = {item.status for item in items} statuses = {item.status for item in items}
assert statuses == {"completed", "indexing"} assert statuses == {"completed", "indexing"}
def test_get_segments_with_empty_status_list(self, db_session_with_containers): def test_get_segments_with_empty_status_list(self, db_session_with_containers: Session):
""" """
Test with empty status list. Test with empty status list.
@ -228,11 +242,14 @@ class TestSegmentServiceGetSegments:
- No status filter is applied to avoid WHERE false condition - No status filter is applied to avoid WHERE false condition
""" """
# Arrange # Arrange
owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant() owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id) dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id)
document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id) document = SegmentServiceTestDataFactory.create_document(
db_session_with_containers, tenant.id, dataset.id, owner.id
)
SegmentServiceTestDataFactory.create_segment( SegmentServiceTestDataFactory.create_segment(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
dataset_id=dataset.id, dataset_id=dataset.id,
document_id=document.id, document_id=document.id,
@ -241,6 +258,7 @@ class TestSegmentServiceGetSegments:
status="completed", status="completed",
) )
SegmentServiceTestDataFactory.create_segment( SegmentServiceTestDataFactory.create_segment(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
dataset_id=dataset.id, dataset_id=dataset.id,
document_id=document.id, document_id=document.id,
@ -256,7 +274,7 @@ class TestSegmentServiceGetSegments:
assert len(items) == 2 assert len(items) == 2
assert total == 2 assert total == 2
def test_get_segments_with_keyword_search(self, db_session_with_containers): def test_get_segments_with_keyword_search(self, db_session_with_containers: Session):
""" """
Test keyword search functionality. Test keyword search functionality.
@ -265,11 +283,14 @@ class TestSegmentServiceGetSegments:
- Search pattern includes wildcards (%keyword%) - Search pattern includes wildcards (%keyword%)
""" """
# Arrange # Arrange
owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant() owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id) dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id)
document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id) document = SegmentServiceTestDataFactory.create_document(
db_session_with_containers, tenant.id, dataset.id, owner.id
)
SegmentServiceTestDataFactory.create_segment( SegmentServiceTestDataFactory.create_segment(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
dataset_id=dataset.id, dataset_id=dataset.id,
document_id=document.id, document_id=document.id,
@ -278,6 +299,7 @@ class TestSegmentServiceGetSegments:
content="This contains search term in the middle", content="This contains search term in the middle",
) )
SegmentServiceTestDataFactory.create_segment( SegmentServiceTestDataFactory.create_segment(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
dataset_id=dataset.id, dataset_id=dataset.id,
document_id=document.id, document_id=document.id,
@ -294,7 +316,7 @@ class TestSegmentServiceGetSegments:
assert total == 1 assert total == 1
assert "search term" in items[0].content assert "search term" in items[0].content
def test_get_segments_ordering_by_position_and_id(self, db_session_with_containers): def test_get_segments_ordering_by_position_and_id(self, db_session_with_containers: Session):
""" """
Test ordering by position and id. Test ordering by position and id.
@ -304,12 +326,15 @@ class TestSegmentServiceGetSegments:
- This prevents duplicate data across pages when positions are not unique - This prevents duplicate data across pages when positions are not unique
""" """
# Arrange # Arrange
owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant() owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id) dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id)
document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id) document = SegmentServiceTestDataFactory.create_document(
db_session_with_containers, tenant.id, dataset.id, owner.id
)
# Create segments with different positions # Create segments with different positions
seg_pos2 = SegmentServiceTestDataFactory.create_segment( seg_pos2 = SegmentServiceTestDataFactory.create_segment(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
dataset_id=dataset.id, dataset_id=dataset.id,
document_id=document.id, document_id=document.id,
@ -318,6 +343,7 @@ class TestSegmentServiceGetSegments:
content="Position 2", content="Position 2",
) )
seg_pos1 = SegmentServiceTestDataFactory.create_segment( seg_pos1 = SegmentServiceTestDataFactory.create_segment(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
dataset_id=dataset.id, dataset_id=dataset.id,
document_id=document.id, document_id=document.id,
@ -326,6 +352,7 @@ class TestSegmentServiceGetSegments:
content="Position 1", content="Position 1",
) )
seg_pos3 = SegmentServiceTestDataFactory.create_segment( seg_pos3 = SegmentServiceTestDataFactory.create_segment(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
dataset_id=dataset.id, dataset_id=dataset.id,
document_id=document.id, document_id=document.id,
@ -344,7 +371,7 @@ class TestSegmentServiceGetSegments:
assert items[1].id == seg_pos2.id assert items[1].id == seg_pos2.id
assert items[2].id == seg_pos3.id assert items[2].id == seg_pos3.id
def test_get_segments_empty_results(self, db_session_with_containers): def test_get_segments_empty_results(self, db_session_with_containers: Session):
""" """
Test when no segments match the criteria. Test when no segments match the criteria.
@ -353,7 +380,7 @@ class TestSegmentServiceGetSegments:
- Total count is 0 - Total count is 0
""" """
# Arrange # Arrange
owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant() owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers)
non_existent_doc_id = str(uuid4()) non_existent_doc_id = str(uuid4())
# Act # Act
@ -363,7 +390,7 @@ class TestSegmentServiceGetSegments:
assert items == [] assert items == []
assert total == 0 assert total == 0
def test_get_segments_combined_filters(self, db_session_with_containers): def test_get_segments_combined_filters(self, db_session_with_containers: Session):
""" """
Test with multiple filters combined. Test with multiple filters combined.
@ -372,12 +399,15 @@ class TestSegmentServiceGetSegments:
- Status list and keyword search both applied - Status list and keyword search both applied
""" """
# Arrange # Arrange
owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant() owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id) dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id)
document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id) document = SegmentServiceTestDataFactory.create_document(
db_session_with_containers, tenant.id, dataset.id, owner.id
)
# Create segments with various statuses and content # Create segments with various statuses and content
SegmentServiceTestDataFactory.create_segment( SegmentServiceTestDataFactory.create_segment(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
dataset_id=dataset.id, dataset_id=dataset.id,
document_id=document.id, document_id=document.id,
@ -387,6 +417,7 @@ class TestSegmentServiceGetSegments:
content="This is important information", content="This is important information",
) )
SegmentServiceTestDataFactory.create_segment( SegmentServiceTestDataFactory.create_segment(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
dataset_id=dataset.id, dataset_id=dataset.id,
document_id=document.id, document_id=document.id,
@ -396,6 +427,7 @@ class TestSegmentServiceGetSegments:
content="This is also important", content="This is also important",
) )
SegmentServiceTestDataFactory.create_segment( SegmentServiceTestDataFactory.create_segment(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
dataset_id=dataset.id, dataset_id=dataset.id,
document_id=document.id, document_id=document.id,
@ -421,7 +453,7 @@ class TestSegmentServiceGetSegments:
assert items[0].status == "completed" assert items[0].status == "completed"
assert "important" in items[0].content assert "important" in items[0].content
def test_get_segments_with_none_status_list(self, db_session_with_containers): def test_get_segments_with_none_status_list(self, db_session_with_containers: Session):
""" """
Test with None status list. Test with None status list.
@ -430,11 +462,14 @@ class TestSegmentServiceGetSegments:
- No status filter is applied - No status filter is applied
""" """
# Arrange # Arrange
owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant() owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id) dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id)
document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id) document = SegmentServiceTestDataFactory.create_document(
db_session_with_containers, tenant.id, dataset.id, owner.id
)
SegmentServiceTestDataFactory.create_segment( SegmentServiceTestDataFactory.create_segment(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
dataset_id=dataset.id, dataset_id=dataset.id,
document_id=document.id, document_id=document.id,
@ -443,6 +478,7 @@ class TestSegmentServiceGetSegments:
status="completed", status="completed",
) )
SegmentServiceTestDataFactory.create_segment( SegmentServiceTestDataFactory.create_segment(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
dataset_id=dataset.id, dataset_id=dataset.id,
document_id=document.id, document_id=document.id,
@ -462,7 +498,7 @@ class TestSegmentServiceGetSegments:
assert len(items) == 2 assert len(items) == 2
assert total == 2 assert total == 2
def test_get_segments_pagination_max_per_page_limit(self, db_session_with_containers): def test_get_segments_pagination_max_per_page_limit(self, db_session_with_containers: Session):
""" """
Test that max_per_page is correctly set to 100. Test that max_per_page is correctly set to 100.
@ -471,13 +507,16 @@ class TestSegmentServiceGetSegments:
- This prevents excessive page sizes - This prevents excessive page sizes
""" """
# Arrange # Arrange
owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant() owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id) dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id)
document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id) document = SegmentServiceTestDataFactory.create_document(
db_session_with_containers, tenant.id, dataset.id, owner.id
)
# Create 105 segments to exceed max_per_page of 100 # Create 105 segments to exceed max_per_page of 100
for i in range(105): for i in range(105):
SegmentServiceTestDataFactory.create_segment( SegmentServiceTestDataFactory.create_segment(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
dataset_id=dataset.id, dataset_id=dataset.id,
document_id=document.id, document_id=document.id,

View File

@ -13,7 +13,8 @@ This test suite covers:
import json import json
from uuid import uuid4 from uuid import uuid4
from extensions.ext_database import db from sqlalchemy.orm import Session
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import ( from models.dataset import (
AppDatasetJoin, AppDatasetJoin,
@ -31,7 +32,9 @@ class DatasetRetrievalTestDataFactory:
"""Factory class for creating database-backed test data for dataset retrieval integration tests.""" """Factory class for creating database-backed test data for dataset retrieval integration tests."""
@staticmethod @staticmethod
def create_account_with_tenant(role: TenantAccountRole = TenantAccountRole.NORMAL) -> tuple[Account, Tenant]: def create_account_with_tenant(
db_session_with_containers: Session, role: TenantAccountRole = TenantAccountRole.NORMAL
) -> tuple[Account, Tenant]:
"""Create an account and tenant with the specified role.""" """Create an account and tenant with the specified role."""
account = Account( account = Account(
email=f"{uuid4()}@example.com", email=f"{uuid4()}@example.com",
@ -43,8 +46,8 @@ class DatasetRetrievalTestDataFactory:
name=f"tenant-{uuid4()}", name=f"tenant-{uuid4()}",
status="normal", status="normal",
) )
db.session.add_all([account, tenant]) db_session_with_containers.add_all([account, tenant])
db.session.flush() db_session_with_containers.flush()
join = TenantAccountJoin( join = TenantAccountJoin(
tenant_id=tenant.id, tenant_id=tenant.id,
@ -52,14 +55,16 @@ class DatasetRetrievalTestDataFactory:
role=role, role=role,
current=True, current=True,
) )
db.session.add(join) db_session_with_containers.add(join)
db.session.commit() db_session_with_containers.commit()
account.current_tenant = tenant account.current_tenant = tenant
return account, tenant return account, tenant
@staticmethod @staticmethod
def create_account_in_tenant(tenant: Tenant, role: TenantAccountRole = TenantAccountRole.OWNER) -> Account: def create_account_in_tenant(
db_session_with_containers: Session, tenant: Tenant, role: TenantAccountRole = TenantAccountRole.OWNER
) -> Account:
"""Create an account and add it to an existing tenant.""" """Create an account and add it to an existing tenant."""
account = Account( account = Account(
email=f"{uuid4()}@example.com", email=f"{uuid4()}@example.com",
@ -67,8 +72,8 @@ class DatasetRetrievalTestDataFactory:
interface_language="en-US", interface_language="en-US",
status="active", status="active",
) )
db.session.add(account) db_session_with_containers.add(account)
db.session.flush() db_session_with_containers.flush()
join = TenantAccountJoin( join = TenantAccountJoin(
tenant_id=tenant.id, tenant_id=tenant.id,
@ -76,14 +81,15 @@ class DatasetRetrievalTestDataFactory:
role=role, role=role,
current=True, current=True,
) )
db.session.add(join) db_session_with_containers.add(join)
db.session.commit() db_session_with_containers.commit()
account.current_tenant = tenant account.current_tenant = tenant
return account return account
@staticmethod @staticmethod
def create_dataset( def create_dataset(
db_session_with_containers: Session,
tenant_id: str, tenant_id: str,
created_by: str, created_by: str,
name: str = "Test Dataset", name: str = "Test Dataset",
@ -101,12 +107,14 @@ class DatasetRetrievalTestDataFactory:
provider="vendor", provider="vendor",
retrieval_model={"top_k": 2}, retrieval_model={"top_k": 2},
) )
db.session.add(dataset) db_session_with_containers.add(dataset)
db.session.commit() db_session_with_containers.commit()
return dataset return dataset
@staticmethod @staticmethod
def create_dataset_permission(dataset_id: str, tenant_id: str, account_id: str) -> DatasetPermission: def create_dataset_permission(
db_session_with_containers: Session, dataset_id: str, tenant_id: str, account_id: str
) -> DatasetPermission:
"""Create a dataset permission.""" """Create a dataset permission."""
permission = DatasetPermission( permission = DatasetPermission(
dataset_id=dataset_id, dataset_id=dataset_id,
@ -114,12 +122,14 @@ class DatasetRetrievalTestDataFactory:
account_id=account_id, account_id=account_id,
has_permission=True, has_permission=True,
) )
db.session.add(permission) db_session_with_containers.add(permission)
db.session.commit() db_session_with_containers.commit()
return permission return permission
@staticmethod @staticmethod
def create_process_rule(dataset_id: str, created_by: str, mode: str, rules: dict) -> DatasetProcessRule: def create_process_rule(
db_session_with_containers: Session, dataset_id: str, created_by: str, mode: str, rules: dict
) -> DatasetProcessRule:
"""Create a dataset process rule.""" """Create a dataset process rule."""
process_rule = DatasetProcessRule( process_rule = DatasetProcessRule(
dataset_id=dataset_id, dataset_id=dataset_id,
@ -127,12 +137,14 @@ class DatasetRetrievalTestDataFactory:
mode=mode, mode=mode,
rules=json.dumps(rules), rules=json.dumps(rules),
) )
db.session.add(process_rule) db_session_with_containers.add(process_rule)
db.session.commit() db_session_with_containers.commit()
return process_rule return process_rule
@staticmethod @staticmethod
def create_dataset_query(dataset_id: str, created_by: str, content: str) -> DatasetQuery: def create_dataset_query(
db_session_with_containers: Session, dataset_id: str, created_by: str, content: str
) -> DatasetQuery:
"""Create a dataset query.""" """Create a dataset query."""
dataset_query = DatasetQuery( dataset_query = DatasetQuery(
dataset_id=dataset_id, dataset_id=dataset_id,
@ -142,23 +154,23 @@ class DatasetRetrievalTestDataFactory:
created_by_role="account", created_by_role="account",
created_by=created_by, created_by=created_by,
) )
db.session.add(dataset_query) db_session_with_containers.add(dataset_query)
db.session.commit() db_session_with_containers.commit()
return dataset_query return dataset_query
@staticmethod @staticmethod
def create_app_dataset_join(dataset_id: str) -> AppDatasetJoin: def create_app_dataset_join(db_session_with_containers: Session, dataset_id: str) -> AppDatasetJoin:
"""Create an app-dataset join.""" """Create an app-dataset join."""
join = AppDatasetJoin( join = AppDatasetJoin(
app_id=str(uuid4()), app_id=str(uuid4()),
dataset_id=dataset_id, dataset_id=dataset_id,
) )
db.session.add(join) db_session_with_containers.add(join)
db.session.commit() db_session_with_containers.commit()
return join return join
@staticmethod @staticmethod
def create_tag_binding(tenant_id: str, created_by: str, target_id: str) -> Tag: def create_tag_binding(db_session_with_containers: Session, tenant_id: str, created_by: str, target_id: str) -> Tag:
"""Create a knowledge tag and bind it to the target dataset.""" """Create a knowledge tag and bind it to the target dataset."""
tag = Tag( tag = Tag(
tenant_id=tenant_id, tenant_id=tenant_id,
@ -166,8 +178,8 @@ class DatasetRetrievalTestDataFactory:
name=f"tag-{uuid4()}", name=f"tag-{uuid4()}",
created_by=created_by, created_by=created_by,
) )
db.session.add(tag) db_session_with_containers.add(tag)
db.session.flush() db_session_with_containers.flush()
binding = TagBinding( binding = TagBinding(
tenant_id=tenant_id, tenant_id=tenant_id,
@ -175,8 +187,8 @@ class DatasetRetrievalTestDataFactory:
target_id=target_id, target_id=target_id,
created_by=created_by, created_by=created_by,
) )
db.session.add(binding) db_session_with_containers.add(binding)
db.session.commit() db_session_with_containers.commit()
return tag return tag
@ -195,15 +207,16 @@ class TestDatasetServiceGetDatasets:
# ==================== Basic Retrieval Tests ==================== # ==================== Basic Retrieval Tests ====================
def test_get_datasets_basic_pagination(self, db_session_with_containers): def test_get_datasets_basic_pagination(self, db_session_with_containers: Session):
"""Test basic pagination without user or filters.""" """Test basic pagination without user or filters."""
# Arrange # Arrange
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers)
page = 1 page = 1
per_page = 20 per_page = 20
for i in range(5): for i in range(5):
DatasetRetrievalTestDataFactory.create_dataset( DatasetRetrievalTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
created_by=account.id, created_by=account.id,
name=f"Dataset {i}", name=f"Dataset {i}",
@ -217,21 +230,23 @@ class TestDatasetServiceGetDatasets:
assert len(datasets) == 5 assert len(datasets) == 5
assert total == 5 assert total == 5
def test_get_datasets_with_search(self, db_session_with_containers): def test_get_datasets_with_search(self, db_session_with_containers: Session):
"""Test get_datasets with search keyword.""" """Test get_datasets with search keyword."""
# Arrange # Arrange
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers)
page = 1 page = 1
per_page = 20 per_page = 20
search = "test" search = "test"
DatasetRetrievalTestDataFactory.create_dataset( DatasetRetrievalTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
created_by=account.id, created_by=account.id,
name="Test Dataset", name="Test Dataset",
permission=DatasetPermissionEnum.ALL_TEAM, permission=DatasetPermissionEnum.ALL_TEAM,
) )
DatasetRetrievalTestDataFactory.create_dataset( DatasetRetrievalTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
created_by=account.id, created_by=account.id,
name="Another Dataset", name="Another Dataset",
@ -245,26 +260,32 @@ class TestDatasetServiceGetDatasets:
assert len(datasets) == 1 assert len(datasets) == 1
assert total == 1 assert total == 1
def test_get_datasets_with_tag_filtering(self, db_session_with_containers): def test_get_datasets_with_tag_filtering(self, db_session_with_containers: Session):
"""Test get_datasets with tag_ids filtering.""" """Test get_datasets with tag_ids filtering."""
# Arrange # Arrange
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers)
page = 1 page = 1
per_page = 20 per_page = 20
dataset_1 = DatasetRetrievalTestDataFactory.create_dataset( dataset_1 = DatasetRetrievalTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
created_by=account.id, created_by=account.id,
permission=DatasetPermissionEnum.ALL_TEAM, permission=DatasetPermissionEnum.ALL_TEAM,
) )
dataset_2 = DatasetRetrievalTestDataFactory.create_dataset( dataset_2 = DatasetRetrievalTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
created_by=account.id, created_by=account.id,
permission=DatasetPermissionEnum.ALL_TEAM, permission=DatasetPermissionEnum.ALL_TEAM,
) )
tag_1 = DatasetRetrievalTestDataFactory.create_tag_binding(tenant.id, account.id, dataset_1.id) tag_1 = DatasetRetrievalTestDataFactory.create_tag_binding(
tag_2 = DatasetRetrievalTestDataFactory.create_tag_binding(tenant.id, account.id, dataset_2.id) db_session_with_containers, tenant.id, account.id, dataset_1.id
)
tag_2 = DatasetRetrievalTestDataFactory.create_tag_binding(
db_session_with_containers, tenant.id, account.id, dataset_2.id
)
tag_ids = [tag_1.id, tag_2.id] tag_ids = [tag_1.id, tag_2.id]
# Act # Act
@ -274,16 +295,17 @@ class TestDatasetServiceGetDatasets:
assert len(datasets) == 2 assert len(datasets) == 2
assert total == 2 assert total == 2
def test_get_datasets_with_empty_tag_ids(self, db_session_with_containers): def test_get_datasets_with_empty_tag_ids(self, db_session_with_containers: Session):
"""Test get_datasets with empty tag_ids skips tag filtering and returns all matching datasets.""" """Test get_datasets with empty tag_ids skips tag filtering and returns all matching datasets."""
# Arrange # Arrange
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers)
page = 1 page = 1
per_page = 20 per_page = 20
tag_ids = [] tag_ids = []
for i in range(3): for i in range(3):
DatasetRetrievalTestDataFactory.create_dataset( DatasetRetrievalTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
created_by=account.id, created_by=account.id,
name=f"dataset-{i}", name=f"dataset-{i}",
@ -300,19 +322,21 @@ class TestDatasetServiceGetDatasets:
# ==================== Permission-Based Filtering Tests ==================== # ==================== Permission-Based Filtering Tests ====================
def test_get_datasets_without_user_shows_only_all_team(self, db_session_with_containers): def test_get_datasets_without_user_shows_only_all_team(self, db_session_with_containers: Session):
"""Test that without user, only ALL_TEAM datasets are shown.""" """Test that without user, only ALL_TEAM datasets are shown."""
# Arrange # Arrange
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers)
page = 1 page = 1
per_page = 20 per_page = 20
DatasetRetrievalTestDataFactory.create_dataset( DatasetRetrievalTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
created_by=account.id, created_by=account.id,
permission=DatasetPermissionEnum.ALL_TEAM, permission=DatasetPermissionEnum.ALL_TEAM,
) )
DatasetRetrievalTestDataFactory.create_dataset( DatasetRetrievalTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
created_by=account.id, created_by=account.id,
permission=DatasetPermissionEnum.ONLY_ME, permission=DatasetPermissionEnum.ONLY_ME,
@ -325,15 +349,18 @@ class TestDatasetServiceGetDatasets:
assert len(datasets) == 1 assert len(datasets) == 1
assert total == 1 assert total == 1
def test_get_datasets_owner_with_include_all(self, db_session_with_containers): def test_get_datasets_owner_with_include_all(self, db_session_with_containers: Session):
"""Test that OWNER with include_all=True sees all datasets.""" """Test that OWNER with include_all=True sees all datasets."""
# Arrange # Arrange
owner, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) owner, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(
db_session_with_containers, role=TenantAccountRole.OWNER
)
for i, permission in enumerate( for i, permission in enumerate(
[DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM] [DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM]
): ):
DatasetRetrievalTestDataFactory.create_dataset( DatasetRetrievalTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
created_by=owner.id, created_by=owner.id,
name=f"dataset-{i}", name=f"dataset-{i}",
@ -353,12 +380,15 @@ class TestDatasetServiceGetDatasets:
assert len(datasets) == 3 assert len(datasets) == 3
assert total == 3 assert total == 3
def test_get_datasets_normal_user_only_me_permission(self, db_session_with_containers): def test_get_datasets_normal_user_only_me_permission(self, db_session_with_containers: Session):
"""Test that normal user sees ONLY_ME datasets they created.""" """Test that normal user sees ONLY_ME datasets they created."""
# Arrange # Arrange
user, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL) user, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(
db_session_with_containers, role=TenantAccountRole.NORMAL
)
DatasetRetrievalTestDataFactory.create_dataset( DatasetRetrievalTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
created_by=user.id, created_by=user.id,
permission=DatasetPermissionEnum.ONLY_ME, permission=DatasetPermissionEnum.ONLY_ME,
@ -371,13 +401,18 @@ class TestDatasetServiceGetDatasets:
assert len(datasets) == 1 assert len(datasets) == 1
assert total == 1 assert total == 1
def test_get_datasets_normal_user_all_team_permission(self, db_session_with_containers): def test_get_datasets_normal_user_all_team_permission(self, db_session_with_containers: Session):
"""Test that normal user sees ALL_TEAM datasets.""" """Test that normal user sees ALL_TEAM datasets."""
# Arrange # Arrange
user, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL) user, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(
owner = DatasetRetrievalTestDataFactory.create_account_in_tenant(tenant, role=TenantAccountRole.OWNER) db_session_with_containers, role=TenantAccountRole.NORMAL
)
owner = DatasetRetrievalTestDataFactory.create_account_in_tenant(
db_session_with_containers, tenant, role=TenantAccountRole.OWNER
)
DatasetRetrievalTestDataFactory.create_dataset( DatasetRetrievalTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
created_by=owner.id, created_by=owner.id,
permission=DatasetPermissionEnum.ALL_TEAM, permission=DatasetPermissionEnum.ALL_TEAM,
@ -390,18 +425,25 @@ class TestDatasetServiceGetDatasets:
assert len(datasets) == 1 assert len(datasets) == 1
assert total == 1 assert total == 1
def test_get_datasets_normal_user_partial_team_with_permission(self, db_session_with_containers): def test_get_datasets_normal_user_partial_team_with_permission(self, db_session_with_containers: Session):
"""Test that normal user sees PARTIAL_TEAM datasets they have permission for.""" """Test that normal user sees PARTIAL_TEAM datasets they have permission for."""
# Arrange # Arrange
user, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL) user, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(
owner = DatasetRetrievalTestDataFactory.create_account_in_tenant(tenant, role=TenantAccountRole.OWNER) db_session_with_containers, role=TenantAccountRole.NORMAL
)
owner = DatasetRetrievalTestDataFactory.create_account_in_tenant(
db_session_with_containers, tenant, role=TenantAccountRole.OWNER
)
dataset = DatasetRetrievalTestDataFactory.create_dataset( dataset = DatasetRetrievalTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
created_by=owner.id, created_by=owner.id,
permission=DatasetPermissionEnum.PARTIAL_TEAM, permission=DatasetPermissionEnum.PARTIAL_TEAM,
) )
DatasetRetrievalTestDataFactory.create_dataset_permission(dataset.id, tenant.id, user.id) DatasetRetrievalTestDataFactory.create_dataset_permission(
db_session_with_containers, dataset.id, tenant.id, user.id
)
# Act # Act
datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant.id, user=user) datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant.id, user=user)
@ -410,20 +452,25 @@ class TestDatasetServiceGetDatasets:
assert len(datasets) == 1 assert len(datasets) == 1
assert total == 1 assert total == 1
def test_get_datasets_dataset_operator_with_permissions(self, db_session_with_containers): def test_get_datasets_dataset_operator_with_permissions(self, db_session_with_containers: Session):
"""Test that DATASET_OPERATOR only sees datasets they have explicit permission for.""" """Test that DATASET_OPERATOR only sees datasets they have explicit permission for."""
# Arrange # Arrange
operator, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant( operator, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(
role=TenantAccountRole.DATASET_OPERATOR db_session_with_containers, role=TenantAccountRole.DATASET_OPERATOR
)
owner = DatasetRetrievalTestDataFactory.create_account_in_tenant(
db_session_with_containers, tenant, role=TenantAccountRole.OWNER
) )
owner = DatasetRetrievalTestDataFactory.create_account_in_tenant(tenant, role=TenantAccountRole.OWNER)
dataset = DatasetRetrievalTestDataFactory.create_dataset( dataset = DatasetRetrievalTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
created_by=owner.id, created_by=owner.id,
permission=DatasetPermissionEnum.ONLY_ME, permission=DatasetPermissionEnum.ONLY_ME,
) )
DatasetRetrievalTestDataFactory.create_dataset_permission(dataset.id, tenant.id, operator.id) DatasetRetrievalTestDataFactory.create_dataset_permission(
db_session_with_containers, dataset.id, tenant.id, operator.id
)
# Act # Act
datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant.id, user=operator) datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant.id, user=operator)
@ -432,14 +479,17 @@ class TestDatasetServiceGetDatasets:
assert len(datasets) == 1 assert len(datasets) == 1
assert total == 1 assert total == 1
def test_get_datasets_dataset_operator_without_permissions(self, db_session_with_containers): def test_get_datasets_dataset_operator_without_permissions(self, db_session_with_containers: Session):
"""Test that DATASET_OPERATOR without permissions returns empty result.""" """Test that DATASET_OPERATOR without permissions returns empty result."""
# Arrange # Arrange
operator, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant( operator, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(
role=TenantAccountRole.DATASET_OPERATOR db_session_with_containers, role=TenantAccountRole.DATASET_OPERATOR
)
owner = DatasetRetrievalTestDataFactory.create_account_in_tenant(
db_session_with_containers, tenant, role=TenantAccountRole.OWNER
) )
owner = DatasetRetrievalTestDataFactory.create_account_in_tenant(tenant, role=TenantAccountRole.OWNER)
DatasetRetrievalTestDataFactory.create_dataset( DatasetRetrievalTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
created_by=owner.id, created_by=owner.id,
permission=DatasetPermissionEnum.ALL_TEAM, permission=DatasetPermissionEnum.ALL_TEAM,
@ -456,11 +506,13 @@ class TestDatasetServiceGetDatasets:
class TestDatasetServiceGetDataset: class TestDatasetServiceGetDataset:
"""Comprehensive integration tests for DatasetService.get_dataset method.""" """Comprehensive integration tests for DatasetService.get_dataset method."""
def test_get_dataset_success(self, db_session_with_containers): def test_get_dataset_success(self, db_session_with_containers: Session):
"""Test successful retrieval of a single dataset.""" """Test successful retrieval of a single dataset."""
# Arrange # Arrange
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) dataset = DatasetRetrievalTestDataFactory.create_dataset(
db_session_with_containers, tenant_id=tenant.id, created_by=account.id
)
# Act # Act
result = DatasetService.get_dataset(dataset.id) result = DatasetService.get_dataset(dataset.id)
@ -469,7 +521,7 @@ class TestDatasetServiceGetDataset:
assert result is not None assert result is not None
assert result.id == dataset.id assert result.id == dataset.id
def test_get_dataset_not_found(self, db_session_with_containers): def test_get_dataset_not_found(self, db_session_with_containers: Session):
"""Test retrieval when dataset doesn't exist.""" """Test retrieval when dataset doesn't exist."""
# Arrange # Arrange
dataset_id = str(uuid4()) dataset_id = str(uuid4())
@ -484,12 +536,15 @@ class TestDatasetServiceGetDataset:
class TestDatasetServiceGetDatasetsByIds: class TestDatasetServiceGetDatasetsByIds:
"""Comprehensive integration tests for DatasetService.get_datasets_by_ids method.""" """Comprehensive integration tests for DatasetService.get_datasets_by_ids method."""
def test_get_datasets_by_ids_success(self, db_session_with_containers): def test_get_datasets_by_ids_success(self, db_session_with_containers: Session):
"""Test successful bulk retrieval of datasets by IDs.""" """Test successful bulk retrieval of datasets by IDs."""
# Arrange # Arrange
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers)
datasets = [ datasets = [
DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) for _ in range(3) DatasetRetrievalTestDataFactory.create_dataset(
db_session_with_containers, tenant_id=tenant.id, created_by=account.id
)
for _ in range(3)
] ]
dataset_ids = [dataset.id for dataset in datasets] dataset_ids = [dataset.id for dataset in datasets]
@ -501,7 +556,7 @@ class TestDatasetServiceGetDatasetsByIds:
assert total == 3 assert total == 3
assert all(dataset.id in dataset_ids for dataset in result_datasets) assert all(dataset.id in dataset_ids for dataset in result_datasets)
def test_get_datasets_by_ids_empty_list(self, db_session_with_containers): def test_get_datasets_by_ids_empty_list(self, db_session_with_containers: Session):
"""Test get_datasets_by_ids with empty list returns empty result.""" """Test get_datasets_by_ids with empty list returns empty result."""
# Arrange # Arrange
tenant_id = str(uuid4()) tenant_id = str(uuid4())
@ -514,7 +569,7 @@ class TestDatasetServiceGetDatasetsByIds:
assert datasets == [] assert datasets == []
assert total == 0 assert total == 0
def test_get_datasets_by_ids_none_list(self, db_session_with_containers): def test_get_datasets_by_ids_none_list(self, db_session_with_containers: Session):
"""Test get_datasets_by_ids with None returns empty result.""" """Test get_datasets_by_ids with None returns empty result."""
# Arrange # Arrange
tenant_id = str(uuid4()) tenant_id = str(uuid4())
@ -530,17 +585,20 @@ class TestDatasetServiceGetDatasetsByIds:
class TestDatasetServiceGetProcessRules: class TestDatasetServiceGetProcessRules:
"""Comprehensive integration tests for DatasetService.get_process_rules method.""" """Comprehensive integration tests for DatasetService.get_process_rules method."""
def test_get_process_rules_with_existing_rule(self, db_session_with_containers): def test_get_process_rules_with_existing_rule(self, db_session_with_containers: Session):
"""Test retrieval of process rules when rule exists.""" """Test retrieval of process rules when rule exists."""
# Arrange # Arrange
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) dataset = DatasetRetrievalTestDataFactory.create_dataset(
db_session_with_containers, tenant_id=tenant.id, created_by=account.id
)
rules_data = { rules_data = {
"pre_processing_rules": [{"id": "remove_extra_spaces", "enabled": True}], "pre_processing_rules": [{"id": "remove_extra_spaces", "enabled": True}],
"segmentation": {"delimiter": "\n", "max_tokens": 500}, "segmentation": {"delimiter": "\n", "max_tokens": 500},
} }
DatasetRetrievalTestDataFactory.create_process_rule( DatasetRetrievalTestDataFactory.create_process_rule(
db_session_with_containers,
dataset_id=dataset.id, dataset_id=dataset.id,
created_by=account.id, created_by=account.id,
mode="custom", mode="custom",
@ -554,11 +612,13 @@ class TestDatasetServiceGetProcessRules:
assert result["mode"] == "custom" assert result["mode"] == "custom"
assert result["rules"] == rules_data assert result["rules"] == rules_data
def test_get_process_rules_without_existing_rule(self, db_session_with_containers): def test_get_process_rules_without_existing_rule(self, db_session_with_containers: Session):
"""Test retrieval of process rules when no rule exists (returns defaults).""" """Test retrieval of process rules when no rule exists (returns defaults)."""
# Arrange # Arrange
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) dataset = DatasetRetrievalTestDataFactory.create_dataset(
db_session_with_containers, tenant_id=tenant.id, created_by=account.id
)
# Act # Act
result = DatasetService.get_process_rules(dataset.id) result = DatasetService.get_process_rules(dataset.id)
@ -572,16 +632,19 @@ class TestDatasetServiceGetProcessRules:
class TestDatasetServiceGetDatasetQueries: class TestDatasetServiceGetDatasetQueries:
"""Comprehensive integration tests for DatasetService.get_dataset_queries method.""" """Comprehensive integration tests for DatasetService.get_dataset_queries method."""
def test_get_dataset_queries_success(self, db_session_with_containers): def test_get_dataset_queries_success(self, db_session_with_containers: Session):
"""Test successful retrieval of dataset queries.""" """Test successful retrieval of dataset queries."""
# Arrange # Arrange
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) dataset = DatasetRetrievalTestDataFactory.create_dataset(
db_session_with_containers, tenant_id=tenant.id, created_by=account.id
)
page = 1 page = 1
per_page = 20 per_page = 20
for i in range(3): for i in range(3):
DatasetRetrievalTestDataFactory.create_dataset_query( DatasetRetrievalTestDataFactory.create_dataset_query(
db_session_with_containers,
dataset_id=dataset.id, dataset_id=dataset.id,
created_by=account.id, created_by=account.id,
content=f"query-{i}", content=f"query-{i}",
@ -595,11 +658,13 @@ class TestDatasetServiceGetDatasetQueries:
assert total == 3 assert total == 3
assert all(query.dataset_id == dataset.id for query in queries) assert all(query.dataset_id == dataset.id for query in queries)
def test_get_dataset_queries_empty_result(self, db_session_with_containers): def test_get_dataset_queries_empty_result(self, db_session_with_containers: Session):
"""Test retrieval when no queries exist.""" """Test retrieval when no queries exist."""
# Arrange # Arrange
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) dataset = DatasetRetrievalTestDataFactory.create_dataset(
db_session_with_containers, tenant_id=tenant.id, created_by=account.id
)
page = 1 page = 1
per_page = 20 per_page = 20
@ -614,14 +679,16 @@ class TestDatasetServiceGetDatasetQueries:
class TestDatasetServiceGetRelatedApps: class TestDatasetServiceGetRelatedApps:
"""Comprehensive integration tests for DatasetService.get_related_apps method.""" """Comprehensive integration tests for DatasetService.get_related_apps method."""
def test_get_related_apps_success(self, db_session_with_containers): def test_get_related_apps_success(self, db_session_with_containers: Session):
"""Test successful retrieval of related apps.""" """Test successful retrieval of related apps."""
# Arrange # Arrange
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) dataset = DatasetRetrievalTestDataFactory.create_dataset(
db_session_with_containers, tenant_id=tenant.id, created_by=account.id
)
for _ in range(2): for _ in range(2):
DatasetRetrievalTestDataFactory.create_app_dataset_join(dataset.id) DatasetRetrievalTestDataFactory.create_app_dataset_join(db_session_with_containers, dataset.id)
# Act # Act
result = DatasetService.get_related_apps(dataset.id) result = DatasetService.get_related_apps(dataset.id)
@ -630,11 +697,13 @@ class TestDatasetServiceGetRelatedApps:
assert len(result) == 2 assert len(result) == 2
assert all(join.dataset_id == dataset.id for join in result) assert all(join.dataset_id == dataset.id for join in result)
def test_get_related_apps_empty_result(self, db_session_with_containers): def test_get_related_apps_empty_result(self, db_session_with_containers: Session):
"""Test retrieval when no related apps exist.""" """Test retrieval when no related apps exist."""
# Arrange # Arrange
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) dataset = DatasetRetrievalTestDataFactory.create_dataset(
db_session_with_containers, tenant_id=tenant.id, created_by=account.id
)
# Act # Act
result = DatasetService.get_related_apps(dataset.id) result = DatasetService.get_related_apps(dataset.id)

View File

@ -2,9 +2,9 @@ from unittest.mock import Mock, patch
from uuid import uuid4 from uuid import uuid4
import pytest import pytest
from sqlalchemy.orm import Session
from dify_graph.model_runtime.entities.model_entities import ModelType from dify_graph.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, ExternalKnowledgeBindings from models.dataset import Dataset, ExternalKnowledgeBindings
from services.dataset_service import DatasetService from services.dataset_service import DatasetService
@ -15,7 +15,9 @@ class DatasetUpdateTestDataFactory:
"""Factory class for creating real test data for dataset update integration tests.""" """Factory class for creating real test data for dataset update integration tests."""
@staticmethod @staticmethod
def create_account_with_tenant(role: TenantAccountRole = TenantAccountRole.OWNER) -> tuple[Account, Tenant]: def create_account_with_tenant(
db_session_with_containers: Session, role: TenantAccountRole = TenantAccountRole.OWNER
) -> tuple[Account, Tenant]:
"""Create a real account and tenant with the given role.""" """Create a real account and tenant with the given role."""
account = Account( account = Account(
email=f"{uuid4()}@example.com", email=f"{uuid4()}@example.com",
@ -23,12 +25,12 @@ class DatasetUpdateTestDataFactory:
interface_language="en-US", interface_language="en-US",
status="active", status="active",
) )
db.session.add(account) db_session_with_containers.add(account)
db.session.commit() db_session_with_containers.commit()
tenant = Tenant(name=f"tenant-{account.id}", status="normal") tenant = Tenant(name=f"tenant-{account.id}", status="normal")
db.session.add(tenant) db_session_with_containers.add(tenant)
db.session.commit() db_session_with_containers.commit()
join = TenantAccountJoin( join = TenantAccountJoin(
tenant_id=tenant.id, tenant_id=tenant.id,
@ -36,14 +38,15 @@ class DatasetUpdateTestDataFactory:
role=role, role=role,
current=True, current=True,
) )
db.session.add(join) db_session_with_containers.add(join)
db.session.commit() db_session_with_containers.commit()
account.current_tenant = tenant account.current_tenant = tenant
return account, tenant return account, tenant
@staticmethod @staticmethod
def create_dataset( def create_dataset(
db_session_with_containers: Session,
tenant_id: str, tenant_id: str,
created_by: str, created_by: str,
provider: str = "vendor", provider: str = "vendor",
@ -71,12 +74,13 @@ class DatasetUpdateTestDataFactory:
embedding_model=embedding_model, embedding_model=embedding_model,
collection_binding_id=collection_binding_id, collection_binding_id=collection_binding_id,
) )
db.session.add(dataset) db_session_with_containers.add(dataset)
db.session.commit() db_session_with_containers.commit()
return dataset return dataset
@staticmethod @staticmethod
def create_external_binding( def create_external_binding(
db_session_with_containers: Session,
tenant_id: str, tenant_id: str,
dataset_id: str, dataset_id: str,
created_by: str, created_by: str,
@ -93,8 +97,8 @@ class DatasetUpdateTestDataFactory:
external_knowledge_id=external_knowledge_id, external_knowledge_id=external_knowledge_id,
external_knowledge_api_id=external_knowledge_api_id, external_knowledge_api_id=external_knowledge_api_id,
) )
db.session.add(binding) db_session_with_containers.add(binding)
db.session.commit() db_session_with_containers.commit()
return binding return binding
@ -112,10 +116,11 @@ class TestDatasetServiceUpdateDataset:
# ==================== External Dataset Tests ==================== # ==================== External Dataset Tests ====================
def test_update_external_dataset_success(self, db_session_with_containers): def test_update_external_dataset_success(self, db_session_with_containers: Session):
"""Test successful update of external dataset.""" """Test successful update of external dataset."""
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = DatasetUpdateTestDataFactory.create_dataset( dataset = DatasetUpdateTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
created_by=user.id, created_by=user.id,
provider="external", provider="external",
@ -124,12 +129,13 @@ class TestDatasetServiceUpdateDataset:
retrieval_model="old_model", retrieval_model="old_model",
) )
binding = DatasetUpdateTestDataFactory.create_external_binding( binding = DatasetUpdateTestDataFactory.create_external_binding(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
dataset_id=dataset.id, dataset_id=dataset.id,
created_by=user.id, created_by=user.id,
) )
binding_id = binding.id binding_id = binding.id
db.session.expunge(binding) db_session_with_containers.expunge(binding)
update_data = { update_data = {
"name": "new_name", "name": "new_name",
@ -142,8 +148,8 @@ class TestDatasetServiceUpdateDataset:
result = DatasetService.update_dataset(dataset.id, update_data, user) result = DatasetService.update_dataset(dataset.id, update_data, user)
db.session.refresh(dataset) db_session_with_containers.refresh(dataset)
updated_binding = db.session.query(ExternalKnowledgeBindings).filter_by(id=binding_id).first() updated_binding = db_session_with_containers.query(ExternalKnowledgeBindings).filter_by(id=binding_id).first()
assert dataset.name == "new_name" assert dataset.name == "new_name"
assert dataset.description == "new_description" assert dataset.description == "new_description"
@ -153,15 +159,17 @@ class TestDatasetServiceUpdateDataset:
assert updated_binding.external_knowledge_api_id == update_data["external_knowledge_api_id"] assert updated_binding.external_knowledge_api_id == update_data["external_knowledge_api_id"]
assert result.id == dataset.id assert result.id == dataset.id
def test_update_external_dataset_missing_knowledge_id_error(self, db_session_with_containers): def test_update_external_dataset_missing_knowledge_id_error(self, db_session_with_containers: Session):
"""Test error when external knowledge id is missing.""" """Test error when external knowledge id is missing."""
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = DatasetUpdateTestDataFactory.create_dataset( dataset = DatasetUpdateTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
created_by=user.id, created_by=user.id,
provider="external", provider="external",
) )
DatasetUpdateTestDataFactory.create_external_binding( DatasetUpdateTestDataFactory.create_external_binding(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
dataset_id=dataset.id, dataset_id=dataset.id,
created_by=user.id, created_by=user.id,
@ -173,17 +181,19 @@ class TestDatasetServiceUpdateDataset:
DatasetService.update_dataset(dataset.id, update_data, user) DatasetService.update_dataset(dataset.id, update_data, user)
assert "External knowledge id is required" in str(context.value) assert "External knowledge id is required" in str(context.value)
db.session.rollback() db_session_with_containers.rollback()
def test_update_external_dataset_missing_api_id_error(self, db_session_with_containers): def test_update_external_dataset_missing_api_id_error(self, db_session_with_containers: Session):
"""Test error when external knowledge api id is missing.""" """Test error when external knowledge api id is missing."""
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = DatasetUpdateTestDataFactory.create_dataset( dataset = DatasetUpdateTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
created_by=user.id, created_by=user.id,
provider="external", provider="external",
) )
DatasetUpdateTestDataFactory.create_external_binding( DatasetUpdateTestDataFactory.create_external_binding(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
dataset_id=dataset.id, dataset_id=dataset.id,
created_by=user.id, created_by=user.id,
@ -195,12 +205,13 @@ class TestDatasetServiceUpdateDataset:
DatasetService.update_dataset(dataset.id, update_data, user) DatasetService.update_dataset(dataset.id, update_data, user)
assert "External knowledge api id is required" in str(context.value) assert "External knowledge api id is required" in str(context.value)
db.session.rollback() db_session_with_containers.rollback()
def test_update_external_dataset_binding_not_found_error(self, db_session_with_containers): def test_update_external_dataset_binding_not_found_error(self, db_session_with_containers: Session):
"""Test error when external knowledge binding is not found.""" """Test error when external knowledge binding is not found."""
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = DatasetUpdateTestDataFactory.create_dataset( dataset = DatasetUpdateTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
created_by=user.id, created_by=user.id,
provider="external", provider="external",
@ -216,15 +227,16 @@ class TestDatasetServiceUpdateDataset:
DatasetService.update_dataset(dataset.id, update_data, user) DatasetService.update_dataset(dataset.id, update_data, user)
assert "External knowledge binding not found" in str(context.value) assert "External knowledge binding not found" in str(context.value)
db.session.rollback() db_session_with_containers.rollback()
# ==================== Internal Dataset Basic Tests ==================== # ==================== Internal Dataset Basic Tests ====================
def test_update_internal_dataset_basic_success(self, db_session_with_containers): def test_update_internal_dataset_basic_success(self, db_session_with_containers: Session):
"""Test successful update of internal dataset with basic fields.""" """Test successful update of internal dataset with basic fields."""
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers)
existing_binding_id = str(uuid4()) existing_binding_id = str(uuid4())
dataset = DatasetUpdateTestDataFactory.create_dataset( dataset = DatasetUpdateTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
created_by=user.id, created_by=user.id,
provider="vendor", provider="vendor",
@ -244,7 +256,7 @@ class TestDatasetServiceUpdateDataset:
} }
result = DatasetService.update_dataset(dataset.id, update_data, user) result = DatasetService.update_dataset(dataset.id, update_data, user)
db.session.refresh(dataset) db_session_with_containers.refresh(dataset)
assert dataset.name == "new_name" assert dataset.name == "new_name"
assert dataset.description == "new_description" assert dataset.description == "new_description"
@ -254,11 +266,12 @@ class TestDatasetServiceUpdateDataset:
assert dataset.embedding_model == "text-embedding-ada-002" assert dataset.embedding_model == "text-embedding-ada-002"
assert result.id == dataset.id assert result.id == dataset.id
def test_update_internal_dataset_filter_none_values(self, db_session_with_containers): def test_update_internal_dataset_filter_none_values(self, db_session_with_containers: Session):
"""Test that None values are filtered out except for description field.""" """Test that None values are filtered out except for description field."""
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers)
existing_binding_id = str(uuid4()) existing_binding_id = str(uuid4())
dataset = DatasetUpdateTestDataFactory.create_dataset( dataset = DatasetUpdateTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
created_by=user.id, created_by=user.id,
provider="vendor", provider="vendor",
@ -278,7 +291,7 @@ class TestDatasetServiceUpdateDataset:
} }
result = DatasetService.update_dataset(dataset.id, update_data, user) result = DatasetService.update_dataset(dataset.id, update_data, user)
db.session.refresh(dataset) db_session_with_containers.refresh(dataset)
assert dataset.name == "new_name" assert dataset.name == "new_name"
assert dataset.description is None assert dataset.description is None
@ -289,11 +302,12 @@ class TestDatasetServiceUpdateDataset:
# ==================== Indexing Technique Switch Tests ==================== # ==================== Indexing Technique Switch Tests ====================
def test_update_internal_dataset_indexing_technique_to_economy(self, db_session_with_containers): def test_update_internal_dataset_indexing_technique_to_economy(self, db_session_with_containers: Session):
"""Test updating internal dataset indexing technique to economy.""" """Test updating internal dataset indexing technique to economy."""
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers)
existing_binding_id = str(uuid4()) existing_binding_id = str(uuid4())
dataset = DatasetUpdateTestDataFactory.create_dataset( dataset = DatasetUpdateTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
created_by=user.id, created_by=user.id,
provider="vendor", provider="vendor",
@ -312,7 +326,7 @@ class TestDatasetServiceUpdateDataset:
result = DatasetService.update_dataset(dataset.id, update_data, user) result = DatasetService.update_dataset(dataset.id, update_data, user)
mock_task.delay.assert_called_once_with(dataset.id, "remove") mock_task.delay.assert_called_once_with(dataset.id, "remove")
db.session.refresh(dataset) db_session_with_containers.refresh(dataset)
assert dataset.indexing_technique == "economy" assert dataset.indexing_technique == "economy"
assert dataset.embedding_model is None assert dataset.embedding_model is None
assert dataset.embedding_model_provider is None assert dataset.embedding_model_provider is None
@ -320,10 +334,11 @@ class TestDatasetServiceUpdateDataset:
assert dataset.retrieval_model == "new_model" assert dataset.retrieval_model == "new_model"
assert result.id == dataset.id assert result.id == dataset.id
def test_update_internal_dataset_indexing_technique_to_high_quality(self, db_session_with_containers): def test_update_internal_dataset_indexing_technique_to_high_quality(self, db_session_with_containers: Session):
"""Test updating internal dataset indexing technique to high_quality.""" """Test updating internal dataset indexing technique to high_quality."""
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = DatasetUpdateTestDataFactory.create_dataset( dataset = DatasetUpdateTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
created_by=user.id, created_by=user.id,
provider="vendor", provider="vendor",
@ -366,7 +381,7 @@ class TestDatasetServiceUpdateDataset:
mock_get_binding.assert_called_once_with("openai", "text-embedding-ada-002") mock_get_binding.assert_called_once_with("openai", "text-embedding-ada-002")
mock_task.delay.assert_called_once_with(dataset.id, "add") mock_task.delay.assert_called_once_with(dataset.id, "add")
db.session.refresh(dataset) db_session_with_containers.refresh(dataset)
assert dataset.indexing_technique == "high_quality" assert dataset.indexing_technique == "high_quality"
assert dataset.embedding_model == "text-embedding-ada-002" assert dataset.embedding_model == "text-embedding-ada-002"
assert dataset.embedding_model_provider == "openai" assert dataset.embedding_model_provider == "openai"
@ -380,9 +395,10 @@ class TestDatasetServiceUpdateDataset:
self, db_session_with_containers self, db_session_with_containers
): ):
"""Test preserving embedding settings when indexing technique remains unchanged.""" """Test preserving embedding settings when indexing technique remains unchanged."""
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers)
existing_binding_id = str(uuid4()) existing_binding_id = str(uuid4())
dataset = DatasetUpdateTestDataFactory.create_dataset( dataset = DatasetUpdateTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
created_by=user.id, created_by=user.id,
provider="vendor", provider="vendor",
@ -399,7 +415,7 @@ class TestDatasetServiceUpdateDataset:
} }
result = DatasetService.update_dataset(dataset.id, update_data, user) result = DatasetService.update_dataset(dataset.id, update_data, user)
db.session.refresh(dataset) db_session_with_containers.refresh(dataset)
assert dataset.name == "new_name" assert dataset.name == "new_name"
assert dataset.indexing_technique == "high_quality" assert dataset.indexing_technique == "high_quality"
@ -409,11 +425,12 @@ class TestDatasetServiceUpdateDataset:
assert dataset.retrieval_model == "new_model" assert dataset.retrieval_model == "new_model"
assert result.id == dataset.id assert result.id == dataset.id
def test_update_internal_dataset_embedding_model_update(self, db_session_with_containers): def test_update_internal_dataset_embedding_model_update(self, db_session_with_containers: Session):
"""Test updating internal dataset with new embedding model.""" """Test updating internal dataset with new embedding model."""
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers)
existing_binding_id = str(uuid4()) existing_binding_id = str(uuid4())
dataset = DatasetUpdateTestDataFactory.create_dataset( dataset = DatasetUpdateTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
created_by=user.id, created_by=user.id,
provider="vendor", provider="vendor",
@ -465,7 +482,7 @@ class TestDatasetServiceUpdateDataset:
regenerate_vectors_only=True, regenerate_vectors_only=True,
) )
db.session.refresh(dataset) db_session_with_containers.refresh(dataset)
assert dataset.embedding_model == "text-embedding-3-small" assert dataset.embedding_model == "text-embedding-3-small"
assert dataset.embedding_model_provider == "openai" assert dataset.embedding_model_provider == "openai"
assert dataset.collection_binding_id == binding.id assert dataset.collection_binding_id == binding.id
@ -474,9 +491,9 @@ class TestDatasetServiceUpdateDataset:
# ==================== Error Handling Tests ==================== # ==================== Error Handling Tests ====================
def test_update_dataset_not_found_error(self, db_session_with_containers): def test_update_dataset_not_found_error(self, db_session_with_containers: Session):
"""Test error when dataset is not found.""" """Test error when dataset is not found."""
user, _ = DatasetUpdateTestDataFactory.create_account_with_tenant() user, _ = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers)
update_data = {"name": "new_name"} update_data = {"name": "new_name"}
with pytest.raises(ValueError) as context: with pytest.raises(ValueError) as context:
@ -484,11 +501,16 @@ class TestDatasetServiceUpdateDataset:
assert "Dataset not found" in str(context.value) assert "Dataset not found" in str(context.value)
def test_update_dataset_permission_error(self, db_session_with_containers): def test_update_dataset_permission_error(self, db_session_with_containers: Session):
"""Test error when user doesn't have permission.""" """Test error when user doesn't have permission."""
owner, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) owner, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(
outsider, _ = DatasetUpdateTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL) db_session_with_containers, role=TenantAccountRole.OWNER
)
outsider, _ = DatasetUpdateTestDataFactory.create_account_with_tenant(
db_session_with_containers, role=TenantAccountRole.NORMAL
)
dataset = DatasetUpdateTestDataFactory.create_dataset( dataset = DatasetUpdateTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
created_by=owner.id, created_by=owner.id,
provider="vendor", provider="vendor",
@ -500,10 +522,11 @@ class TestDatasetServiceUpdateDataset:
with pytest.raises(NoPermissionError): with pytest.raises(NoPermissionError):
DatasetService.update_dataset(dataset.id, update_data, outsider) DatasetService.update_dataset(dataset.id, update_data, outsider)
def test_update_internal_dataset_embedding_model_error(self, db_session_with_containers): def test_update_internal_dataset_embedding_model_error(self, db_session_with_containers: Session):
"""Test error when embedding model is not available.""" """Test error when embedding model is not available."""
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = DatasetUpdateTestDataFactory.create_dataset( dataset = DatasetUpdateTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id, tenant_id=tenant.id,
created_by=user.id, created_by=user.id,
provider="vendor", provider="vendor",

View File

@ -5,6 +5,7 @@ from unittest.mock import create_autospec, patch
import pytest import pytest
from faker import Faker from faker import Faker
from sqlalchemy import Engine from sqlalchemy import Engine
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from configs import dify_config from configs import dify_config
@ -19,7 +20,7 @@ class TestFileService:
"""Integration tests for FileService using testcontainers.""" """Integration tests for FileService using testcontainers."""
@pytest.fixture @pytest.fixture
def engine(self, db_session_with_containers): def engine(self, db_session_with_containers: Session):
bind = db_session_with_containers.get_bind() bind = db_session_with_containers.get_bind()
assert isinstance(bind, Engine) assert isinstance(bind, Engine)
return bind return bind
@ -46,7 +47,7 @@ class TestFileService:
"extract_processor": mock_extract_processor, "extract_processor": mock_extract_processor,
} }
def _create_test_account(self, db_session_with_containers, mock_external_service_dependencies): def _create_test_account(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Helper method to create a test account for testing. Helper method to create a test account for testing.
@ -67,18 +68,16 @@ class TestFileService:
status="active", status="active",
) )
from extensions.ext_database import db db_session_with_containers.add(account)
db_session_with_containers.commit()
db.session.add(account)
db.session.commit()
# Create tenant for the account # Create tenant for the account
tenant = Tenant( tenant = Tenant(
name=fake.company(), name=fake.company(),
status="normal", status="normal",
) )
db.session.add(tenant) db_session_with_containers.add(tenant)
db.session.commit() db_session_with_containers.commit()
# Create tenant-account join # Create tenant-account join
from models.account import TenantAccountJoin, TenantAccountRole from models.account import TenantAccountJoin, TenantAccountRole
@ -89,15 +88,15 @@ class TestFileService:
role=TenantAccountRole.OWNER, role=TenantAccountRole.OWNER,
current=True, current=True,
) )
db.session.add(join) db_session_with_containers.add(join)
db.session.commit() db_session_with_containers.commit()
# Set current tenant for account # Set current tenant for account
account.current_tenant = tenant account.current_tenant = tenant
return account return account
def _create_test_end_user(self, db_session_with_containers, mock_external_service_dependencies): def _create_test_end_user(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Helper method to create a test end user for testing. Helper method to create a test end user for testing.
@ -118,14 +117,14 @@ class TestFileService:
session_id=fake.uuid4(), session_id=fake.uuid4(),
) )
from extensions.ext_database import db db_session_with_containers.add(end_user)
db_session_with_containers.commit()
db.session.add(end_user)
db.session.commit()
return end_user return end_user
def _create_test_upload_file(self, db_session_with_containers, mock_external_service_dependencies, account): def _create_test_upload_file(
self, db_session_with_containers: Session, mock_external_service_dependencies, account
):
""" """
Helper method to create a test upload file for testing. Helper method to create a test upload file for testing.
@ -155,15 +154,13 @@ class TestFileService:
source_url="", source_url="",
) )
from extensions.ext_database import db db_session_with_containers.add(upload_file)
db_session_with_containers.commit()
db.session.add(upload_file)
db.session.commit()
return upload_file return upload_file
# Test upload_file method # Test upload_file method
def test_upload_file_success(self, db_session_with_containers, engine, mock_external_service_dependencies): def test_upload_file_success(self, db_session_with_containers: Session, engine, mock_external_service_dependencies):
""" """
Test successful file upload with valid parameters. Test successful file upload with valid parameters.
""" """
@ -196,7 +193,9 @@ class TestFileService:
assert upload_file.id is not None assert upload_file.id is not None
def test_upload_file_with_end_user(self, db_session_with_containers, engine, mock_external_service_dependencies): def test_upload_file_with_end_user(
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
""" """
Test file upload with end user instead of account. Test file upload with end user instead of account.
""" """
@ -219,7 +218,7 @@ class TestFileService:
assert upload_file.created_by_role == CreatorUserRole.END_USER assert upload_file.created_by_role == CreatorUserRole.END_USER
def test_upload_file_with_datasets_source( def test_upload_file_with_datasets_source(
self, db_session_with_containers, engine, mock_external_service_dependencies self, db_session_with_containers: Session, engine, mock_external_service_dependencies
): ):
""" """
Test file upload with datasets source parameter. Test file upload with datasets source parameter.
@ -244,7 +243,7 @@ class TestFileService:
assert upload_file.source_url == "https://example.com/source" assert upload_file.source_url == "https://example.com/source"
def test_upload_file_invalid_filename_characters( def test_upload_file_invalid_filename_characters(
self, db_session_with_containers, engine, mock_external_service_dependencies self, db_session_with_containers: Session, engine, mock_external_service_dependencies
): ):
""" """
Test file upload with invalid filename characters. Test file upload with invalid filename characters.
@ -265,7 +264,7 @@ class TestFileService:
) )
def test_upload_file_filename_too_long( def test_upload_file_filename_too_long(
self, db_session_with_containers, engine, mock_external_service_dependencies self, db_session_with_containers: Session, engine, mock_external_service_dependencies
): ):
""" """
Test file upload with filename that exceeds length limit. Test file upload with filename that exceeds length limit.
@ -295,7 +294,7 @@ class TestFileService:
assert len(base_name) <= 200 assert len(base_name) <= 200
def test_upload_file_datasets_unsupported_type( def test_upload_file_datasets_unsupported_type(
self, db_session_with_containers, engine, mock_external_service_dependencies self, db_session_with_containers: Session, engine, mock_external_service_dependencies
): ):
""" """
Test file upload for datasets with unsupported file type. Test file upload for datasets with unsupported file type.
@ -316,7 +315,9 @@ class TestFileService:
source="datasets", source="datasets",
) )
def test_upload_file_too_large(self, db_session_with_containers, engine, mock_external_service_dependencies): def test_upload_file_too_large(
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
""" """
Test file upload with file size exceeding limit. Test file upload with file size exceeding limit.
""" """
@ -338,7 +339,7 @@ class TestFileService:
# Test is_file_size_within_limit method # Test is_file_size_within_limit method
def test_is_file_size_within_limit_image_success( def test_is_file_size_within_limit_image_success(
self, db_session_with_containers, engine, mock_external_service_dependencies self, db_session_with_containers: Session, engine, mock_external_service_dependencies
): ):
""" """
Test file size check for image files within limit. Test file size check for image files within limit.
@ -351,7 +352,7 @@ class TestFileService:
assert result is True assert result is True
def test_is_file_size_within_limit_video_success( def test_is_file_size_within_limit_video_success(
self, db_session_with_containers, engine, mock_external_service_dependencies self, db_session_with_containers: Session, engine, mock_external_service_dependencies
): ):
""" """
Test file size check for video files within limit. Test file size check for video files within limit.
@ -364,7 +365,7 @@ class TestFileService:
assert result is True assert result is True
def test_is_file_size_within_limit_audio_success( def test_is_file_size_within_limit_audio_success(
self, db_session_with_containers, engine, mock_external_service_dependencies self, db_session_with_containers: Session, engine, mock_external_service_dependencies
): ):
""" """
Test file size check for audio files within limit. Test file size check for audio files within limit.
@ -377,7 +378,7 @@ class TestFileService:
assert result is True assert result is True
def test_is_file_size_within_limit_document_success( def test_is_file_size_within_limit_document_success(
self, db_session_with_containers, engine, mock_external_service_dependencies self, db_session_with_containers: Session, engine, mock_external_service_dependencies
): ):
""" """
Test file size check for document files within limit. Test file size check for document files within limit.
@ -390,7 +391,7 @@ class TestFileService:
assert result is True assert result is True
def test_is_file_size_within_limit_image_exceeded( def test_is_file_size_within_limit_image_exceeded(
self, db_session_with_containers, engine, mock_external_service_dependencies self, db_session_with_containers: Session, engine, mock_external_service_dependencies
): ):
""" """
Test file size check for image files exceeding limit. Test file size check for image files exceeding limit.
@ -403,7 +404,7 @@ class TestFileService:
assert result is False assert result is False
def test_is_file_size_within_limit_unknown_extension( def test_is_file_size_within_limit_unknown_extension(
self, db_session_with_containers, engine, mock_external_service_dependencies self, db_session_with_containers: Session, engine, mock_external_service_dependencies
): ):
""" """
Test file size check for unknown file extension. Test file size check for unknown file extension.
@ -416,7 +417,7 @@ class TestFileService:
assert result is True assert result is True
# Test upload_text method # Test upload_text method
def test_upload_text_success(self, db_session_with_containers, engine, mock_external_service_dependencies): def test_upload_text_success(self, db_session_with_containers: Session, engine, mock_external_service_dependencies):
""" """
Test successful text upload. Test successful text upload.
""" """
@ -447,7 +448,9 @@ class TestFileService:
# Verify storage was called # Verify storage was called
mock_external_service_dependencies["storage"].save.assert_called_once() mock_external_service_dependencies["storage"].save.assert_called_once()
def test_upload_text_name_too_long(self, db_session_with_containers, engine, mock_external_service_dependencies): def test_upload_text_name_too_long(
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
""" """
Test text upload with name that exceeds length limit. Test text upload with name that exceeds length limit.
""" """
@ -472,7 +475,9 @@ class TestFileService:
assert upload_file.name == "a" * 200 assert upload_file.name == "a" * 200
# Test get_file_preview method # Test get_file_preview method
def test_get_file_preview_success(self, db_session_with_containers, engine, mock_external_service_dependencies): def test_get_file_preview_success(
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
""" """
Test successful file preview generation. Test successful file preview generation.
""" """
@ -484,9 +489,8 @@ class TestFileService:
# Update file to have document extension # Update file to have document extension
upload_file.extension = "pdf" upload_file.extension = "pdf"
from extensions.ext_database import db
db.session.commit() db_session_with_containers.commit()
result = FileService(engine).get_file_preview(file_id=upload_file.id) result = FileService(engine).get_file_preview(file_id=upload_file.id)
@ -494,7 +498,7 @@ class TestFileService:
mock_external_service_dependencies["extract_processor"].load_from_upload_file.assert_called_once() mock_external_service_dependencies["extract_processor"].load_from_upload_file.assert_called_once()
def test_get_file_preview_file_not_found( def test_get_file_preview_file_not_found(
self, db_session_with_containers, engine, mock_external_service_dependencies self, db_session_with_containers: Session, engine, mock_external_service_dependencies
): ):
""" """
Test file preview with non-existent file. Test file preview with non-existent file.
@ -506,7 +510,7 @@ class TestFileService:
FileService(engine).get_file_preview(file_id=non_existent_id) FileService(engine).get_file_preview(file_id=non_existent_id)
def test_get_file_preview_unsupported_file_type( def test_get_file_preview_unsupported_file_type(
self, db_session_with_containers, engine, mock_external_service_dependencies self, db_session_with_containers: Session, engine, mock_external_service_dependencies
): ):
""" """
Test file preview with unsupported file type. Test file preview with unsupported file type.
@ -519,15 +523,14 @@ class TestFileService:
# Update file to have non-document extension # Update file to have non-document extension
upload_file.extension = "jpg" upload_file.extension = "jpg"
from extensions.ext_database import db
db.session.commit() db_session_with_containers.commit()
with pytest.raises(UnsupportedFileTypeError): with pytest.raises(UnsupportedFileTypeError):
FileService(engine).get_file_preview(file_id=upload_file.id) FileService(engine).get_file_preview(file_id=upload_file.id)
def test_get_file_preview_text_truncation( def test_get_file_preview_text_truncation(
self, db_session_with_containers, engine, mock_external_service_dependencies self, db_session_with_containers: Session, engine, mock_external_service_dependencies
): ):
""" """
Test file preview with text that exceeds preview limit. Test file preview with text that exceeds preview limit.
@ -540,9 +543,8 @@ class TestFileService:
# Update file to have document extension # Update file to have document extension
upload_file.extension = "pdf" upload_file.extension = "pdf"
from extensions.ext_database import db
db.session.commit() db_session_with_containers.commit()
# Mock long text content # Mock long text content
long_text = "x" * 5000 # Longer than PREVIEW_WORDS_LIMIT long_text = "x" * 5000 # Longer than PREVIEW_WORDS_LIMIT
@ -554,7 +556,9 @@ class TestFileService:
assert result == "x" * 3000 assert result == "x" * 3000
# Test get_image_preview method # Test get_image_preview method
def test_get_image_preview_success(self, db_session_with_containers, engine, mock_external_service_dependencies): def test_get_image_preview_success(
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
""" """
Test successful image preview generation. Test successful image preview generation.
""" """
@ -566,9 +570,8 @@ class TestFileService:
# Update file to have image extension # Update file to have image extension
upload_file.extension = "jpg" upload_file.extension = "jpg"
from extensions.ext_database import db
db.session.commit() db_session_with_containers.commit()
timestamp = "1234567890" timestamp = "1234567890"
nonce = "test_nonce" nonce = "test_nonce"
@ -586,7 +589,7 @@ class TestFileService:
mock_external_service_dependencies["file_helpers"].verify_image_signature.assert_called_once() mock_external_service_dependencies["file_helpers"].verify_image_signature.assert_called_once()
def test_get_image_preview_invalid_signature( def test_get_image_preview_invalid_signature(
self, db_session_with_containers, engine, mock_external_service_dependencies self, db_session_with_containers: Session, engine, mock_external_service_dependencies
): ):
""" """
Test image preview with invalid signature. Test image preview with invalid signature.
@ -613,7 +616,7 @@ class TestFileService:
) )
def test_get_image_preview_file_not_found( def test_get_image_preview_file_not_found(
self, db_session_with_containers, engine, mock_external_service_dependencies self, db_session_with_containers: Session, engine, mock_external_service_dependencies
): ):
""" """
Test image preview with non-existent file. Test image preview with non-existent file.
@ -634,7 +637,7 @@ class TestFileService:
) )
def test_get_image_preview_unsupported_file_type( def test_get_image_preview_unsupported_file_type(
self, db_session_with_containers, engine, mock_external_service_dependencies self, db_session_with_containers: Session, engine, mock_external_service_dependencies
): ):
""" """
Test image preview with non-image file type. Test image preview with non-image file type.
@ -647,9 +650,8 @@ class TestFileService:
# Update file to have non-image extension # Update file to have non-image extension
upload_file.extension = "pdf" upload_file.extension = "pdf"
from extensions.ext_database import db
db.session.commit() db_session_with_containers.commit()
timestamp = "1234567890" timestamp = "1234567890"
nonce = "test_nonce" nonce = "test_nonce"
@ -665,7 +667,7 @@ class TestFileService:
# Test get_file_generator_by_file_id method # Test get_file_generator_by_file_id method
def test_get_file_generator_by_file_id_success( def test_get_file_generator_by_file_id_success(
self, db_session_with_containers, engine, mock_external_service_dependencies self, db_session_with_containers: Session, engine, mock_external_service_dependencies
): ):
""" """
Test successful file generator retrieval. Test successful file generator retrieval.
@ -692,7 +694,7 @@ class TestFileService:
mock_external_service_dependencies["file_helpers"].verify_file_signature.assert_called_once() mock_external_service_dependencies["file_helpers"].verify_file_signature.assert_called_once()
def test_get_file_generator_by_file_id_invalid_signature( def test_get_file_generator_by_file_id_invalid_signature(
self, db_session_with_containers, engine, mock_external_service_dependencies self, db_session_with_containers: Session, engine, mock_external_service_dependencies
): ):
""" """
Test file generator retrieval with invalid signature. Test file generator retrieval with invalid signature.
@ -719,7 +721,7 @@ class TestFileService:
) )
def test_get_file_generator_by_file_id_file_not_found( def test_get_file_generator_by_file_id_file_not_found(
self, db_session_with_containers, engine, mock_external_service_dependencies self, db_session_with_containers: Session, engine, mock_external_service_dependencies
): ):
""" """
Test file generator retrieval with non-existent file. Test file generator retrieval with non-existent file.
@ -741,7 +743,7 @@ class TestFileService:
# Test get_public_image_preview method # Test get_public_image_preview method
def test_get_public_image_preview_success( def test_get_public_image_preview_success(
self, db_session_with_containers, engine, mock_external_service_dependencies self, db_session_with_containers: Session, engine, mock_external_service_dependencies
): ):
""" """
Test successful public image preview generation. Test successful public image preview generation.
@ -754,9 +756,8 @@ class TestFileService:
# Update file to have image extension # Update file to have image extension
upload_file.extension = "jpg" upload_file.extension = "jpg"
from extensions.ext_database import db
db.session.commit() db_session_with_containers.commit()
generator, mime_type = FileService(engine).get_public_image_preview(file_id=upload_file.id) generator, mime_type = FileService(engine).get_public_image_preview(file_id=upload_file.id)
@ -765,7 +766,7 @@ class TestFileService:
mock_external_service_dependencies["storage"].load.assert_called_once() mock_external_service_dependencies["storage"].load.assert_called_once()
def test_get_public_image_preview_file_not_found( def test_get_public_image_preview_file_not_found(
self, db_session_with_containers, engine, mock_external_service_dependencies self, db_session_with_containers: Session, engine, mock_external_service_dependencies
): ):
""" """
Test public image preview with non-existent file. Test public image preview with non-existent file.
@ -777,7 +778,7 @@ class TestFileService:
FileService(engine).get_public_image_preview(file_id=non_existent_id) FileService(engine).get_public_image_preview(file_id=non_existent_id)
def test_get_public_image_preview_unsupported_file_type( def test_get_public_image_preview_unsupported_file_type(
self, db_session_with_containers, engine, mock_external_service_dependencies self, db_session_with_containers: Session, engine, mock_external_service_dependencies
): ):
""" """
Test public image preview with non-image file type. Test public image preview with non-image file type.
@ -790,15 +791,16 @@ class TestFileService:
# Update file to have non-image extension # Update file to have non-image extension
upload_file.extension = "pdf" upload_file.extension = "pdf"
from extensions.ext_database import db
db.session.commit() db_session_with_containers.commit()
with pytest.raises(UnsupportedFileTypeError): with pytest.raises(UnsupportedFileTypeError):
FileService(engine).get_public_image_preview(file_id=upload_file.id) FileService(engine).get_public_image_preview(file_id=upload_file.id)
# Test edge cases and boundary conditions # Test edge cases and boundary conditions
def test_upload_file_empty_content(self, db_session_with_containers, engine, mock_external_service_dependencies): def test_upload_file_empty_content(
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
""" """
Test file upload with empty content. Test file upload with empty content.
""" """
@ -820,7 +822,7 @@ class TestFileService:
assert upload_file.size == 0 assert upload_file.size == 0
def test_upload_file_special_characters_in_name( def test_upload_file_special_characters_in_name(
self, db_session_with_containers, engine, mock_external_service_dependencies self, db_session_with_containers: Session, engine, mock_external_service_dependencies
): ):
""" """
Test file upload with special characters in filename (but valid ones). Test file upload with special characters in filename (but valid ones).
@ -843,7 +845,7 @@ class TestFileService:
assert upload_file.name == filename assert upload_file.name == filename
def test_upload_file_different_case_extensions( def test_upload_file_different_case_extensions(
self, db_session_with_containers, engine, mock_external_service_dependencies self, db_session_with_containers: Session, engine, mock_external_service_dependencies
): ):
""" """
Test file upload with different case extensions. Test file upload with different case extensions.
@ -865,7 +867,9 @@ class TestFileService:
assert upload_file is not None assert upload_file is not None
assert upload_file.extension == "pdf" # Should be converted to lowercase assert upload_file.extension == "pdf" # Should be converted to lowercase
def test_upload_text_empty_text(self, db_session_with_containers, engine, mock_external_service_dependencies): def test_upload_text_empty_text(
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
""" """
Test text upload with empty text. Test text upload with empty text.
""" """
@ -888,7 +892,9 @@ class TestFileService:
assert upload_file is not None assert upload_file is not None
assert upload_file.size == 0 assert upload_file.size == 0
def test_file_size_limits_edge_cases(self, db_session_with_containers, engine, mock_external_service_dependencies): def test_file_size_limits_edge_cases(
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
""" """
Test file size limits with edge case values. Test file size limits with edge case values.
""" """
@ -908,7 +914,9 @@ class TestFileService:
result = FileService(engine).is_file_size_within_limit(extension=extension, file_size=file_size) result = FileService(engine).is_file_size_within_limit(extension=extension, file_size=file_size)
assert result is False assert result is False
def test_upload_file_with_source_url(self, db_session_with_containers, engine, mock_external_service_dependencies): def test_upload_file_with_source_url(
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
""" """
Test file upload with source URL that gets overridden by signed URL. Test file upload with source URL that gets overridden by signed URL.
""" """
@ -946,7 +954,7 @@ class TestFileService:
# Test file extension blacklist # Test file extension blacklist
def test_upload_file_blocked_extension( def test_upload_file_blocked_extension(
self, db_session_with_containers, engine, mock_external_service_dependencies self, db_session_with_containers: Session, engine, mock_external_service_dependencies
): ):
""" """
Test file upload with blocked extension. Test file upload with blocked extension.
@ -969,7 +977,7 @@ class TestFileService:
) )
def test_upload_file_blocked_extension_case_insensitive( def test_upload_file_blocked_extension_case_insensitive(
self, db_session_with_containers, engine, mock_external_service_dependencies self, db_session_with_containers: Session, engine, mock_external_service_dependencies
): ):
""" """
Test file upload with blocked extension (case insensitive). Test file upload with blocked extension (case insensitive).
@ -992,7 +1000,9 @@ class TestFileService:
user=account, user=account,
) )
def test_upload_file_not_in_blacklist(self, db_session_with_containers, engine, mock_external_service_dependencies): def test_upload_file_not_in_blacklist(
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
""" """
Test file upload with extension not in blacklist. Test file upload with extension not in blacklist.
""" """
@ -1016,7 +1026,9 @@ class TestFileService:
assert upload_file.name == filename assert upload_file.name == filename
assert upload_file.extension == "pdf" assert upload_file.extension == "pdf"
def test_upload_file_empty_blacklist(self, db_session_with_containers, engine, mock_external_service_dependencies): def test_upload_file_empty_blacklist(
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
""" """
Test file upload with empty blacklist (default behavior). Test file upload with empty blacklist (default behavior).
""" """
@ -1041,7 +1053,7 @@ class TestFileService:
assert upload_file.extension == "sh" assert upload_file.extension == "sh"
def test_upload_file_multiple_blocked_extensions( def test_upload_file_multiple_blocked_extensions(
self, db_session_with_containers, engine, mock_external_service_dependencies self, db_session_with_containers: Session, engine, mock_external_service_dependencies
): ):
""" """
Test file upload with multiple blocked extensions. Test file upload with multiple blocked extensions.
@ -1066,7 +1078,7 @@ class TestFileService:
) )
def test_upload_file_no_extension_with_blacklist( def test_upload_file_no_extension_with_blacklist(
self, db_session_with_containers, engine, mock_external_service_dependencies self, db_session_with_containers: Session, engine, mock_external_service_dependencies
): ):
""" """
Test file upload with no extension when blacklist is configured. Test file upload with no extension when blacklist is configured.

View File

@ -2,6 +2,7 @@ from unittest.mock import patch
import pytest import pytest
from faker import Faker from faker import Faker
from sqlalchemy.orm import Session
from models.model import MessageFeedback from models.model import MessageFeedback
from services.app_service import AppService from services.app_service import AppService
@ -69,7 +70,7 @@ class TestMessageService:
# "current_user": mock_current_user, # "current_user": mock_current_user,
} }
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Helper method to create a test app and account for testing. Helper method to create a test app and account for testing.
@ -127,11 +128,10 @@ class TestMessageService:
# mock_external_service_dependencies["current_user"].id = account_id # mock_external_service_dependencies["current_user"].id = account_id
# mock_external_service_dependencies["current_user"].current_tenant_id = tenant_id # mock_external_service_dependencies["current_user"].current_tenant_id = tenant_id
def _create_test_conversation(self, app, account, fake): def _create_test_conversation(self, db_session_with_containers: Session, app, account, fake):
""" """
Helper method to create a test conversation with all required fields. Helper method to create a test conversation with all required fields.
""" """
from extensions.ext_database import db
from models.model import Conversation from models.model import Conversation
conversation = Conversation( conversation = Conversation(
@ -153,17 +153,16 @@ class TestMessageService:
from_account_id=account.id, from_account_id=account.id,
) )
db.session.add(conversation) db_session_with_containers.add(conversation)
db.session.flush() db_session_with_containers.flush()
return conversation return conversation
def _create_test_message(self, app, conversation, account, fake): def _create_test_message(self, db_session_with_containers: Session, app, conversation, account, fake):
""" """
Helper method to create a test message with all required fields. Helper method to create a test message with all required fields.
""" """
import json import json
from extensions.ext_database import db
from models.model import Message from models.model import Message
message = Message( message = Message(
@ -192,11 +191,13 @@ class TestMessageService:
from_account_id=account.id, from_account_id=account.id,
) )
db.session.add(message) db_session_with_containers.add(message)
db.session.commit() db_session_with_containers.commit()
return message return message
def test_pagination_by_first_id_success(self, db_session_with_containers, mock_external_service_dependencies): def test_pagination_by_first_id_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful pagination by first ID. Test successful pagination by first ID.
""" """
@ -204,10 +205,10 @@ class TestMessageService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and multiple messages # Create a conversation and multiple messages
conversation = self._create_test_conversation(app, account, fake) conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
messages = [] messages = []
for i in range(5): for i in range(5):
message = self._create_test_message(app, conversation, account, fake) message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
messages.append(message) messages.append(message)
# Test pagination by first ID # Test pagination by first ID
@ -228,7 +229,9 @@ class TestMessageService:
# Verify messages are in ascending order # Verify messages are in ascending order
assert result.data[0].created_at <= result.data[1].created_at assert result.data[0].created_at <= result.data[1].created_at
def test_pagination_by_first_id_no_user(self, db_session_with_containers, mock_external_service_dependencies): def test_pagination_by_first_id_no_user(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test pagination by first ID when no user is provided. Test pagination by first ID when no user is provided.
""" """
@ -246,7 +249,7 @@ class TestMessageService:
assert result.has_more is False assert result.has_more is False
def test_pagination_by_first_id_no_conversation_id( def test_pagination_by_first_id_no_conversation_id(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test pagination by first ID when no conversation ID is provided. Test pagination by first ID when no conversation ID is provided.
@ -265,7 +268,7 @@ class TestMessageService:
assert result.has_more is False assert result.has_more is False
def test_pagination_by_first_id_invalid_first_id( def test_pagination_by_first_id_invalid_first_id(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test pagination by first ID with invalid first_id. Test pagination by first ID with invalid first_id.
@ -274,8 +277,8 @@ class TestMessageService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and message # Create a conversation and message
conversation = self._create_test_conversation(app, account, fake) conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
self._create_test_message(app, conversation, account, fake) self._create_test_message(db_session_with_containers, app, conversation, account, fake)
# Test pagination with invalid first_id # Test pagination with invalid first_id
with pytest.raises(FirstMessageNotExistsError): with pytest.raises(FirstMessageNotExistsError):
@ -287,7 +290,9 @@ class TestMessageService:
limit=10, limit=10,
) )
def test_pagination_by_last_id_success(self, db_session_with_containers, mock_external_service_dependencies): def test_pagination_by_last_id_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful pagination by last ID. Test successful pagination by last ID.
""" """
@ -295,10 +300,10 @@ class TestMessageService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and multiple messages # Create a conversation and multiple messages
conversation = self._create_test_conversation(app, account, fake) conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
messages = [] messages = []
for i in range(5): for i in range(5):
message = self._create_test_message(app, conversation, account, fake) message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
messages.append(message) messages.append(message)
# Test pagination by last ID # Test pagination by last ID
@ -319,7 +324,7 @@ class TestMessageService:
assert result.data[0].created_at >= result.data[1].created_at assert result.data[0].created_at >= result.data[1].created_at
def test_pagination_by_last_id_with_include_ids( def test_pagination_by_last_id_with_include_ids(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test pagination by last ID with include_ids filter. Test pagination by last ID with include_ids filter.
@ -328,10 +333,10 @@ class TestMessageService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and multiple messages # Create a conversation and multiple messages
conversation = self._create_test_conversation(app, account, fake) conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
messages = [] messages = []
for i in range(5): for i in range(5):
message = self._create_test_message(app, conversation, account, fake) message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
messages.append(message) messages.append(message)
# Test pagination with include_ids # Test pagination with include_ids
@ -347,7 +352,9 @@ class TestMessageService:
for message in result.data: for message in result.data:
assert message.id in include_ids assert message.id in include_ids
def test_pagination_by_last_id_no_user(self, db_session_with_containers, mock_external_service_dependencies): def test_pagination_by_last_id_no_user(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test pagination by last ID when no user is provided. Test pagination by last ID when no user is provided.
""" """
@ -363,7 +370,7 @@ class TestMessageService:
assert result.has_more is False assert result.has_more is False
def test_pagination_by_last_id_invalid_last_id( def test_pagination_by_last_id_invalid_last_id(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test pagination by last ID with invalid last_id. Test pagination by last ID with invalid last_id.
@ -372,8 +379,8 @@ class TestMessageService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and message # Create a conversation and message
conversation = self._create_test_conversation(app, account, fake) conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
self._create_test_message(app, conversation, account, fake) self._create_test_message(db_session_with_containers, app, conversation, account, fake)
# Test pagination with invalid last_id # Test pagination with invalid last_id
with pytest.raises(LastMessageNotExistsError): with pytest.raises(LastMessageNotExistsError):
@ -385,7 +392,7 @@ class TestMessageService:
conversation_id=conversation.id, conversation_id=conversation.id,
) )
def test_create_feedback_success(self, db_session_with_containers, mock_external_service_dependencies): def test_create_feedback_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test successful creation of feedback. Test successful creation of feedback.
""" """
@ -393,8 +400,8 @@ class TestMessageService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and message # Create a conversation and message
conversation = self._create_test_conversation(app, account, fake) conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
message = self._create_test_message(app, conversation, account, fake) message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
# Create feedback # Create feedback
rating = "like" rating = "like"
@ -413,7 +420,7 @@ class TestMessageService:
assert feedback.from_account_id == account.id assert feedback.from_account_id == account.id
assert feedback.from_end_user_id is None assert feedback.from_end_user_id is None
def test_create_feedback_no_user(self, db_session_with_containers, mock_external_service_dependencies): def test_create_feedback_no_user(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test creating feedback when no user is provided. Test creating feedback when no user is provided.
""" """
@ -421,8 +428,8 @@ class TestMessageService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and message # Create a conversation and message
conversation = self._create_test_conversation(app, account, fake) conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
message = self._create_test_message(app, conversation, account, fake) message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
# Test creating feedback with no user # Test creating feedback with no user
with pytest.raises(ValueError, match="user cannot be None"): with pytest.raises(ValueError, match="user cannot be None"):
@ -430,7 +437,9 @@ class TestMessageService:
app_model=app, message_id=message.id, user=None, rating="like", content=fake.text(max_nb_chars=100) app_model=app, message_id=message.id, user=None, rating="like", content=fake.text(max_nb_chars=100)
) )
def test_create_feedback_update_existing(self, db_session_with_containers, mock_external_service_dependencies): def test_create_feedback_update_existing(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test updating existing feedback. Test updating existing feedback.
""" """
@ -438,8 +447,8 @@ class TestMessageService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and message # Create a conversation and message
conversation = self._create_test_conversation(app, account, fake) conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
message = self._create_test_message(app, conversation, account, fake) message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
# Create initial feedback # Create initial feedback
initial_rating = "like" initial_rating = "like"
@ -462,7 +471,9 @@ class TestMessageService:
assert updated_feedback.rating != initial_rating assert updated_feedback.rating != initial_rating
assert updated_feedback.content != initial_content assert updated_feedback.content != initial_content
def test_create_feedback_delete_existing(self, db_session_with_containers, mock_external_service_dependencies): def test_create_feedback_delete_existing(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test deleting existing feedback by setting rating to None. Test deleting existing feedback by setting rating to None.
""" """
@ -470,8 +481,8 @@ class TestMessageService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and message # Create a conversation and message
conversation = self._create_test_conversation(app, account, fake) conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
message = self._create_test_message(app, conversation, account, fake) message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
# Create initial feedback # Create initial feedback
feedback = MessageService.create_feedback( feedback = MessageService.create_feedback(
@ -482,13 +493,14 @@ class TestMessageService:
MessageService.create_feedback(app_model=app, message_id=message.id, user=account, rating=None, content=None) MessageService.create_feedback(app_model=app, message_id=message.id, user=account, rating=None, content=None)
# Verify feedback was deleted # Verify feedback was deleted
from extensions.ext_database import db
deleted_feedback = db.session.query(MessageFeedback).where(MessageFeedback.id == feedback.id).first() deleted_feedback = (
db_session_with_containers.query(MessageFeedback).where(MessageFeedback.id == feedback.id).first()
)
assert deleted_feedback is None assert deleted_feedback is None
def test_create_feedback_no_rating_when_not_exists( def test_create_feedback_no_rating_when_not_exists(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test creating feedback with no rating when feedback doesn't exist. Test creating feedback with no rating when feedback doesn't exist.
@ -497,8 +509,8 @@ class TestMessageService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and message # Create a conversation and message
conversation = self._create_test_conversation(app, account, fake) conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
message = self._create_test_message(app, conversation, account, fake) message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
# Test creating feedback with no rating when no feedback exists # Test creating feedback with no rating when no feedback exists
with pytest.raises(ValueError, match="rating cannot be None when feedback not exists"): with pytest.raises(ValueError, match="rating cannot be None when feedback not exists"):
@ -506,7 +518,9 @@ class TestMessageService:
app_model=app, message_id=message.id, user=account, rating=None, content=None app_model=app, message_id=message.id, user=account, rating=None, content=None
) )
def test_get_all_messages_feedbacks_success(self, db_session_with_containers, mock_external_service_dependencies): def test_get_all_messages_feedbacks_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful retrieval of all message feedbacks. Test successful retrieval of all message feedbacks.
""" """
@ -516,8 +530,8 @@ class TestMessageService:
# Create multiple conversations and messages with feedbacks # Create multiple conversations and messages with feedbacks
feedbacks = [] feedbacks = []
for i in range(3): for i in range(3):
conversation = self._create_test_conversation(app, account, fake) conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
message = self._create_test_message(app, conversation, account, fake) message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
feedback = MessageService.create_feedback( feedback = MessageService.create_feedback(
app_model=app, app_model=app,
@ -539,7 +553,7 @@ class TestMessageService:
assert result[i]["created_at"] >= result[i + 1]["created_at"] assert result[i]["created_at"] >= result[i + 1]["created_at"]
def test_get_all_messages_feedbacks_pagination( def test_get_all_messages_feedbacks_pagination(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test pagination of message feedbacks. Test pagination of message feedbacks.
@ -549,8 +563,8 @@ class TestMessageService:
# Create multiple conversations and messages with feedbacks # Create multiple conversations and messages with feedbacks
for i in range(5): for i in range(5):
conversation = self._create_test_conversation(app, account, fake) conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
message = self._create_test_message(app, conversation, account, fake) message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
MessageService.create_feedback( MessageService.create_feedback(
app_model=app, message_id=message.id, user=account, rating="like", content=f"Feedback {i}" app_model=app, message_id=message.id, user=account, rating="like", content=f"Feedback {i}"
@ -569,7 +583,7 @@ class TestMessageService:
page_2_ids = {feedback["id"] for feedback in result_page_2} page_2_ids = {feedback["id"] for feedback in result_page_2}
assert len(page_1_ids.intersection(page_2_ids)) == 0 assert len(page_1_ids.intersection(page_2_ids)) == 0
def test_get_message_success(self, db_session_with_containers, mock_external_service_dependencies): def test_get_message_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test successful retrieval of message. Test successful retrieval of message.
""" """
@ -577,8 +591,8 @@ class TestMessageService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and message # Create a conversation and message
conversation = self._create_test_conversation(app, account, fake) conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
message = self._create_test_message(app, conversation, account, fake) message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
# Get message # Get message
retrieved_message = MessageService.get_message(app_model=app, user=account, message_id=message.id) retrieved_message = MessageService.get_message(app_model=app, user=account, message_id=message.id)
@ -590,7 +604,7 @@ class TestMessageService:
assert retrieved_message.from_source == "console" assert retrieved_message.from_source == "console"
assert retrieved_message.from_account_id == account.id assert retrieved_message.from_account_id == account.id
def test_get_message_not_exists(self, db_session_with_containers, mock_external_service_dependencies): def test_get_message_not_exists(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test getting message that doesn't exist. Test getting message that doesn't exist.
""" """
@ -601,7 +615,7 @@ class TestMessageService:
with pytest.raises(MessageNotExistsError): with pytest.raises(MessageNotExistsError):
MessageService.get_message(app_model=app, user=account, message_id=fake.uuid4()) MessageService.get_message(app_model=app, user=account, message_id=fake.uuid4())
def test_get_message_wrong_user(self, db_session_with_containers, mock_external_service_dependencies): def test_get_message_wrong_user(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test getting message with wrong user (different account). Test getting message with wrong user (different account).
""" """
@ -609,8 +623,8 @@ class TestMessageService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and message # Create a conversation and message
conversation = self._create_test_conversation(app, account, fake) conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
message = self._create_test_message(app, conversation, account, fake) message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
# Create another account # Create another account
from services.account_service import AccountService, TenantService from services.account_service import AccountService, TenantService
@ -628,7 +642,7 @@ class TestMessageService:
MessageService.get_message(app_model=app, user=other_account, message_id=message.id) MessageService.get_message(app_model=app, user=other_account, message_id=message.id)
def test_get_suggested_questions_after_answer_success( def test_get_suggested_questions_after_answer_success(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful generation of suggested questions after answer. Test successful generation of suggested questions after answer.
@ -637,8 +651,8 @@ class TestMessageService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and message # Create a conversation and message
conversation = self._create_test_conversation(app, account, fake) conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
message = self._create_test_message(app, conversation, account, fake) message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
# Mock the LLMGenerator to return specific questions # Mock the LLMGenerator to return specific questions
mock_questions = ["What is AI?", "How does machine learning work?", "Tell me about neural networks"] mock_questions = ["What is AI?", "How does machine learning work?", "Tell me about neural networks"]
@ -665,7 +679,7 @@ class TestMessageService:
mock_external_service_dependencies["trace_manager_instance"].add_trace_task.assert_called_once() mock_external_service_dependencies["trace_manager_instance"].add_trace_task.assert_called_once()
def test_get_suggested_questions_after_answer_no_user( def test_get_suggested_questions_after_answer_no_user(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test getting suggested questions when no user is provided. Test getting suggested questions when no user is provided.
@ -674,8 +688,8 @@ class TestMessageService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and message # Create a conversation and message
conversation = self._create_test_conversation(app, account, fake) conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
message = self._create_test_message(app, conversation, account, fake) message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
# Test getting suggested questions with no user # Test getting suggested questions with no user
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
@ -686,7 +700,7 @@ class TestMessageService:
) )
def test_get_suggested_questions_after_answer_disabled( def test_get_suggested_questions_after_answer_disabled(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test getting suggested questions when feature is disabled. Test getting suggested questions when feature is disabled.
@ -695,8 +709,8 @@ class TestMessageService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and message # Create a conversation and message
conversation = self._create_test_conversation(app, account, fake) conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
message = self._create_test_message(app, conversation, account, fake) message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
# Mock the feature to be disabled # Mock the feature to be disabled
mock_external_service_dependencies[ mock_external_service_dependencies[
@ -712,7 +726,7 @@ class TestMessageService:
) )
def test_get_suggested_questions_after_answer_no_workflow( def test_get_suggested_questions_after_answer_no_workflow(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test getting suggested questions when no workflow exists. Test getting suggested questions when no workflow exists.
@ -721,8 +735,8 @@ class TestMessageService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and message # Create a conversation and message
conversation = self._create_test_conversation(app, account, fake) conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
message = self._create_test_message(app, conversation, account, fake) message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
# Mock no workflow # Mock no workflow
mock_external_service_dependencies["workflow_service"].return_value.get_published_workflow.return_value = None mock_external_service_dependencies["workflow_service"].return_value.get_published_workflow.return_value = None
@ -738,7 +752,7 @@ class TestMessageService:
assert result == [] assert result == []
def test_get_suggested_questions_after_answer_debugger_mode( def test_get_suggested_questions_after_answer_debugger_mode(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test getting suggested questions in debugger mode. Test getting suggested questions in debugger mode.
@ -747,8 +761,8 @@ class TestMessageService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and message # Create a conversation and message
conversation = self._create_test_conversation(app, account, fake) conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
message = self._create_test_message(app, conversation, account, fake) message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
# Mock questions # Mock questions
mock_questions = ["Debug question 1", "Debug question 2"] mock_questions = ["Debug question 1", "Debug question 2"]

View File

@ -6,9 +6,9 @@ from unittest.mock import patch
import pytest import pytest
from faker import Faker from faker import Faker
from sqlalchemy.orm import Session
from enums.cloud_plan import CloudPlan from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.model import ( from models.model import (
@ -40,25 +40,25 @@ class TestMessagesCleanServiceIntegration:
PLAN_CACHE_KEY_PREFIX = BillingService._PLAN_CACHE_KEY_PREFIX # "tenant_plan:" PLAN_CACHE_KEY_PREFIX = BillingService._PLAN_CACHE_KEY_PREFIX # "tenant_plan:"
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def cleanup_database(self, db_session_with_containers): def cleanup_database(self, db_session_with_containers: Session):
"""Clean up database before and after each test to ensure isolation.""" """Clean up database before and after each test to ensure isolation."""
yield yield
# Clear all test data in correct order (respecting foreign key constraints) # Clear all test data in correct order (respecting foreign key constraints)
db.session.query(DatasetRetrieverResource).delete() db_session_with_containers.query(DatasetRetrieverResource).delete()
db.session.query(AppAnnotationHitHistory).delete() db_session_with_containers.query(AppAnnotationHitHistory).delete()
db.session.query(SavedMessage).delete() db_session_with_containers.query(SavedMessage).delete()
db.session.query(MessageFile).delete() db_session_with_containers.query(MessageFile).delete()
db.session.query(MessageAgentThought).delete() db_session_with_containers.query(MessageAgentThought).delete()
db.session.query(MessageChain).delete() db_session_with_containers.query(MessageChain).delete()
db.session.query(MessageAnnotation).delete() db_session_with_containers.query(MessageAnnotation).delete()
db.session.query(MessageFeedback).delete() db_session_with_containers.query(MessageFeedback).delete()
db.session.query(Message).delete() db_session_with_containers.query(Message).delete()
db.session.query(Conversation).delete() db_session_with_containers.query(Conversation).delete()
db.session.query(App).delete() db_session_with_containers.query(App).delete()
db.session.query(TenantAccountJoin).delete() db_session_with_containers.query(TenantAccountJoin).delete()
db.session.query(Tenant).delete() db_session_with_containers.query(Tenant).delete()
db.session.query(Account).delete() db_session_with_containers.query(Account).delete()
db.session.commit() db_session_with_containers.commit()
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def cleanup_redis(self): def cleanup_redis(self):
@ -100,7 +100,7 @@ class TestMessagesCleanServiceIntegration:
with patch("services.retention.conversation.messages_clean_policy.dify_config.BILLING_ENABLED", False): with patch("services.retention.conversation.messages_clean_policy.dify_config.BILLING_ENABLED", False):
yield yield
def _create_account_and_tenant(self, plan: str = CloudPlan.SANDBOX): def _create_account_and_tenant(self, db_session_with_containers: Session, plan: str = CloudPlan.SANDBOX):
"""Helper to create account and tenant.""" """Helper to create account and tenant."""
fake = Faker() fake = Faker()
@ -110,28 +110,28 @@ class TestMessagesCleanServiceIntegration:
interface_language="en-US", interface_language="en-US",
status="active", status="active",
) )
db.session.add(account) db_session_with_containers.add(account)
db.session.flush() db_session_with_containers.flush()
tenant = Tenant( tenant = Tenant(
name=fake.company(), name=fake.company(),
plan=str(plan), plan=str(plan),
status="normal", status="normal",
) )
db.session.add(tenant) db_session_with_containers.add(tenant)
db.session.flush() db_session_with_containers.flush()
tenant_account_join = TenantAccountJoin( tenant_account_join = TenantAccountJoin(
tenant_id=tenant.id, tenant_id=tenant.id,
account_id=account.id, account_id=account.id,
role=TenantAccountRole.OWNER, role=TenantAccountRole.OWNER,
) )
db.session.add(tenant_account_join) db_session_with_containers.add(tenant_account_join)
db.session.commit() db_session_with_containers.commit()
return account, tenant return account, tenant
def _create_app(self, tenant, account): def _create_app(self, db_session_with_containers: Session, tenant, account):
"""Helper to create an app.""" """Helper to create an app."""
fake = Faker() fake = Faker()
@ -149,12 +149,12 @@ class TestMessagesCleanServiceIntegration:
created_by=account.id, created_by=account.id,
updated_by=account.id, updated_by=account.id,
) )
db.session.add(app) db_session_with_containers.add(app)
db.session.commit() db_session_with_containers.commit()
return app return app
def _create_conversation(self, app): def _create_conversation(self, db_session_with_containers: Session, app):
"""Helper to create a conversation.""" """Helper to create a conversation."""
conversation = Conversation( conversation = Conversation(
app_id=app.id, app_id=app.id,
@ -168,12 +168,14 @@ class TestMessagesCleanServiceIntegration:
from_source="api", from_source="api",
from_end_user_id=str(uuid.uuid4()), from_end_user_id=str(uuid.uuid4()),
) )
db.session.add(conversation) db_session_with_containers.add(conversation)
db.session.commit() db_session_with_containers.commit()
return conversation return conversation
def _create_message(self, app, conversation, created_at=None, with_relations=True): def _create_message(
self, db_session_with_containers: Session, app, conversation, created_at=None, with_relations=True
):
"""Helper to create a message with optional related records.""" """Helper to create a message with optional related records."""
if created_at is None: if created_at is None:
created_at = datetime.datetime.now() created_at = datetime.datetime.now()
@ -197,16 +199,16 @@ class TestMessagesCleanServiceIntegration:
from_account_id=conversation.from_end_user_id, from_account_id=conversation.from_end_user_id,
created_at=created_at, created_at=created_at,
) )
db.session.add(message) db_session_with_containers.add(message)
db.session.flush() db_session_with_containers.flush()
if with_relations: if with_relations:
self._create_message_relations(message) self._create_message_relations(db_session_with_containers, message)
db.session.commit() db_session_with_containers.commit()
return message return message
def _create_message_relations(self, message): def _create_message_relations(self, db_session_with_containers: Session, message):
"""Helper to create all message-related records.""" """Helper to create all message-related records."""
# MessageFeedback # MessageFeedback
feedback = MessageFeedback( feedback = MessageFeedback(
@ -217,7 +219,7 @@ class TestMessagesCleanServiceIntegration:
from_source="api", from_source="api",
from_end_user_id=str(uuid.uuid4()), from_end_user_id=str(uuid.uuid4()),
) )
db.session.add(feedback) db_session_with_containers.add(feedback)
# MessageAnnotation # MessageAnnotation
annotation = MessageAnnotation( annotation = MessageAnnotation(
@ -228,7 +230,7 @@ class TestMessagesCleanServiceIntegration:
content="Test annotation", content="Test annotation",
account_id=message.from_account_id, account_id=message.from_account_id,
) )
db.session.add(annotation) db_session_with_containers.add(annotation)
# MessageChain # MessageChain
chain = MessageChain( chain = MessageChain(
@ -237,8 +239,8 @@ class TestMessagesCleanServiceIntegration:
input=json.dumps({"test": "input"}), input=json.dumps({"test": "input"}),
output=json.dumps({"test": "output"}), output=json.dumps({"test": "output"}),
) )
db.session.add(chain) db_session_with_containers.add(chain)
db.session.flush() db_session_with_containers.flush()
# MessageFile # MessageFile
file = MessageFile( file = MessageFile(
@ -250,7 +252,7 @@ class TestMessagesCleanServiceIntegration:
created_by_role="end_user", created_by_role="end_user",
created_by=str(uuid.uuid4()), created_by=str(uuid.uuid4()),
) )
db.session.add(file) db_session_with_containers.add(file)
# SavedMessage # SavedMessage
saved = SavedMessage( saved = SavedMessage(
@ -259,9 +261,9 @@ class TestMessagesCleanServiceIntegration:
created_by_role="end_user", created_by_role="end_user",
created_by=str(uuid.uuid4()), created_by=str(uuid.uuid4()),
) )
db.session.add(saved) db_session_with_containers.add(saved)
db.session.flush() db_session_with_containers.flush()
# AppAnnotationHitHistory # AppAnnotationHitHistory
hit = AppAnnotationHitHistory( hit = AppAnnotationHitHistory(
@ -275,7 +277,7 @@ class TestMessagesCleanServiceIntegration:
annotation_question="Test annotation question", annotation_question="Test annotation question",
annotation_content="Test annotation content", annotation_content="Test annotation content",
) )
db.session.add(hit) db_session_with_containers.add(hit)
# DatasetRetrieverResource # DatasetRetrieverResource
resource = DatasetRetrieverResource( resource = DatasetRetrieverResource(
@ -296,25 +298,29 @@ class TestMessagesCleanServiceIntegration:
retriever_from="dataset", retriever_from="dataset",
created_by=message.from_account_id, created_by=message.from_account_id,
) )
db.session.add(resource) db_session_with_containers.add(resource)
def test_billing_disabled_deletes_all_messages_in_time_range( def test_billing_disabled_deletes_all_messages_in_time_range(
self, db_session_with_containers, mock_billing_disabled self, db_session_with_containers: Session, mock_billing_disabled
): ):
"""Test that BillingDisabledPolicy deletes all messages within time range regardless of tenant plan.""" """Test that BillingDisabledPolicy deletes all messages within time range regardless of tenant plan."""
# Arrange - Create tenant with messages (plan doesn't matter for billing disabled) # Arrange - Create tenant with messages (plan doesn't matter for billing disabled)
account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX)
app = self._create_app(tenant, account) app = self._create_app(db_session_with_containers, tenant, account)
conv = self._create_conversation(app) conv = self._create_conversation(db_session_with_containers, app)
# Create messages: in-range (should be deleted) and out-of-range (should be kept) # Create messages: in-range (should be deleted) and out-of-range (should be kept)
in_range_date = datetime.datetime(2024, 1, 15, 12, 0, 0) in_range_date = datetime.datetime(2024, 1, 15, 12, 0, 0)
out_of_range_date = datetime.datetime(2024, 1, 25, 12, 0, 0) out_of_range_date = datetime.datetime(2024, 1, 25, 12, 0, 0)
in_range_msg = self._create_message(app, conv, created_at=in_range_date, with_relations=True) in_range_msg = self._create_message(
db_session_with_containers, app, conv, created_at=in_range_date, with_relations=True
)
in_range_msg_id = in_range_msg.id in_range_msg_id = in_range_msg.id
out_of_range_msg = self._create_message(app, conv, created_at=out_of_range_date, with_relations=True) out_of_range_msg = self._create_message(
db_session_with_containers, app, conv, created_at=out_of_range_date, with_relations=True
)
out_of_range_msg_id = out_of_range_msg.id out_of_range_msg_id = out_of_range_msg.id
# Act - create_message_clean_policy should return BillingDisabledPolicy # Act - create_message_clean_policy should return BillingDisabledPolicy
@ -336,17 +342,34 @@ class TestMessagesCleanServiceIntegration:
assert stats["total_deleted"] == 1 assert stats["total_deleted"] == 1
# In-range message deleted # In-range message deleted
assert db.session.query(Message).where(Message.id == in_range_msg_id).count() == 0 assert db_session_with_containers.query(Message).where(Message.id == in_range_msg_id).count() == 0
# Out-of-range message kept # Out-of-range message kept
assert db.session.query(Message).where(Message.id == out_of_range_msg_id).count() == 1 assert db_session_with_containers.query(Message).where(Message.id == out_of_range_msg_id).count() == 1
# Related records of in-range message deleted # Related records of in-range message deleted
assert db.session.query(MessageFeedback).where(MessageFeedback.message_id == in_range_msg_id).count() == 0 assert (
assert db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == in_range_msg_id).count() == 0 db_session_with_containers.query(MessageFeedback)
.where(MessageFeedback.message_id == in_range_msg_id)
.count()
== 0
)
assert (
db_session_with_containers.query(MessageAnnotation)
.where(MessageAnnotation.message_id == in_range_msg_id)
.count()
== 0
)
# Related records of out-of-range message kept # Related records of out-of-range message kept
assert db.session.query(MessageFeedback).where(MessageFeedback.message_id == out_of_range_msg_id).count() == 1 assert (
db_session_with_containers.query(MessageFeedback)
.where(MessageFeedback.message_id == out_of_range_msg_id)
.count()
== 1
)
def test_no_messages_returns_empty_stats(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): def test_no_messages_returns_empty_stats(
self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist
):
"""Test cleaning when there are no messages to delete (B1).""" """Test cleaning when there are no messages to delete (B1)."""
# Arrange # Arrange
end_before = datetime.datetime.now() - datetime.timedelta(days=30) end_before = datetime.datetime.now() - datetime.timedelta(days=30)
@ -371,36 +394,42 @@ class TestMessagesCleanServiceIntegration:
assert stats["filtered_messages"] == 0 assert stats["filtered_messages"] == 0
assert stats["total_deleted"] == 0 assert stats["total_deleted"] == 0
def test_mixed_sandbox_and_paid_tenants(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): def test_mixed_sandbox_and_paid_tenants(
self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist
):
"""Test cleaning with mixed sandbox and paid tenants (B2).""" """Test cleaning with mixed sandbox and paid tenants (B2)."""
# Arrange - Create sandbox tenants with expired messages # Arrange - Create sandbox tenants with expired messages
sandbox_tenants = [] sandbox_tenants = []
sandbox_message_ids = [] sandbox_message_ids = []
for i in range(2): for i in range(2):
account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX)
sandbox_tenants.append(tenant) sandbox_tenants.append(tenant)
app = self._create_app(tenant, account) app = self._create_app(db_session_with_containers, tenant, account)
conv = self._create_conversation(app) conv = self._create_conversation(db_session_with_containers, app)
# Create 3 expired messages per sandbox tenant # Create 3 expired messages per sandbox tenant
expired_date = datetime.datetime.now() - datetime.timedelta(days=35) expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
for j in range(3): for j in range(3):
msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=j)) msg = self._create_message(
db_session_with_containers, app, conv, created_at=expired_date - datetime.timedelta(hours=j)
)
sandbox_message_ids.append(msg.id) sandbox_message_ids.append(msg.id)
# Create paid tenants with expired messages (should NOT be deleted) # Create paid tenants with expired messages (should NOT be deleted)
paid_tenants = [] paid_tenants = []
paid_message_ids = [] paid_message_ids = []
for i in range(2): for i in range(2):
account, tenant = self._create_account_and_tenant(plan=CloudPlan.PROFESSIONAL) account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.PROFESSIONAL)
paid_tenants.append(tenant) paid_tenants.append(tenant)
app = self._create_app(tenant, account) app = self._create_app(db_session_with_containers, tenant, account)
conv = self._create_conversation(app) conv = self._create_conversation(db_session_with_containers, app)
# Create 2 expired messages per paid tenant # Create 2 expired messages per paid tenant
expired_date = datetime.datetime.now() - datetime.timedelta(days=35) expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
for j in range(2): for j in range(2):
msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=j)) msg = self._create_message(
db_session_with_containers, app, conv, created_at=expired_date - datetime.timedelta(hours=j)
)
paid_message_ids.append(msg.id) paid_message_ids.append(msg.id)
# Mock billing service - return plan and expiration_date # Mock billing service - return plan and expiration_date
@ -442,29 +471,39 @@ class TestMessagesCleanServiceIntegration:
assert stats["total_deleted"] == 6 assert stats["total_deleted"] == 6
# Only sandbox messages should be deleted # Only sandbox messages should be deleted
assert db.session.query(Message).where(Message.id.in_(sandbox_message_ids)).count() == 0 assert db_session_with_containers.query(Message).where(Message.id.in_(sandbox_message_ids)).count() == 0
# Paid messages should remain # Paid messages should remain
assert db.session.query(Message).where(Message.id.in_(paid_message_ids)).count() == 4 assert db_session_with_containers.query(Message).where(Message.id.in_(paid_message_ids)).count() == 4
# Related records of sandbox messages should be deleted # Related records of sandbox messages should be deleted
assert db.session.query(MessageFeedback).where(MessageFeedback.message_id.in_(sandbox_message_ids)).count() == 0
assert ( assert (
db.session.query(MessageAnnotation).where(MessageAnnotation.message_id.in_(sandbox_message_ids)).count() db_session_with_containers.query(MessageFeedback)
.where(MessageFeedback.message_id.in_(sandbox_message_ids))
.count()
== 0
)
assert (
db_session_with_containers.query(MessageAnnotation)
.where(MessageAnnotation.message_id.in_(sandbox_message_ids))
.count()
== 0 == 0
) )
def test_cursor_pagination_multiple_batches(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): def test_cursor_pagination_multiple_batches(
self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist
):
"""Test cursor pagination works correctly across multiple batches (B3).""" """Test cursor pagination works correctly across multiple batches (B3)."""
# Arrange - Create sandbox tenant with messages that will span multiple batches # Arrange - Create sandbox tenant with messages that will span multiple batches
account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX)
app = self._create_app(tenant, account) app = self._create_app(db_session_with_containers, tenant, account)
conv = self._create_conversation(app) conv = self._create_conversation(db_session_with_containers, app)
# Create 10 expired messages with different timestamps # Create 10 expired messages with different timestamps
base_date = datetime.datetime.now() - datetime.timedelta(days=35) base_date = datetime.datetime.now() - datetime.timedelta(days=35)
message_ids = [] message_ids = []
for i in range(10): for i in range(10):
msg = self._create_message( msg = self._create_message(
db_session_with_containers,
app, app,
conv, conv,
created_at=base_date + datetime.timedelta(hours=i), created_at=base_date + datetime.timedelta(hours=i),
@ -498,20 +537,22 @@ class TestMessagesCleanServiceIntegration:
assert stats["total_deleted"] == 10 assert stats["total_deleted"] == 10
# All messages should be deleted # All messages should be deleted
assert db.session.query(Message).where(Message.id.in_(message_ids)).count() == 0 assert db_session_with_containers.query(Message).where(Message.id.in_(message_ids)).count() == 0
def test_dry_run_does_not_delete(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): def test_dry_run_does_not_delete(self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist):
"""Test dry_run mode does not delete messages (B4).""" """Test dry_run mode does not delete messages (B4)."""
# Arrange # Arrange
account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX)
app = self._create_app(tenant, account) app = self._create_app(db_session_with_containers, tenant, account)
conv = self._create_conversation(app) conv = self._create_conversation(db_session_with_containers, app)
# Create expired messages # Create expired messages
expired_date = datetime.datetime.now() - datetime.timedelta(days=35) expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
message_ids = [] message_ids = []
for i in range(3): for i in range(3):
msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=i)) msg = self._create_message(
db_session_with_containers, app, conv, created_at=expired_date - datetime.timedelta(hours=i)
)
message_ids.append(msg.id) message_ids.append(msg.id)
with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing:
@ -540,21 +581,26 @@ class TestMessagesCleanServiceIntegration:
assert stats["total_deleted"] == 0 # But NOT deleted assert stats["total_deleted"] == 0 # But NOT deleted
# All messages should still exist # All messages should still exist
assert db.session.query(Message).where(Message.id.in_(message_ids)).count() == 3 assert db_session_with_containers.query(Message).where(Message.id.in_(message_ids)).count() == 3
# Related records should also still exist # Related records should also still exist
assert db.session.query(MessageFeedback).where(MessageFeedback.message_id.in_(message_ids)).count() == 3 assert (
db_session_with_containers.query(MessageFeedback).where(MessageFeedback.message_id.in_(message_ids)).count()
== 3
)
def test_partial_plan_data_safe_default(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): def test_partial_plan_data_safe_default(
self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist
):
"""Test when billing returns partial data, unknown tenants are preserved (B5).""" """Test when billing returns partial data, unknown tenants are preserved (B5)."""
# Arrange - Create 3 tenants # Arrange - Create 3 tenants
tenants_data = [] tenants_data = []
for i in range(3): for i in range(3):
account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX)
app = self._create_app(tenant, account) app = self._create_app(db_session_with_containers, tenant, account)
conv = self._create_conversation(app) conv = self._create_conversation(db_session_with_containers, app)
expired_date = datetime.datetime.now() - datetime.timedelta(days=35) expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
msg = self._create_message(app, conv, created_at=expired_date) msg = self._create_message(db_session_with_containers, app, conv, created_at=expired_date)
tenants_data.append( tenants_data.append(
{ {
@ -600,28 +646,30 @@ class TestMessagesCleanServiceIntegration:
# Check which messages were deleted # Check which messages were deleted
assert ( assert (
db.session.query(Message).where(Message.id == tenants_data[0]["message_id"]).count() == 0 db_session_with_containers.query(Message).where(Message.id == tenants_data[0]["message_id"]).count() == 0
) # Sandbox tenant's message deleted ) # Sandbox tenant's message deleted
assert ( assert (
db.session.query(Message).where(Message.id == tenants_data[1]["message_id"]).count() == 1 db_session_with_containers.query(Message).where(Message.id == tenants_data[1]["message_id"]).count() == 1
) # Professional tenant's message preserved ) # Professional tenant's message preserved
assert ( assert (
db.session.query(Message).where(Message.id == tenants_data[2]["message_id"]).count() == 1 db_session_with_containers.query(Message).where(Message.id == tenants_data[2]["message_id"]).count() == 1
) # Unknown tenant's message preserved (safe default) ) # Unknown tenant's message preserved (safe default)
def test_empty_plan_data_skips_deletion(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): def test_empty_plan_data_skips_deletion(
self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist
):
"""Test when billing returns empty data, skip deletion entirely (B6).""" """Test when billing returns empty data, skip deletion entirely (B6)."""
# Arrange # Arrange
account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX)
app = self._create_app(tenant, account) app = self._create_app(db_session_with_containers, tenant, account)
conv = self._create_conversation(app) conv = self._create_conversation(db_session_with_containers, app)
expired_date = datetime.datetime.now() - datetime.timedelta(days=35) expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
msg = self._create_message(app, conv, created_at=expired_date) msg = self._create_message(db_session_with_containers, app, conv, created_at=expired_date)
msg_id = msg.id msg_id = msg.id
db.session.commit() db_session_with_containers.commit()
# Mock billing service to return empty data (simulating failure/no data scenario) # Mock billing service to return empty data (simulating failure/no data scenario)
with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing:
@ -644,17 +692,20 @@ class TestMessagesCleanServiceIntegration:
assert stats["total_deleted"] == 0 assert stats["total_deleted"] == 0
# Message should still exist (safe default - don't delete if plan is unknown) # Message should still exist (safe default - don't delete if plan is unknown)
assert db.session.query(Message).where(Message.id == msg_id).count() == 1 assert db_session_with_containers.query(Message).where(Message.id == msg_id).count() == 1
def test_time_range_boundary_behavior(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): def test_time_range_boundary_behavior(
self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist
):
"""Test that messages are correctly filtered by [start_from, end_before) time range (B7).""" """Test that messages are correctly filtered by [start_from, end_before) time range (B7)."""
# Arrange # Arrange
account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX)
app = self._create_app(tenant, account) app = self._create_app(db_session_with_containers, tenant, account)
conv = self._create_conversation(app) conv = self._create_conversation(db_session_with_containers, app)
# Create messages: before range, in range, after range # Create messages: before range, in range, after range
msg_before = self._create_message( msg_before = self._create_message(
db_session_with_containers,
app, app,
conv, conv,
created_at=datetime.datetime(2024, 1, 1, 12, 0, 0), # Before start_from created_at=datetime.datetime(2024, 1, 1, 12, 0, 0), # Before start_from
@ -663,6 +714,7 @@ class TestMessagesCleanServiceIntegration:
msg_before_id = msg_before.id msg_before_id = msg_before.id
msg_at_start = self._create_message( msg_at_start = self._create_message(
db_session_with_containers,
app, app,
conv, conv,
created_at=datetime.datetime(2024, 1, 10, 12, 0, 0), # At start_from (inclusive) created_at=datetime.datetime(2024, 1, 10, 12, 0, 0), # At start_from (inclusive)
@ -671,6 +723,7 @@ class TestMessagesCleanServiceIntegration:
msg_at_start_id = msg_at_start.id msg_at_start_id = msg_at_start.id
msg_in_range = self._create_message( msg_in_range = self._create_message(
db_session_with_containers,
app, app,
conv, conv,
created_at=datetime.datetime(2024, 1, 15, 12, 0, 0), # In range created_at=datetime.datetime(2024, 1, 15, 12, 0, 0), # In range
@ -679,6 +732,7 @@ class TestMessagesCleanServiceIntegration:
msg_in_range_id = msg_in_range.id msg_in_range_id = msg_in_range.id
msg_at_end = self._create_message( msg_at_end = self._create_message(
db_session_with_containers,
app, app,
conv, conv,
created_at=datetime.datetime(2024, 1, 20, 12, 0, 0), # At end_before (exclusive) created_at=datetime.datetime(2024, 1, 20, 12, 0, 0), # At end_before (exclusive)
@ -687,6 +741,7 @@ class TestMessagesCleanServiceIntegration:
msg_at_end_id = msg_at_end.id msg_at_end_id = msg_at_end.id
msg_after = self._create_message( msg_after = self._create_message(
db_session_with_containers,
app, app,
conv, conv,
created_at=datetime.datetime(2024, 1, 25, 12, 0, 0), # After end_before created_at=datetime.datetime(2024, 1, 25, 12, 0, 0), # After end_before
@ -694,7 +749,7 @@ class TestMessagesCleanServiceIntegration:
) )
msg_after_id = msg_after.id msg_after_id = msg_after.id
db.session.commit() db_session_with_containers.commit()
# Mock billing service # Mock billing service
with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing:
@ -722,17 +777,17 @@ class TestMessagesCleanServiceIntegration:
# Verify specific messages using stored IDs # Verify specific messages using stored IDs
# Before range, kept # Before range, kept
assert db.session.query(Message).where(Message.id == msg_before_id).count() == 1 assert db_session_with_containers.query(Message).where(Message.id == msg_before_id).count() == 1
# At start (inclusive), deleted # At start (inclusive), deleted
assert db.session.query(Message).where(Message.id == msg_at_start_id).count() == 0 assert db_session_with_containers.query(Message).where(Message.id == msg_at_start_id).count() == 0
# In range, deleted # In range, deleted
assert db.session.query(Message).where(Message.id == msg_in_range_id).count() == 0 assert db_session_with_containers.query(Message).where(Message.id == msg_in_range_id).count() == 0
# At end (exclusive), kept # At end (exclusive), kept
assert db.session.query(Message).where(Message.id == msg_at_end_id).count() == 1 assert db_session_with_containers.query(Message).where(Message.id == msg_at_end_id).count() == 1
# After range, kept # After range, kept
assert db.session.query(Message).where(Message.id == msg_after_id).count() == 1 assert db_session_with_containers.query(Message).where(Message.id == msg_after_id).count() == 1
def test_grace_period_scenarios(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): def test_grace_period_scenarios(self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist):
"""Test cleaning with different graceful period scenarios (B8).""" """Test cleaning with different graceful period scenarios (B8)."""
# Arrange - Create 5 different tenants with different plan and expiration scenarios # Arrange - Create 5 different tenants with different plan and expiration scenarios
now_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp()) now_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp())
@ -740,50 +795,60 @@ class TestMessagesCleanServiceIntegration:
# Scenario 1: Sandbox plan with expiration within graceful period (5 days ago) # Scenario 1: Sandbox plan with expiration within graceful period (5 days ago)
# Should NOT be deleted # Should NOT be deleted
account1, tenant1 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) account1, tenant1 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX)
app1 = self._create_app(tenant1, account1) app1 = self._create_app(db_session_with_containers, tenant1, account1)
conv1 = self._create_conversation(app1) conv1 = self._create_conversation(db_session_with_containers, app1)
expired_date = datetime.datetime.now() - datetime.timedelta(days=35) expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
msg1 = self._create_message(app1, conv1, created_at=expired_date, with_relations=False) msg1 = self._create_message(
db_session_with_containers, app1, conv1, created_at=expired_date, with_relations=False
)
msg1_id = msg1.id msg1_id = msg1.id
expired_5_days_ago = now_timestamp - (5 * 24 * 60 * 60) # Within grace period expired_5_days_ago = now_timestamp - (5 * 24 * 60 * 60) # Within grace period
# Scenario 2: Sandbox plan with expiration beyond graceful period (10 days ago) # Scenario 2: Sandbox plan with expiration beyond graceful period (10 days ago)
# Should be deleted # Should be deleted
account2, tenant2 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) account2, tenant2 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX)
app2 = self._create_app(tenant2, account2) app2 = self._create_app(db_session_with_containers, tenant2, account2)
conv2 = self._create_conversation(app2) conv2 = self._create_conversation(db_session_with_containers, app2)
msg2 = self._create_message(app2, conv2, created_at=expired_date, with_relations=False) msg2 = self._create_message(
db_session_with_containers, app2, conv2, created_at=expired_date, with_relations=False
)
msg2_id = msg2.id msg2_id = msg2.id
expired_10_days_ago = now_timestamp - (10 * 24 * 60 * 60) # Beyond grace period expired_10_days_ago = now_timestamp - (10 * 24 * 60 * 60) # Beyond grace period
# Scenario 3: Sandbox plan with expiration_date = -1 (no previous subscription) # Scenario 3: Sandbox plan with expiration_date = -1 (no previous subscription)
# Should be deleted # Should be deleted
account3, tenant3 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) account3, tenant3 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX)
app3 = self._create_app(tenant3, account3) app3 = self._create_app(db_session_with_containers, tenant3, account3)
conv3 = self._create_conversation(app3) conv3 = self._create_conversation(db_session_with_containers, app3)
msg3 = self._create_message(app3, conv3, created_at=expired_date, with_relations=False) msg3 = self._create_message(
db_session_with_containers, app3, conv3, created_at=expired_date, with_relations=False
)
msg3_id = msg3.id msg3_id = msg3.id
# Scenario 4: Non-sandbox plan (professional) with no expiration (future date) # Scenario 4: Non-sandbox plan (professional) with no expiration (future date)
# Should NOT be deleted # Should NOT be deleted
account4, tenant4 = self._create_account_and_tenant(plan=CloudPlan.PROFESSIONAL) account4, tenant4 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.PROFESSIONAL)
app4 = self._create_app(tenant4, account4) app4 = self._create_app(db_session_with_containers, tenant4, account4)
conv4 = self._create_conversation(app4) conv4 = self._create_conversation(db_session_with_containers, app4)
msg4 = self._create_message(app4, conv4, created_at=expired_date, with_relations=False) msg4 = self._create_message(
db_session_with_containers, app4, conv4, created_at=expired_date, with_relations=False
)
msg4_id = msg4.id msg4_id = msg4.id
future_expiration = now_timestamp + (365 * 24 * 60 * 60) # Active for 1 year future_expiration = now_timestamp + (365 * 24 * 60 * 60) # Active for 1 year
# Scenario 5: Sandbox plan with expiration exactly at grace period boundary (8 days ago) # Scenario 5: Sandbox plan with expiration exactly at grace period boundary (8 days ago)
# Should NOT be deleted (boundary is exclusive: > graceful_period) # Should NOT be deleted (boundary is exclusive: > graceful_period)
account5, tenant5 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) account5, tenant5 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX)
app5 = self._create_app(tenant5, account5) app5 = self._create_app(db_session_with_containers, tenant5, account5)
conv5 = self._create_conversation(app5) conv5 = self._create_conversation(db_session_with_containers, app5)
msg5 = self._create_message(app5, conv5, created_at=expired_date, with_relations=False) msg5 = self._create_message(
db_session_with_containers, app5, conv5, created_at=expired_date, with_relations=False
)
msg5_id = msg5.id msg5_id = msg5.id
expired_exactly_8_days_ago = now_timestamp - (8 * 24 * 60 * 60) # Exactly at boundary expired_exactly_8_days_ago = now_timestamp - (8 * 24 * 60 * 60) # Exactly at boundary
db.session.commit() db_session_with_containers.commit()
# Mock billing service with all scenarios # Mock billing service with all scenarios
plan_map = { plan_map = {
@ -832,23 +897,31 @@ class TestMessagesCleanServiceIntegration:
assert stats["total_deleted"] == 2 assert stats["total_deleted"] == 2
# Verify each scenario using saved IDs # Verify each scenario using saved IDs
assert db.session.query(Message).where(Message.id == msg1_id).count() == 1 # Within grace, kept assert db_session_with_containers.query(Message).where(Message.id == msg1_id).count() == 1 # Within grace, kept
assert db.session.query(Message).where(Message.id == msg2_id).count() == 0 # Beyond grace, deleted assert (
assert db.session.query(Message).where(Message.id == msg3_id).count() == 0 # No subscription, deleted db_session_with_containers.query(Message).where(Message.id == msg2_id).count() == 0
assert db.session.query(Message).where(Message.id == msg4_id).count() == 1 # Professional plan, kept ) # Beyond grace, deleted
assert db.session.query(Message).where(Message.id == msg5_id).count() == 1 # At boundary, kept assert (
db_session_with_containers.query(Message).where(Message.id == msg3_id).count() == 0
) # No subscription, deleted
assert (
db_session_with_containers.query(Message).where(Message.id == msg4_id).count() == 1
) # Professional plan, kept
assert db_session_with_containers.query(Message).where(Message.id == msg5_id).count() == 1 # At boundary, kept
def test_tenant_whitelist(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): def test_tenant_whitelist(self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist):
"""Test that whitelisted tenants' messages are not deleted (B9).""" """Test that whitelisted tenants' messages are not deleted (B9)."""
# Arrange - Create 3 sandbox tenants with expired messages # Arrange - Create 3 sandbox tenants with expired messages
tenants_data = [] tenants_data = []
for i in range(3): for i in range(3):
account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX)
app = self._create_app(tenant, account) app = self._create_app(db_session_with_containers, tenant, account)
conv = self._create_conversation(app) conv = self._create_conversation(db_session_with_containers, app)
expired_date = datetime.datetime.now() - datetime.timedelta(days=35) expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
msg = self._create_message(app, conv, created_at=expired_date, with_relations=False) msg = self._create_message(
db_session_with_containers, app, conv, created_at=expired_date, with_relations=False
)
tenants_data.append( tenants_data.append(
{ {
@ -897,27 +970,33 @@ class TestMessagesCleanServiceIntegration:
assert stats["total_deleted"] == 1 assert stats["total_deleted"] == 1
# Verify tenant0's message still exists (whitelisted) # Verify tenant0's message still exists (whitelisted)
assert db.session.query(Message).where(Message.id == tenants_data[0]["message_id"]).count() == 1 assert db_session_with_containers.query(Message).where(Message.id == tenants_data[0]["message_id"]).count() == 1
# Verify tenant1's message still exists (whitelisted) # Verify tenant1's message still exists (whitelisted)
assert db.session.query(Message).where(Message.id == tenants_data[1]["message_id"]).count() == 1 assert db_session_with_containers.query(Message).where(Message.id == tenants_data[1]["message_id"]).count() == 1
# Verify tenant2's message was deleted (not whitelisted) # Verify tenant2's message was deleted (not whitelisted)
assert db.session.query(Message).where(Message.id == tenants_data[2]["message_id"]).count() == 0 assert db_session_with_containers.query(Message).where(Message.id == tenants_data[2]["message_id"]).count() == 0
def test_from_days_cleans_old_messages(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): def test_from_days_cleans_old_messages(
self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist
):
"""Test from_days correctly cleans messages older than N days (B11).""" """Test from_days correctly cleans messages older than N days (B11)."""
# Arrange # Arrange
account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX)
app = self._create_app(tenant, account) app = self._create_app(db_session_with_containers, tenant, account)
conv = self._create_conversation(app) conv = self._create_conversation(db_session_with_containers, app)
# Create old messages (should be deleted - older than 30 days) # Create old messages (should be deleted - older than 30 days)
old_date = datetime.datetime.now() - datetime.timedelta(days=45) old_date = datetime.datetime.now() - datetime.timedelta(days=45)
old_msg_ids = [] old_msg_ids = []
for i in range(3): for i in range(3):
msg = self._create_message( msg = self._create_message(
app, conv, created_at=old_date - datetime.timedelta(hours=i), with_relations=False db_session_with_containers,
app,
conv,
created_at=old_date - datetime.timedelta(hours=i),
with_relations=False,
) )
old_msg_ids.append(msg.id) old_msg_ids.append(msg.id)
@ -926,11 +1005,15 @@ class TestMessagesCleanServiceIntegration:
recent_msg_ids = [] recent_msg_ids = []
for i in range(2): for i in range(2):
msg = self._create_message( msg = self._create_message(
app, conv, created_at=recent_date - datetime.timedelta(hours=i), with_relations=False db_session_with_containers,
app,
conv,
created_at=recent_date - datetime.timedelta(hours=i),
with_relations=False,
) )
recent_msg_ids.append(msg.id) recent_msg_ids.append(msg.id)
db.session.commit() db_session_with_containers.commit()
with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing:
mock_billing.return_value = { mock_billing.return_value = {
@ -955,30 +1038,34 @@ class TestMessagesCleanServiceIntegration:
assert stats["total_deleted"] == 3 assert stats["total_deleted"] == 3
# Old messages deleted # Old messages deleted
assert db.session.query(Message).where(Message.id.in_(old_msg_ids)).count() == 0 assert db_session_with_containers.query(Message).where(Message.id.in_(old_msg_ids)).count() == 0
# Recent messages kept # Recent messages kept
assert db.session.query(Message).where(Message.id.in_(recent_msg_ids)).count() == 2 assert db_session_with_containers.query(Message).where(Message.id.in_(recent_msg_ids)).count() == 2
def test_whitelist_precedence_over_grace_period( def test_whitelist_precedence_over_grace_period(
self, db_session_with_containers, mock_billing_enabled, mock_whitelist self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist
): ):
"""Test that whitelist takes precedence over grace period logic.""" """Test that whitelist takes precedence over grace period logic."""
# Arrange - Create 2 sandbox tenants # Arrange - Create 2 sandbox tenants
now_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp()) now_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp())
# Tenant1: whitelisted, expired beyond grace period # Tenant1: whitelisted, expired beyond grace period
account1, tenant1 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) account1, tenant1 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX)
app1 = self._create_app(tenant1, account1) app1 = self._create_app(db_session_with_containers, tenant1, account1)
conv1 = self._create_conversation(app1) conv1 = self._create_conversation(db_session_with_containers, app1)
expired_date = datetime.datetime.now() - datetime.timedelta(days=35) expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
msg1 = self._create_message(app1, conv1, created_at=expired_date, with_relations=False) msg1 = self._create_message(
db_session_with_containers, app1, conv1, created_at=expired_date, with_relations=False
)
expired_30_days_ago = now_timestamp - (30 * 24 * 60 * 60) # Well beyond 21-day grace expired_30_days_ago = now_timestamp - (30 * 24 * 60 * 60) # Well beyond 21-day grace
# Tenant2: not whitelisted, within grace period # Tenant2: not whitelisted, within grace period
account2, tenant2 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) account2, tenant2 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX)
app2 = self._create_app(tenant2, account2) app2 = self._create_app(db_session_with_containers, tenant2, account2)
conv2 = self._create_conversation(app2) conv2 = self._create_conversation(db_session_with_containers, app2)
msg2 = self._create_message(app2, conv2, created_at=expired_date, with_relations=False) msg2 = self._create_message(
db_session_with_containers, app2, conv2, created_at=expired_date, with_relations=False
)
expired_10_days_ago = now_timestamp - (10 * 24 * 60 * 60) # Within 21-day grace expired_10_days_ago = now_timestamp - (10 * 24 * 60 * 60) # Within 21-day grace
# Mock billing service # Mock billing service
@ -1019,22 +1106,26 @@ class TestMessagesCleanServiceIntegration:
assert stats["total_deleted"] == 0 assert stats["total_deleted"] == 0
# Verify both messages still exist # Verify both messages still exist
assert db.session.query(Message).where(Message.id == msg1.id).count() == 1 # Whitelisted assert db_session_with_containers.query(Message).where(Message.id == msg1.id).count() == 1 # Whitelisted
assert db.session.query(Message).where(Message.id == msg2.id).count() == 1 # Within grace period assert (
db_session_with_containers.query(Message).where(Message.id == msg2.id).count() == 1
) # Within grace period
def test_empty_whitelist_deletes_eligible_messages( def test_empty_whitelist_deletes_eligible_messages(
self, db_session_with_containers, mock_billing_enabled, mock_whitelist self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist
): ):
"""Test that empty whitelist behaves as no whitelist (all eligible messages deleted).""" """Test that empty whitelist behaves as no whitelist (all eligible messages deleted)."""
# Arrange - Create sandbox tenant with expired messages # Arrange - Create sandbox tenant with expired messages
account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX)
app = self._create_app(tenant, account) app = self._create_app(db_session_with_containers, tenant, account)
conv = self._create_conversation(app) conv = self._create_conversation(db_session_with_containers, app)
expired_date = datetime.datetime.now() - datetime.timedelta(days=35) expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
msg_ids = [] msg_ids = []
for i in range(3): for i in range(3):
msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=i)) msg = self._create_message(
db_session_with_containers, app, conv, created_at=expired_date - datetime.timedelta(hours=i)
)
msg_ids.append(msg.id) msg_ids.append(msg.id)
# Mock billing service # Mock billing service
@ -1068,4 +1159,4 @@ class TestMessagesCleanServiceIntegration:
assert stats["total_deleted"] == 3 assert stats["total_deleted"] == 3
# Verify all messages were deleted # Verify all messages were deleted
assert db.session.query(Message).where(Message.id.in_(msg_ids)).count() == 0 assert db_session_with_containers.query(Message).where(Message.id.in_(msg_ids)).count() == 0

View File

@ -2,6 +2,7 @@ from unittest.mock import create_autospec, patch
import pytest import pytest
from faker import Faker from faker import Faker
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.built_in_field import BuiltInField from core.rag.index_processor.constant.built_in_field import BuiltInField
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
@ -32,7 +33,7 @@ class TestMetadataService:
"document_service": mock_document_service, "document_service": mock_document_service,
} }
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Helper method to create a test account and tenant for testing. Helper method to create a test account and tenant for testing.
@ -53,18 +54,16 @@ class TestMetadataService:
status="active", status="active",
) )
from extensions.ext_database import db db_session_with_containers.add(account)
db_session_with_containers.commit()
db.session.add(account)
db.session.commit()
# Create tenant for the account # Create tenant for the account
tenant = Tenant( tenant = Tenant(
name=fake.company(), name=fake.company(),
status="normal", status="normal",
) )
db.session.add(tenant) db_session_with_containers.add(tenant)
db.session.commit() db_session_with_containers.commit()
# Create tenant-account join # Create tenant-account join
join = TenantAccountJoin( join = TenantAccountJoin(
@ -73,15 +72,17 @@ class TestMetadataService:
role=TenantAccountRole.OWNER, role=TenantAccountRole.OWNER,
current=True, current=True,
) )
db.session.add(join) db_session_with_containers.add(join)
db.session.commit() db_session_with_containers.commit()
# Set current tenant for account # Set current tenant for account
account.current_tenant = tenant account.current_tenant = tenant
return account, tenant return account, tenant
def _create_test_dataset(self, db_session_with_containers, mock_external_service_dependencies, account, tenant): def _create_test_dataset(
self, db_session_with_containers: Session, mock_external_service_dependencies, account, tenant
):
""" """
Helper method to create a test dataset for testing. Helper method to create a test dataset for testing.
@ -105,14 +106,14 @@ class TestMetadataService:
built_in_field_enabled=False, built_in_field_enabled=False,
) )
from extensions.ext_database import db db_session_with_containers.add(dataset)
db_session_with_containers.commit()
db.session.add(dataset)
db.session.commit()
return dataset return dataset
def _create_test_document(self, db_session_with_containers, mock_external_service_dependencies, dataset, account): def _create_test_document(
self, db_session_with_containers: Session, mock_external_service_dependencies, dataset, account
):
""" """
Helper method to create a test document for testing. Helper method to create a test document for testing.
@ -141,14 +142,12 @@ class TestMetadataService:
doc_language="en", doc_language="en",
) )
from extensions.ext_database import db db_session_with_containers.add(document)
db_session_with_containers.commit()
db.session.add(document)
db.session.commit()
return document return document
def test_create_metadata_success(self, db_session_with_containers, mock_external_service_dependencies): def test_create_metadata_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test successful metadata creation with valid parameters. Test successful metadata creation with valid parameters.
""" """
@ -178,13 +177,14 @@ class TestMetadataService:
assert result.created_by == account.id assert result.created_by == account.id
# Verify database state # Verify database state
from extensions.ext_database import db
db.session.refresh(result) db_session_with_containers.refresh(result)
assert result.id is not None assert result.id is not None
assert result.created_at is not None assert result.created_at is not None
def test_create_metadata_name_too_long(self, db_session_with_containers, mock_external_service_dependencies): def test_create_metadata_name_too_long(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test metadata creation fails when name exceeds 255 characters. Test metadata creation fails when name exceeds 255 characters.
""" """
@ -207,7 +207,9 @@ class TestMetadataService:
with pytest.raises(ValueError, match="Metadata name cannot exceed 255 characters."): with pytest.raises(ValueError, match="Metadata name cannot exceed 255 characters."):
MetadataService.create_metadata(dataset.id, metadata_args) MetadataService.create_metadata(dataset.id, metadata_args)
def test_create_metadata_name_already_exists(self, db_session_with_containers, mock_external_service_dependencies): def test_create_metadata_name_already_exists(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test metadata creation fails when name already exists in the same dataset. Test metadata creation fails when name already exists in the same dataset.
""" """
@ -235,7 +237,7 @@ class TestMetadataService:
MetadataService.create_metadata(dataset.id, second_metadata_args) MetadataService.create_metadata(dataset.id, second_metadata_args)
def test_create_metadata_name_conflicts_with_built_in_field( def test_create_metadata_name_conflicts_with_built_in_field(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test metadata creation fails when name conflicts with built-in field names. Test metadata creation fails when name conflicts with built-in field names.
@ -260,7 +262,9 @@ class TestMetadataService:
with pytest.raises(ValueError, match="Metadata name already exists in Built-in fields."): with pytest.raises(ValueError, match="Metadata name already exists in Built-in fields."):
MetadataService.create_metadata(dataset.id, metadata_args) MetadataService.create_metadata(dataset.id, metadata_args)
def test_update_metadata_name_success(self, db_session_with_containers, mock_external_service_dependencies): def test_update_metadata_name_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful metadata name update with valid parameters. Test successful metadata name update with valid parameters.
""" """
@ -291,12 +295,13 @@ class TestMetadataService:
assert result.updated_at is not None assert result.updated_at is not None
# Verify database state # Verify database state
from extensions.ext_database import db
db.session.refresh(result) db_session_with_containers.refresh(result)
assert result.name == new_name assert result.name == new_name
def test_update_metadata_name_too_long(self, db_session_with_containers, mock_external_service_dependencies): def test_update_metadata_name_too_long(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test metadata name update fails when new name exceeds 255 characters. Test metadata name update fails when new name exceeds 255 characters.
""" """
@ -323,7 +328,9 @@ class TestMetadataService:
with pytest.raises(ValueError, match="Metadata name cannot exceed 255 characters."): with pytest.raises(ValueError, match="Metadata name cannot exceed 255 characters."):
MetadataService.update_metadata_name(dataset.id, metadata.id, long_name) MetadataService.update_metadata_name(dataset.id, metadata.id, long_name)
def test_update_metadata_name_already_exists(self, db_session_with_containers, mock_external_service_dependencies): def test_update_metadata_name_already_exists(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test metadata name update fails when new name already exists in the same dataset. Test metadata name update fails when new name already exists in the same dataset.
""" """
@ -351,7 +358,7 @@ class TestMetadataService:
MetadataService.update_metadata_name(dataset.id, first_metadata.id, "second_metadata") MetadataService.update_metadata_name(dataset.id, first_metadata.id, "second_metadata")
def test_update_metadata_name_conflicts_with_built_in_field( def test_update_metadata_name_conflicts_with_built_in_field(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test metadata name update fails when new name conflicts with built-in field names. Test metadata name update fails when new name conflicts with built-in field names.
@ -378,7 +385,9 @@ class TestMetadataService:
with pytest.raises(ValueError, match="Metadata name already exists in Built-in fields."): with pytest.raises(ValueError, match="Metadata name already exists in Built-in fields."):
MetadataService.update_metadata_name(dataset.id, metadata.id, built_in_field_name) MetadataService.update_metadata_name(dataset.id, metadata.id, built_in_field_name)
def test_update_metadata_name_not_found(self, db_session_with_containers, mock_external_service_dependencies): def test_update_metadata_name_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test metadata name update fails when metadata ID does not exist. Test metadata name update fails when metadata ID does not exist.
""" """
@ -406,7 +415,7 @@ class TestMetadataService:
# Assert: Verify the method returns None when metadata is not found # Assert: Verify the method returns None when metadata is not found
assert result is None assert result is None
def test_delete_metadata_success(self, db_session_with_containers, mock_external_service_dependencies): def test_delete_metadata_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test successful metadata deletion with valid parameters. Test successful metadata deletion with valid parameters.
""" """
@ -434,12 +443,11 @@ class TestMetadataService:
assert result.id == metadata.id assert result.id == metadata.id
# Verify metadata was deleted from database # Verify metadata was deleted from database
from extensions.ext_database import db
deleted_metadata = db.session.query(DatasetMetadata).filter_by(id=metadata.id).first() deleted_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(id=metadata.id).first()
assert deleted_metadata is None assert deleted_metadata is None
def test_delete_metadata_not_found(self, db_session_with_containers, mock_external_service_dependencies): def test_delete_metadata_not_found(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test metadata deletion fails when metadata ID does not exist. Test metadata deletion fails when metadata ID does not exist.
""" """
@ -467,7 +475,7 @@ class TestMetadataService:
assert result is None assert result is None
def test_delete_metadata_with_document_bindings( def test_delete_metadata_with_document_bindings(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test metadata deletion successfully removes document metadata bindings. Test metadata deletion successfully removes document metadata bindings.
@ -500,15 +508,13 @@ class TestMetadataService:
created_by=account.id, created_by=account.id,
) )
from extensions.ext_database import db db_session_with_containers.add(binding)
db_session_with_containers.commit()
db.session.add(binding)
db.session.commit()
# Set document metadata # Set document metadata
document.doc_metadata = {"test_metadata": "test_value"} document.doc_metadata = {"test_metadata": "test_value"}
db.session.add(document) db_session_with_containers.add(document)
db.session.commit() db_session_with_containers.commit()
# Act: Execute the method under test # Act: Execute the method under test
result = MetadataService.delete_metadata(dataset.id, metadata.id) result = MetadataService.delete_metadata(dataset.id, metadata.id)
@ -517,13 +523,13 @@ class TestMetadataService:
assert result is not None assert result is not None
# Verify metadata was deleted from database # Verify metadata was deleted from database
deleted_metadata = db.session.query(DatasetMetadata).filter_by(id=metadata.id).first() deleted_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(id=metadata.id).first()
assert deleted_metadata is None assert deleted_metadata is None
# Note: The service attempts to update document metadata but may not succeed # Note: The service attempts to update document metadata but may not succeed
# due to mock configuration. The main functionality (metadata deletion) is verified. # due to mock configuration. The main functionality (metadata deletion) is verified.
def test_get_built_in_fields_success(self, db_session_with_containers, mock_external_service_dependencies): def test_get_built_in_fields_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test successful retrieval of built-in metadata fields. Test successful retrieval of built-in metadata fields.
""" """
@ -548,7 +554,9 @@ class TestMetadataService:
assert "string" in field_types assert "string" in field_types
assert "time" in field_types assert "time" in field_types
def test_enable_built_in_field_success(self, db_session_with_containers, mock_external_service_dependencies): def test_enable_built_in_field_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful enabling of built-in fields for a dataset. Test successful enabling of built-in fields for a dataset.
""" """
@ -579,16 +587,15 @@ class TestMetadataService:
MetadataService.enable_built_in_field(dataset) MetadataService.enable_built_in_field(dataset)
# Assert: Verify the expected outcomes # Assert: Verify the expected outcomes
from extensions.ext_database import db
db.session.refresh(dataset) db_session_with_containers.refresh(dataset)
assert dataset.built_in_field_enabled is True assert dataset.built_in_field_enabled is True
# Note: Document metadata update depends on DocumentService mock working correctly # Note: Document metadata update depends on DocumentService mock working correctly
# The main functionality (enabling built-in fields) is verified # The main functionality (enabling built-in fields) is verified
def test_enable_built_in_field_already_enabled( def test_enable_built_in_field_already_enabled(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test enabling built-in fields when they are already enabled. Test enabling built-in fields when they are already enabled.
@ -607,10 +614,9 @@ class TestMetadataService:
# Enable built-in fields first # Enable built-in fields first
dataset.built_in_field_enabled = True dataset.built_in_field_enabled = True
from extensions.ext_database import db
db.session.add(dataset) db_session_with_containers.add(dataset)
db.session.commit() db_session_with_containers.commit()
# Mock DocumentService.get_working_documents_by_dataset_id # Mock DocumentService.get_working_documents_by_dataset_id
mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = [] mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = []
@ -619,11 +625,11 @@ class TestMetadataService:
MetadataService.enable_built_in_field(dataset) MetadataService.enable_built_in_field(dataset)
# Assert: Verify the method returns early without changes # Assert: Verify the method returns early without changes
db.session.refresh(dataset) db_session_with_containers.refresh(dataset)
assert dataset.built_in_field_enabled is True assert dataset.built_in_field_enabled is True
def test_enable_built_in_field_with_no_documents( def test_enable_built_in_field_with_no_documents(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test enabling built-in fields for a dataset with no documents. Test enabling built-in fields for a dataset with no documents.
@ -647,12 +653,13 @@ class TestMetadataService:
MetadataService.enable_built_in_field(dataset) MetadataService.enable_built_in_field(dataset)
# Assert: Verify the expected outcomes # Assert: Verify the expected outcomes
from extensions.ext_database import db
db.session.refresh(dataset) db_session_with_containers.refresh(dataset)
assert dataset.built_in_field_enabled is True assert dataset.built_in_field_enabled is True
def test_disable_built_in_field_success(self, db_session_with_containers, mock_external_service_dependencies): def test_disable_built_in_field_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful disabling of built-in fields for a dataset. Test successful disabling of built-in fields for a dataset.
""" """
@ -673,10 +680,9 @@ class TestMetadataService:
# Enable built-in fields first # Enable built-in fields first
dataset.built_in_field_enabled = True dataset.built_in_field_enabled = True
from extensions.ext_database import db
db.session.add(dataset) db_session_with_containers.add(dataset)
db.session.commit() db_session_with_containers.commit()
# Set document metadata with built-in fields # Set document metadata with built-in fields
document.doc_metadata = { document.doc_metadata = {
@ -686,8 +692,8 @@ class TestMetadataService:
BuiltInField.last_update_date: 1234567890.0, BuiltInField.last_update_date: 1234567890.0,
BuiltInField.source: "test_source", BuiltInField.source: "test_source",
} }
db.session.add(document) db_session_with_containers.add(document)
db.session.commit() db_session_with_containers.commit()
# Mock DocumentService.get_working_documents_by_dataset_id # Mock DocumentService.get_working_documents_by_dataset_id
mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = [ mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = [
@ -698,14 +704,14 @@ class TestMetadataService:
MetadataService.disable_built_in_field(dataset) MetadataService.disable_built_in_field(dataset)
# Assert: Verify the expected outcomes # Assert: Verify the expected outcomes
db.session.refresh(dataset) db_session_with_containers.refresh(dataset)
assert dataset.built_in_field_enabled is False assert dataset.built_in_field_enabled is False
# Note: Document metadata update depends on DocumentService mock working correctly # Note: Document metadata update depends on DocumentService mock working correctly
# The main functionality (disabling built-in fields) is verified # The main functionality (disabling built-in fields) is verified
def test_disable_built_in_field_already_disabled( def test_disable_built_in_field_already_disabled(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test disabling built-in fields when they are already disabled. Test disabling built-in fields when they are already disabled.
@ -732,13 +738,12 @@ class TestMetadataService:
MetadataService.disable_built_in_field(dataset) MetadataService.disable_built_in_field(dataset)
# Assert: Verify the method returns early without changes # Assert: Verify the method returns early without changes
from extensions.ext_database import db
db.session.refresh(dataset) db_session_with_containers.refresh(dataset)
assert dataset.built_in_field_enabled is False assert dataset.built_in_field_enabled is False
def test_disable_built_in_field_with_no_documents( def test_disable_built_in_field_with_no_documents(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test disabling built-in fields for a dataset with no documents. Test disabling built-in fields for a dataset with no documents.
@ -757,10 +762,9 @@ class TestMetadataService:
# Enable built-in fields first # Enable built-in fields first
dataset.built_in_field_enabled = True dataset.built_in_field_enabled = True
from extensions.ext_database import db
db.session.add(dataset) db_session_with_containers.add(dataset)
db.session.commit() db_session_with_containers.commit()
# Mock DocumentService.get_working_documents_by_dataset_id to return empty list # Mock DocumentService.get_working_documents_by_dataset_id to return empty list
mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = [] mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = []
@ -769,10 +773,12 @@ class TestMetadataService:
MetadataService.disable_built_in_field(dataset) MetadataService.disable_built_in_field(dataset)
# Assert: Verify the expected outcomes # Assert: Verify the expected outcomes
db.session.refresh(dataset) db_session_with_containers.refresh(dataset)
assert dataset.built_in_field_enabled is False assert dataset.built_in_field_enabled is False
def test_update_documents_metadata_success(self, db_session_with_containers, mock_external_service_dependencies): def test_update_documents_metadata_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful update of documents metadata. Test successful update of documents metadata.
""" """
@ -815,24 +821,25 @@ class TestMetadataService:
MetadataService.update_documents_metadata(dataset, operation_data) MetadataService.update_documents_metadata(dataset, operation_data)
# Assert: Verify the expected outcomes # Assert: Verify the expected outcomes
from extensions.ext_database import db
# Verify document metadata was updated # Verify document metadata was updated
db.session.refresh(document) db_session_with_containers.refresh(document)
assert document.doc_metadata is not None assert document.doc_metadata is not None
assert "test_metadata" in document.doc_metadata assert "test_metadata" in document.doc_metadata
assert document.doc_metadata["test_metadata"] == "test_value" assert document.doc_metadata["test_metadata"] == "test_value"
# Verify metadata binding was created # Verify metadata binding was created
binding = ( binding = (
db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata.id, document_id=document.id).first() db_session_with_containers.query(DatasetMetadataBinding)
.filter_by(metadata_id=metadata.id, document_id=document.id)
.first()
) )
assert binding is not None assert binding is not None
assert binding.tenant_id == tenant.id assert binding.tenant_id == tenant.id
assert binding.dataset_id == dataset.id assert binding.dataset_id == dataset.id
def test_update_documents_metadata_with_built_in_fields_enabled( def test_update_documents_metadata_with_built_in_fields_enabled(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test update of documents metadata when built-in fields are enabled. Test update of documents metadata when built-in fields are enabled.
@ -850,10 +857,9 @@ class TestMetadataService:
# Enable built-in fields # Enable built-in fields
dataset.built_in_field_enabled = True dataset.built_in_field_enabled = True
from extensions.ext_database import db
db.session.add(dataset) db_session_with_containers.add(dataset)
db.session.commit() db_session_with_containers.commit()
# Setup mocks # Setup mocks
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
@ -884,7 +890,7 @@ class TestMetadataService:
# Assert: Verify the expected outcomes # Assert: Verify the expected outcomes
# Verify document metadata was updated with both custom and built-in fields # Verify document metadata was updated with both custom and built-in fields
db.session.refresh(document) db_session_with_containers.refresh(document)
assert document.doc_metadata is not None assert document.doc_metadata is not None
assert "test_metadata" in document.doc_metadata assert "test_metadata" in document.doc_metadata
assert document.doc_metadata["test_metadata"] == "test_value" assert document.doc_metadata["test_metadata"] == "test_value"
@ -893,7 +899,7 @@ class TestMetadataService:
# The main functionality (custom metadata update) is verified # The main functionality (custom metadata update) is verified
def test_update_documents_metadata_document_not_found( def test_update_documents_metadata_document_not_found(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test update of documents metadata when document is not found. Test update of documents metadata when document is not found.
@ -936,7 +942,7 @@ class TestMetadataService:
MetadataService.update_documents_metadata(dataset, operation_data) MetadataService.update_documents_metadata(dataset, operation_data)
def test_knowledge_base_metadata_lock_check_dataset_id( def test_knowledge_base_metadata_lock_check_dataset_id(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test metadata lock check for dataset operations. Test metadata lock check for dataset operations.
@ -959,7 +965,7 @@ class TestMetadataService:
assert call_args[0][0] == f"dataset_metadata_lock_{dataset_id}" assert call_args[0][0] == f"dataset_metadata_lock_{dataset_id}"
def test_knowledge_base_metadata_lock_check_document_id( def test_knowledge_base_metadata_lock_check_document_id(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test metadata lock check for document operations. Test metadata lock check for document operations.
@ -982,7 +988,7 @@ class TestMetadataService:
assert call_args[0][0] == f"document_metadata_lock_{document_id}" assert call_args[0][0] == f"document_metadata_lock_{document_id}"
def test_knowledge_base_metadata_lock_check_lock_exists( def test_knowledge_base_metadata_lock_check_lock_exists(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test metadata lock check when lock already exists. Test metadata lock check when lock already exists.
@ -999,7 +1005,7 @@ class TestMetadataService:
MetadataService.knowledge_base_metadata_lock_check(dataset_id, None) MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
def test_knowledge_base_metadata_lock_check_document_lock_exists( def test_knowledge_base_metadata_lock_check_document_lock_exists(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test metadata lock check when document lock already exists. Test metadata lock check when document lock already exists.
@ -1013,7 +1019,9 @@ class TestMetadataService:
with pytest.raises(ValueError, match="Another document metadata operation is running, please wait a moment."): with pytest.raises(ValueError, match="Another document metadata operation is running, please wait a moment."):
MetadataService.knowledge_base_metadata_lock_check(None, document_id) MetadataService.knowledge_base_metadata_lock_check(None, document_id)
def test_get_dataset_metadatas_success(self, db_session_with_containers, mock_external_service_dependencies): def test_get_dataset_metadatas_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful retrieval of dataset metadata information. Test successful retrieval of dataset metadata information.
""" """
@ -1046,10 +1054,8 @@ class TestMetadataService:
created_by=account.id, created_by=account.id,
) )
from extensions.ext_database import db db_session_with_containers.add(binding)
db_session_with_containers.commit()
db.session.add(binding)
db.session.commit()
# Act: Execute the method under test # Act: Execute the method under test
result = MetadataService.get_dataset_metadatas(dataset) result = MetadataService.get_dataset_metadatas(dataset)
@ -1071,7 +1077,7 @@ class TestMetadataService:
assert result["built_in_field_enabled"] is False assert result["built_in_field_enabled"] is False
def test_get_dataset_metadatas_with_built_in_fields_enabled( def test_get_dataset_metadatas_with_built_in_fields_enabled(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test retrieval of dataset metadata when built-in fields are enabled. Test retrieval of dataset metadata when built-in fields are enabled.
@ -1086,10 +1092,9 @@ class TestMetadataService:
# Enable built-in fields # Enable built-in fields
dataset.built_in_field_enabled = True dataset.built_in_field_enabled = True
from extensions.ext_database import db
db.session.add(dataset) db_session_with_containers.add(dataset)
db.session.commit() db_session_with_containers.commit()
# Setup mocks # Setup mocks
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
@ -1114,7 +1119,9 @@ class TestMetadataService:
# Verify built-in field status # Verify built-in field status
assert result["built_in_field_enabled"] is True assert result["built_in_field_enabled"] is True
def test_get_dataset_metadatas_no_metadata(self, db_session_with_containers, mock_external_service_dependencies): def test_get_dataset_metadatas_no_metadata(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test retrieval of dataset metadata when no metadata exists. Test retrieval of dataset metadata when no metadata exists.
""" """

View File

@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch
import pytest import pytest
from faker import Faker from faker import Faker
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session
from models.account import TenantAccountJoin, TenantAccountRole from models.account import TenantAccountJoin, TenantAccountRole
from models.model import Account, Tenant from models.model import Account, Tenant
@ -67,7 +68,7 @@ class TestModelLoadBalancingService:
"credential_schema": mock_credential_schema, "credential_schema": mock_credential_schema,
} }
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Helper method to create a test account and tenant for testing. Helper method to create a test account and tenant for testing.
@ -88,18 +89,16 @@ class TestModelLoadBalancingService:
status="active", status="active",
) )
from extensions.ext_database import db db_session_with_containers.add(account)
db_session_with_containers.commit()
db.session.add(account)
db.session.commit()
# Create tenant for the account # Create tenant for the account
tenant = Tenant( tenant = Tenant(
name=fake.company(), name=fake.company(),
status="normal", status="normal",
) )
db.session.add(tenant) db_session_with_containers.add(tenant)
db.session.commit() db_session_with_containers.commit()
# Create tenant-account join # Create tenant-account join
join = TenantAccountJoin( join = TenantAccountJoin(
@ -108,8 +107,8 @@ class TestModelLoadBalancingService:
role=TenantAccountRole.OWNER, role=TenantAccountRole.OWNER,
current=True, current=True,
) )
db.session.add(join) db_session_with_containers.add(join)
db.session.commit() db_session_with_containers.commit()
# Set current tenant for account # Set current tenant for account
account.current_tenant = tenant account.current_tenant = tenant
@ -117,7 +116,7 @@ class TestModelLoadBalancingService:
return account, tenant return account, tenant
def _create_test_provider_and_setting( def _create_test_provider_and_setting(
self, db_session_with_containers, tenant_id, mock_external_service_dependencies self, db_session_with_containers: Session, tenant_id, mock_external_service_dependencies
): ):
""" """
Helper method to create a test provider and provider model setting. Helper method to create a test provider and provider model setting.
@ -132,8 +131,6 @@ class TestModelLoadBalancingService:
""" """
fake = Faker() fake = Faker()
from extensions.ext_database import db
# Create provider # Create provider
provider = Provider( provider = Provider(
tenant_id=tenant_id, tenant_id=tenant_id,
@ -141,8 +138,8 @@ class TestModelLoadBalancingService:
provider_type="custom", provider_type="custom",
is_valid=True, is_valid=True,
) )
db.session.add(provider) db_session_with_containers.add(provider)
db.session.commit() db_session_with_containers.commit()
# Create provider model setting # Create provider model setting
provider_model_setting = ProviderModelSetting( provider_model_setting = ProviderModelSetting(
@ -153,12 +150,14 @@ class TestModelLoadBalancingService:
enabled=True, enabled=True,
load_balancing_enabled=False, load_balancing_enabled=False,
) )
db.session.add(provider_model_setting) db_session_with_containers.add(provider_model_setting)
db.session.commit() db_session_with_containers.commit()
return provider, provider_model_setting return provider, provider_model_setting
def test_enable_model_load_balancing_success(self, db_session_with_containers, mock_external_service_dependencies): def test_enable_model_load_balancing_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful model load balancing enablement. Test successful model load balancing enablement.
@ -193,14 +192,15 @@ class TestModelLoadBalancingService:
assert call_args.kwargs["model_type"].value == "llm" # ModelType enum value assert call_args.kwargs["model_type"].value == "llm" # ModelType enum value
# Verify database state # Verify database state
from extensions.ext_database import db
db.session.refresh(provider) db_session_with_containers.refresh(provider)
db.session.refresh(provider_model_setting) db_session_with_containers.refresh(provider_model_setting)
assert provider.id is not None assert provider.id is not None
assert provider_model_setting.id is not None assert provider_model_setting.id is not None
def test_disable_model_load_balancing_success(self, db_session_with_containers, mock_external_service_dependencies): def test_disable_model_load_balancing_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful model load balancing disablement. Test successful model load balancing disablement.
@ -235,15 +235,14 @@ class TestModelLoadBalancingService:
assert call_args.kwargs["model_type"].value == "llm" # ModelType enum value assert call_args.kwargs["model_type"].value == "llm" # ModelType enum value
# Verify database state # Verify database state
from extensions.ext_database import db
db.session.refresh(provider) db_session_with_containers.refresh(provider)
db.session.refresh(provider_model_setting) db_session_with_containers.refresh(provider_model_setting)
assert provider.id is not None assert provider.id is not None
assert provider_model_setting.id is not None assert provider_model_setting.id is not None
def test_enable_model_load_balancing_provider_not_found( def test_enable_model_load_balancing_provider_not_found(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test error handling when provider does not exist. Test error handling when provider does not exist.
@ -275,11 +274,12 @@ class TestModelLoadBalancingService:
assert "Provider nonexistent_provider does not exist." in str(exc_info.value) assert "Provider nonexistent_provider does not exist." in str(exc_info.value)
# Verify no database state changes occurred # Verify no database state changes occurred
from extensions.ext_database import db
db.session.rollback() db_session_with_containers.rollback()
def test_get_load_balancing_configs_success(self, db_session_with_containers, mock_external_service_dependencies): def test_get_load_balancing_configs_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful retrieval of load balancing configurations. Test successful retrieval of load balancing configurations.
@ -298,7 +298,6 @@ class TestModelLoadBalancingService:
) )
# Create load balancing config # Create load balancing config
from extensions.ext_database import db
load_balancing_config = LoadBalancingModelConfig( load_balancing_config = LoadBalancingModelConfig(
tenant_id=tenant.id, tenant_id=tenant.id,
@ -309,11 +308,11 @@ class TestModelLoadBalancingService:
encrypted_config='{"api_key": "test_key"}', encrypted_config='{"api_key": "test_key"}',
enabled=True, enabled=True,
) )
db.session.add(load_balancing_config) db_session_with_containers.add(load_balancing_config)
db.session.commit() db_session_with_containers.commit()
# Verify the config was created # Verify the config was created
db.session.refresh(load_balancing_config) db_session_with_containers.refresh(load_balancing_config)
assert load_balancing_config.id is not None assert load_balancing_config.id is not None
# Setup mocks for get_load_balancing_configs method # Setup mocks for get_load_balancing_configs method
@ -358,11 +357,11 @@ class TestModelLoadBalancingService:
assert configs[0]["ttl"] == 0 assert configs[0]["ttl"] == 0
# Verify database state # Verify database state
db.session.refresh(load_balancing_config) db_session_with_containers.refresh(load_balancing_config)
assert load_balancing_config.id is not None assert load_balancing_config.id is not None
def test_get_load_balancing_configs_provider_not_found( def test_get_load_balancing_configs_provider_not_found(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test error handling when provider does not exist in get_load_balancing_configs. Test error handling when provider does not exist in get_load_balancing_configs.
@ -394,12 +393,11 @@ class TestModelLoadBalancingService:
assert "Provider nonexistent_provider does not exist." in str(exc_info.value) assert "Provider nonexistent_provider does not exist." in str(exc_info.value)
# Verify no database state changes occurred # Verify no database state changes occurred
from extensions.ext_database import db
db.session.rollback() db_session_with_containers.rollback()
def test_get_load_balancing_configs_with_inherit_config( def test_get_load_balancing_configs_with_inherit_config(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test load balancing configs retrieval with inherit configuration. Test load balancing configs retrieval with inherit configuration.
@ -419,7 +417,6 @@ class TestModelLoadBalancingService:
) )
# Create load balancing config # Create load balancing config
from extensions.ext_database import db
load_balancing_config = LoadBalancingModelConfig( load_balancing_config = LoadBalancingModelConfig(
tenant_id=tenant.id, tenant_id=tenant.id,
@ -430,8 +427,8 @@ class TestModelLoadBalancingService:
encrypted_config='{"api_key": "test_key"}', encrypted_config='{"api_key": "test_key"}',
enabled=True, enabled=True,
) )
db.session.add(load_balancing_config) db_session_with_containers.add(load_balancing_config)
db.session.commit() db_session_with_containers.commit()
# Setup mocks for inherit config scenario # Setup mocks for inherit config scenario
mock_provider_config = mock_external_service_dependencies["provider_config"] mock_provider_config = mock_external_service_dependencies["provider_config"]
@ -467,11 +464,11 @@ class TestModelLoadBalancingService:
assert configs[1]["name"] == "config1" assert configs[1]["name"] == "config1"
# Verify database state # Verify database state
db.session.refresh(load_balancing_config) db_session_with_containers.refresh(load_balancing_config)
assert load_balancing_config.id is not None assert load_balancing_config.id is not None
# Verify inherit config was created in database # Verify inherit config was created in database
inherit_configs = db.session.scalars( inherit_configs = db_session_with_containers.scalars(
select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.name == "__inherit__") select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.name == "__inherit__")
).all() ).all()
assert len(inherit_configs) == 1 assert len(inherit_configs) == 1

View File

@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch
import pytest import pytest
from faker import Faker from faker import Faker
from sqlalchemy.orm import Session
from core.entities.model_entities import ModelStatus from core.entities.model_entities import ModelStatus
from dify_graph.model_runtime.entities.model_entities import FetchFrom, ModelType from dify_graph.model_runtime.entities.model_entities import FetchFrom, ModelType
@ -29,7 +30,7 @@ class TestModelProviderService:
"model_provider_factory": mock_model_provider_factory, "model_provider_factory": mock_model_provider_factory,
} }
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Helper method to create a test account and tenant for testing. Helper method to create a test account and tenant for testing.
@ -50,18 +51,16 @@ class TestModelProviderService:
status="active", status="active",
) )
from extensions.ext_database import db db_session_with_containers.add(account)
db_session_with_containers.commit()
db.session.add(account)
db.session.commit()
# Create tenant for the account # Create tenant for the account
tenant = Tenant( tenant = Tenant(
name=fake.company(), name=fake.company(),
status="normal", status="normal",
) )
db.session.add(tenant) db_session_with_containers.add(tenant)
db.session.commit() db_session_with_containers.commit()
# Create tenant-account join # Create tenant-account join
join = TenantAccountJoin( join = TenantAccountJoin(
@ -70,8 +69,8 @@ class TestModelProviderService:
role=TenantAccountRole.OWNER, role=TenantAccountRole.OWNER,
current=True, current=True,
) )
db.session.add(join) db_session_with_containers.add(join)
db.session.commit() db_session_with_containers.commit()
# Set current tenant for account # Set current tenant for account
account.current_tenant = tenant account.current_tenant = tenant
@ -80,7 +79,7 @@ class TestModelProviderService:
def _create_test_provider( def _create_test_provider(
self, self,
db_session_with_containers, db_session_with_containers: Session,
mock_external_service_dependencies, mock_external_service_dependencies,
tenant_id: str, tenant_id: str,
provider_name: str = "openai", provider_name: str = "openai",
@ -109,16 +108,14 @@ class TestModelProviderService:
quota_used=0, quota_used=0,
) )
from extensions.ext_database import db db_session_with_containers.add(provider)
db_session_with_containers.commit()
db.session.add(provider)
db.session.commit()
return provider return provider
def _create_test_provider_model( def _create_test_provider_model(
self, self,
db_session_with_containers, db_session_with_containers: Session,
mock_external_service_dependencies, mock_external_service_dependencies,
tenant_id: str, tenant_id: str,
provider_name: str, provider_name: str,
@ -149,16 +146,14 @@ class TestModelProviderService:
is_valid=True, is_valid=True,
) )
from extensions.ext_database import db db_session_with_containers.add(provider_model)
db_session_with_containers.commit()
db.session.add(provider_model)
db.session.commit()
return provider_model return provider_model
def _create_test_provider_model_setting( def _create_test_provider_model_setting(
self, self,
db_session_with_containers, db_session_with_containers: Session,
mock_external_service_dependencies, mock_external_service_dependencies,
tenant_id: str, tenant_id: str,
provider_name: str, provider_name: str,
@ -190,14 +185,12 @@ class TestModelProviderService:
load_balancing_enabled=False, load_balancing_enabled=False,
) )
from extensions.ext_database import db db_session_with_containers.add(provider_model_setting)
db_session_with_containers.commit()
db.session.add(provider_model_setting)
db.session.commit()
return provider_model_setting return provider_model_setting
def test_get_provider_list_success(self, db_session_with_containers, mock_external_service_dependencies): def test_get_provider_list_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test successful provider list retrieval. Test successful provider list retrieval.
@ -275,7 +268,7 @@ class TestModelProviderService:
mock_provider_config.is_custom_configuration_available.assert_called_once() mock_provider_config.is_custom_configuration_available.assert_called_once()
def test_get_provider_list_with_model_type_filter( def test_get_provider_list_with_model_type_filter(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test provider list retrieval with model type filtering. Test provider list retrieval with model type filtering.
@ -374,7 +367,9 @@ class TestModelProviderService:
assert result[0].provider == "cohere" assert result[0].provider == "cohere"
assert ModelType.TEXT_EMBEDDING in result[0].supported_model_types assert ModelType.TEXT_EMBEDDING in result[0].supported_model_types
def test_get_models_by_provider_success(self, db_session_with_containers, mock_external_service_dependencies): def test_get_models_by_provider_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful retrieval of models by provider. Test successful retrieval of models by provider.
@ -485,7 +480,9 @@ class TestModelProviderService:
mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) mock_provider_manager.get_configurations.assert_called_once_with(tenant.id)
mock_configurations.get_models.assert_called_once_with(provider="openai") mock_configurations.get_models.assert_called_once_with(provider="openai")
def test_get_provider_credentials_success(self, db_session_with_containers, mock_external_service_dependencies): def test_get_provider_credentials_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful retrieval of provider credentials. Test successful retrieval of provider credentials.
@ -543,7 +540,7 @@ class TestModelProviderService:
mock_method.assert_called_once_with(tenant.id, "openai") mock_method.assert_called_once_with(tenant.id, "openai")
def test_provider_credentials_validate_success( def test_provider_credentials_validate_success(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful validation of provider credentials. Test successful validation of provider credentials.
@ -585,7 +582,7 @@ class TestModelProviderService:
mock_provider_configuration.validate_provider_credentials.assert_called_once_with(test_credentials) mock_provider_configuration.validate_provider_credentials.assert_called_once_with(test_credentials)
def test_provider_credentials_validate_invalid_provider( def test_provider_credentials_validate_invalid_provider(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test validation failure for non-existent provider. Test validation failure for non-existent provider.
@ -617,7 +614,7 @@ class TestModelProviderService:
mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) mock_provider_manager.get_configurations.assert_called_once_with(tenant.id)
def test_get_default_model_of_model_type_success( def test_get_default_model_of_model_type_success(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful retrieval of default model for a specific model type. Test successful retrieval of default model for a specific model type.
@ -673,7 +670,7 @@ class TestModelProviderService:
mock_provider_manager.get_default_model.assert_called_once_with(tenant_id=tenant.id, model_type=ModelType.LLM) mock_provider_manager.get_default_model.assert_called_once_with(tenant_id=tenant.id, model_type=ModelType.LLM)
def test_update_default_model_of_model_type_success( def test_update_default_model_of_model_type_success(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful update of default model for a specific model type. Test successful update of default model for a specific model type.
@ -706,7 +703,9 @@ class TestModelProviderService:
tenant_id=tenant.id, model_type=ModelType.LLM, provider="openai", model="gpt-4" tenant_id=tenant.id, model_type=ModelType.LLM, provider="openai", model="gpt-4"
) )
def test_get_model_provider_icon_success(self, db_session_with_containers, mock_external_service_dependencies): def test_get_model_provider_icon_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful retrieval of model provider icon. Test successful retrieval of model provider icon.
@ -743,7 +742,9 @@ class TestModelProviderService:
# Verify mock interactions # Verify mock interactions
mock_model_provider_factory.get_provider_icon.assert_called_once_with("openai", "icon_small", "en_US") mock_model_provider_factory.get_provider_icon.assert_called_once_with("openai", "icon_small", "en_US")
def test_switch_preferred_provider_success(self, db_session_with_containers, mock_external_service_dependencies): def test_switch_preferred_provider_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful switching of preferred provider type. Test successful switching of preferred provider type.
@ -779,7 +780,7 @@ class TestModelProviderService:
mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) mock_provider_manager.get_configurations.assert_called_once_with(tenant.id)
mock_provider_configuration.switch_preferred_provider_type.assert_called_once() mock_provider_configuration.switch_preferred_provider_type.assert_called_once()
def test_enable_model_success(self, db_session_with_containers, mock_external_service_dependencies): def test_enable_model_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test successful enabling of a model. Test successful enabling of a model.
@ -815,7 +816,9 @@ class TestModelProviderService:
mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) mock_provider_manager.get_configurations.assert_called_once_with(tenant.id)
mock_provider_configuration.enable_model.assert_called_once_with(model_type=ModelType.LLM, model="gpt-4") mock_provider_configuration.enable_model.assert_called_once_with(model_type=ModelType.LLM, model="gpt-4")
def test_get_model_credentials_success(self, db_session_with_containers, mock_external_service_dependencies): def test_get_model_credentials_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful retrieval of model credentials. Test successful retrieval of model credentials.
@ -872,7 +875,9 @@ class TestModelProviderService:
# Verify the method was called with correct parameters # Verify the method was called with correct parameters
mock_method.assert_called_once_with(tenant.id, "openai", "llm", "gpt-4", None) mock_method.assert_called_once_with(tenant.id, "openai", "llm", "gpt-4", None)
def test_model_credentials_validate_success(self, db_session_with_containers, mock_external_service_dependencies): def test_model_credentials_validate_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful validation of model credentials. Test successful validation of model credentials.
@ -914,7 +919,9 @@ class TestModelProviderService:
model_type=ModelType.LLM, model="gpt-4", credentials=test_credentials model_type=ModelType.LLM, model="gpt-4", credentials=test_credentials
) )
def test_save_model_credentials_success(self, db_session_with_containers, mock_external_service_dependencies): def test_save_model_credentials_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful saving of model credentials. Test successful saving of model credentials.
@ -955,7 +962,9 @@ class TestModelProviderService:
model_type=ModelType.LLM, model="gpt-4", credentials=test_credentials, credential_name="testname" model_type=ModelType.LLM, model="gpt-4", credentials=test_credentials, credential_name="testname"
) )
def test_remove_model_credentials_success(self, db_session_with_containers, mock_external_service_dependencies): def test_remove_model_credentials_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful removal of model credentials. Test successful removal of model credentials.
@ -993,7 +1002,9 @@ class TestModelProviderService:
model_type=ModelType.LLM, model="gpt-4", credential_id="5540007c-b988-46e0-b1c7-9b5fb9f330d6" model_type=ModelType.LLM, model="gpt-4", credential_id="5540007c-b988-46e0-b1c7-9b5fb9f330d6"
) )
def test_get_models_by_model_type_success(self, db_session_with_containers, mock_external_service_dependencies): def test_get_models_by_model_type_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful retrieval of models by model type. Test successful retrieval of models by model type.
@ -1070,7 +1081,9 @@ class TestModelProviderService:
mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) mock_provider_manager.get_configurations.assert_called_once_with(tenant.id)
mock_provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=True) mock_provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=True)
def test_get_model_parameter_rules_success(self, db_session_with_containers, mock_external_service_dependencies): def test_get_model_parameter_rules_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful retrieval of model parameter rules. Test successful retrieval of model parameter rules.
@ -1137,7 +1150,7 @@ class TestModelProviderService:
) )
def test_get_model_parameter_rules_no_credentials( def test_get_model_parameter_rules_no_credentials(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test parameter rules retrieval when no credentials are available. Test parameter rules retrieval when no credentials are available.
@ -1181,7 +1194,7 @@ class TestModelProviderService:
) )
def test_get_model_parameter_rules_provider_not_found( def test_get_model_parameter_rules_provider_not_found(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test parameter rules retrieval when provider does not exist. Test parameter rules retrieval when provider does not exist.

View File

@ -2,6 +2,7 @@ from unittest.mock import patch
import pytest import pytest
from faker import Faker from faker import Faker
from sqlalchemy.orm import Session
from models.model import EndUser, Message from models.model import EndUser, Message
from models.web import SavedMessage from models.web import SavedMessage
@ -38,7 +39,7 @@ class TestSavedMessageService:
"message_service": mock_message_service, "message_service": mock_message_service,
} }
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Helper method to create a test app and account for testing. Helper method to create a test app and account for testing.
@ -85,7 +86,7 @@ class TestSavedMessageService:
return app, account return app, account
def _create_test_end_user(self, db_session_with_containers, app): def _create_test_end_user(self, db_session_with_containers: Session, app):
""" """
Helper method to create a test end user for testing. Helper method to create a test end user for testing.
@ -108,14 +109,12 @@ class TestSavedMessageService:
is_anonymous=False, is_anonymous=False,
) )
from extensions.ext_database import db db_session_with_containers.add(end_user)
db_session_with_containers.commit()
db.session.add(end_user)
db.session.commit()
return end_user return end_user
def _create_test_message(self, db_session_with_containers, app, user): def _create_test_message(self, db_session_with_containers: Session, app, user):
""" """
Helper method to create a test message for testing. Helper method to create a test message for testing.
@ -143,10 +142,8 @@ class TestSavedMessageService:
mode="chat", mode="chat",
) )
from extensions.ext_database import db db_session_with_containers.add(conversation)
db_session_with_containers.commit()
db.session.add(conversation)
db.session.commit()
# Create message # Create message
message = Message( message = Message(
@ -168,13 +165,13 @@ class TestSavedMessageService:
status="success", status="success",
) )
db.session.add(message) db_session_with_containers.add(message)
db.session.commit() db_session_with_containers.commit()
return message return message
def test_pagination_by_last_id_success_with_account_user( def test_pagination_by_last_id_success_with_account_user(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful pagination by last ID with account user. Test successful pagination by last ID with account user.
@ -207,10 +204,8 @@ class TestSavedMessageService:
created_by=account.id, created_by=account.id,
) )
from extensions.ext_database import db db_session_with_containers.add_all([saved_message1, saved_message2])
db_session_with_containers.commit()
db.session.add_all([saved_message1, saved_message2])
db.session.commit()
# Mock MessageService.pagination_by_last_id return value # Mock MessageService.pagination_by_last_id return value
from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.infinite_scroll_pagination import InfiniteScrollPagination
@ -240,15 +235,15 @@ class TestSavedMessageService:
assert actual_include_ids == expected_include_ids assert actual_include_ids == expected_include_ids
# Verify database state # Verify database state
db.session.refresh(saved_message1) db_session_with_containers.refresh(saved_message1)
db.session.refresh(saved_message2) db_session_with_containers.refresh(saved_message2)
assert saved_message1.id is not None assert saved_message1.id is not None
assert saved_message2.id is not None assert saved_message2.id is not None
assert saved_message1.created_by_role == "account" assert saved_message1.created_by_role == "account"
assert saved_message2.created_by_role == "account" assert saved_message2.created_by_role == "account"
def test_pagination_by_last_id_success_with_end_user( def test_pagination_by_last_id_success_with_end_user(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful pagination by last ID with end user. Test successful pagination by last ID with end user.
@ -282,10 +277,8 @@ class TestSavedMessageService:
created_by=end_user.id, created_by=end_user.id,
) )
from extensions.ext_database import db db_session_with_containers.add_all([saved_message1, saved_message2])
db_session_with_containers.commit()
db.session.add_all([saved_message1, saved_message2])
db.session.commit()
# Mock MessageService.pagination_by_last_id return value # Mock MessageService.pagination_by_last_id return value
from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.infinite_scroll_pagination import InfiniteScrollPagination
@ -317,14 +310,16 @@ class TestSavedMessageService:
assert actual_include_ids == expected_include_ids assert actual_include_ids == expected_include_ids
# Verify database state # Verify database state
db.session.refresh(saved_message1) db_session_with_containers.refresh(saved_message1)
db.session.refresh(saved_message2) db_session_with_containers.refresh(saved_message2)
assert saved_message1.id is not None assert saved_message1.id is not None
assert saved_message2.id is not None assert saved_message2.id is not None
assert saved_message1.created_by_role == "end_user" assert saved_message1.created_by_role == "end_user"
assert saved_message2.created_by_role == "end_user" assert saved_message2.created_by_role == "end_user"
def test_save_success_with_new_message(self, db_session_with_containers, mock_external_service_dependencies): def test_save_success_with_new_message(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful save of a new message. Test successful save of a new message.
@ -347,10 +342,9 @@ class TestSavedMessageService:
# Assert: Verify the expected outcomes # Assert: Verify the expected outcomes
# Check if saved message was created in database # Check if saved message was created in database
from extensions.ext_database import db
saved_message = ( saved_message = (
db.session.query(SavedMessage) db_session_with_containers.query(SavedMessage)
.where( .where(
SavedMessage.app_id == app.id, SavedMessage.app_id == app.id,
SavedMessage.message_id == message.id, SavedMessage.message_id == message.id,
@ -373,10 +367,12 @@ class TestSavedMessageService:
) )
# Verify database state # Verify database state
db.session.refresh(saved_message) db_session_with_containers.refresh(saved_message)
assert saved_message.id is not None assert saved_message.id is not None
def test_pagination_by_last_id_error_no_user(self, db_session_with_containers, mock_external_service_dependencies): def test_pagination_by_last_id_error_no_user(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test error handling when no user is provided. Test error handling when no user is provided.
@ -396,12 +392,11 @@ class TestSavedMessageService:
assert "User is required" in str(exc_info.value) assert "User is required" in str(exc_info.value)
# Verify no database operations were performed # Verify no database operations were performed
from extensions.ext_database import db
saved_messages = db.session.query(SavedMessage).all() saved_messages = db_session_with_containers.query(SavedMessage).all()
assert len(saved_messages) == 0 assert len(saved_messages) == 0
def test_save_error_no_user(self, db_session_with_containers, mock_external_service_dependencies): def test_save_error_no_user(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test error handling when saving message with no user. Test error handling when saving message with no user.
@ -422,10 +417,9 @@ class TestSavedMessageService:
assert result is None assert result is None
# Verify no saved message was created # Verify no saved message was created
from extensions.ext_database import db
saved_message = ( saved_message = (
db.session.query(SavedMessage) db_session_with_containers.query(SavedMessage)
.where( .where(
SavedMessage.app_id == app.id, SavedMessage.app_id == app.id,
SavedMessage.message_id == message.id, SavedMessage.message_id == message.id,
@ -435,7 +429,9 @@ class TestSavedMessageService:
assert saved_message is None assert saved_message is None
def test_delete_success_existing_message(self, db_session_with_containers, mock_external_service_dependencies): def test_delete_success_existing_message(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful deletion of an existing saved message. Test successful deletion of an existing saved message.
@ -457,14 +453,12 @@ class TestSavedMessageService:
created_by=account.id, created_by=account.id,
) )
from extensions.ext_database import db db_session_with_containers.add(saved_message)
db_session_with_containers.commit()
db.session.add(saved_message)
db.session.commit()
# Verify saved message exists # Verify saved message exists
assert ( assert (
db.session.query(SavedMessage) db_session_with_containers.query(SavedMessage)
.where( .where(
SavedMessage.app_id == app.id, SavedMessage.app_id == app.id,
SavedMessage.message_id == message.id, SavedMessage.message_id == message.id,
@ -481,7 +475,7 @@ class TestSavedMessageService:
# Assert: Verify the expected outcomes # Assert: Verify the expected outcomes
# Check if saved message was deleted from database # Check if saved message was deleted from database
deleted_saved_message = ( deleted_saved_message = (
db.session.query(SavedMessage) db_session_with_containers.query(SavedMessage)
.where( .where(
SavedMessage.app_id == app.id, SavedMessage.app_id == app.id,
SavedMessage.message_id == message.id, SavedMessage.message_id == message.id,
@ -494,11 +488,13 @@ class TestSavedMessageService:
assert deleted_saved_message is None assert deleted_saved_message is None
# Verify database state # Verify database state
db.session.commit() db_session_with_containers.commit()
# The message should still exist, only the saved_message should be deleted # The message should still exist, only the saved_message should be deleted
assert db.session.query(Message).where(Message.id == message.id).first() is not None assert db_session_with_containers.query(Message).where(Message.id == message.id).first() is not None
def test_pagination_by_last_id_error_no_user(self, db_session_with_containers, mock_external_service_dependencies): def test_pagination_by_last_id_error_no_user(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test error handling when no user is provided. Test error handling when no user is provided.
@ -522,7 +518,7 @@ class TestSavedMessageService:
# Instead, we verify that the error was properly raised # Instead, we verify that the error was properly raised
pass pass
def test_save_error_no_user(self, db_session_with_containers, mock_external_service_dependencies): def test_save_error_no_user(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test error handling when saving message with no user. Test error handling when saving message with no user.
@ -543,10 +539,9 @@ class TestSavedMessageService:
assert result is None assert result is None
# Verify no saved message was created # Verify no saved message was created
from extensions.ext_database import db
saved_message = ( saved_message = (
db.session.query(SavedMessage) db_session_with_containers.query(SavedMessage)
.where( .where(
SavedMessage.app_id == app.id, SavedMessage.app_id == app.id,
SavedMessage.message_id == message.id, SavedMessage.message_id == message.id,
@ -556,7 +551,9 @@ class TestSavedMessageService:
assert saved_message is None assert saved_message is None
def test_delete_success_existing_message(self, db_session_with_containers, mock_external_service_dependencies): def test_delete_success_existing_message(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful deletion of an existing saved message. Test successful deletion of an existing saved message.
@ -578,14 +575,12 @@ class TestSavedMessageService:
created_by=account.id, created_by=account.id,
) )
from extensions.ext_database import db db_session_with_containers.add(saved_message)
db_session_with_containers.commit()
db.session.add(saved_message)
db.session.commit()
# Verify saved message exists # Verify saved message exists
assert ( assert (
db.session.query(SavedMessage) db_session_with_containers.query(SavedMessage)
.where( .where(
SavedMessage.app_id == app.id, SavedMessage.app_id == app.id,
SavedMessage.message_id == message.id, SavedMessage.message_id == message.id,
@ -602,7 +597,7 @@ class TestSavedMessageService:
# Assert: Verify the expected outcomes # Assert: Verify the expected outcomes
# Check if saved message was deleted from database # Check if saved message was deleted from database
deleted_saved_message = ( deleted_saved_message = (
db.session.query(SavedMessage) db_session_with_containers.query(SavedMessage)
.where( .where(
SavedMessage.app_id == app.id, SavedMessage.app_id == app.id,
SavedMessage.message_id == message.id, SavedMessage.message_id == message.id,
@ -615,6 +610,6 @@ class TestSavedMessageService:
assert deleted_saved_message is None assert deleted_saved_message is None
# Verify database state # Verify database state
db.session.commit() db_session_with_containers.commit()
# The message should still exist, only the saved_message should be deleted # The message should still exist, only the saved_message should be deleted
assert db.session.query(Message).where(Message.id == message.id).first() is not None assert db_session_with_containers.query(Message).where(Message.id == message.id).first() is not None

View File

@ -4,6 +4,7 @@ from unittest.mock import create_autospec, patch
import pytest import pytest
from faker import Faker from faker import Faker
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
@ -29,7 +30,7 @@ class TestTagService:
"current_user": mock_current_user, "current_user": mock_current_user,
} }
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Helper method to create a test account and tenant for testing. Helper method to create a test account and tenant for testing.
@ -50,18 +51,16 @@ class TestTagService:
status="active", status="active",
) )
from extensions.ext_database import db db_session_with_containers.add(account)
db_session_with_containers.commit()
db.session.add(account)
db.session.commit()
# Create tenant for the account # Create tenant for the account
tenant = Tenant( tenant = Tenant(
name=fake.company(), name=fake.company(),
status="normal", status="normal",
) )
db.session.add(tenant) db_session_with_containers.add(tenant)
db.session.commit() db_session_with_containers.commit()
# Create tenant-account join # Create tenant-account join
join = TenantAccountJoin( join = TenantAccountJoin(
@ -70,8 +69,8 @@ class TestTagService:
role=TenantAccountRole.OWNER, role=TenantAccountRole.OWNER,
current=True, current=True,
) )
db.session.add(join) db_session_with_containers.add(join)
db.session.commit() db_session_with_containers.commit()
# Set current tenant for account # Set current tenant for account
account.current_tenant = tenant account.current_tenant = tenant
@ -82,7 +81,7 @@ class TestTagService:
return account, tenant return account, tenant
def _create_test_dataset(self, db_session_with_containers, mock_external_service_dependencies, tenant_id): def _create_test_dataset(self, db_session_with_containers: Session, mock_external_service_dependencies, tenant_id):
""" """
Helper method to create a test dataset for testing. Helper method to create a test dataset for testing.
@ -107,14 +106,12 @@ class TestTagService:
created_by=mock_external_service_dependencies["current_user"].id, created_by=mock_external_service_dependencies["current_user"].id,
) )
from extensions.ext_database import db db_session_with_containers.add(dataset)
db_session_with_containers.commit()
db.session.add(dataset)
db.session.commit()
return dataset return dataset
def _create_test_app(self, db_session_with_containers, mock_external_service_dependencies, tenant_id): def _create_test_app(self, db_session_with_containers: Session, mock_external_service_dependencies, tenant_id):
""" """
Helper method to create a test app for testing. Helper method to create a test app for testing.
@ -141,15 +138,13 @@ class TestTagService:
created_by=mock_external_service_dependencies["current_user"].id, created_by=mock_external_service_dependencies["current_user"].id,
) )
from extensions.ext_database import db db_session_with_containers.add(app)
db_session_with_containers.commit()
db.session.add(app)
db.session.commit()
return app return app
def _create_test_tags( def _create_test_tags(
self, db_session_with_containers, mock_external_service_dependencies, tenant_id, tag_type, count=3 self, db_session_with_containers: Session, mock_external_service_dependencies, tenant_id, tag_type, count=3
): ):
""" """
Helper method to create test tags for testing. Helper method to create test tags for testing.
@ -176,16 +171,14 @@ class TestTagService:
) )
tags.append(tag) tags.append(tag)
from extensions.ext_database import db
for tag in tags: for tag in tags:
db.session.add(tag) db_session_with_containers.add(tag)
db.session.commit() db_session_with_containers.commit()
return tags return tags
def _create_test_tag_bindings( def _create_test_tag_bindings(
self, db_session_with_containers, mock_external_service_dependencies, tags, target_id, tenant_id self, db_session_with_containers: Session, mock_external_service_dependencies, tags, target_id, tenant_id
): ):
""" """
Helper method to create test tag bindings for testing. Helper method to create test tag bindings for testing.
@ -211,15 +204,13 @@ class TestTagService:
) )
tag_bindings.append(tag_binding) tag_bindings.append(tag_binding)
from extensions.ext_database import db
for tag_binding in tag_bindings: for tag_binding in tag_bindings:
db.session.add(tag_binding) db_session_with_containers.add(tag_binding)
db.session.commit() db_session_with_containers.commit()
return tag_bindings return tag_bindings
def test_get_tags_success(self, db_session_with_containers, mock_external_service_dependencies): def test_get_tags_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test successful retrieval of tags with binding count. Test successful retrieval of tags with binding count.
@ -270,7 +261,9 @@ class TestTagService:
# The ordering is handled by the database, we just verify the results are returned # The ordering is handled by the database, we just verify the results are returned
assert len(result) == 3 assert len(result) == 3
def test_get_tags_with_keyword_filter(self, db_session_with_containers, mock_external_service_dependencies): def test_get_tags_with_keyword_filter(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test tag retrieval with keyword filtering. Test tag retrieval with keyword filtering.
@ -291,12 +284,11 @@ class TestTagService:
) )
# Update tag names to make them searchable # Update tag names to make them searchable
from extensions.ext_database import db
tags[0].name = "python_development" tags[0].name = "python_development"
tags[1].name = "machine_learning" tags[1].name = "machine_learning"
tags[2].name = "web_development" tags[2].name = "web_development"
db.session.commit() db_session_with_containers.commit()
# Act: Execute the method under test with keyword filter # Act: Execute the method under test with keyword filter
result = TagService.get_tags("app", tenant.id, keyword="development") result = TagService.get_tags("app", tenant.id, keyword="development")
@ -314,7 +306,7 @@ class TestTagService:
assert len(result_no_match) == 0 assert len(result_no_match) == 0
def test_get_tags_with_special_characters_in_keyword( def test_get_tags_with_special_characters_in_keyword(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
r""" r"""
Test tag retrieval with special characters in keyword to verify SQL injection prevention. Test tag retrieval with special characters in keyword to verify SQL injection prevention.
@ -330,8 +322,6 @@ class TestTagService:
db_session_with_containers, mock_external_service_dependencies db_session_with_containers, mock_external_service_dependencies
) )
from extensions.ext_database import db
# Create tags with special characters in names # Create tags with special characters in names
tag_with_percent = Tag( tag_with_percent = Tag(
name="50% discount", name="50% discount",
@ -340,7 +330,7 @@ class TestTagService:
created_by=account.id, created_by=account.id,
) )
tag_with_percent.id = str(uuid.uuid4()) tag_with_percent.id = str(uuid.uuid4())
db.session.add(tag_with_percent) db_session_with_containers.add(tag_with_percent)
tag_with_underscore = Tag( tag_with_underscore = Tag(
name="test_data_tag", name="test_data_tag",
@ -349,7 +339,7 @@ class TestTagService:
created_by=account.id, created_by=account.id,
) )
tag_with_underscore.id = str(uuid.uuid4()) tag_with_underscore.id = str(uuid.uuid4())
db.session.add(tag_with_underscore) db_session_with_containers.add(tag_with_underscore)
tag_with_backslash = Tag( tag_with_backslash = Tag(
name="path\\to\\tag", name="path\\to\\tag",
@ -358,7 +348,7 @@ class TestTagService:
created_by=account.id, created_by=account.id,
) )
tag_with_backslash.id = str(uuid.uuid4()) tag_with_backslash.id = str(uuid.uuid4())
db.session.add(tag_with_backslash) db_session_with_containers.add(tag_with_backslash)
# Create tag that should NOT match # Create tag that should NOT match
tag_no_match = Tag( tag_no_match = Tag(
@ -368,9 +358,9 @@ class TestTagService:
created_by=account.id, created_by=account.id,
) )
tag_no_match.id = str(uuid.uuid4()) tag_no_match.id = str(uuid.uuid4())
db.session.add(tag_no_match) db_session_with_containers.add(tag_no_match)
db.session.commit() db_session_with_containers.commit()
# Act & Assert: Test 1 - Search with % character # Act & Assert: Test 1 - Search with % character
result = TagService.get_tags("app", tenant.id, keyword="50%") result = TagService.get_tags("app", tenant.id, keyword="50%")
@ -392,7 +382,7 @@ class TestTagService:
assert len(result) == 1 assert len(result) == 1
assert all("50%" in item.name for item in result) assert all("50%" in item.name for item in result)
def test_get_tags_empty_result(self, db_session_with_containers, mock_external_service_dependencies): def test_get_tags_empty_result(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test tag retrieval when no tags exist. Test tag retrieval when no tags exist.
@ -414,7 +404,9 @@ class TestTagService:
assert len(result) == 0 assert len(result) == 0
assert isinstance(result, list) assert isinstance(result, list)
def test_get_target_ids_by_tag_ids_success(self, db_session_with_containers, mock_external_service_dependencies): def test_get_target_ids_by_tag_ids_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful retrieval of target IDs by tag IDs. Test successful retrieval of target IDs by tag IDs.
@ -469,7 +461,7 @@ class TestTagService:
assert second_dataset_count == 1 assert second_dataset_count == 1
def test_get_target_ids_by_tag_ids_empty_tag_ids( def test_get_target_ids_by_tag_ids_empty_tag_ids(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test target ID retrieval with empty tag IDs list. Test target ID retrieval with empty tag IDs list.
@ -493,7 +485,7 @@ class TestTagService:
assert isinstance(result, list) assert isinstance(result, list)
def test_get_target_ids_by_tag_ids_no_matching_tags( def test_get_target_ids_by_tag_ids_no_matching_tags(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test target ID retrieval when no tags match the criteria. Test target ID retrieval when no tags match the criteria.
@ -521,7 +513,7 @@ class TestTagService:
assert len(result) == 0 assert len(result) == 0
assert isinstance(result, list) assert isinstance(result, list)
def test_get_tag_by_tag_name_success(self, db_session_with_containers, mock_external_service_dependencies): def test_get_tag_by_tag_name_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test successful retrieval of tags by tag name. Test successful retrieval of tags by tag name.
@ -542,11 +534,10 @@ class TestTagService:
) )
# Update tag names to make them searchable # Update tag names to make them searchable
from extensions.ext_database import db
tags[0].name = "python_tag" tags[0].name = "python_tag"
tags[1].name = "ml_tag" tags[1].name = "ml_tag"
db.session.commit() db_session_with_containers.commit()
# Act: Execute the method under test # Act: Execute the method under test
result = TagService.get_tag_by_tag_name("app", tenant.id, "python_tag") result = TagService.get_tag_by_tag_name("app", tenant.id, "python_tag")
@ -558,7 +549,9 @@ class TestTagService:
assert result[0].type == "app" assert result[0].type == "app"
assert result[0].tenant_id == tenant.id assert result[0].tenant_id == tenant.id
def test_get_tag_by_tag_name_no_matches(self, db_session_with_containers, mock_external_service_dependencies): def test_get_tag_by_tag_name_no_matches(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test tag retrieval by name when no matches exist. Test tag retrieval by name when no matches exist.
@ -580,7 +573,9 @@ class TestTagService:
assert len(result) == 0 assert len(result) == 0
assert isinstance(result, list) assert isinstance(result, list)
def test_get_tag_by_tag_name_empty_parameters(self, db_session_with_containers, mock_external_service_dependencies): def test_get_tag_by_tag_name_empty_parameters(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test tag retrieval by name with empty parameters. Test tag retrieval by name with empty parameters.
@ -605,7 +600,9 @@ class TestTagService:
assert result_empty_name is not None assert result_empty_name is not None
assert len(result_empty_name) == 0 assert len(result_empty_name) == 0
def test_get_tags_by_target_id_success(self, db_session_with_containers, mock_external_service_dependencies): def test_get_tags_by_target_id_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful retrieval of tags by target ID. Test successful retrieval of tags by target ID.
@ -644,7 +641,9 @@ class TestTagService:
assert tag.tenant_id == tenant.id assert tag.tenant_id == tenant.id
assert tag.id in [t.id for t in tags] assert tag.id in [t.id for t in tags]
def test_get_tags_by_target_id_no_bindings(self, db_session_with_containers, mock_external_service_dependencies): def test_get_tags_by_target_id_no_bindings(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test tag retrieval by target ID when no tags are bound. Test tag retrieval by target ID when no tags are bound.
@ -669,7 +668,7 @@ class TestTagService:
assert len(result) == 0 assert len(result) == 0
assert isinstance(result, list) assert isinstance(result, list)
def test_save_tags_success(self, db_session_with_containers, mock_external_service_dependencies): def test_save_tags_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test successful tag creation. Test successful tag creation.
@ -698,17 +697,18 @@ class TestTagService:
assert result.id is not None assert result.id is not None
# Verify database state # Verify database state
from extensions.ext_database import db
db.session.refresh(result) db_session_with_containers.refresh(result)
assert result.id is not None assert result.id is not None
# Verify tag was actually saved to database # Verify tag was actually saved to database
saved_tag = db.session.query(Tag).where(Tag.id == result.id).first() saved_tag = db_session_with_containers.query(Tag).where(Tag.id == result.id).first()
assert saved_tag is not None assert saved_tag is not None
assert saved_tag.name == "test_tag_name" assert saved_tag.name == "test_tag_name"
def test_save_tags_duplicate_name_error(self, db_session_with_containers, mock_external_service_dependencies): def test_save_tags_duplicate_name_error(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test tag creation with duplicate name. Test tag creation with duplicate name.
@ -731,7 +731,7 @@ class TestTagService:
TagService.save_tags(tag_args) TagService.save_tags(tag_args)
assert "Tag name already exists" in str(exc_info.value) assert "Tag name already exists" in str(exc_info.value)
def test_update_tags_success(self, db_session_with_containers, mock_external_service_dependencies): def test_update_tags_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test successful tag update. Test successful tag update.
@ -763,17 +763,16 @@ class TestTagService:
assert result.id == tag.id assert result.id == tag.id
# Verify database state # Verify database state
from extensions.ext_database import db
db.session.refresh(result) db_session_with_containers.refresh(result)
assert result.name == "updated_name" assert result.name == "updated_name"
# Verify tag was actually updated in database # Verify tag was actually updated in database
updated_tag = db.session.query(Tag).where(Tag.id == tag.id).first() updated_tag = db_session_with_containers.query(Tag).where(Tag.id == tag.id).first()
assert updated_tag is not None assert updated_tag is not None
assert updated_tag.name == "updated_name" assert updated_tag.name == "updated_name"
def test_update_tags_not_found_error(self, db_session_with_containers, mock_external_service_dependencies): def test_update_tags_not_found_error(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test tag update for non-existent tag. Test tag update for non-existent tag.
@ -799,7 +798,9 @@ class TestTagService:
TagService.update_tags(update_args, non_existent_tag_id) TagService.update_tags(update_args, non_existent_tag_id)
assert "Tag not found" in str(exc_info.value) assert "Tag not found" in str(exc_info.value)
def test_update_tags_duplicate_name_error(self, db_session_with_containers, mock_external_service_dependencies): def test_update_tags_duplicate_name_error(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test tag update with duplicate name. Test tag update with duplicate name.
@ -828,7 +829,9 @@ class TestTagService:
TagService.update_tags(update_args, tag2.id) TagService.update_tags(update_args, tag2.id)
assert "Tag name already exists" in str(exc_info.value) assert "Tag name already exists" in str(exc_info.value)
def test_get_tag_binding_count_success(self, db_session_with_containers, mock_external_service_dependencies): def test_get_tag_binding_count_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful retrieval of tag binding count. Test successful retrieval of tag binding count.
@ -863,7 +866,7 @@ class TestTagService:
assert result_tag_without_bindings == 0 assert result_tag_without_bindings == 0
def test_get_tag_binding_count_non_existent_tag( def test_get_tag_binding_count_non_existent_tag(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test binding count retrieval for non-existent tag. Test binding count retrieval for non-existent tag.
@ -889,7 +892,7 @@ class TestTagService:
# Assert: Verify the expected outcomes # Assert: Verify the expected outcomes
assert result == 0 assert result == 0
def test_delete_tag_success(self, db_session_with_containers, mock_external_service_dependencies): def test_delete_tag_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test successful tag deletion. Test successful tag deletion.
@ -916,12 +919,11 @@ class TestTagService:
) )
# Verify tag and binding exist before deletion # Verify tag and binding exist before deletion
from extensions.ext_database import db
tag_before = db.session.query(Tag).where(Tag.id == tag.id).first() tag_before = db_session_with_containers.query(Tag).where(Tag.id == tag.id).first()
assert tag_before is not None assert tag_before is not None
binding_before = db.session.query(TagBinding).where(TagBinding.tag_id == tag.id).first() binding_before = db_session_with_containers.query(TagBinding).where(TagBinding.tag_id == tag.id).first()
assert binding_before is not None assert binding_before is not None
# Act: Execute the method under test # Act: Execute the method under test
@ -929,14 +931,14 @@ class TestTagService:
# Assert: Verify the expected outcomes # Assert: Verify the expected outcomes
# Verify tag was deleted # Verify tag was deleted
tag_after = db.session.query(Tag).where(Tag.id == tag.id).first() tag_after = db_session_with_containers.query(Tag).where(Tag.id == tag.id).first()
assert tag_after is None assert tag_after is None
# Verify tag binding was deleted # Verify tag binding was deleted
binding_after = db.session.query(TagBinding).where(TagBinding.tag_id == tag.id).first() binding_after = db_session_with_containers.query(TagBinding).where(TagBinding.tag_id == tag.id).first()
assert binding_after is None assert binding_after is None
def test_delete_tag_not_found_error(self, db_session_with_containers, mock_external_service_dependencies): def test_delete_tag_not_found_error(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test tag deletion for non-existent tag. Test tag deletion for non-existent tag.
@ -960,7 +962,7 @@ class TestTagService:
TagService.delete_tag(non_existent_tag_id) TagService.delete_tag(non_existent_tag_id)
assert "Tag not found" in str(exc_info.value) assert "Tag not found" in str(exc_info.value)
def test_save_tag_binding_success(self, db_session_with_containers, mock_external_service_dependencies): def test_save_tag_binding_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test successful tag binding creation. Test successful tag binding creation.
@ -988,12 +990,11 @@ class TestTagService:
TagService.save_tag_binding(binding_args) TagService.save_tag_binding(binding_args)
# Assert: Verify the expected outcomes # Assert: Verify the expected outcomes
from extensions.ext_database import db
# Verify tag bindings were created # Verify tag bindings were created
for tag in tags: for tag in tags:
binding = ( binding = (
db.session.query(TagBinding) db_session_with_containers.query(TagBinding)
.where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id) .where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id)
.first() .first()
) )
@ -1001,7 +1002,9 @@ class TestTagService:
assert binding.tenant_id == tenant.id assert binding.tenant_id == tenant.id
assert binding.created_by == account.id assert binding.created_by == account.id
def test_save_tag_binding_duplicate_handling(self, db_session_with_containers, mock_external_service_dependencies): def test_save_tag_binding_duplicate_handling(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test tag binding creation with duplicate bindings. Test tag binding creation with duplicate bindings.
@ -1032,15 +1035,16 @@ class TestTagService:
TagService.save_tag_binding(binding_args) TagService.save_tag_binding(binding_args)
# Assert: Verify the expected outcomes # Assert: Verify the expected outcomes
from extensions.ext_database import db
# Verify only one binding exists # Verify only one binding exists
bindings = db.session.scalars( bindings = db_session_with_containers.scalars(
select(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id) select(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id)
).all() ).all()
assert len(bindings) == 1 assert len(bindings) == 1
def test_save_tag_binding_invalid_target_type(self, db_session_with_containers, mock_external_service_dependencies): def test_save_tag_binding_invalid_target_type(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test tag binding creation with invalid target type. Test tag binding creation with invalid target type.
@ -1071,7 +1075,7 @@ class TestTagService:
TagService.save_tag_binding(binding_args) TagService.save_tag_binding(binding_args)
assert "Invalid binding type" in str(exc_info.value) assert "Invalid binding type" in str(exc_info.value)
def test_delete_tag_binding_success(self, db_session_with_containers, mock_external_service_dependencies): def test_delete_tag_binding_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test successful tag binding deletion. Test successful tag binding deletion.
@ -1098,10 +1102,11 @@ class TestTagService:
) )
# Verify binding exists before deletion # Verify binding exists before deletion
from extensions.ext_database import db
binding_before = ( binding_before = (
db.session.query(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id).first() db_session_with_containers.query(TagBinding)
.where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id)
.first()
) )
assert binding_before is not None assert binding_before is not None
@ -1112,12 +1117,14 @@ class TestTagService:
# Assert: Verify the expected outcomes # Assert: Verify the expected outcomes
# Verify tag binding was deleted # Verify tag binding was deleted
binding_after = ( binding_after = (
db.session.query(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id).first() db_session_with_containers.query(TagBinding)
.where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id)
.first()
) )
assert binding_after is None assert binding_after is None
def test_delete_tag_binding_non_existent_binding( def test_delete_tag_binding_non_existent_binding(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test tag binding deletion for non-existent binding. Test tag binding deletion for non-existent binding.
@ -1145,15 +1152,14 @@ class TestTagService:
# Assert: Verify the expected outcomes # Assert: Verify the expected outcomes
# No error should be raised, and database state should remain unchanged # No error should be raised, and database state should remain unchanged
from extensions.ext_database import db
bindings = db.session.scalars( bindings = db_session_with_containers.scalars(
select(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id) select(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id)
).all() ).all()
assert len(bindings) == 0 assert len(bindings) == 0
def test_check_target_exists_knowledge_success( def test_check_target_exists_knowledge_success(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful target existence check for knowledge type. Test successful target existence check for knowledge type.
@ -1179,7 +1185,7 @@ class TestTagService:
# No exception should be raised for existing dataset # No exception should be raised for existing dataset
def test_check_target_exists_knowledge_not_found( def test_check_target_exists_knowledge_not_found(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test target existence check for non-existent knowledge dataset. Test target existence check for non-existent knowledge dataset.
@ -1204,7 +1210,9 @@ class TestTagService:
TagService.check_target_exists("knowledge", non_existent_dataset_id) TagService.check_target_exists("knowledge", non_existent_dataset_id)
assert "Dataset not found" in str(exc_info.value) assert "Dataset not found" in str(exc_info.value)
def test_check_target_exists_app_success(self, db_session_with_containers, mock_external_service_dependencies): def test_check_target_exists_app_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful target existence check for app type. Test successful target existence check for app type.
@ -1228,7 +1236,9 @@ class TestTagService:
# Assert: Verify the expected outcomes # Assert: Verify the expected outcomes
# No exception should be raised for existing app # No exception should be raised for existing app
def test_check_target_exists_app_not_found(self, db_session_with_containers, mock_external_service_dependencies): def test_check_target_exists_app_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test target existence check for non-existent app. Test target existence check for non-existent app.
@ -1252,7 +1262,9 @@ class TestTagService:
TagService.check_target_exists("app", non_existent_app_id) TagService.check_target_exists("app", non_existent_app_id)
assert "App not found" in str(exc_info.value) assert "App not found" in str(exc_info.value)
def test_check_target_exists_invalid_type(self, db_session_with_containers, mock_external_service_dependencies): def test_check_target_exists_invalid_type(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test target existence check for invalid type. Test target existence check for invalid type.

View File

@ -2,11 +2,11 @@ from unittest.mock import MagicMock, patch
import pytest import pytest
from faker import Faker from faker import Faker
from sqlalchemy.orm import Session
from constants import HIDDEN_VALUE, UNKNOWN_VALUE from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.entities.plugin_daemon import CredentialType
from core.trigger.entities.entities import Subscription as TriggerSubscriptionEntity from core.trigger.entities.entities import Subscription as TriggerSubscriptionEntity
from extensions.ext_database import db
from models.provider_ids import TriggerProviderID from models.provider_ids import TriggerProviderID
from models.trigger import TriggerSubscription from models.trigger import TriggerSubscription
from services.trigger.trigger_provider_service import TriggerProviderService from services.trigger.trigger_provider_service import TriggerProviderService
@ -47,7 +47,7 @@ class TestTriggerProviderService:
"account_feature_service": mock_account_feature_service, "account_feature_service": mock_account_feature_service,
} }
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Helper method to create a test account and tenant for testing. Helper method to create a test account and tenant for testing.
@ -84,7 +84,7 @@ class TestTriggerProviderService:
def _create_test_subscription( def _create_test_subscription(
self, self,
db_session_with_containers, db_session_with_containers: Session,
tenant_id, tenant_id,
user_id, user_id,
provider_id, provider_id,
@ -135,14 +135,14 @@ class TestTriggerProviderService:
expires_at=-1, expires_at=-1,
) )
db.session.add(subscription) db_session_with_containers.add(subscription)
db.session.commit() db_session_with_containers.commit()
db.session.refresh(subscription) db_session_with_containers.refresh(subscription)
return subscription return subscription
def test_rebuild_trigger_subscription_success_with_merged_credentials( def test_rebuild_trigger_subscription_success_with_merged_credentials(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful rebuild with credential merging (HIDDEN_VALUE handling). Test successful rebuild with credential merging (HIDDEN_VALUE handling).
@ -217,7 +217,7 @@ class TestTriggerProviderService:
assert subscribe_credentials["api_secret"] == "new-secret-value" # New value assert subscribe_credentials["api_secret"] == "new-secret-value" # New value
# Verify database state was updated # Verify database state was updated
db.session.refresh(subscription) db_session_with_containers.refresh(subscription)
assert subscription.name == "updated_name" assert subscription.name == "updated_name"
assert subscription.parameters == {"param1": "updated_value"} assert subscription.parameters == {"param1": "updated_value"}
@ -244,7 +244,7 @@ class TestTriggerProviderService:
) )
def test_rebuild_trigger_subscription_with_all_new_credentials( def test_rebuild_trigger_subscription_with_all_new_credentials(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test rebuild when all credentials are new (no HIDDEN_VALUE). Test rebuild when all credentials are new (no HIDDEN_VALUE).
@ -304,7 +304,7 @@ class TestTriggerProviderService:
assert subscribe_credentials["api_secret"] == "completely-new-secret" assert subscribe_credentials["api_secret"] == "completely-new-secret"
def test_rebuild_trigger_subscription_with_all_hidden_values( def test_rebuild_trigger_subscription_with_all_hidden_values(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test rebuild when all credentials are HIDDEN_VALUE (preserve all existing). Test rebuild when all credentials are HIDDEN_VALUE (preserve all existing).
@ -363,7 +363,7 @@ class TestTriggerProviderService:
assert subscribe_credentials["api_secret"] == original_credentials["api_secret"] assert subscribe_credentials["api_secret"] == original_credentials["api_secret"]
def test_rebuild_trigger_subscription_with_missing_key_uses_unknown_value( def test_rebuild_trigger_subscription_with_missing_key_uses_unknown_value(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test rebuild when HIDDEN_VALUE is used for a key that doesn't exist in original. Test rebuild when HIDDEN_VALUE is used for a key that doesn't exist in original.
@ -422,7 +422,7 @@ class TestTriggerProviderService:
assert subscribe_credentials["non_existent_key"] == UNKNOWN_VALUE assert subscribe_credentials["non_existent_key"] == UNKNOWN_VALUE
def test_rebuild_trigger_subscription_rollback_on_error( def test_rebuild_trigger_subscription_rollback_on_error(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test that transaction is rolled back on error. Test that transaction is rolled back on error.
@ -470,12 +470,12 @@ class TestTriggerProviderService:
) )
# Verify subscription state was not changed (rolled back) # Verify subscription state was not changed (rolled back)
db.session.refresh(subscription) db_session_with_containers.refresh(subscription)
assert subscription.name == original_name assert subscription.name == original_name
assert subscription.parameters == original_parameters assert subscription.parameters == original_parameters
def test_rebuild_trigger_subscription_subscription_not_found( def test_rebuild_trigger_subscription_subscription_not_found(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test error when subscription is not found. Test error when subscription is not found.
@ -501,7 +501,7 @@ class TestTriggerProviderService:
) )
def test_rebuild_trigger_subscription_name_uniqueness_check( def test_rebuild_trigger_subscription_name_uniqueness_check(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test that name uniqueness is checked when updating name. Test that name uniqueness is checked when updating name.

View File

@ -3,6 +3,7 @@ from unittest.mock import patch
import pytest import pytest
from faker import Faker from faker import Faker
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from models import Account from models import Account
@ -45,7 +46,7 @@ class TestWebConversationService:
"account_feature_service": mock_account_feature_service, "account_feature_service": mock_account_feature_service,
} }
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Helper method to create a test app and account for testing. Helper method to create a test app and account for testing.
@ -90,7 +91,7 @@ class TestWebConversationService:
return app, account return app, account
def _create_test_end_user(self, db_session_with_containers, app): def _create_test_end_user(self, db_session_with_containers: Session, app):
""" """
Helper method to create a test end user for testing. Helper method to create a test end user for testing.
@ -111,14 +112,12 @@ class TestWebConversationService:
tenant_id=app.tenant_id, tenant_id=app.tenant_id,
) )
from extensions.ext_database import db db_session_with_containers.add(end_user)
db_session_with_containers.commit()
db.session.add(end_user)
db.session.commit()
return end_user return end_user
def _create_test_conversation(self, db_session_with_containers, app, user, fake): def _create_test_conversation(self, db_session_with_containers: Session, app, user, fake):
""" """
Helper method to create a test conversation for testing. Helper method to create a test conversation for testing.
@ -152,14 +151,14 @@ class TestWebConversationService:
is_deleted=False, is_deleted=False,
) )
from extensions.ext_database import db db_session_with_containers.add(conversation)
db_session_with_containers.commit()
db.session.add(conversation)
db.session.commit()
return conversation return conversation
def test_pagination_by_last_id_success(self, db_session_with_containers, mock_external_service_dependencies): def test_pagination_by_last_id_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful pagination by last ID with basic parameters. Test successful pagination by last ID with basic parameters.
""" """
@ -194,7 +193,7 @@ class TestWebConversationService:
assert result.data[1].updated_at >= result.data[2].updated_at assert result.data[1].updated_at >= result.data[2].updated_at
def test_pagination_by_last_id_with_pinned_filter( def test_pagination_by_last_id_with_pinned_filter(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test pagination by last ID with pinned conversation filter. Test pagination by last ID with pinned conversation filter.
@ -222,11 +221,9 @@ class TestWebConversationService:
created_by=account.id, created_by=account.id,
) )
from extensions.ext_database import db db_session_with_containers.add(pinned_conversation1)
db_session_with_containers.add(pinned_conversation2)
db.session.add(pinned_conversation1) db_session_with_containers.commit()
db.session.add(pinned_conversation2)
db.session.commit()
# Test pagination with pinned filter # Test pagination with pinned filter
result = WebConversationService.pagination_by_last_id( result = WebConversationService.pagination_by_last_id(
@ -251,7 +248,7 @@ class TestWebConversationService:
assert set(returned_ids) == set(expected_ids) assert set(returned_ids) == set(expected_ids)
def test_pagination_by_last_id_with_unpinned_filter( def test_pagination_by_last_id_with_unpinned_filter(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test pagination by last ID with unpinned conversation filter. Test pagination by last ID with unpinned conversation filter.
@ -273,10 +270,8 @@ class TestWebConversationService:
created_by=account.id, created_by=account.id,
) )
from extensions.ext_database import db db_session_with_containers.add(pinned_conversation)
db_session_with_containers.commit()
db.session.add(pinned_conversation)
db.session.commit()
# Test pagination with unpinned filter # Test pagination with unpinned filter
result = WebConversationService.pagination_by_last_id( result = WebConversationService.pagination_by_last_id(
@ -303,7 +298,7 @@ class TestWebConversationService:
expected_unpinned_ids = [conv.id for conv in conversations[1:]] expected_unpinned_ids = [conv.id for conv in conversations[1:]]
assert set(returned_ids) == set(expected_unpinned_ids) assert set(returned_ids) == set(expected_unpinned_ids)
def test_pin_conversation_success(self, db_session_with_containers, mock_external_service_dependencies): def test_pin_conversation_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test successful pinning of a conversation. Test successful pinning of a conversation.
""" """
@ -317,10 +312,9 @@ class TestWebConversationService:
WebConversationService.pin(app, conversation.id, account) WebConversationService.pin(app, conversation.id, account)
# Verify the conversation was pinned # Verify the conversation was pinned
from extensions.ext_database import db
pinned_conversation = ( pinned_conversation = (
db.session.query(PinnedConversation) db_session_with_containers.query(PinnedConversation)
.where( .where(
PinnedConversation.app_id == app.id, PinnedConversation.app_id == app.id,
PinnedConversation.conversation_id == conversation.id, PinnedConversation.conversation_id == conversation.id,
@ -336,7 +330,9 @@ class TestWebConversationService:
assert pinned_conversation.created_by_role == "account" assert pinned_conversation.created_by_role == "account"
assert pinned_conversation.created_by == account.id assert pinned_conversation.created_by == account.id
def test_pin_conversation_already_pinned(self, db_session_with_containers, mock_external_service_dependencies): def test_pin_conversation_already_pinned(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test pinning a conversation that is already pinned (should not create duplicate). Test pinning a conversation that is already pinned (should not create duplicate).
""" """
@ -353,9 +349,8 @@ class TestWebConversationService:
WebConversationService.pin(app, conversation.id, account) WebConversationService.pin(app, conversation.id, account)
# Verify only one pinned conversation record exists # Verify only one pinned conversation record exists
from extensions.ext_database import db
pinned_conversations = db.session.scalars( pinned_conversations = db_session_with_containers.scalars(
select(PinnedConversation).where( select(PinnedConversation).where(
PinnedConversation.app_id == app.id, PinnedConversation.app_id == app.id,
PinnedConversation.conversation_id == conversation.id, PinnedConversation.conversation_id == conversation.id,
@ -366,7 +361,9 @@ class TestWebConversationService:
assert len(pinned_conversations) == 1 assert len(pinned_conversations) == 1
def test_pin_conversation_with_end_user(self, db_session_with_containers, mock_external_service_dependencies): def test_pin_conversation_with_end_user(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test pinning a conversation with an end user. Test pinning a conversation with an end user.
""" """
@ -383,10 +380,9 @@ class TestWebConversationService:
WebConversationService.pin(app, conversation.id, end_user) WebConversationService.pin(app, conversation.id, end_user)
# Verify the conversation was pinned # Verify the conversation was pinned
from extensions.ext_database import db
pinned_conversation = ( pinned_conversation = (
db.session.query(PinnedConversation) db_session_with_containers.query(PinnedConversation)
.where( .where(
PinnedConversation.app_id == app.id, PinnedConversation.app_id == app.id,
PinnedConversation.conversation_id == conversation.id, PinnedConversation.conversation_id == conversation.id,
@ -402,7 +398,7 @@ class TestWebConversationService:
assert pinned_conversation.created_by_role == "end_user" assert pinned_conversation.created_by_role == "end_user"
assert pinned_conversation.created_by == end_user.id assert pinned_conversation.created_by == end_user.id
def test_unpin_conversation_success(self, db_session_with_containers, mock_external_service_dependencies): def test_unpin_conversation_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test successful unpinning of a conversation. Test successful unpinning of a conversation.
""" """
@ -416,10 +412,9 @@ class TestWebConversationService:
WebConversationService.pin(app, conversation.id, account) WebConversationService.pin(app, conversation.id, account)
# Verify it was pinned # Verify it was pinned
from extensions.ext_database import db
pinned_conversation = ( pinned_conversation = (
db.session.query(PinnedConversation) db_session_with_containers.query(PinnedConversation)
.where( .where(
PinnedConversation.app_id == app.id, PinnedConversation.app_id == app.id,
PinnedConversation.conversation_id == conversation.id, PinnedConversation.conversation_id == conversation.id,
@ -436,7 +431,7 @@ class TestWebConversationService:
# Verify it was unpinned # Verify it was unpinned
pinned_conversation = ( pinned_conversation = (
db.session.query(PinnedConversation) db_session_with_containers.query(PinnedConversation)
.where( .where(
PinnedConversation.app_id == app.id, PinnedConversation.app_id == app.id,
PinnedConversation.conversation_id == conversation.id, PinnedConversation.conversation_id == conversation.id,
@ -448,7 +443,9 @@ class TestWebConversationService:
assert pinned_conversation is None assert pinned_conversation is None
def test_unpin_conversation_not_pinned(self, db_session_with_containers, mock_external_service_dependencies): def test_unpin_conversation_not_pinned(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test unpinning a conversation that is not pinned (should not cause error). Test unpinning a conversation that is not pinned (should not cause error).
""" """
@ -462,10 +459,9 @@ class TestWebConversationService:
WebConversationService.unpin(app, conversation.id, account) WebConversationService.unpin(app, conversation.id, account)
# Verify no pinned conversation record exists # Verify no pinned conversation record exists
from extensions.ext_database import db
pinned_conversation = ( pinned_conversation = (
db.session.query(PinnedConversation) db_session_with_containers.query(PinnedConversation)
.where( .where(
PinnedConversation.app_id == app.id, PinnedConversation.app_id == app.id,
PinnedConversation.conversation_id == conversation.id, PinnedConversation.conversation_id == conversation.id,
@ -478,7 +474,7 @@ class TestWebConversationService:
assert pinned_conversation is None assert pinned_conversation is None
def test_pagination_by_last_id_user_required_error( def test_pagination_by_last_id_user_required_error(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test that pagination_by_last_id raises ValueError when user is None. Test that pagination_by_last_id raises ValueError when user is None.
@ -499,7 +495,7 @@ class TestWebConversationService:
sort_by="-updated_at", sort_by="-updated_at",
) )
def test_pin_conversation_user_none(self, db_session_with_containers, mock_external_service_dependencies): def test_pin_conversation_user_none(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test that pin method returns early when user is None. Test that pin method returns early when user is None.
""" """
@ -513,10 +509,9 @@ class TestWebConversationService:
WebConversationService.pin(app, conversation.id, None) WebConversationService.pin(app, conversation.id, None)
# Verify no pinned conversation was created # Verify no pinned conversation was created
from extensions.ext_database import db
pinned_conversation = ( pinned_conversation = (
db.session.query(PinnedConversation) db_session_with_containers.query(PinnedConversation)
.where( .where(
PinnedConversation.app_id == app.id, PinnedConversation.app_id == app.id,
PinnedConversation.conversation_id == conversation.id, PinnedConversation.conversation_id == conversation.id,
@ -526,7 +521,9 @@ class TestWebConversationService:
assert pinned_conversation is None assert pinned_conversation is None
def test_unpin_conversation_user_none(self, db_session_with_containers, mock_external_service_dependencies): def test_unpin_conversation_user_none(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test that unpin method returns early when user is None. Test that unpin method returns early when user is None.
""" """
@ -540,10 +537,9 @@ class TestWebConversationService:
WebConversationService.pin(app, conversation.id, account) WebConversationService.pin(app, conversation.id, account)
# Verify it was pinned # Verify it was pinned
from extensions.ext_database import db
pinned_conversation = ( pinned_conversation = (
db.session.query(PinnedConversation) db_session_with_containers.query(PinnedConversation)
.where( .where(
PinnedConversation.app_id == app.id, PinnedConversation.app_id == app.id,
PinnedConversation.conversation_id == conversation.id, PinnedConversation.conversation_id == conversation.id,
@ -560,7 +556,7 @@ class TestWebConversationService:
# Verify the conversation is still pinned # Verify the conversation is still pinned
pinned_conversation = ( pinned_conversation = (
db.session.query(PinnedConversation) db_session_with_containers.query(PinnedConversation)
.where( .where(
PinnedConversation.app_id == app.id, PinnedConversation.app_id == app.id,
PinnedConversation.conversation_id == conversation.id, PinnedConversation.conversation_id == conversation.id,

View File

@ -4,6 +4,7 @@ from unittest.mock import patch
import pytest import pytest
from faker import Faker from faker import Faker
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound, Unauthorized from werkzeug.exceptions import NotFound, Unauthorized
from libs.password import hash_password from libs.password import hash_password
@ -45,7 +46,7 @@ class TestWebAppAuthService:
"enterprise_service": mock_enterprise_service, "enterprise_service": mock_enterprise_service,
} }
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Helper method to create a test account and tenant for testing. Helper method to create a test account and tenant for testing.
@ -68,18 +69,16 @@ class TestWebAppAuthService:
status="active", status="active",
) )
from extensions.ext_database import db db_session_with_containers.add(account)
db_session_with_containers.commit()
db.session.add(account)
db.session.commit()
# Create tenant for the account # Create tenant for the account
tenant = Tenant( tenant = Tenant(
name=fake.company(), name=fake.company(),
status="normal", status="normal",
) )
db.session.add(tenant) db_session_with_containers.add(tenant)
db.session.commit() db_session_with_containers.commit()
# Create tenant-account join # Create tenant-account join
join = TenantAccountJoin( join = TenantAccountJoin(
@ -88,15 +87,17 @@ class TestWebAppAuthService:
role=TenantAccountRole.OWNER, role=TenantAccountRole.OWNER,
current=True, current=True,
) )
db.session.add(join) db_session_with_containers.add(join)
db.session.commit() db_session_with_containers.commit()
# Set current tenant for account # Set current tenant for account
account.current_tenant = tenant account.current_tenant = tenant
return account, tenant return account, tenant
def _create_test_account_with_password(self, db_session_with_containers, mock_external_service_dependencies): def _create_test_account_with_password(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Helper method to create a test account with password for testing. Helper method to create a test account with password for testing.
@ -131,18 +132,16 @@ class TestWebAppAuthService:
account.password = base64.b64encode(password_hash).decode() account.password = base64.b64encode(password_hash).decode()
account.password_salt = base64.b64encode(salt).decode() account.password_salt = base64.b64encode(salt).decode()
from extensions.ext_database import db db_session_with_containers.add(account)
db_session_with_containers.commit()
db.session.add(account)
db.session.commit()
# Create tenant for the account # Create tenant for the account
tenant = Tenant( tenant = Tenant(
name=fake.company(), name=fake.company(),
status="normal", status="normal",
) )
db.session.add(tenant) db_session_with_containers.add(tenant)
db.session.commit() db_session_with_containers.commit()
# Create tenant-account join # Create tenant-account join
join = TenantAccountJoin( join = TenantAccountJoin(
@ -151,15 +150,17 @@ class TestWebAppAuthService:
role=TenantAccountRole.OWNER, role=TenantAccountRole.OWNER,
current=True, current=True,
) )
db.session.add(join) db_session_with_containers.add(join)
db.session.commit() db_session_with_containers.commit()
# Set current tenant for account # Set current tenant for account
account.current_tenant = tenant account.current_tenant = tenant
return account, tenant, password return account, tenant, password
def _create_test_app_and_site(self, db_session_with_containers, mock_external_service_dependencies, tenant): def _create_test_app_and_site(
self, db_session_with_containers: Session, mock_external_service_dependencies, tenant
):
""" """
Helper method to create a test app and site for testing. Helper method to create a test app and site for testing.
@ -188,10 +189,8 @@ class TestWebAppAuthService:
enable_api=True, enable_api=True,
) )
from extensions.ext_database import db db_session_with_containers.add(app)
db_session_with_containers.commit()
db.session.add(app)
db.session.commit()
# Create site # Create site
site = Site( site = Site(
@ -203,12 +202,12 @@ class TestWebAppAuthService:
status="normal", status="normal",
customize_token_strategy="not_allow", customize_token_strategy="not_allow",
) )
db.session.add(site) db_session_with_containers.add(site)
db.session.commit() db_session_with_containers.commit()
return app, site return app, site
def test_authenticate_success(self, db_session_with_containers, mock_external_service_dependencies): def test_authenticate_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test successful authentication with valid email and password. Test successful authentication with valid email and password.
@ -233,14 +232,15 @@ class TestWebAppAuthService:
assert result.status == AccountStatus.ACTIVE assert result.status == AccountStatus.ACTIVE
# Verify database state # Verify database state
from extensions.ext_database import db
db.session.refresh(result) db_session_with_containers.refresh(result)
assert result.id is not None assert result.id is not None
assert result.password is not None assert result.password is not None
assert result.password_salt is not None assert result.password_salt is not None
def test_authenticate_account_not_found(self, db_session_with_containers, mock_external_service_dependencies): def test_authenticate_account_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test authentication with non-existent email. Test authentication with non-existent email.
@ -262,7 +262,7 @@ class TestWebAppAuthService:
with pytest.raises(AccountNotFoundError): with pytest.raises(AccountNotFoundError):
WebAppAuthService.authenticate(non_existent_email, "any_password") WebAppAuthService.authenticate(non_existent_email, "any_password")
def test_authenticate_account_banned(self, db_session_with_containers, mock_external_service_dependencies): def test_authenticate_account_banned(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test authentication with banned account. Test authentication with banned account.
@ -292,10 +292,8 @@ class TestWebAppAuthService:
account.password = base64.b64encode(password_hash).decode() account.password = base64.b64encode(password_hash).decode()
account.password_salt = base64.b64encode(salt).decode() account.password_salt = base64.b64encode(salt).decode()
from extensions.ext_database import db db_session_with_containers.add(account)
db_session_with_containers.commit()
db.session.add(account)
db.session.commit()
# Act & Assert: Verify proper error handling # Act & Assert: Verify proper error handling
with pytest.raises(AccountLoginError) as exc_info: with pytest.raises(AccountLoginError) as exc_info:
@ -303,7 +301,9 @@ class TestWebAppAuthService:
assert "Account is banned." in str(exc_info.value) assert "Account is banned." in str(exc_info.value)
def test_authenticate_invalid_password(self, db_session_with_containers, mock_external_service_dependencies): def test_authenticate_invalid_password(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test authentication with invalid password. Test authentication with invalid password.
@ -323,7 +323,7 @@ class TestWebAppAuthService:
assert "Invalid email or password." in str(exc_info.value) assert "Invalid email or password." in str(exc_info.value)
def test_authenticate_account_without_password( def test_authenticate_account_without_password(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test authentication for account without password. Test authentication for account without password.
@ -345,10 +345,8 @@ class TestWebAppAuthService:
status="active", status="active",
) )
from extensions.ext_database import db db_session_with_containers.add(account)
db_session_with_containers.commit()
db.session.add(account)
db.session.commit()
# Act & Assert: Verify proper error handling # Act & Assert: Verify proper error handling
with pytest.raises(AccountPasswordError) as exc_info: with pytest.raises(AccountPasswordError) as exc_info:
@ -356,7 +354,7 @@ class TestWebAppAuthService:
assert "Invalid email or password." in str(exc_info.value) assert "Invalid email or password." in str(exc_info.value)
def test_login_success(self, db_session_with_containers, mock_external_service_dependencies): def test_login_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test successful login and JWT token generation. Test successful login and JWT token generation.
@ -388,7 +386,9 @@ class TestWebAppAuthService:
assert call_args["auth_type"] == "internal" assert call_args["auth_type"] == "internal"
assert "exp" in call_args assert "exp" in call_args
def test_get_user_through_email_success(self, db_session_with_containers, mock_external_service_dependencies): def test_get_user_through_email_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful user retrieval through email. Test successful user retrieval through email.
@ -413,12 +413,13 @@ class TestWebAppAuthService:
assert result.status == AccountStatus.ACTIVE assert result.status == AccountStatus.ACTIVE
# Verify database state # Verify database state
from extensions.ext_database import db
db.session.refresh(result) db_session_with_containers.refresh(result)
assert result.id is not None assert result.id is not None
def test_get_user_through_email_not_found(self, db_session_with_containers, mock_external_service_dependencies): def test_get_user_through_email_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test user retrieval with non-existent email. Test user retrieval with non-existent email.
@ -435,7 +436,9 @@ class TestWebAppAuthService:
# Assert: Verify proper handling # Assert: Verify proper handling
assert result is None assert result is None
def test_get_user_through_email_banned(self, db_session_with_containers, mock_external_service_dependencies): def test_get_user_through_email_banned(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test user retrieval with banned account. Test user retrieval with banned account.
@ -456,10 +459,8 @@ class TestWebAppAuthService:
status=AccountStatus.BANNED, status=AccountStatus.BANNED,
) )
from extensions.ext_database import db db_session_with_containers.add(account)
db_session_with_containers.commit()
db.session.add(account)
db.session.commit()
# Act & Assert: Verify proper error handling # Act & Assert: Verify proper error handling
with pytest.raises(Unauthorized) as exc_info: with pytest.raises(Unauthorized) as exc_info:
@ -468,7 +469,7 @@ class TestWebAppAuthService:
assert "Account is banned." in str(exc_info.value) assert "Account is banned." in str(exc_info.value)
def test_send_email_code_login_email_with_account( def test_send_email_code_login_email_with_account(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test sending email code login email with account. Test sending email code login email with account.
@ -509,7 +510,7 @@ class TestWebAppAuthService:
assert "code" in mail_call_args[1] assert "code" in mail_call_args[1]
def test_send_email_code_login_email_with_email_only( def test_send_email_code_login_email_with_email_only(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test sending email code login email with email only. Test sending email code login email with email only.
@ -549,7 +550,7 @@ class TestWebAppAuthService:
assert "code" in mail_call_args[1] assert "code" in mail_call_args[1]
def test_send_email_code_login_email_no_email_provided( def test_send_email_code_login_email_no_email_provided(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test sending email code login email without providing email. Test sending email code login email without providing email.
@ -566,7 +567,9 @@ class TestWebAppAuthService:
assert "Email must be provided." in str(exc_info.value) assert "Email must be provided." in str(exc_info.value)
def test_get_email_code_login_data_success(self, db_session_with_containers, mock_external_service_dependencies): def test_get_email_code_login_data_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful retrieval of email code login data. Test successful retrieval of email code login data.
@ -593,7 +596,9 @@ class TestWebAppAuthService:
"mock_token", "email_code_login" "mock_token", "email_code_login"
) )
def test_get_email_code_login_data_no_data(self, db_session_with_containers, mock_external_service_dependencies): def test_get_email_code_login_data_no_data(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test email code login data retrieval when no data exists. Test email code login data retrieval when no data exists.
@ -617,7 +622,7 @@ class TestWebAppAuthService:
) )
def test_revoke_email_code_login_token_success( def test_revoke_email_code_login_token_success(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful revocation of email code login token. Test successful revocation of email code login token.
@ -636,7 +641,7 @@ class TestWebAppAuthService:
"mock_token", "email_code_login" "mock_token", "email_code_login"
) )
def test_create_end_user_success(self, db_session_with_containers, mock_external_service_dependencies): def test_create_end_user_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test successful end user creation. Test successful end user creation.
@ -668,14 +673,15 @@ class TestWebAppAuthService:
assert result.external_user_id == "enterpriseuser" assert result.external_user_id == "enterpriseuser"
# Verify database state # Verify database state
from extensions.ext_database import db
db.session.refresh(result) db_session_with_containers.refresh(result)
assert result.id is not None assert result.id is not None
assert result.created_at is not None assert result.created_at is not None
assert result.updated_at is not None assert result.updated_at is not None
def test_create_end_user_site_not_found(self, db_session_with_containers, mock_external_service_dependencies): def test_create_end_user_site_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test end user creation with non-existent site code. Test end user creation with non-existent site code.
@ -693,7 +699,9 @@ class TestWebAppAuthService:
assert "Site not found." in str(exc_info.value) assert "Site not found." in str(exc_info.value)
def test_create_end_user_app_not_found(self, db_session_with_containers, mock_external_service_dependencies): def test_create_end_user_app_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test end user creation when app is not found. Test end user creation when app is not found.
@ -708,10 +716,8 @@ class TestWebAppAuthService:
status="normal", status="normal",
) )
from extensions.ext_database import db db_session_with_containers.add(tenant)
db_session_with_containers.commit()
db.session.add(tenant)
db.session.commit()
site = Site( site = Site(
app_id="00000000-0000-0000-0000-000000000000", app_id="00000000-0000-0000-0000-000000000000",
@ -722,8 +728,8 @@ class TestWebAppAuthService:
status="normal", status="normal",
customize_token_strategy="not_allow", customize_token_strategy="not_allow",
) )
db.session.add(site) db_session_with_containers.add(site)
db.session.commit() db_session_with_containers.commit()
# Act & Assert: Verify proper error handling # Act & Assert: Verify proper error handling
with pytest.raises(NotFound) as exc_info: with pytest.raises(NotFound) as exc_info:
@ -732,7 +738,7 @@ class TestWebAppAuthService:
assert "App not found." in str(exc_info.value) assert "App not found." in str(exc_info.value)
def test_is_app_require_permission_check_with_access_mode_private( def test_is_app_require_permission_check_with_access_mode_private(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test permission check requirement for private access mode. Test permission check requirement for private access mode.
@ -751,7 +757,7 @@ class TestWebAppAuthService:
assert result is True assert result is True
def test_is_app_require_permission_check_with_access_mode_public( def test_is_app_require_permission_check_with_access_mode_public(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test permission check requirement for public access mode. Test permission check requirement for public access mode.
@ -770,7 +776,7 @@ class TestWebAppAuthService:
assert result is False assert result is False
def test_is_app_require_permission_check_with_app_code( def test_is_app_require_permission_check_with_app_code(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test permission check requirement using app code. Test permission check requirement using app code.
@ -796,7 +802,7 @@ class TestWebAppAuthService:
].WebAppAuth.get_app_access_mode_by_id.assert_called_once_with("mock_app_id") ].WebAppAuth.get_app_access_mode_by_id.assert_called_once_with("mock_app_id")
def test_is_app_require_permission_check_no_parameters( def test_is_app_require_permission_check_no_parameters(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test permission check requirement with no parameters. Test permission check requirement with no parameters.
@ -814,7 +820,7 @@ class TestWebAppAuthService:
assert "Either app_code or app_id must be provided." in str(exc_info.value) assert "Either app_code or app_id must be provided." in str(exc_info.value)
def test_get_app_auth_type_with_access_mode_public( def test_get_app_auth_type_with_access_mode_public(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test app authentication type for public access mode. Test app authentication type for public access mode.
@ -833,7 +839,7 @@ class TestWebAppAuthService:
assert result == WebAppAuthType.PUBLIC assert result == WebAppAuthType.PUBLIC
def test_get_app_auth_type_with_access_mode_private( def test_get_app_auth_type_with_access_mode_private(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test app authentication type for private access mode. Test app authentication type for private access mode.
@ -851,7 +857,9 @@ class TestWebAppAuthService:
# Assert: Verify correct result # Assert: Verify correct result
assert result == WebAppAuthType.INTERNAL assert result == WebAppAuthType.INTERNAL
def test_get_app_auth_type_with_app_code(self, db_session_with_containers, mock_external_service_dependencies): def test_get_app_auth_type_with_app_code(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test app authentication type using app code. Test app authentication type using app code.
@ -878,7 +886,9 @@ class TestWebAppAuthService:
"enterprise_service" "enterprise_service"
].WebAppAuth.get_app_access_mode_by_id.assert_called_once_with(app_id="mock_app_id") ].WebAppAuth.get_app_access_mode_by_id.assert_called_once_with(app_id="mock_app_id")
def test_get_app_auth_type_no_parameters(self, db_session_with_containers, mock_external_service_dependencies): def test_get_app_auth_type_no_parameters(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test app authentication type with no parameters. Test app authentication type with no parameters.

View File

@ -5,6 +5,7 @@ from unittest.mock import patch
import pytest import pytest
from faker import Faker from faker import Faker
from sqlalchemy.orm import Session
from dify_graph.entities.workflow_execution import WorkflowExecutionStatus from dify_graph.entities.workflow_execution import WorkflowExecutionStatus
from models import EndUser, Workflow, WorkflowAppLog, WorkflowRun from models import EndUser, Workflow, WorkflowAppLog, WorkflowRun
@ -48,7 +49,7 @@ class TestWorkflowAppService:
"account_feature_service": mock_account_feature_service, "account_feature_service": mock_account_feature_service,
} }
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Helper method to create a test app and account for testing. Helper method to create a test app and account for testing.
@ -96,7 +97,7 @@ class TestWorkflowAppService:
return app, account return app, account
def _create_test_tenant_and_account(self, db_session_with_containers, mock_external_service_dependencies): def _create_test_tenant_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Helper method to create a test tenant and account for testing. Helper method to create a test tenant and account for testing.
@ -126,7 +127,7 @@ class TestWorkflowAppService:
return tenant, account return tenant, account
def _create_test_app(self, db_session_with_containers, tenant, account): def _create_test_app(self, db_session_with_containers: Session, tenant, account):
""" """
Helper method to create a test app for testing. Helper method to create a test app for testing.
@ -160,7 +161,7 @@ class TestWorkflowAppService:
return app return app
def _create_test_workflow_data(self, db_session_with_containers, app, account): def _create_test_workflow_data(self, db_session_with_containers: Session, app, account):
""" """
Helper method to create test workflow data for testing. Helper method to create test workflow data for testing.
@ -174,8 +175,6 @@ class TestWorkflowAppService:
""" """
fake = Faker() fake = Faker()
from extensions.ext_database import db
# Create workflow # Create workflow
workflow = Workflow( workflow = Workflow(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
@ -188,8 +187,8 @@ class TestWorkflowAppService:
created_by=account.id, created_by=account.id,
updated_by=account.id, updated_by=account.id,
) )
db.session.add(workflow) db_session_with_containers.add(workflow)
db.session.commit() db_session_with_containers.commit()
# Create workflow run # Create workflow run
workflow_run = WorkflowRun( workflow_run = WorkflowRun(
@ -212,8 +211,8 @@ class TestWorkflowAppService:
created_at=datetime.now(UTC), created_at=datetime.now(UTC),
finished_at=datetime.now(UTC), finished_at=datetime.now(UTC),
) )
db.session.add(workflow_run) db_session_with_containers.add(workflow_run)
db.session.commit() db_session_with_containers.commit()
# Create workflow app log # Create workflow app log
workflow_app_log = WorkflowAppLog( workflow_app_log = WorkflowAppLog(
@ -227,13 +226,13 @@ class TestWorkflowAppService:
) )
workflow_app_log.id = str(uuid.uuid4()) workflow_app_log.id = str(uuid.uuid4())
workflow_app_log.created_at = datetime.now(UTC) workflow_app_log.created_at = datetime.now(UTC)
db.session.add(workflow_app_log) db_session_with_containers.add(workflow_app_log)
db.session.commit() db_session_with_containers.commit()
return workflow, workflow_run, workflow_app_log return workflow, workflow_run, workflow_app_log
def test_get_paginate_workflow_app_logs_basic_success( def test_get_paginate_workflow_app_logs_basic_success(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful pagination of workflow app logs with basic parameters. Test successful pagination of workflow app logs with basic parameters.
@ -268,13 +267,12 @@ class TestWorkflowAppService:
assert log_entry.workflow_run_id == workflow_run.id assert log_entry.workflow_run_id == workflow_run.id
# Verify database state # Verify database state
from extensions.ext_database import db
db.session.refresh(workflow_app_log) db_session_with_containers.refresh(workflow_app_log)
assert workflow_app_log.id is not None assert workflow_app_log.id is not None
def test_get_paginate_workflow_app_logs_with_keyword_search( def test_get_paginate_workflow_app_logs_with_keyword_search(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test workflow app logs pagination with keyword search functionality. Test workflow app logs pagination with keyword search functionality.
@ -287,11 +285,10 @@ class TestWorkflowAppService:
) )
# Update workflow run with searchable content # Update workflow run with searchable content
from extensions.ext_database import db
workflow_run.inputs = json.dumps({"search_term": "test_keyword", "input2": "other_value"}) workflow_run.inputs = json.dumps({"search_term": "test_keyword", "input2": "other_value"})
workflow_run.outputs = json.dumps({"result": "test_keyword_found", "status": "success"}) workflow_run.outputs = json.dumps({"result": "test_keyword_found", "status": "success"})
db.session.commit() db_session_with_containers.commit()
# Act: Execute the method under test with keyword search # Act: Execute the method under test with keyword search
service = WorkflowAppService() service = WorkflowAppService()
@ -317,7 +314,7 @@ class TestWorkflowAppService:
assert len(result_no_match["data"]) == 0 assert len(result_no_match["data"]) == 0
def test_get_paginate_workflow_app_logs_with_special_characters_in_keyword( def test_get_paginate_workflow_app_logs_with_special_characters_in_keyword(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
r""" r"""
Test workflow app logs pagination with special characters in keyword to verify SQL injection prevention. Test workflow app logs pagination with special characters in keyword to verify SQL injection prevention.
@ -332,8 +329,6 @@ class TestWorkflowAppService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
workflow, _, _ = self._create_test_workflow_data(db_session_with_containers, app, account) workflow, _, _ = self._create_test_workflow_data(db_session_with_containers, app, account)
from extensions.ext_database import db
service = WorkflowAppService() service = WorkflowAppService()
# Test 1: Search with % character # Test 1: Search with % character
@ -353,8 +348,8 @@ class TestWorkflowAppService:
created_by=account.id, created_by=account.id,
created_at=datetime.now(UTC), created_at=datetime.now(UTC),
) )
db.session.add(workflow_run_1) db_session_with_containers.add(workflow_run_1)
db.session.flush() db_session_with_containers.flush()
workflow_app_log_1 = WorkflowAppLog( workflow_app_log_1 = WorkflowAppLog(
tenant_id=app.tenant_id, tenant_id=app.tenant_id,
@ -367,8 +362,8 @@ class TestWorkflowAppService:
) )
workflow_app_log_1.id = str(uuid.uuid4()) workflow_app_log_1.id = str(uuid.uuid4())
workflow_app_log_1.created_at = datetime.now(UTC) workflow_app_log_1.created_at = datetime.now(UTC)
db.session.add(workflow_app_log_1) db_session_with_containers.add(workflow_app_log_1)
db.session.commit() db_session_with_containers.commit()
result = service.get_paginate_workflow_app_logs( result = service.get_paginate_workflow_app_logs(
session=db_session_with_containers, app_model=app, keyword="50%", page=1, limit=20 session=db_session_with_containers, app_model=app, keyword="50%", page=1, limit=20
@ -395,8 +390,8 @@ class TestWorkflowAppService:
created_by=account.id, created_by=account.id,
created_at=datetime.now(UTC), created_at=datetime.now(UTC),
) )
db.session.add(workflow_run_2) db_session_with_containers.add(workflow_run_2)
db.session.flush() db_session_with_containers.flush()
workflow_app_log_2 = WorkflowAppLog( workflow_app_log_2 = WorkflowAppLog(
tenant_id=app.tenant_id, tenant_id=app.tenant_id,
@ -409,8 +404,8 @@ class TestWorkflowAppService:
) )
workflow_app_log_2.id = str(uuid.uuid4()) workflow_app_log_2.id = str(uuid.uuid4())
workflow_app_log_2.created_at = datetime.now(UTC) workflow_app_log_2.created_at = datetime.now(UTC)
db.session.add(workflow_app_log_2) db_session_with_containers.add(workflow_app_log_2)
db.session.commit() db_session_with_containers.commit()
result = service.get_paginate_workflow_app_logs( result = service.get_paginate_workflow_app_logs(
session=db_session_with_containers, app_model=app, keyword="test_data", page=1, limit=20 session=db_session_with_containers, app_model=app, keyword="test_data", page=1, limit=20
@ -437,8 +432,8 @@ class TestWorkflowAppService:
created_by=account.id, created_by=account.id,
created_at=datetime.now(UTC), created_at=datetime.now(UTC),
) )
db.session.add(workflow_run_4) db_session_with_containers.add(workflow_run_4)
db.session.flush() db_session_with_containers.flush()
workflow_app_log_4 = WorkflowAppLog( workflow_app_log_4 = WorkflowAppLog(
tenant_id=app.tenant_id, tenant_id=app.tenant_id,
@ -451,8 +446,8 @@ class TestWorkflowAppService:
) )
workflow_app_log_4.id = str(uuid.uuid4()) workflow_app_log_4.id = str(uuid.uuid4())
workflow_app_log_4.created_at = datetime.now(UTC) workflow_app_log_4.created_at = datetime.now(UTC)
db.session.add(workflow_app_log_4) db_session_with_containers.add(workflow_app_log_4)
db.session.commit() db_session_with_containers.commit()
result = service.get_paginate_workflow_app_logs( result = service.get_paginate_workflow_app_logs(
session=db_session_with_containers, app_model=app, keyword="50%", page=1, limit=20 session=db_session_with_containers, app_model=app, keyword="50%", page=1, limit=20
@ -467,7 +462,7 @@ class TestWorkflowAppService:
assert workflow_run_4.id not in found_run_ids assert workflow_run_4.id not in found_run_ids
def test_get_paginate_workflow_app_logs_with_status_filter( def test_get_paginate_workflow_app_logs_with_status_filter(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test workflow app logs pagination with status filtering. Test workflow app logs pagination with status filtering.
@ -476,8 +471,6 @@ class TestWorkflowAppService:
fake = Faker() fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
from extensions.ext_database import db
# Create workflow # Create workflow
workflow = Workflow( workflow = Workflow(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
@ -490,8 +483,8 @@ class TestWorkflowAppService:
created_by=account.id, created_by=account.id,
updated_by=account.id, updated_by=account.id,
) )
db.session.add(workflow) db_session_with_containers.add(workflow)
db.session.commit() db_session_with_containers.commit()
# Create workflow runs with different statuses # Create workflow runs with different statuses
statuses = ["succeeded", "failed", "running", "stopped"] statuses = ["succeeded", "failed", "running", "stopped"]
@ -519,8 +512,8 @@ class TestWorkflowAppService:
created_at=datetime.now(UTC) + timedelta(minutes=i), created_at=datetime.now(UTC) + timedelta(minutes=i),
finished_at=datetime.now(UTC) + timedelta(minutes=i + 1) if status != "running" else None, finished_at=datetime.now(UTC) + timedelta(minutes=i + 1) if status != "running" else None,
) )
db.session.add(workflow_run) db_session_with_containers.add(workflow_run)
db.session.commit() db_session_with_containers.commit()
workflow_app_log = WorkflowAppLog( workflow_app_log = WorkflowAppLog(
tenant_id=app.tenant_id, tenant_id=app.tenant_id,
@ -533,8 +526,8 @@ class TestWorkflowAppService:
) )
workflow_app_log.id = str(uuid.uuid4()) workflow_app_log.id = str(uuid.uuid4())
workflow_app_log.created_at = datetime.now(UTC) + timedelta(minutes=i) workflow_app_log.created_at = datetime.now(UTC) + timedelta(minutes=i)
db.session.add(workflow_app_log) db_session_with_containers.add(workflow_app_log)
db.session.commit() db_session_with_containers.commit()
workflow_runs.append(workflow_run) workflow_runs.append(workflow_run)
workflow_app_logs.append(workflow_app_log) workflow_app_logs.append(workflow_app_log)
@ -568,7 +561,7 @@ class TestWorkflowAppService:
assert result_running["data"][0].workflow_run.status == "running" assert result_running["data"][0].workflow_run.status == "running"
def test_get_paginate_workflow_app_logs_with_time_filtering( def test_get_paginate_workflow_app_logs_with_time_filtering(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test workflow app logs pagination with time-based filtering. Test workflow app logs pagination with time-based filtering.
@ -577,8 +570,6 @@ class TestWorkflowAppService:
fake = Faker() fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
from extensions.ext_database import db
# Create workflow # Create workflow
workflow = Workflow( workflow = Workflow(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
@ -591,8 +582,8 @@ class TestWorkflowAppService:
created_by=account.id, created_by=account.id,
updated_by=account.id, updated_by=account.id,
) )
db.session.add(workflow) db_session_with_containers.add(workflow)
db.session.commit() db_session_with_containers.commit()
# Create workflow runs with different timestamps # Create workflow runs with different timestamps
base_time = datetime.now(UTC) base_time = datetime.now(UTC)
@ -627,8 +618,8 @@ class TestWorkflowAppService:
created_at=timestamp, created_at=timestamp,
finished_at=timestamp + timedelta(minutes=1), finished_at=timestamp + timedelta(minutes=1),
) )
db.session.add(workflow_run) db_session_with_containers.add(workflow_run)
db.session.commit() db_session_with_containers.commit()
workflow_app_log = WorkflowAppLog( workflow_app_log = WorkflowAppLog(
tenant_id=app.tenant_id, tenant_id=app.tenant_id,
@ -641,8 +632,8 @@ class TestWorkflowAppService:
) )
workflow_app_log.id = str(uuid.uuid4()) workflow_app_log.id = str(uuid.uuid4())
workflow_app_log.created_at = timestamp workflow_app_log.created_at = timestamp
db.session.add(workflow_app_log) db_session_with_containers.add(workflow_app_log)
db.session.commit() db_session_with_containers.commit()
workflow_runs.append(workflow_run) workflow_runs.append(workflow_run)
workflow_app_logs.append(workflow_app_log) workflow_app_logs.append(workflow_app_log)
@ -682,7 +673,7 @@ class TestWorkflowAppService:
assert result_range["total"] == 2 # Should get logs from 2 hours ago and 1 hour ago assert result_range["total"] == 2 # Should get logs from 2 hours ago and 1 hour ago
def test_get_paginate_workflow_app_logs_with_pagination( def test_get_paginate_workflow_app_logs_with_pagination(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test workflow app logs pagination with different page sizes and limits. Test workflow app logs pagination with different page sizes and limits.
@ -691,8 +682,6 @@ class TestWorkflowAppService:
fake = Faker() fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
from extensions.ext_database import db
# Create workflow # Create workflow
workflow = Workflow( workflow = Workflow(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
@ -705,8 +694,8 @@ class TestWorkflowAppService:
created_by=account.id, created_by=account.id,
updated_by=account.id, updated_by=account.id,
) )
db.session.add(workflow) db_session_with_containers.add(workflow)
db.session.commit() db_session_with_containers.commit()
# Create 25 workflow runs and logs # Create 25 workflow runs and logs
total_logs = 25 total_logs = 25
@ -734,8 +723,8 @@ class TestWorkflowAppService:
created_at=datetime.now(UTC) + timedelta(minutes=i), created_at=datetime.now(UTC) + timedelta(minutes=i),
finished_at=datetime.now(UTC) + timedelta(minutes=i + 1), finished_at=datetime.now(UTC) + timedelta(minutes=i + 1),
) )
db.session.add(workflow_run) db_session_with_containers.add(workflow_run)
db.session.commit() db_session_with_containers.commit()
workflow_app_log = WorkflowAppLog( workflow_app_log = WorkflowAppLog(
tenant_id=app.tenant_id, tenant_id=app.tenant_id,
@ -748,8 +737,8 @@ class TestWorkflowAppService:
) )
workflow_app_log.id = str(uuid.uuid4()) workflow_app_log.id = str(uuid.uuid4())
workflow_app_log.created_at = datetime.now(UTC) + timedelta(minutes=i) workflow_app_log.created_at = datetime.now(UTC) + timedelta(minutes=i)
db.session.add(workflow_app_log) db_session_with_containers.add(workflow_app_log)
db.session.commit() db_session_with_containers.commit()
workflow_runs.append(workflow_run) workflow_runs.append(workflow_run)
workflow_app_logs.append(workflow_app_log) workflow_app_logs.append(workflow_app_log)
@ -798,7 +787,7 @@ class TestWorkflowAppService:
assert len(result_large_limit["data"]) == total_logs assert len(result_large_limit["data"]) == total_logs
def test_get_paginate_workflow_app_logs_with_user_role_filtering( def test_get_paginate_workflow_app_logs_with_user_role_filtering(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test workflow app logs pagination with user role and session filtering. Test workflow app logs pagination with user role and session filtering.
@ -807,8 +796,6 @@ class TestWorkflowAppService:
fake = Faker() fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
from extensions.ext_database import db
# Create workflow # Create workflow
workflow = Workflow( workflow = Workflow(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
@ -821,8 +808,8 @@ class TestWorkflowAppService:
created_by=account.id, created_by=account.id,
updated_by=account.id, updated_by=account.id,
) )
db.session.add(workflow) db_session_with_containers.add(workflow)
db.session.commit() db_session_with_containers.commit()
# Create end user # Create end user
end_user = EndUser( end_user = EndUser(
@ -835,8 +822,8 @@ class TestWorkflowAppService:
created_at=datetime.now(UTC), created_at=datetime.now(UTC),
updated_at=datetime.now(UTC), updated_at=datetime.now(UTC),
) )
db.session.add(end_user) db_session_with_containers.add(end_user)
db.session.commit() db_session_with_containers.commit()
# Create workflow runs and logs for both account and end user # Create workflow runs and logs for both account and end user
workflow_runs = [] workflow_runs = []
@ -864,8 +851,8 @@ class TestWorkflowAppService:
created_at=datetime.now(UTC) + timedelta(minutes=i), created_at=datetime.now(UTC) + timedelta(minutes=i),
finished_at=datetime.now(UTC) + timedelta(minutes=i + 1), finished_at=datetime.now(UTC) + timedelta(minutes=i + 1),
) )
db.session.add(workflow_run) db_session_with_containers.add(workflow_run)
db.session.commit() db_session_with_containers.commit()
workflow_app_log = WorkflowAppLog( workflow_app_log = WorkflowAppLog(
tenant_id=app.tenant_id, tenant_id=app.tenant_id,
@ -878,8 +865,8 @@ class TestWorkflowAppService:
) )
workflow_app_log.id = str(uuid.uuid4()) workflow_app_log.id = str(uuid.uuid4())
workflow_app_log.created_at = datetime.now(UTC) + timedelta(minutes=i) workflow_app_log.created_at = datetime.now(UTC) + timedelta(minutes=i)
db.session.add(workflow_app_log) db_session_with_containers.add(workflow_app_log)
db.session.commit() db_session_with_containers.commit()
workflow_runs.append(workflow_run) workflow_runs.append(workflow_run)
workflow_app_logs.append(workflow_app_log) workflow_app_logs.append(workflow_app_log)
@ -906,8 +893,8 @@ class TestWorkflowAppService:
created_at=datetime.now(UTC) + timedelta(minutes=i + 10), created_at=datetime.now(UTC) + timedelta(minutes=i + 10),
finished_at=datetime.now(UTC) + timedelta(minutes=i + 11), finished_at=datetime.now(UTC) + timedelta(minutes=i + 11),
) )
db.session.add(workflow_run) db_session_with_containers.add(workflow_run)
db.session.commit() db_session_with_containers.commit()
workflow_app_log = WorkflowAppLog( workflow_app_log = WorkflowAppLog(
tenant_id=app.tenant_id, tenant_id=app.tenant_id,
@ -920,8 +907,8 @@ class TestWorkflowAppService:
) )
workflow_app_log.id = str(uuid.uuid4()) workflow_app_log.id = str(uuid.uuid4())
workflow_app_log.created_at = datetime.now(UTC) + timedelta(minutes=i + 10) workflow_app_log.created_at = datetime.now(UTC) + timedelta(minutes=i + 10)
db.session.add(workflow_app_log) db_session_with_containers.add(workflow_app_log)
db.session.commit() db_session_with_containers.commit()
workflow_runs.append(workflow_run) workflow_runs.append(workflow_run)
workflow_app_logs.append(workflow_app_log) workflow_app_logs.append(workflow_app_log)
@ -994,7 +981,7 @@ class TestWorkflowAppService:
assert "Account not found" in str(exc_info.value) assert "Account not found" in str(exc_info.value)
def test_get_paginate_workflow_app_logs_with_uuid_keyword_search( def test_get_paginate_workflow_app_logs_with_uuid_keyword_search(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test workflow app logs pagination with UUID keyword search functionality. Test workflow app logs pagination with UUID keyword search functionality.
@ -1003,8 +990,6 @@ class TestWorkflowAppService:
fake = Faker() fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
from extensions.ext_database import db
# Create workflow # Create workflow
workflow = Workflow( workflow = Workflow(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
@ -1017,8 +1002,8 @@ class TestWorkflowAppService:
created_by=account.id, created_by=account.id,
updated_by=account.id, updated_by=account.id,
) )
db.session.add(workflow) db_session_with_containers.add(workflow)
db.session.commit() db_session_with_containers.commit()
# Create workflow run with specific UUID # Create workflow run with specific UUID
workflow_run_id = str(uuid.uuid4()) workflow_run_id = str(uuid.uuid4())
@ -1042,8 +1027,8 @@ class TestWorkflowAppService:
created_at=datetime.now(UTC), created_at=datetime.now(UTC),
finished_at=datetime.now(UTC) + timedelta(minutes=1), finished_at=datetime.now(UTC) + timedelta(minutes=1),
) )
db.session.add(workflow_run) db_session_with_containers.add(workflow_run)
db.session.commit() db_session_with_containers.commit()
# Create workflow app log # Create workflow app log
workflow_app_log = WorkflowAppLog( workflow_app_log = WorkflowAppLog(
@ -1057,8 +1042,8 @@ class TestWorkflowAppService:
) )
workflow_app_log.id = str(uuid.uuid4()) workflow_app_log.id = str(uuid.uuid4())
workflow_app_log.created_at = datetime.now(UTC) workflow_app_log.created_at = datetime.now(UTC)
db.session.add(workflow_app_log) db_session_with_containers.add(workflow_app_log)
db.session.commit() db_session_with_containers.commit()
# Act & Assert: Test UUID keyword search # Act & Assert: Test UUID keyword search
service = WorkflowAppService() service = WorkflowAppService()
@ -1085,7 +1070,7 @@ class TestWorkflowAppService:
assert result_invalid_uuid["total"] == 0 assert result_invalid_uuid["total"] == 0
def test_get_paginate_workflow_app_logs_with_edge_cases( def test_get_paginate_workflow_app_logs_with_edge_cases(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test workflow app logs pagination with edge cases and boundary conditions. Test workflow app logs pagination with edge cases and boundary conditions.
@ -1094,8 +1079,6 @@ class TestWorkflowAppService:
fake = Faker() fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
from extensions.ext_database import db
# Create workflow # Create workflow
workflow = Workflow( workflow = Workflow(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
@ -1108,8 +1091,8 @@ class TestWorkflowAppService:
created_by=account.id, created_by=account.id,
updated_by=account.id, updated_by=account.id,
) )
db.session.add(workflow) db_session_with_containers.add(workflow)
db.session.commit() db_session_with_containers.commit()
# Create workflow run with edge case data # Create workflow run with edge case data
workflow_run = WorkflowRun( workflow_run = WorkflowRun(
@ -1132,8 +1115,8 @@ class TestWorkflowAppService:
created_at=datetime.now(UTC), created_at=datetime.now(UTC),
finished_at=datetime.now(UTC), finished_at=datetime.now(UTC),
) )
db.session.add(workflow_run) db_session_with_containers.add(workflow_run)
db.session.commit() db_session_with_containers.commit()
# Create workflow app log # Create workflow app log
workflow_app_log = WorkflowAppLog( workflow_app_log = WorkflowAppLog(
@ -1147,8 +1130,8 @@ class TestWorkflowAppService:
) )
workflow_app_log.id = str(uuid.uuid4()) workflow_app_log.id = str(uuid.uuid4())
workflow_app_log.created_at = datetime.now(UTC) workflow_app_log.created_at = datetime.now(UTC)
db.session.add(workflow_app_log) db_session_with_containers.add(workflow_app_log)
db.session.commit() db_session_with_containers.commit()
# Act & Assert: Test edge cases # Act & Assert: Test edge cases
service = WorkflowAppService() service = WorkflowAppService()
@ -1185,7 +1168,7 @@ class TestWorkflowAppService:
assert result_high_page["has_more"] is False assert result_high_page["has_more"] is False
def test_get_paginate_workflow_app_logs_with_empty_results( def test_get_paginate_workflow_app_logs_with_empty_results(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test workflow app logs pagination with empty results and no data scenarios. Test workflow app logs pagination with empty results and no data scenarios.
@ -1252,7 +1235,7 @@ class TestWorkflowAppService:
assert "Account not found" in str(exc_info.value) assert "Account not found" in str(exc_info.value)
def test_get_paginate_workflow_app_logs_with_complex_query_combinations( def test_get_paginate_workflow_app_logs_with_complex_query_combinations(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test workflow app logs pagination with complex query combinations. Test workflow app logs pagination with complex query combinations.
@ -1352,7 +1335,7 @@ class TestWorkflowAppService:
assert len(result_time_status_limit["data"]) <= 2 assert len(result_time_status_limit["data"]) <= 2
def test_get_paginate_workflow_app_logs_with_large_dataset_performance( def test_get_paginate_workflow_app_logs_with_large_dataset_performance(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test workflow app logs pagination with large dataset for performance validation. Test workflow app logs pagination with large dataset for performance validation.
@ -1444,7 +1427,7 @@ class TestWorkflowAppService:
assert result_last_page["page"] == 3 assert result_last_page["page"] == 3
def test_get_paginate_workflow_app_logs_with_tenant_isolation( def test_get_paginate_workflow_app_logs_with_tenant_isolation(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test workflow app logs pagination with proper tenant isolation. Test workflow app logs pagination with proper tenant isolation.

View File

@ -1,5 +1,6 @@
import pytest import pytest
from faker import Faker from faker import Faker
from sqlalchemy.orm import Session
from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from dify_graph.variables.segments import StringSegment from dify_graph.variables.segments import StringSegment
@ -44,7 +45,7 @@ class TestWorkflowDraftVariableService:
# WorkflowDraftVariableService doesn't have external dependencies that need mocking # WorkflowDraftVariableService doesn't have external dependencies that need mocking
return {} return {}
def _create_test_app(self, db_session_with_containers, mock_external_service_dependencies, fake=None): def _create_test_app(self, db_session_with_containers: Session, mock_external_service_dependencies, fake=None):
""" """
Helper method to create a test app with realistic data for testing. Helper method to create a test app with realistic data for testing.
@ -75,13 +76,11 @@ class TestWorkflowDraftVariableService:
app.created_by = fake.uuid4() app.created_by = fake.uuid4()
app.updated_by = app.created_by app.updated_by = app.created_by
from extensions.ext_database import db db_session_with_containers.add(app)
db_session_with_containers.commit()
db.session.add(app)
db.session.commit()
return app return app
def _create_test_workflow(self, db_session_with_containers, app, fake=None): def _create_test_workflow(self, db_session_with_containers: Session, app, fake=None):
""" """
Helper method to create a test workflow associated with an app. Helper method to create a test workflow associated with an app.
@ -110,15 +109,14 @@ class TestWorkflowDraftVariableService:
conversation_variables=[], conversation_variables=[],
rag_pipeline_variables=[], rag_pipeline_variables=[],
) )
from extensions.ext_database import db
db.session.add(workflow) db_session_with_containers.add(workflow)
db.session.commit() db_session_with_containers.commit()
return workflow return workflow
def _create_test_variable( def _create_test_variable(
self, self,
db_session_with_containers, db_session_with_containers: Session,
app_id, app_id,
node_id, node_id,
name, name,
@ -174,13 +172,12 @@ class TestWorkflowDraftVariableService:
visible=True, visible=True,
editable=True, editable=True,
) )
from extensions.ext_database import db
db.session.add(variable) db_session_with_containers.add(variable)
db.session.commit() db_session_with_containers.commit()
return variable return variable
def test_get_variable_success(self, db_session_with_containers, mock_external_service_dependencies): def test_get_variable_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test getting a single variable by ID successfully. Test getting a single variable by ID successfully.
@ -202,7 +199,7 @@ class TestWorkflowDraftVariableService:
assert retrieved_variable.app_id == app.id assert retrieved_variable.app_id == app.id
assert retrieved_variable.get_value().value == test_value.value assert retrieved_variable.get_value().value == test_value.value
def test_get_variable_not_found(self, db_session_with_containers, mock_external_service_dependencies): def test_get_variable_not_found(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test getting a variable that doesn't exist. Test getting a variable that doesn't exist.
@ -217,7 +214,7 @@ class TestWorkflowDraftVariableService:
assert retrieved_variable is None assert retrieved_variable is None
def test_get_draft_variables_by_selectors_success( def test_get_draft_variables_by_selectors_success(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test getting variables by selectors successfully. Test getting variables by selectors successfully.
@ -268,7 +265,7 @@ class TestWorkflowDraftVariableService:
assert var.get_value().value == var3_value.value assert var.get_value().value == var3_value.value
def test_list_variables_without_values_success( def test_list_variables_without_values_success(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test listing variables without values successfully with pagination. Test listing variables without values successfully with pagination.
@ -300,7 +297,7 @@ class TestWorkflowDraftVariableService:
assert var.name is not None assert var.name is not None
assert var.app_id == app.id assert var.app_id == app.id
def test_list_node_variables_success(self, db_session_with_containers, mock_external_service_dependencies): def test_list_node_variables_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test listing variables for a specific node successfully. Test listing variables for a specific node successfully.
@ -352,7 +349,9 @@ class TestWorkflowDraftVariableService:
assert "var2" in var_names assert "var2" in var_names
assert "var3" not in var_names assert "var3" not in var_names
def test_list_conversation_variables_success(self, db_session_with_containers, mock_external_service_dependencies): def test_list_conversation_variables_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test listing conversation variables successfully. Test listing conversation variables successfully.
@ -393,7 +392,7 @@ class TestWorkflowDraftVariableService:
assert "conv_var2" in var_names assert "conv_var2" in var_names
assert "sys_var" not in var_names assert "sys_var" not in var_names
def test_update_variable_success(self, db_session_with_containers, mock_external_service_dependencies): def test_update_variable_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test updating a variable's name and value successfully. Test updating a variable's name and value successfully.
@ -418,14 +417,15 @@ class TestWorkflowDraftVariableService:
assert updated_variable.name == "new_name" assert updated_variable.name == "new_name"
assert updated_variable.get_value().value == new_value.value assert updated_variable.get_value().value == new_value.value
assert updated_variable.last_edited_at is not None assert updated_variable.last_edited_at is not None
from extensions.ext_database import db
db.session.refresh(variable) db_session_with_containers.refresh(variable)
assert variable.name == "new_name" assert variable.name == "new_name"
assert variable.get_value().value == new_value.value assert variable.get_value().value == new_value.value
assert variable.last_edited_at is not None assert variable.last_edited_at is not None
def test_update_variable_not_editable(self, db_session_with_containers, mock_external_service_dependencies): def test_update_variable_not_editable(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test that updating a non-editable variable raises an exception. Test that updating a non-editable variable raises an exception.
@ -445,17 +445,18 @@ class TestWorkflowDraftVariableService:
node_execution_id=fake.uuid4(), node_execution_id=fake.uuid4(),
editable=False, # Set as non-editable editable=False, # Set as non-editable
) )
from extensions.ext_database import db
db.session.add(variable) db_session_with_containers.add(variable)
db.session.commit() db_session_with_containers.commit()
service = WorkflowDraftVariableService(db_session_with_containers) service = WorkflowDraftVariableService(db_session_with_containers)
with pytest.raises(UpdateNotSupportedError) as exc_info: with pytest.raises(UpdateNotSupportedError) as exc_info:
service.update_variable(variable, name="new_name", value=new_value) service.update_variable(variable, name="new_name", value=new_value)
assert "variable not support updating" in str(exc_info.value) assert "variable not support updating" in str(exc_info.value)
assert variable.id in str(exc_info.value) assert variable.id in str(exc_info.value)
def test_reset_conversation_variable_success(self, db_session_with_containers, mock_external_service_dependencies): def test_reset_conversation_variable_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test resetting conversation variable successfully. Test resetting conversation variable successfully.
@ -476,9 +477,8 @@ class TestWorkflowDraftVariableService:
selector=[CONVERSATION_VARIABLE_NODE_ID, "test_conv_var"], selector=[CONVERSATION_VARIABLE_NODE_ID, "test_conv_var"],
) )
workflow.conversation_variables = [conv_var] workflow.conversation_variables = [conv_var]
from extensions.ext_database import db
db.session.commit() db_session_with_containers.commit()
modified_value = StringSegment(value=fake.word()) modified_value = StringSegment(value=fake.word())
variable = self._create_test_variable( variable = self._create_test_variable(
db_session_with_containers, db_session_with_containers,
@ -489,17 +489,17 @@ class TestWorkflowDraftVariableService:
fake=fake, fake=fake,
) )
variable.last_edited_at = fake.date_time() variable.last_edited_at = fake.date_time()
db.session.commit() db_session_with_containers.commit()
service = WorkflowDraftVariableService(db_session_with_containers) service = WorkflowDraftVariableService(db_session_with_containers)
reset_variable = service.reset_variable(workflow, variable) reset_variable = service.reset_variable(workflow, variable)
assert reset_variable is not None assert reset_variable is not None
assert reset_variable.get_value().value == "default_value" assert reset_variable.get_value().value == "default_value"
assert reset_variable.last_edited_at is None assert reset_variable.last_edited_at is None
db.session.refresh(variable) db_session_with_containers.refresh(variable)
assert variable.get_value().value == "default_value" assert variable.get_value().value == "default_value"
assert variable.last_edited_at is None assert variable.last_edited_at is None
def test_delete_variable_success(self, db_session_with_containers, mock_external_service_dependencies): def test_delete_variable_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test deleting a single variable successfully. Test deleting a single variable successfully.
@ -513,14 +513,15 @@ class TestWorkflowDraftVariableService:
variable = self._create_test_variable( variable = self._create_test_variable(
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "test_var", test_value, fake=fake db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "test_var", test_value, fake=fake
) )
from extensions.ext_database import db
assert db.session.query(WorkflowDraftVariable).filter_by(id=variable.id).first() is not None assert db_session_with_containers.query(WorkflowDraftVariable).filter_by(id=variable.id).first() is not None
service = WorkflowDraftVariableService(db_session_with_containers) service = WorkflowDraftVariableService(db_session_with_containers)
service.delete_variable(variable) service.delete_variable(variable)
assert db.session.query(WorkflowDraftVariable).filter_by(id=variable.id).first() is None assert db_session_with_containers.query(WorkflowDraftVariable).filter_by(id=variable.id).first() is None
def test_delete_workflow_variables_success(self, db_session_with_containers, mock_external_service_dependencies): def test_delete_workflow_variables_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test deleting all variables for a workflow successfully. Test deleting all variables for a workflow successfully.
@ -550,20 +551,25 @@ class TestWorkflowDraftVariableService:
other_value, other_value,
fake=fake, fake=fake,
) )
from extensions.ext_database import db
app_variables = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id).all() app_variables = db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id).all()
other_app_variables = db.session.query(WorkflowDraftVariable).filter_by(app_id=other_app.id).all() other_app_variables = (
db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=other_app.id).all()
)
assert len(app_variables) == 3 assert len(app_variables) == 3
assert len(other_app_variables) == 1 assert len(other_app_variables) == 1
service = WorkflowDraftVariableService(db_session_with_containers) service = WorkflowDraftVariableService(db_session_with_containers)
service.delete_workflow_variables(app.id) service.delete_workflow_variables(app.id)
app_variables_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id).all() app_variables_after = db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id).all()
other_app_variables_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=other_app.id).all() other_app_variables_after = (
db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=other_app.id).all()
)
assert len(app_variables_after) == 0 assert len(app_variables_after) == 0
assert len(other_app_variables_after) == 1 assert len(other_app_variables_after) == 1
def test_delete_node_variables_success(self, db_session_with_containers, mock_external_service_dependencies): def test_delete_node_variables_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test deleting all variables for a specific node successfully. Test deleting all variables for a specific node successfully.
@ -605,14 +611,15 @@ class TestWorkflowDraftVariableService:
conv_value, conv_value,
fake=fake, fake=fake,
) )
from extensions.ext_database import db
target_node_variables = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id=node_id).all() target_node_variables = (
db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id=node_id).all()
)
other_node_variables = ( other_node_variables = (
db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id="other_node").all() db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id="other_node").all()
) )
conv_variables = ( conv_variables = (
db.session.query(WorkflowDraftVariable) db_session_with_containers.query(WorkflowDraftVariable)
.filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID) .filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID)
.all() .all()
) )
@ -622,13 +629,13 @@ class TestWorkflowDraftVariableService:
service = WorkflowDraftVariableService(db_session_with_containers) service = WorkflowDraftVariableService(db_session_with_containers)
service.delete_node_variables(app.id, node_id) service.delete_node_variables(app.id, node_id)
target_node_variables_after = ( target_node_variables_after = (
db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id=node_id).all() db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id=node_id).all()
) )
other_node_variables_after = ( other_node_variables_after = (
db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id="other_node").all() db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id="other_node").all()
) )
conv_variables_after = ( conv_variables_after = (
db.session.query(WorkflowDraftVariable) db_session_with_containers.query(WorkflowDraftVariable)
.filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID) .filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID)
.all() .all()
) )
@ -637,7 +644,7 @@ class TestWorkflowDraftVariableService:
assert len(conv_variables_after) == 1 assert len(conv_variables_after) == 1
def test_prefill_conversation_variable_default_values_success( def test_prefill_conversation_variable_default_values_success(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test prefill conversation variable default values successfully. Test prefill conversation variable default values successfully.
@ -665,13 +672,12 @@ class TestWorkflowDraftVariableService:
selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_var2"], selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_var2"],
) )
workflow.conversation_variables = [conv_var1, conv_var2] workflow.conversation_variables = [conv_var1, conv_var2]
from extensions.ext_database import db
db.session.commit() db_session_with_containers.commit()
service = WorkflowDraftVariableService(db_session_with_containers) service = WorkflowDraftVariableService(db_session_with_containers)
service.prefill_conversation_variable_default_values(workflow) service.prefill_conversation_variable_default_values(workflow)
draft_variables = ( draft_variables = (
db.session.query(WorkflowDraftVariable) db_session_with_containers.query(WorkflowDraftVariable)
.filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID) .filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID)
.all() .all()
) )
@ -686,7 +692,7 @@ class TestWorkflowDraftVariableService:
assert var.get_variable_type() == DraftVariableType.CONVERSATION assert var.get_variable_type() == DraftVariableType.CONVERSATION
def test_get_conversation_id_from_draft_variable_success( def test_get_conversation_id_from_draft_variable_success(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test getting conversation ID from draft variable successfully. Test getting conversation ID from draft variable successfully.
@ -713,7 +719,7 @@ class TestWorkflowDraftVariableService:
assert retrieved_conv_id == conversation_id assert retrieved_conv_id == conversation_id
def test_get_conversation_id_from_draft_variable_not_found( def test_get_conversation_id_from_draft_variable_not_found(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test getting conversation ID when it doesn't exist. Test getting conversation ID when it doesn't exist.
@ -728,7 +734,9 @@ class TestWorkflowDraftVariableService:
retrieved_conv_id = service._get_conversation_id_from_draft_variable(app.id) retrieved_conv_id = service._get_conversation_id_from_draft_variable(app.id)
assert retrieved_conv_id is None assert retrieved_conv_id is None
def test_list_system_variables_success(self, db_session_with_containers, mock_external_service_dependencies): def test_list_system_variables_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test listing system variables successfully. Test listing system variables successfully.
@ -775,7 +783,9 @@ class TestWorkflowDraftVariableService:
assert "sys_var2" in var_names assert "sys_var2" in var_names
assert "conv_var" not in var_names assert "conv_var" not in var_names
def test_get_variable_by_name_success(self, db_session_with_containers, mock_external_service_dependencies): def test_get_variable_by_name_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test getting variables by name successfully for different types. Test getting variables by name successfully for different types.
@ -822,7 +832,9 @@ class TestWorkflowDraftVariableService:
assert retrieved_node_var.name == "test_node_var" assert retrieved_node_var.name == "test_node_var"
assert retrieved_node_var.node_id == "test_node" assert retrieved_node_var.node_id == "test_node"
def test_get_variable_by_name_not_found(self, db_session_with_containers, mock_external_service_dependencies): def test_get_variable_by_name_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test getting variables by name when they don't exist. Test getting variables by name when they don't exist.

View File

@ -5,6 +5,7 @@ from unittest.mock import patch
import pytest import pytest
from faker import Faker from faker import Faker
from sqlalchemy.orm import Session
from models.enums import CreatorUserRole from models.enums import CreatorUserRole
from models.model import ( from models.model import (
@ -48,7 +49,7 @@ class TestWorkflowRunService:
"account_feature_service": mock_account_feature_service, "account_feature_service": mock_account_feature_service,
} }
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Helper method to create a test app and account for testing. Helper method to create a test app and account for testing.
@ -94,7 +95,7 @@ class TestWorkflowRunService:
return app, account return app, account
def _create_test_workflow_run( def _create_test_workflow_run(
self, db_session_with_containers, app, account, triggered_from="debugging", offset_minutes=0 self, db_session_with_containers: Session, app, account, triggered_from="debugging", offset_minutes=0
): ):
""" """
Helper method to create a test workflow run for testing. Helper method to create a test workflow run for testing.
@ -110,8 +111,6 @@ class TestWorkflowRunService:
""" """
fake = Faker() fake = Faker()
from extensions.ext_database import db
# Create workflow run with offset timestamp # Create workflow run with offset timestamp
base_time = datetime.now(UTC) base_time = datetime.now(UTC)
created_time = base_time - timedelta(minutes=offset_minutes) created_time = base_time - timedelta(minutes=offset_minutes)
@ -136,12 +135,12 @@ class TestWorkflowRunService:
finished_at=created_time, finished_at=created_time,
) )
db.session.add(workflow_run) db_session_with_containers.add(workflow_run)
db.session.commit() db_session_with_containers.commit()
return workflow_run return workflow_run
def _create_test_message(self, db_session_with_containers, app, account, workflow_run): def _create_test_message(self, db_session_with_containers: Session, app, account, workflow_run):
""" """
Helper method to create a test message for testing. Helper method to create a test message for testing.
@ -156,8 +155,6 @@ class TestWorkflowRunService:
""" """
fake = Faker() fake = Faker()
from extensions.ext_database import db
# Create conversation first (required for message) # Create conversation first (required for message)
from models.model import Conversation from models.model import Conversation
@ -170,8 +167,8 @@ class TestWorkflowRunService:
from_source=CreatorUserRole.ACCOUNT, from_source=CreatorUserRole.ACCOUNT,
from_account_id=account.id, from_account_id=account.id,
) )
db.session.add(conversation) db_session_with_containers.add(conversation)
db.session.commit() db_session_with_containers.commit()
# Create message # Create message
message = Message() message = Message()
@ -193,12 +190,14 @@ class TestWorkflowRunService:
message.workflow_run_id = workflow_run.id message.workflow_run_id = workflow_run.id
message.inputs = {"input": "test input"} message.inputs = {"input": "test input"}
db.session.add(message) db_session_with_containers.add(message)
db.session.commit() db_session_with_containers.commit()
return message return message
def test_get_paginate_workflow_runs_success(self, db_session_with_containers, mock_external_service_dependencies): def test_get_paginate_workflow_runs_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful pagination of workflow runs with debugging trigger. Test successful pagination of workflow runs with debugging trigger.
@ -239,7 +238,7 @@ class TestWorkflowRunService:
assert workflow_run.tenant_id == app.tenant_id assert workflow_run.tenant_id == app.tenant_id
def test_get_paginate_workflow_runs_with_last_id( def test_get_paginate_workflow_runs_with_last_id(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test pagination of workflow runs with last_id parameter. Test pagination of workflow runs with last_id parameter.
@ -282,7 +281,7 @@ class TestWorkflowRunService:
assert workflow_run.tenant_id == app.tenant_id assert workflow_run.tenant_id == app.tenant_id
def test_get_paginate_workflow_runs_default_limit( def test_get_paginate_workflow_runs_default_limit(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test pagination of workflow runs with default limit. Test pagination of workflow runs with default limit.
@ -320,7 +319,7 @@ class TestWorkflowRunService:
assert workflow_run_result.tenant_id == app.tenant_id assert workflow_run_result.tenant_id == app.tenant_id
def test_get_paginate_advanced_chat_workflow_runs_success( def test_get_paginate_advanced_chat_workflow_runs_success(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful pagination of advanced chat workflow runs with message information. Test successful pagination of advanced chat workflow runs with message information.
@ -365,7 +364,7 @@ class TestWorkflowRunService:
assert workflow_run.app_id == app.id assert workflow_run.app_id == app.id
assert workflow_run.tenant_id == app.tenant_id assert workflow_run.tenant_id == app.tenant_id
def test_get_workflow_run_success(self, db_session_with_containers, mock_external_service_dependencies): def test_get_workflow_run_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test successful retrieval of workflow run by ID. Test successful retrieval of workflow run by ID.
@ -395,7 +394,7 @@ class TestWorkflowRunService:
assert result.type == "chat" assert result.type == "chat"
assert result.version == "1.0.0" assert result.version == "1.0.0"
def test_get_workflow_run_not_found(self, db_session_with_containers, mock_external_service_dependencies): def test_get_workflow_run_not_found(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test workflow run retrieval when run ID does not exist. Test workflow run retrieval when run ID does not exist.
@ -419,7 +418,7 @@ class TestWorkflowRunService:
assert result is None assert result is None
def test_get_workflow_run_node_executions_success( def test_get_workflow_run_node_executions_success(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful retrieval of workflow run node executions. Test successful retrieval of workflow run node executions.
@ -438,7 +437,6 @@ class TestWorkflowRunService:
workflow_run = self._create_test_workflow_run(db_session_with_containers, app, account, "debugging") workflow_run = self._create_test_workflow_run(db_session_with_containers, app, account, "debugging")
# Create node executions # Create node executions
from extensions.ext_database import db
from models.workflow import WorkflowNodeExecutionModel from models.workflow import WorkflowNodeExecutionModel
node_executions = [] node_executions = []
@ -462,7 +460,7 @@ class TestWorkflowRunService:
created_by=account.id, created_by=account.id,
created_at=datetime.now(UTC), created_at=datetime.now(UTC),
) )
db.session.add(node_execution) db_session_with_containers.add(node_execution)
node_executions.append(node_execution) node_executions.append(node_execution)
paused_node_execution = WorkflowNodeExecutionModel( paused_node_execution = WorkflowNodeExecutionModel(
@ -484,9 +482,9 @@ class TestWorkflowRunService:
created_by=account.id, created_by=account.id,
created_at=datetime.now(UTC), created_at=datetime.now(UTC),
) )
db.session.add(paused_node_execution) db_session_with_containers.add(paused_node_execution)
db.session.commit() db_session_with_containers.commit()
# Act: Execute the method under test # Act: Execute the method under test
workflow_run_service = WorkflowRunService() workflow_run_service = WorkflowRunService()
@ -509,7 +507,7 @@ class TestWorkflowRunService:
assert node_execution.node_id.startswith("node_") assert node_execution.node_id.startswith("node_")
def test_get_workflow_run_node_executions_empty( def test_get_workflow_run_node_executions_empty(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test getting node executions for a workflow run with no executions. Test getting node executions for a workflow run with no executions.
@ -560,7 +558,7 @@ class TestWorkflowRunService:
assert len(result) == 0 assert len(result) == 0
def test_get_workflow_run_node_executions_invalid_workflow_run_id( def test_get_workflow_run_node_executions_invalid_workflow_run_id(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test getting node executions with invalid workflow run ID. Test getting node executions with invalid workflow run ID.
@ -611,7 +609,7 @@ class TestWorkflowRunService:
assert len(result) == 0 assert len(result) == 0
def test_get_workflow_run_node_executions_database_error( def test_get_workflow_run_node_executions_database_error(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test getting node executions when database encounters an error. Test getting node executions when database encounters an error.
@ -662,7 +660,7 @@ class TestWorkflowRunService:
) )
def test_get_workflow_run_node_executions_end_user( def test_get_workflow_run_node_executions_end_user(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test node execution retrieval for end user. Test node execution retrieval for end user.
@ -680,7 +678,6 @@ class TestWorkflowRunService:
workflow_run = self._create_test_workflow_run(db_session_with_containers, app, account, "debugging") workflow_run = self._create_test_workflow_run(db_session_with_containers, app, account, "debugging")
# Create end user # Create end user
from extensions.ext_database import db
from models.model import EndUser from models.model import EndUser
end_user = EndUser( end_user = EndUser(
@ -692,8 +689,8 @@ class TestWorkflowRunService:
external_user_id=str(uuid.uuid4()), external_user_id=str(uuid.uuid4()),
name=fake.name(), name=fake.name(),
) )
db.session.add(end_user) db_session_with_containers.add(end_user)
db.session.commit() db_session_with_containers.commit()
# Create node execution # Create node execution
from models.workflow import WorkflowNodeExecutionModel from models.workflow import WorkflowNodeExecutionModel
@ -717,8 +714,8 @@ class TestWorkflowRunService:
created_by=end_user.id, created_by=end_user.id,
created_at=datetime.now(UTC), created_at=datetime.now(UTC),
) )
db.session.add(node_execution) db_session_with_containers.add(node_execution)
db.session.commit() db_session_with_containers.commit()
# Act: Execute the method under test # Act: Execute the method under test
workflow_run_service = WorkflowRunService() workflow_run_service = WorkflowRunService()

View File

@ -10,6 +10,7 @@ from unittest.mock import MagicMock
import pytest import pytest
from faker import Faker from faker import Faker
from sqlalchemy.orm import Session
from models import Account, App, Workflow from models import Account, App, Workflow
from models.model import AppMode from models.model import AppMode
@ -32,7 +33,7 @@ class TestWorkflowService:
and realistic testing environment with actual database interactions. and realistic testing environment with actual database interactions.
""" """
def _create_test_account(self, db_session_with_containers, fake=None): def _create_test_account(self, db_session_with_containers: Session, fake=None):
""" """
Helper method to create a test account with realistic data. Helper method to create a test account with realistic data.
@ -67,18 +68,16 @@ class TestWorkflowService:
tenant.created_at = fake.date_time_this_year() tenant.created_at = fake.date_time_this_year()
tenant.updated_at = tenant.created_at tenant.updated_at = tenant.created_at
from extensions.ext_database import db db_session_with_containers.add(tenant)
db_session_with_containers.add(account)
db.session.add(tenant) db_session_with_containers.commit()
db.session.add(account)
db.session.commit()
# Set the current tenant for the account # Set the current tenant for the account
account.current_tenant = tenant account.current_tenant = tenant
return account return account
def _create_test_app(self, db_session_with_containers, fake=None): def _create_test_app(self, db_session_with_containers: Session, fake=None):
""" """
Helper method to create a test app with realistic data. Helper method to create a test app with realistic data.
@ -106,13 +105,11 @@ class TestWorkflowService:
) )
app.updated_by = app.created_by app.updated_by = app.created_by
from extensions.ext_database import db db_session_with_containers.add(app)
db_session_with_containers.commit()
db.session.add(app)
db.session.commit()
return app return app
def _create_test_workflow(self, db_session_with_containers, app, account, fake=None): def _create_test_workflow(self, db_session_with_containers: Session, app, account, fake=None):
""" """
Helper method to create a test workflow associated with an app. Helper method to create a test workflow associated with an app.
@ -141,13 +138,11 @@ class TestWorkflowService:
conversation_variables=[], conversation_variables=[],
) )
from extensions.ext_database import db db_session_with_containers.add(workflow)
db_session_with_containers.commit()
db.session.add(workflow)
db.session.commit()
return workflow return workflow
def test_get_node_last_run_success(self, db_session_with_containers): def test_get_node_last_run_success(self, db_session_with_containers: Session):
""" """
Test successful retrieval of the most recent execution for a specific node. Test successful retrieval of the most recent execution for a specific node.
@ -180,10 +175,8 @@ class TestWorkflowService:
node_execution.created_by = account.id # Required field node_execution.created_by = account.id # Required field
node_execution.created_at = fake.date_time_this_year() node_execution.created_at = fake.date_time_this_year()
from extensions.ext_database import db db_session_with_containers.add(node_execution)
db_session_with_containers.commit()
db.session.add(node_execution)
db.session.commit()
workflow_service = WorkflowService() workflow_service = WorkflowService()
@ -196,7 +189,7 @@ class TestWorkflowService:
assert result.workflow_id == workflow.id assert result.workflow_id == workflow.id
assert result.status == "succeeded" assert result.status == "succeeded"
def test_get_node_last_run_not_found(self, db_session_with_containers): def test_get_node_last_run_not_found(self, db_session_with_containers: Session):
""" """
Test retrieval when no execution record exists for the specified node. Test retrieval when no execution record exists for the specified node.
@ -217,7 +210,7 @@ class TestWorkflowService:
# Assert # Assert
assert result is None assert result is None
def test_is_workflow_exist_true(self, db_session_with_containers): def test_is_workflow_exist_true(self, db_session_with_containers: Session):
""" """
Test workflow existence check when a draft workflow exists. Test workflow existence check when a draft workflow exists.
@ -238,7 +231,7 @@ class TestWorkflowService:
# Assert # Assert
assert result is True assert result is True
def test_is_workflow_exist_false(self, db_session_with_containers): def test_is_workflow_exist_false(self, db_session_with_containers: Session):
""" """
Test workflow existence check when no draft workflow exists. Test workflow existence check when no draft workflow exists.
@ -258,7 +251,7 @@ class TestWorkflowService:
# Assert # Assert
assert result is False assert result is False
def test_get_draft_workflow_success(self, db_session_with_containers): def test_get_draft_workflow_success(self, db_session_with_containers: Session):
""" """
Test successful retrieval of a draft workflow. Test successful retrieval of a draft workflow.
@ -284,7 +277,7 @@ class TestWorkflowService:
assert result.app_id == app.id assert result.app_id == app.id
assert result.tenant_id == app.tenant_id assert result.tenant_id == app.tenant_id
def test_get_draft_workflow_not_found(self, db_session_with_containers): def test_get_draft_workflow_not_found(self, db_session_with_containers: Session):
""" """
Test draft workflow retrieval when no draft workflow exists. Test draft workflow retrieval when no draft workflow exists.
@ -304,7 +297,7 @@ class TestWorkflowService:
# Assert # Assert
assert result is None assert result is None
def test_get_published_workflow_by_id_success(self, db_session_with_containers): def test_get_published_workflow_by_id_success(self, db_session_with_containers: Session):
""" """
Test successful retrieval of a published workflow by ID. Test successful retrieval of a published workflow by ID.
@ -321,9 +314,7 @@ class TestWorkflowService:
workflow = self._create_test_workflow(db_session_with_containers, app, account, fake) workflow = self._create_test_workflow(db_session_with_containers, app, account, fake)
workflow.version = "2024.01.01.001" # Published version workflow.version = "2024.01.01.001" # Published version
from extensions.ext_database import db db_session_with_containers.commit()
db.session.commit()
workflow_service = WorkflowService() workflow_service = WorkflowService()
@ -336,7 +327,7 @@ class TestWorkflowService:
assert result.version != Workflow.VERSION_DRAFT assert result.version != Workflow.VERSION_DRAFT
assert result.app_id == app.id assert result.app_id == app.id
def test_get_published_workflow_by_id_draft_error(self, db_session_with_containers): def test_get_published_workflow_by_id_draft_error(self, db_session_with_containers: Session):
""" """
Test error when trying to retrieve a draft workflow as published. Test error when trying to retrieve a draft workflow as published.
@ -359,7 +350,7 @@ class TestWorkflowService:
with pytest.raises(IsDraftWorkflowError): with pytest.raises(IsDraftWorkflowError):
workflow_service.get_published_workflow_by_id(app, workflow.id) workflow_service.get_published_workflow_by_id(app, workflow.id)
def test_get_published_workflow_by_id_not_found(self, db_session_with_containers): def test_get_published_workflow_by_id_not_found(self, db_session_with_containers: Session):
""" """
Test retrieval when no workflow exists with the specified ID. Test retrieval when no workflow exists with the specified ID.
@ -379,7 +370,7 @@ class TestWorkflowService:
# Assert # Assert
assert result is None assert result is None
def test_get_published_workflow_success(self, db_session_with_containers): def test_get_published_workflow_success(self, db_session_with_containers: Session):
""" """
Test successful retrieval of the current published workflow for an app. Test successful retrieval of the current published workflow for an app.
@ -395,10 +386,8 @@ class TestWorkflowService:
workflow = self._create_test_workflow(db_session_with_containers, app, account, fake) workflow = self._create_test_workflow(db_session_with_containers, app, account, fake)
workflow.version = "2024.01.01.001" # Published version workflow.version = "2024.01.01.001" # Published version
from extensions.ext_database import db
app.workflow_id = workflow.id app.workflow_id = workflow.id
db.session.commit() db_session_with_containers.commit()
workflow_service = WorkflowService() workflow_service = WorkflowService()
@ -411,7 +400,7 @@ class TestWorkflowService:
assert result.version != Workflow.VERSION_DRAFT assert result.version != Workflow.VERSION_DRAFT
assert result.app_id == app.id assert result.app_id == app.id
def test_get_published_workflow_no_workflow_id(self, db_session_with_containers): def test_get_published_workflow_no_workflow_id(self, db_session_with_containers: Session):
""" """
Test retrieval when app has no associated workflow ID. Test retrieval when app has no associated workflow ID.
@ -431,7 +420,7 @@ class TestWorkflowService:
# Assert # Assert
assert result is None assert result is None
def test_get_all_published_workflow_pagination(self, db_session_with_containers): def test_get_all_published_workflow_pagination(self, db_session_with_containers: Session):
""" """
Test pagination of published workflows. Test pagination of published workflows.
@ -455,15 +444,13 @@ class TestWorkflowService:
# Set the app's workflow_id to the first workflow # Set the app's workflow_id to the first workflow
app.workflow_id = workflows[0].id app.workflow_id = workflows[0].id
from extensions.ext_database import db db_session_with_containers.commit()
db.session.commit()
workflow_service = WorkflowService() workflow_service = WorkflowService()
# Act - First page # Act - First page
result_workflows, has_more = workflow_service.get_all_published_workflow( result_workflows, has_more = workflow_service.get_all_published_workflow(
session=db.session, session=db_session_with_containers,
app_model=app, app_model=app,
page=1, page=1,
limit=3, limit=3,
@ -476,7 +463,7 @@ class TestWorkflowService:
# Act - Second page # Act - Second page
result_workflows, has_more = workflow_service.get_all_published_workflow( result_workflows, has_more = workflow_service.get_all_published_workflow(
session=db.session, session=db_session_with_containers,
app_model=app, app_model=app,
page=2, page=2,
limit=3, limit=3,
@ -487,7 +474,7 @@ class TestWorkflowService:
assert len(result_workflows) == 2 assert len(result_workflows) == 2
assert has_more is False assert has_more is False
def test_get_all_published_workflow_user_filter(self, db_session_with_containers): def test_get_all_published_workflow_user_filter(self, db_session_with_containers: Session):
""" """
Test filtering published workflows by user. Test filtering published workflows by user.
@ -513,22 +500,20 @@ class TestWorkflowService:
# Set the app's workflow_id to the first workflow # Set the app's workflow_id to the first workflow
app.workflow_id = workflow1.id app.workflow_id = workflow1.id
from extensions.ext_database import db db_session_with_containers.commit()
db.session.commit()
workflow_service = WorkflowService() workflow_service = WorkflowService()
# Act - Filter by account1 # Act - Filter by account1
result_workflows, has_more = workflow_service.get_all_published_workflow( result_workflows, has_more = workflow_service.get_all_published_workflow(
session=db.session, app_model=app, page=1, limit=10, user_id=account1.id session=db_session_with_containers, app_model=app, page=1, limit=10, user_id=account1.id
) )
# Assert # Assert
assert len(result_workflows) == 1 assert len(result_workflows) == 1
assert result_workflows[0].created_by == account1.id assert result_workflows[0].created_by == account1.id
def test_get_all_published_workflow_named_only_filter(self, db_session_with_containers): def test_get_all_published_workflow_named_only_filter(self, db_session_with_containers: Session):
""" """
Test filtering published workflows to show only named workflows. Test filtering published workflows to show only named workflows.
@ -557,22 +542,20 @@ class TestWorkflowService:
# Set the app's workflow_id to the first workflow # Set the app's workflow_id to the first workflow
app.workflow_id = workflow1.id app.workflow_id = workflow1.id
from extensions.ext_database import db db_session_with_containers.commit()
db.session.commit()
workflow_service = WorkflowService() workflow_service = WorkflowService()
# Act - Filter named only # Act - Filter named only
result_workflows, has_more = workflow_service.get_all_published_workflow( result_workflows, has_more = workflow_service.get_all_published_workflow(
session=db.session, app_model=app, page=1, limit=10, user_id=None, named_only=True session=db_session_with_containers, app_model=app, page=1, limit=10, user_id=None, named_only=True
) )
# Assert # Assert
assert len(result_workflows) == 2 assert len(result_workflows) == 2
assert all(wf.marked_name for wf in result_workflows) assert all(wf.marked_name for wf in result_workflows)
def test_sync_draft_workflow_create_new(self, db_session_with_containers): def test_sync_draft_workflow_create_new(self, db_session_with_containers: Session):
""" """
Test creating a new draft workflow through sync operation. Test creating a new draft workflow through sync operation.
@ -624,7 +607,7 @@ class TestWorkflowService:
assert result.features == json.dumps(features) assert result.features == json.dumps(features)
assert result.created_by == account.id assert result.created_by == account.id
def test_sync_draft_workflow_update_existing(self, db_session_with_containers): def test_sync_draft_workflow_update_existing(self, db_session_with_containers: Session):
""" """
Test updating an existing draft workflow through sync operation. Test updating an existing draft workflow through sync operation.
@ -688,7 +671,7 @@ class TestWorkflowService:
assert result.features == json.dumps(new_features) assert result.features == json.dumps(new_features)
assert result.updated_by == account.id assert result.updated_by == account.id
def test_sync_draft_workflow_hash_mismatch_error(self, db_session_with_containers): def test_sync_draft_workflow_hash_mismatch_error(self, db_session_with_containers: Session):
""" """
Test error when sync is attempted with mismatched hash. Test error when sync is attempted with mismatched hash.
@ -738,7 +721,7 @@ class TestWorkflowService:
conversation_variables=conversation_variables, conversation_variables=conversation_variables,
) )
def test_publish_workflow_success(self, db_session_with_containers): def test_publish_workflow_success(self, db_session_with_containers: Session):
""" """
Test successful workflow publishing. Test successful workflow publishing.
@ -755,9 +738,7 @@ class TestWorkflowService:
workflow = self._create_test_workflow(db_session_with_containers, app, account, fake) workflow = self._create_test_workflow(db_session_with_containers, app, account, fake)
workflow.version = Workflow.VERSION_DRAFT workflow.version = Workflow.VERSION_DRAFT
from extensions.ext_database import db db_session_with_containers.commit()
db.session.commit()
workflow_service = WorkflowService() workflow_service = WorkflowService()
@ -777,7 +758,7 @@ class TestWorkflowService:
assert len(result.version) > 10 # Should be a reasonable timestamp length assert len(result.version) > 10 # Should be a reasonable timestamp length
assert result.created_by == account.id assert result.created_by == account.id
def test_publish_workflow_no_draft_error(self, db_session_with_containers): def test_publish_workflow_no_draft_error(self, db_session_with_containers: Session):
""" """
Test error when publishing workflow without draft. Test error when publishing workflow without draft.
@ -797,7 +778,7 @@ class TestWorkflowService:
with pytest.raises(ValueError, match="No valid workflow found"): with pytest.raises(ValueError, match="No valid workflow found"):
workflow_service.publish_workflow(session=db_session_with_containers, app_model=app, account=account) workflow_service.publish_workflow(session=db_session_with_containers, app_model=app, account=account)
def test_publish_workflow_already_published_error(self, db_session_with_containers): def test_publish_workflow_already_published_error(self, db_session_with_containers: Session):
""" """
Test error when publishing already published workflow. Test error when publishing already published workflow.
@ -813,9 +794,7 @@ class TestWorkflowService:
workflow = self._create_test_workflow(db_session_with_containers, app, account, fake) workflow = self._create_test_workflow(db_session_with_containers, app, account, fake)
workflow.version = "2024.01.01.001" # Already published workflow.version = "2024.01.01.001" # Already published
from extensions.ext_database import db db_session_with_containers.commit()
db.session.commit()
workflow_service = WorkflowService() workflow_service = WorkflowService()
@ -823,7 +802,7 @@ class TestWorkflowService:
with pytest.raises(ValueError, match="No valid workflow found"): with pytest.raises(ValueError, match="No valid workflow found"):
workflow_service.publish_workflow(session=db_session_with_containers, app_model=app, account=account) workflow_service.publish_workflow(session=db_session_with_containers, app_model=app, account=account)
def test_get_default_block_configs(self, db_session_with_containers): def test_get_default_block_configs(self, db_session_with_containers: Session):
""" """
Test retrieval of default block configurations for all node types. Test retrieval of default block configurations for all node types.
@ -847,7 +826,7 @@ class TestWorkflowService:
assert isinstance(config, dict) assert isinstance(config, dict)
# The structure can vary, so we just check it's a dict # The structure can vary, so we just check it's a dict
def test_get_default_block_config_specific_type(self, db_session_with_containers): def test_get_default_block_config_specific_type(self, db_session_with_containers: Session):
""" """
Test retrieval of default block configuration for a specific node type. Test retrieval of default block configuration for a specific node type.
@ -867,7 +846,7 @@ class TestWorkflowService:
# This is acceptable behavior # This is acceptable behavior
assert result is None or isinstance(result, dict) assert result is None or isinstance(result, dict)
def test_get_default_block_config_invalid_type(self, db_session_with_containers): def test_get_default_block_config_invalid_type(self, db_session_with_containers: Session):
""" """
Test retrieval of default block configuration for invalid node type. Test retrieval of default block configuration for invalid node type.
@ -887,7 +866,7 @@ class TestWorkflowService:
# It's also acceptable for the service to raise a ValueError for invalid types # It's also acceptable for the service to raise a ValueError for invalid types
pass pass
def test_get_default_block_config_with_filters(self, db_session_with_containers): def test_get_default_block_config_with_filters(self, db_session_with_containers: Session):
""" """
Test retrieval of default block configuration with filters. Test retrieval of default block configuration with filters.
@ -907,7 +886,7 @@ class TestWorkflowService:
# Result might be None if filters don't match, but should not raise error # Result might be None if filters don't match, but should not raise error
assert result is None or isinstance(result, dict) assert result is None or isinstance(result, dict)
def test_convert_to_workflow_chat_mode_success(self, db_session_with_containers): def test_convert_to_workflow_chat_mode_success(self, db_session_with_containers: Session):
""" """
Test successful conversion from chat mode app to workflow mode. Test successful conversion from chat mode app to workflow mode.
@ -944,11 +923,9 @@ class TestWorkflowService:
) )
app_model_config.id = fake.uuid4() app_model_config.id = fake.uuid4()
from extensions.ext_database import db db_session_with_containers.add(app_model_config)
db.session.add(app_model_config)
app.app_model_config_id = app_model_config.id app.app_model_config_id = app_model_config.id
db.session.commit() db_session_with_containers.commit()
workflow_service = WorkflowService() workflow_service = WorkflowService()
conversion_args = { conversion_args = {
@ -969,7 +946,7 @@ class TestWorkflowService:
assert result.icon_type == conversion_args["icon_type"] assert result.icon_type == conversion_args["icon_type"]
assert result.icon_background == conversion_args["icon_background"] assert result.icon_background == conversion_args["icon_background"]
def test_convert_to_workflow_completion_mode_success(self, db_session_with_containers): def test_convert_to_workflow_completion_mode_success(self, db_session_with_containers: Session):
""" """
Test successful conversion from completion mode app to workflow mode. Test successful conversion from completion mode app to workflow mode.
@ -1006,11 +983,9 @@ class TestWorkflowService:
) )
app_model_config.id = fake.uuid4() app_model_config.id = fake.uuid4()
from extensions.ext_database import db db_session_with_containers.add(app_model_config)
db.session.add(app_model_config)
app.app_model_config_id = app_model_config.id app.app_model_config_id = app_model_config.id
db.session.commit() db_session_with_containers.commit()
workflow_service = WorkflowService() workflow_service = WorkflowService()
conversion_args = { conversion_args = {
@ -1031,7 +1006,7 @@ class TestWorkflowService:
assert result.icon_type == conversion_args["icon_type"] assert result.icon_type == conversion_args["icon_type"]
assert result.icon_background == conversion_args["icon_background"] assert result.icon_background == conversion_args["icon_background"]
def test_convert_to_workflow_unsupported_mode_error(self, db_session_with_containers): def test_convert_to_workflow_unsupported_mode_error(self, db_session_with_containers: Session):
""" """
Test error when attempting to convert unsupported app mode. Test error when attempting to convert unsupported app mode.
@ -1046,9 +1021,7 @@ class TestWorkflowService:
app = self._create_test_app(db_session_with_containers, fake) app = self._create_test_app(db_session_with_containers, fake)
app.mode = AppMode.WORKFLOW app.mode = AppMode.WORKFLOW
from extensions.ext_database import db db_session_with_containers.commit()
db.session.commit()
workflow_service = WorkflowService() workflow_service = WorkflowService()
conversion_args = {"name": "Test"} conversion_args = {"name": "Test"}
@ -1057,7 +1030,7 @@ class TestWorkflowService:
with pytest.raises(ValueError, match="Current App mode: workflow is not supported convert to workflow"): with pytest.raises(ValueError, match="Current App mode: workflow is not supported convert to workflow"):
workflow_service.convert_to_workflow(app_model=app, account=account, args=conversion_args) workflow_service.convert_to_workflow(app_model=app, account=account, args=conversion_args)
def test_validate_features_structure_advanced_chat(self, db_session_with_containers): def test_validate_features_structure_advanced_chat(self, db_session_with_containers: Session):
""" """
Test feature structure validation for advanced chat mode apps. Test feature structure validation for advanced chat mode apps.
@ -1069,9 +1042,7 @@ class TestWorkflowService:
app = self._create_test_app(db_session_with_containers, fake) app = self._create_test_app(db_session_with_containers, fake)
app.mode = AppMode.ADVANCED_CHAT app.mode = AppMode.ADVANCED_CHAT
from extensions.ext_database import db db_session_with_containers.commit()
db.session.commit()
workflow_service = WorkflowService() workflow_service = WorkflowService()
features = { features = {
@ -1088,7 +1059,7 @@ class TestWorkflowService:
# The exact behavior depends on the AdvancedChatAppConfigManager implementation # The exact behavior depends on the AdvancedChatAppConfigManager implementation
assert result is not None or isinstance(result, dict) assert result is not None or isinstance(result, dict)
def test_validate_features_structure_workflow(self, db_session_with_containers): def test_validate_features_structure_workflow(self, db_session_with_containers: Session):
""" """
Test feature structure validation for workflow mode apps. Test feature structure validation for workflow mode apps.
@ -1100,9 +1071,7 @@ class TestWorkflowService:
app = self._create_test_app(db_session_with_containers, fake) app = self._create_test_app(db_session_with_containers, fake)
app.mode = AppMode.WORKFLOW app.mode = AppMode.WORKFLOW
from extensions.ext_database import db db_session_with_containers.commit()
db.session.commit()
workflow_service = WorkflowService() workflow_service = WorkflowService()
features = {"workflow_config": {"max_steps": 10, "timeout": 300}} features = {"workflow_config": {"max_steps": 10, "timeout": 300}}
@ -1115,7 +1084,7 @@ class TestWorkflowService:
# The exact behavior depends on the WorkflowAppConfigManager implementation # The exact behavior depends on the WorkflowAppConfigManager implementation
assert result is not None or isinstance(result, dict) assert result is not None or isinstance(result, dict)
def test_validate_features_structure_invalid_mode(self, db_session_with_containers): def test_validate_features_structure_invalid_mode(self, db_session_with_containers: Session):
""" """
Test error when validating features for invalid app mode. Test error when validating features for invalid app mode.
@ -1127,9 +1096,7 @@ class TestWorkflowService:
app = self._create_test_app(db_session_with_containers, fake) app = self._create_test_app(db_session_with_containers, fake)
app.mode = "invalid_mode" # Invalid mode app.mode = "invalid_mode" # Invalid mode
from extensions.ext_database import db db_session_with_containers.commit()
db.session.commit()
workflow_service = WorkflowService() workflow_service = WorkflowService()
features = {"test": "value"} features = {"test": "value"}
@ -1138,7 +1105,7 @@ class TestWorkflowService:
with pytest.raises(ValueError, match="Invalid app mode: invalid_mode"): with pytest.raises(ValueError, match="Invalid app mode: invalid_mode"):
workflow_service.validate_features_structure(app_model=app, features=features) workflow_service.validate_features_structure(app_model=app, features=features)
def test_update_workflow_success(self, db_session_with_containers): def test_update_workflow_success(self, db_session_with_containers: Session):
""" """
Test successful workflow update with allowed fields. Test successful workflow update with allowed fields.
@ -1152,16 +1119,14 @@ class TestWorkflowService:
app = self._create_test_app(db_session_with_containers, fake) app = self._create_test_app(db_session_with_containers, fake)
workflow = self._create_test_workflow(db_session_with_containers, app, account, fake) workflow = self._create_test_workflow(db_session_with_containers, app, account, fake)
from extensions.ext_database import db db_session_with_containers.commit()
db.session.commit()
workflow_service = WorkflowService() workflow_service = WorkflowService()
update_data = {"marked_name": "Updated Workflow Name", "marked_comment": "Updated workflow comment"} update_data = {"marked_name": "Updated Workflow Name", "marked_comment": "Updated workflow comment"}
# Act # Act
result = workflow_service.update_workflow( result = workflow_service.update_workflow(
session=db.session, session=db_session_with_containers,
workflow_id=workflow.id, workflow_id=workflow.id,
tenant_id=workflow.tenant_id, tenant_id=workflow.tenant_id,
account_id=account.id, account_id=account.id,
@ -1174,7 +1139,7 @@ class TestWorkflowService:
assert result.marked_comment == update_data["marked_comment"] assert result.marked_comment == update_data["marked_comment"]
assert result.updated_by == account.id assert result.updated_by == account.id
def test_update_workflow_not_found(self, db_session_with_containers): def test_update_workflow_not_found(self, db_session_with_containers: Session):
""" """
Test workflow update when workflow doesn't exist. Test workflow update when workflow doesn't exist.
@ -1186,15 +1151,13 @@ class TestWorkflowService:
account = self._create_test_account(db_session_with_containers, fake) account = self._create_test_account(db_session_with_containers, fake)
app = self._create_test_app(db_session_with_containers, fake) app = self._create_test_app(db_session_with_containers, fake)
from extensions.ext_database import db
workflow_service = WorkflowService() workflow_service = WorkflowService()
non_existent_workflow_id = fake.uuid4() non_existent_workflow_id = fake.uuid4()
update_data = {"marked_name": "Test"} update_data = {"marked_name": "Test"}
# Act # Act
result = workflow_service.update_workflow( result = workflow_service.update_workflow(
session=db.session, session=db_session_with_containers,
workflow_id=non_existent_workflow_id, workflow_id=non_existent_workflow_id,
tenant_id=app.tenant_id, tenant_id=app.tenant_id,
account_id=account.id, account_id=account.id,
@ -1204,7 +1167,7 @@ class TestWorkflowService:
# Assert # Assert
assert result is None assert result is None
def test_update_workflow_ignores_disallowed_fields(self, db_session_with_containers): def test_update_workflow_ignores_disallowed_fields(self, db_session_with_containers: Session):
""" """
Test that workflow update ignores disallowed fields. Test that workflow update ignores disallowed fields.
@ -1218,9 +1181,7 @@ class TestWorkflowService:
workflow = self._create_test_workflow(db_session_with_containers, app, account, fake) workflow = self._create_test_workflow(db_session_with_containers, app, account, fake)
original_name = workflow.marked_name original_name = workflow.marked_name
from extensions.ext_database import db db_session_with_containers.commit()
db.session.commit()
workflow_service = WorkflowService() workflow_service = WorkflowService()
update_data = { update_data = {
@ -1231,7 +1192,7 @@ class TestWorkflowService:
# Act # Act
result = workflow_service.update_workflow( result = workflow_service.update_workflow(
session=db.session, session=db_session_with_containers,
workflow_id=workflow.id, workflow_id=workflow.id,
tenant_id=workflow.tenant_id, tenant_id=workflow.tenant_id,
account_id=account.id, account_id=account.id,
@ -1245,7 +1206,7 @@ class TestWorkflowService:
assert result.graph == workflow.graph assert result.graph == workflow.graph
assert result.features == workflow.features assert result.features == workflow.features
def test_delete_workflow_success(self, db_session_with_containers): def test_delete_workflow_success(self, db_session_with_containers: Session):
""" """
Test successful workflow deletion. Test successful workflow deletion.
@ -1262,25 +1223,23 @@ class TestWorkflowService:
workflow = self._create_test_workflow(db_session_with_containers, app, account, fake) workflow = self._create_test_workflow(db_session_with_containers, app, account, fake)
workflow.version = "2024.01.01.001" # Published version workflow.version = "2024.01.01.001" # Published version
from extensions.ext_database import db db_session_with_containers.commit()
db.session.commit()
workflow_service = WorkflowService() workflow_service = WorkflowService()
# Act # Act
result = workflow_service.delete_workflow( result = workflow_service.delete_workflow(
session=db.session, workflow_id=workflow.id, tenant_id=workflow.tenant_id session=db_session_with_containers, workflow_id=workflow.id, tenant_id=workflow.tenant_id
) )
# Assert # Assert
assert result is True assert result is True
# Verify workflow is actually deleted # Verify workflow is actually deleted
deleted_workflow = db.session.query(Workflow).filter_by(id=workflow.id).first() deleted_workflow = db_session_with_containers.query(Workflow).filter_by(id=workflow.id).first()
assert deleted_workflow is None assert deleted_workflow is None
def test_delete_workflow_draft_error(self, db_session_with_containers): def test_delete_workflow_draft_error(self, db_session_with_containers: Session):
""" """
Test error when attempting to delete a draft workflow. Test error when attempting to delete a draft workflow.
@ -1296,9 +1255,7 @@ class TestWorkflowService:
workflow = self._create_test_workflow(db_session_with_containers, app, account, fake) workflow = self._create_test_workflow(db_session_with_containers, app, account, fake)
# Keep as draft version # Keep as draft version
from extensions.ext_database import db db_session_with_containers.commit()
db.session.commit()
workflow_service = WorkflowService() workflow_service = WorkflowService()
@ -1306,9 +1263,11 @@ class TestWorkflowService:
from services.errors.workflow_service import DraftWorkflowDeletionError from services.errors.workflow_service import DraftWorkflowDeletionError
with pytest.raises(DraftWorkflowDeletionError, match="Cannot delete draft workflow versions"): with pytest.raises(DraftWorkflowDeletionError, match="Cannot delete draft workflow versions"):
workflow_service.delete_workflow(session=db.session, workflow_id=workflow.id, tenant_id=workflow.tenant_id) workflow_service.delete_workflow(
session=db_session_with_containers, workflow_id=workflow.id, tenant_id=workflow.tenant_id
)
def test_delete_workflow_in_use_error(self, db_session_with_containers): def test_delete_workflow_in_use_error(self, db_session_with_containers: Session):
""" """
Test error when attempting to delete a workflow that's in use by an app. Test error when attempting to delete a workflow that's in use by an app.
@ -1327,9 +1286,7 @@ class TestWorkflowService:
# Associate workflow with app # Associate workflow with app
app.workflow_id = workflow.id app.workflow_id = workflow.id
from extensions.ext_database import db db_session_with_containers.commit()
db.session.commit()
workflow_service = WorkflowService() workflow_service = WorkflowService()
@ -1337,9 +1294,11 @@ class TestWorkflowService:
from services.errors.workflow_service import WorkflowInUseError from services.errors.workflow_service import WorkflowInUseError
with pytest.raises(WorkflowInUseError, match="Cannot delete workflow that is currently in use by app"): with pytest.raises(WorkflowInUseError, match="Cannot delete workflow that is currently in use by app"):
workflow_service.delete_workflow(session=db.session, workflow_id=workflow.id, tenant_id=workflow.tenant_id) workflow_service.delete_workflow(
session=db_session_with_containers, workflow_id=workflow.id, tenant_id=workflow.tenant_id
)
def test_delete_workflow_not_found_error(self, db_session_with_containers): def test_delete_workflow_not_found_error(self, db_session_with_containers: Session):
""" """
Test error when attempting to delete a non-existent workflow. Test error when attempting to delete a non-existent workflow.
@ -1351,17 +1310,15 @@ class TestWorkflowService:
app = self._create_test_app(db_session_with_containers, fake) app = self._create_test_app(db_session_with_containers, fake)
non_existent_workflow_id = fake.uuid4() non_existent_workflow_id = fake.uuid4()
from extensions.ext_database import db
workflow_service = WorkflowService() workflow_service = WorkflowService()
# Act & Assert # Act & Assert
with pytest.raises(ValueError, match=f"Workflow with ID {non_existent_workflow_id} not found"): with pytest.raises(ValueError, match=f"Workflow with ID {non_existent_workflow_id} not found"):
workflow_service.delete_workflow( workflow_service.delete_workflow(
session=db.session, workflow_id=non_existent_workflow_id, tenant_id=app.tenant_id session=db_session_with_containers, workflow_id=non_existent_workflow_id, tenant_id=app.tenant_id
) )
def test_run_free_workflow_node_success(self, db_session_with_containers): def test_run_free_workflow_node_success(self, db_session_with_containers: Session):
""" """
Test successful execution of a free workflow node. Test successful execution of a free workflow node.
@ -1413,7 +1370,7 @@ class TestWorkflowService:
assert result.workflow_id == "" # No workflow ID for free nodes assert result.workflow_id == "" # No workflow ID for free nodes
assert result.index == 1 assert result.index == 1
def test_run_free_workflow_node_with_complex_inputs(self, db_session_with_containers): def test_run_free_workflow_node_with_complex_inputs(self, db_session_with_containers: Session):
""" """
Test execution of a free workflow node with complex input data. Test execution of a free workflow node with complex input data.
@ -1454,7 +1411,7 @@ class TestWorkflowService:
error_msg = str(exc_info.value).lower() error_msg = str(exc_info.value).lower()
assert any(keyword in error_msg for keyword in ["start", "not supported", "external"]) assert any(keyword in error_msg for keyword in ["start", "not supported", "external"])
def test_handle_node_run_result_success(self, db_session_with_containers): def test_handle_node_run_result_success(self, db_session_with_containers: Session):
""" """
Test successful handling of node run results. Test successful handling of node run results.
@ -1529,7 +1486,7 @@ class TestWorkflowService:
assert result.outputs is not None assert result.outputs is not None
assert result.process_data is not None assert result.process_data is not None
def test_handle_node_run_result_failure(self, db_session_with_containers): def test_handle_node_run_result_failure(self, db_session_with_containers: Session):
""" """
Test handling of failed node run results. Test handling of failed node run results.
@ -1598,7 +1555,7 @@ class TestWorkflowService:
assert result.error is not None assert result.error is not None
assert "Test error message" in str(result.error) assert "Test error message" in str(result.error)
def test_handle_node_run_result_continue_on_error(self, db_session_with_containers): def test_handle_node_run_result_continue_on_error(self, db_session_with_containers: Session):
""" """
Test handling of node run results with continue_on_error strategy. Test handling of node run results with continue_on_error strategy.

View File

@ -2,6 +2,7 @@ from unittest.mock import patch
import pytest import pytest
from faker import Faker from faker import Faker
from sqlalchemy.orm import Session
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from services.workspace_service import WorkspaceService from services.workspace_service import WorkspaceService
@ -29,7 +30,7 @@ class TestWorkspaceService:
"dify_config": mock_dify_config, "dify_config": mock_dify_config,
} }
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Helper method to create a test account and tenant for testing. Helper method to create a test account and tenant for testing.
@ -50,10 +51,8 @@ class TestWorkspaceService:
status="active", status="active",
) )
from extensions.ext_database import db db_session_with_containers.add(account)
db_session_with_containers.commit()
db.session.add(account)
db.session.commit()
# Create tenant # Create tenant
tenant = Tenant( tenant = Tenant(
@ -62,8 +61,8 @@ class TestWorkspaceService:
plan="basic", plan="basic",
custom_config='{"replace_webapp_logo": true, "remove_webapp_brand": false}', custom_config='{"replace_webapp_logo": true, "remove_webapp_brand": false}',
) )
db.session.add(tenant) db_session_with_containers.add(tenant)
db.session.commit() db_session_with_containers.commit()
# Create tenant-account join with owner role # Create tenant-account join with owner role
join = TenantAccountJoin( join = TenantAccountJoin(
@ -72,15 +71,15 @@ class TestWorkspaceService:
role=TenantAccountRole.OWNER, role=TenantAccountRole.OWNER,
current=True, current=True,
) )
db.session.add(join) db_session_with_containers.add(join)
db.session.commit() db_session_with_containers.commit()
# Set current tenant for account # Set current tenant for account
account.current_tenant = tenant account.current_tenant = tenant
return account, tenant return account, tenant
def test_get_tenant_info_success(self, db_session_with_containers, mock_external_service_dependencies): def test_get_tenant_info_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test successful retrieval of tenant information with all features enabled. Test successful retrieval of tenant information with all features enabled.
@ -121,13 +120,12 @@ class TestWorkspaceService:
assert "replace_webapp_logo" in result["custom_config"] assert "replace_webapp_logo" in result["custom_config"]
# Verify database state # Verify database state
from extensions.ext_database import db
db.session.refresh(tenant) db_session_with_containers.refresh(tenant)
assert tenant.id is not None assert tenant.id is not None
def test_get_tenant_info_without_custom_config( def test_get_tenant_info_without_custom_config(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test tenant info retrieval when custom config features are disabled. Test tenant info retrieval when custom config features are disabled.
@ -167,13 +165,12 @@ class TestWorkspaceService:
assert "custom_config" not in result assert "custom_config" not in result
# Verify database state # Verify database state
from extensions.ext_database import db
db.session.refresh(tenant) db_session_with_containers.refresh(tenant)
assert tenant.id is not None assert tenant.id is not None
def test_get_tenant_info_with_normal_user_role( def test_get_tenant_info_with_normal_user_role(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test tenant info retrieval for normal user role without privileged features. Test tenant info retrieval for normal user role without privileged features.
@ -191,11 +188,14 @@ class TestWorkspaceService:
) )
# Update the join to have normal role # Update the join to have normal role
from extensions.ext_database import db
join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() join = (
db_session_with_containers.query(TenantAccountJoin)
.filter_by(tenant_id=tenant.id, account_id=account.id)
.first()
)
join.role = TenantAccountRole.NORMAL join.role = TenantAccountRole.NORMAL
db.session.commit() db_session_with_containers.commit()
# Setup mocks for feature service # Setup mocks for feature service
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
@ -220,11 +220,11 @@ class TestWorkspaceService:
assert "custom_config" not in result assert "custom_config" not in result
# Verify database state # Verify database state
db.session.refresh(tenant) db_session_with_containers.refresh(tenant)
assert tenant.id is not None assert tenant.id is not None
def test_get_tenant_info_with_admin_role_and_logo_replacement( def test_get_tenant_info_with_admin_role_and_logo_replacement(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test tenant info retrieval for admin role with logo replacement enabled. Test tenant info retrieval for admin role with logo replacement enabled.
@ -242,11 +242,14 @@ class TestWorkspaceService:
) )
# Update the join to have admin role # Update the join to have admin role
from extensions.ext_database import db
join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() join = (
db_session_with_containers.query(TenantAccountJoin)
.filter_by(tenant_id=tenant.id, account_id=account.id)
.first()
)
join.role = TenantAccountRole.ADMIN join.role = TenantAccountRole.ADMIN
db.session.commit() db_session_with_containers.commit()
# Setup mocks for feature service and tenant service # Setup mocks for feature service and tenant service
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
@ -268,10 +271,12 @@ class TestWorkspaceService:
assert "replace_webapp_logo" in result["custom_config"] assert "replace_webapp_logo" in result["custom_config"]
# Verify database state # Verify database state
db.session.refresh(tenant) db_session_with_containers.refresh(tenant)
assert tenant.id is not None assert tenant.id is not None
def test_get_tenant_info_with_tenant_none(self, db_session_with_containers, mock_external_service_dependencies): def test_get_tenant_info_with_tenant_none(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test tenant info retrieval when tenant parameter is None. Test tenant info retrieval when tenant parameter is None.
@ -290,7 +295,7 @@ class TestWorkspaceService:
assert result is None assert result is None
def test_get_tenant_info_with_custom_config_variations( def test_get_tenant_info_with_custom_config_variations(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test tenant info retrieval with various custom config configurations. Test tenant info retrieval with various custom config configurations.
@ -323,10 +328,8 @@ class TestWorkspaceService:
# Update tenant custom config # Update tenant custom config
import json import json
from extensions.ext_database import db
tenant.custom_config = json.dumps(config) tenant.custom_config = json.dumps(config)
db.session.commit() db_session_with_containers.commit()
# Setup mocks # Setup mocks
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
@ -353,11 +356,11 @@ class TestWorkspaceService:
assert result["custom_config"]["remove_webapp_brand"] == config["remove_webapp_brand"] assert result["custom_config"]["remove_webapp_brand"] == config["remove_webapp_brand"]
# Verify database state # Verify database state
db.session.refresh(tenant) db_session_with_containers.refresh(tenant)
assert tenant.id is not None assert tenant.id is not None
def test_get_tenant_info_with_editor_role_and_limited_permissions( def test_get_tenant_info_with_editor_role_and_limited_permissions(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test tenant info retrieval for editor role with limited permissions. Test tenant info retrieval for editor role with limited permissions.
@ -375,11 +378,14 @@ class TestWorkspaceService:
) )
# Update the join to have editor role # Update the join to have editor role
from extensions.ext_database import db
join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() join = (
db_session_with_containers.query(TenantAccountJoin)
.filter_by(tenant_id=tenant.id, account_id=account.id)
.first()
)
join.role = TenantAccountRole.EDITOR join.role = TenantAccountRole.EDITOR
db.session.commit() db_session_with_containers.commit()
# Setup mocks for feature service and tenant service # Setup mocks for feature service and tenant service
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
@ -400,11 +406,11 @@ class TestWorkspaceService:
assert "custom_config" not in result assert "custom_config" not in result
# Verify database state # Verify database state
db.session.refresh(tenant) db_session_with_containers.refresh(tenant)
assert tenant.id is not None assert tenant.id is not None
def test_get_tenant_info_with_dataset_operator_role( def test_get_tenant_info_with_dataset_operator_role(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test tenant info retrieval for dataset operator role. Test tenant info retrieval for dataset operator role.
@ -422,11 +428,14 @@ class TestWorkspaceService:
) )
# Update the join to have dataset operator role # Update the join to have dataset operator role
from extensions.ext_database import db
join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() join = (
db_session_with_containers.query(TenantAccountJoin)
.filter_by(tenant_id=tenant.id, account_id=account.id)
.first()
)
join.role = TenantAccountRole.DATASET_OPERATOR join.role = TenantAccountRole.DATASET_OPERATOR
db.session.commit() db_session_with_containers.commit()
# Setup mocks for feature service and tenant service # Setup mocks for feature service and tenant service
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
@ -447,11 +456,11 @@ class TestWorkspaceService:
assert "custom_config" not in result assert "custom_config" not in result
# Verify database state # Verify database state
db.session.refresh(tenant) db_session_with_containers.refresh(tenant)
assert tenant.id is not None assert tenant.id is not None
def test_get_tenant_info_with_complex_custom_config_scenarios( def test_get_tenant_info_with_complex_custom_config_scenarios(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test tenant info retrieval with complex custom config scenarios. Test tenant info retrieval with complex custom config scenarios.
@ -491,10 +500,8 @@ class TestWorkspaceService:
# Update tenant custom config # Update tenant custom config
import json import json
from extensions.ext_database import db
tenant.custom_config = json.dumps(config) tenant.custom_config = json.dumps(config)
db.session.commit() db_session_with_containers.commit()
# Setup mocks # Setup mocks
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
@ -525,5 +532,5 @@ class TestWorkspaceService:
assert result["custom_config"]["remove_webapp_brand"] is False assert result["custom_config"]["remove_webapp_brand"] is False
# Verify database state # Verify database state
db.session.refresh(tenant) db_session_with_containers.refresh(tenant)
assert tenant.id is not None assert tenant.id is not None

View File

@ -3,6 +3,7 @@ from unittest.mock import patch
import pytest import pytest
from faker import Faker from faker import Faker
from pydantic import TypeAdapter, ValidationError from pydantic import TypeAdapter, ValidationError
from sqlalchemy.orm import Session
from core.tools.entities.tool_entities import ApiProviderSchemaType from core.tools.entities.tool_entities import ApiProviderSchemaType
from models import Account, Tenant from models import Account, Tenant
@ -34,7 +35,7 @@ class TestApiToolManageService:
"provider_controller": mock_provider_controller, "provider_controller": mock_provider_controller,
} }
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Helper method to create a test account and tenant for testing. Helper method to create a test account and tenant for testing.
@ -55,18 +56,16 @@ class TestApiToolManageService:
status="active", status="active",
) )
from extensions.ext_database import db db_session_with_containers.add(account)
db_session_with_containers.commit()
db.session.add(account)
db.session.commit()
# Create tenant for the account # Create tenant for the account
tenant = Tenant( tenant = Tenant(
name=fake.company(), name=fake.company(),
status="normal", status="normal",
) )
db.session.add(tenant) db_session_with_containers.add(tenant)
db.session.commit() db_session_with_containers.commit()
# Create tenant-account join # Create tenant-account join
from models.account import TenantAccountJoin, TenantAccountRole from models.account import TenantAccountJoin, TenantAccountRole
@ -77,8 +76,8 @@ class TestApiToolManageService:
role=TenantAccountRole.OWNER, role=TenantAccountRole.OWNER,
current=True, current=True,
) )
db.session.add(join) db_session_with_containers.add(join)
db.session.commit() db_session_with_containers.commit()
# Set current tenant for account # Set current tenant for account
account.current_tenant = tenant account.current_tenant = tenant
@ -118,7 +117,7 @@ class TestApiToolManageService:
""" """
def test_parser_api_schema_success( def test_parser_api_schema_success(
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful parsing of API schema. Test successful parsing of API schema.
@ -163,7 +162,7 @@ class TestApiToolManageService:
assert api_key_value_field["default"] == "" assert api_key_value_field["default"] == ""
def test_parser_api_schema_invalid_schema( def test_parser_api_schema_invalid_schema(
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test parsing of invalid API schema. Test parsing of invalid API schema.
@ -183,7 +182,7 @@ class TestApiToolManageService:
assert "invalid schema" in str(exc_info.value) assert "invalid schema" in str(exc_info.value)
def test_parser_api_schema_malformed_json( def test_parser_api_schema_malformed_json(
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test parsing of malformed JSON schema. Test parsing of malformed JSON schema.
@ -203,7 +202,7 @@ class TestApiToolManageService:
assert "invalid schema" in str(exc_info.value) assert "invalid schema" in str(exc_info.value)
def test_convert_schema_to_tool_bundles_success( def test_convert_schema_to_tool_bundles_success(
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful conversion of schema to tool bundles. Test successful conversion of schema to tool bundles.
@ -233,7 +232,7 @@ class TestApiToolManageService:
assert tool_bundle.operation_id == "testOperation" assert tool_bundle.operation_id == "testOperation"
def test_convert_schema_to_tool_bundles_with_extra_info( def test_convert_schema_to_tool_bundles_with_extra_info(
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful conversion of schema to tool bundles with extra info. Test successful conversion of schema to tool bundles with extra info.
@ -259,7 +258,7 @@ class TestApiToolManageService:
assert isinstance(schema_type, str) assert isinstance(schema_type, str)
def test_convert_schema_to_tool_bundles_invalid_schema( def test_convert_schema_to_tool_bundles_invalid_schema(
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test conversion of invalid schema to tool bundles. Test conversion of invalid schema to tool bundles.
@ -279,7 +278,7 @@ class TestApiToolManageService:
assert "invalid schema" in str(exc_info.value) assert "invalid schema" in str(exc_info.value)
def test_create_api_tool_provider_success( def test_create_api_tool_provider_success(
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful creation of API tool provider. Test successful creation of API tool provider.
@ -324,10 +323,9 @@ class TestApiToolManageService:
assert result == {"result": "success"} assert result == {"result": "success"}
# Verify database state # Verify database state
from extensions.ext_database import db
provider = ( provider = (
db.session.query(ApiToolProvider) db_session_with_containers.query(ApiToolProvider)
.filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name) .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name)
.first() .first()
) )
@ -347,7 +345,7 @@ class TestApiToolManageService:
mock_external_service_dependencies["provider_controller"].load_bundled_tools.assert_called_once() mock_external_service_dependencies["provider_controller"].load_bundled_tools.assert_called_once()
def test_create_api_tool_provider_duplicate_name( def test_create_api_tool_provider_duplicate_name(
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test creation of API tool provider with duplicate name. Test creation of API tool provider with duplicate name.
@ -404,7 +402,7 @@ class TestApiToolManageService:
assert f"provider {provider_name} already exists" in str(exc_info.value) assert f"provider {provider_name} already exists" in str(exc_info.value)
def test_create_api_tool_provider_invalid_schema_type( def test_create_api_tool_provider_invalid_schema_type(
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test creation of API tool provider with invalid schema type. Test creation of API tool provider with invalid schema type.
@ -436,7 +434,7 @@ class TestApiToolManageService:
assert "validation error" in str(exc_info.value) assert "validation error" in str(exc_info.value)
def test_create_api_tool_provider_missing_auth_type( def test_create_api_tool_provider_missing_auth_type(
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test creation of API tool provider with missing auth type. Test creation of API tool provider with missing auth type.
@ -479,7 +477,7 @@ class TestApiToolManageService:
assert "auth_type is required" in str(exc_info.value) assert "auth_type is required" in str(exc_info.value)
def test_create_api_tool_provider_with_api_key_auth( def test_create_api_tool_provider_with_api_key_auth(
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful creation of API tool provider with API key authentication. Test successful creation of API tool provider with API key authentication.
@ -522,10 +520,9 @@ class TestApiToolManageService:
assert result == {"result": "success"} assert result == {"result": "success"}
# Verify database state # Verify database state
from extensions.ext_database import db
provider = ( provider = (
db.session.query(ApiToolProvider) db_session_with_containers.query(ApiToolProvider)
.filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name) .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name)
.first() .first()
) )

View File

@ -2,6 +2,7 @@ from unittest.mock import patch
import pytest import pytest
from faker import Faker from faker import Faker
from sqlalchemy.orm import Session
from core.tools.entities.tool_entities import ToolProviderType from core.tools.entities.tool_entities import ToolProviderType
from models import Account, Tenant from models import Account, Tenant
@ -41,7 +42,7 @@ class TestMCPToolManageService:
"tool_transform_service": mock_tool_transform_service, "tool_transform_service": mock_tool_transform_service,
} }
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Helper method to create a test account and tenant for testing. Helper method to create a test account and tenant for testing.
@ -62,18 +63,16 @@ class TestMCPToolManageService:
status="active", status="active",
) )
from extensions.ext_database import db db_session_with_containers.add(account)
db_session_with_containers.commit()
db.session.add(account)
db.session.commit()
# Create tenant for the account # Create tenant for the account
tenant = Tenant( tenant = Tenant(
name=fake.company(), name=fake.company(),
status="normal", status="normal",
) )
db.session.add(tenant) db_session_with_containers.add(tenant)
db.session.commit() db_session_with_containers.commit()
# Create tenant-account join # Create tenant-account join
from models.account import TenantAccountJoin, TenantAccountRole from models.account import TenantAccountJoin, TenantAccountRole
@ -84,8 +83,8 @@ class TestMCPToolManageService:
role=TenantAccountRole.OWNER, role=TenantAccountRole.OWNER,
current=True, current=True,
) )
db.session.add(join) db_session_with_containers.add(join)
db.session.commit() db_session_with_containers.commit()
# Set current tenant for account # Set current tenant for account
account.current_tenant = tenant account.current_tenant = tenant
@ -93,7 +92,7 @@ class TestMCPToolManageService:
return account, tenant return account, tenant
def _create_test_mcp_provider( def _create_test_mcp_provider(
self, db_session_with_containers, mock_external_service_dependencies, tenant_id, user_id self, db_session_with_containers: Session, mock_external_service_dependencies, tenant_id, user_id
): ):
""" """
Helper method to create a test MCP tool provider for testing. Helper method to create a test MCP tool provider for testing.
@ -124,15 +123,13 @@ class TestMCPToolManageService:
sse_read_timeout=300.0, sse_read_timeout=300.0,
) )
from extensions.ext_database import db db_session_with_containers.add(mcp_provider)
db_session_with_containers.commit()
db.session.add(mcp_provider)
db.session.commit()
return mcp_provider return mcp_provider
def test_get_mcp_provider_by_provider_id_success( def test_get_mcp_provider_by_provider_id_success(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful retrieval of MCP provider by provider ID. Test successful retrieval of MCP provider by provider ID.
@ -153,9 +150,8 @@ class TestMCPToolManageService:
) )
# Act: Execute the method under test # Act: Execute the method under test
from extensions.ext_database import db
service = MCPToolManageService(db.session()) service = MCPToolManageService(db_session_with_containers)
result = service.get_provider(provider_id=mcp_provider.id, tenant_id=tenant.id) result = service.get_provider(provider_id=mcp_provider.id, tenant_id=tenant.id)
# Assert: Verify the expected outcomes # Assert: Verify the expected outcomes
@ -166,12 +162,12 @@ class TestMCPToolManageService:
assert result.user_id == account.id assert result.user_id == account.id
# Verify database state # Verify database state
db.session.refresh(result) db_session_with_containers.refresh(result)
assert result.id is not None assert result.id is not None
assert result.server_identifier == mcp_provider.server_identifier assert result.server_identifier == mcp_provider.server_identifier
def test_get_mcp_provider_by_provider_id_not_found( def test_get_mcp_provider_by_provider_id_not_found(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test error handling when MCP provider is not found by provider ID. Test error handling when MCP provider is not found by provider ID.
@ -190,14 +186,13 @@ class TestMCPToolManageService:
non_existent_id = str(fake.uuid4()) non_existent_id = str(fake.uuid4())
# Act & Assert: Verify proper error handling # Act & Assert: Verify proper error handling
from extensions.ext_database import db
service = MCPToolManageService(db.session()) service = MCPToolManageService(db_session_with_containers)
with pytest.raises(ValueError, match="MCP tool not found"): with pytest.raises(ValueError, match="MCP tool not found"):
service.get_provider(provider_id=non_existent_id, tenant_id=tenant.id) service.get_provider(provider_id=non_existent_id, tenant_id=tenant.id)
def test_get_mcp_provider_by_provider_id_tenant_isolation( def test_get_mcp_provider_by_provider_id_tenant_isolation(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test tenant isolation when retrieving MCP provider by provider ID. Test tenant isolation when retrieving MCP provider by provider ID.
@ -223,14 +218,13 @@ class TestMCPToolManageService:
) )
# Act & Assert: Verify tenant isolation # Act & Assert: Verify tenant isolation
from extensions.ext_database import db
service = MCPToolManageService(db.session()) service = MCPToolManageService(db_session_with_containers)
with pytest.raises(ValueError, match="MCP tool not found"): with pytest.raises(ValueError, match="MCP tool not found"):
service.get_provider(provider_id=mcp_provider1.id, tenant_id=tenant2.id) service.get_provider(provider_id=mcp_provider1.id, tenant_id=tenant2.id)
def test_get_mcp_provider_by_server_identifier_success( def test_get_mcp_provider_by_server_identifier_success(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful retrieval of MCP provider by server identifier. Test successful retrieval of MCP provider by server identifier.
@ -251,9 +245,8 @@ class TestMCPToolManageService:
) )
# Act: Execute the method under test # Act: Execute the method under test
from extensions.ext_database import db
service = MCPToolManageService(db.session()) service = MCPToolManageService(db_session_with_containers)
result = service.get_provider(server_identifier=mcp_provider.server_identifier, tenant_id=tenant.id) result = service.get_provider(server_identifier=mcp_provider.server_identifier, tenant_id=tenant.id)
# Assert: Verify the expected outcomes # Assert: Verify the expected outcomes
@ -264,12 +257,12 @@ class TestMCPToolManageService:
assert result.user_id == account.id assert result.user_id == account.id
# Verify database state # Verify database state
db.session.refresh(result) db_session_with_containers.refresh(result)
assert result.id is not None assert result.id is not None
assert result.name == mcp_provider.name assert result.name == mcp_provider.name
def test_get_mcp_provider_by_server_identifier_not_found( def test_get_mcp_provider_by_server_identifier_not_found(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test error handling when MCP provider is not found by server identifier. Test error handling when MCP provider is not found by server identifier.
@ -288,14 +281,13 @@ class TestMCPToolManageService:
non_existent_identifier = str(fake.uuid4()) non_existent_identifier = str(fake.uuid4())
# Act & Assert: Verify proper error handling # Act & Assert: Verify proper error handling
from extensions.ext_database import db
service = MCPToolManageService(db.session()) service = MCPToolManageService(db_session_with_containers)
with pytest.raises(ValueError, match="MCP tool not found"): with pytest.raises(ValueError, match="MCP tool not found"):
service.get_provider(server_identifier=non_existent_identifier, tenant_id=tenant.id) service.get_provider(server_identifier=non_existent_identifier, tenant_id=tenant.id)
def test_get_mcp_provider_by_server_identifier_tenant_isolation( def test_get_mcp_provider_by_server_identifier_tenant_isolation(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test tenant isolation when retrieving MCP provider by server identifier. Test tenant isolation when retrieving MCP provider by server identifier.
@ -321,13 +313,12 @@ class TestMCPToolManageService:
) )
# Act & Assert: Verify tenant isolation # Act & Assert: Verify tenant isolation
from extensions.ext_database import db
service = MCPToolManageService(db.session()) service = MCPToolManageService(db_session_with_containers)
with pytest.raises(ValueError, match="MCP tool not found"): with pytest.raises(ValueError, match="MCP tool not found"):
service.get_provider(server_identifier=mcp_provider1.server_identifier, tenant_id=tenant2.id) service.get_provider(server_identifier=mcp_provider1.server_identifier, tenant_id=tenant2.id)
def test_create_mcp_provider_success(self, db_session_with_containers, mock_external_service_dependencies): def test_create_mcp_provider_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test successful creation of MCP provider. Test successful creation of MCP provider.
@ -365,9 +356,8 @@ class TestMCPToolManageService:
# Act: Execute the method under test # Act: Execute the method under test
from core.entities.mcp_provider import MCPConfiguration from core.entities.mcp_provider import MCPConfiguration
from extensions.ext_database import db
service = MCPToolManageService(db.session()) service = MCPToolManageService(db_session_with_containers)
result = service.create_provider( result = service.create_provider(
tenant_id=tenant.id, tenant_id=tenant.id,
name="Test MCP Provider", name="Test MCP Provider",
@ -389,10 +379,9 @@ class TestMCPToolManageService:
assert result.type == ToolProviderType.MCP assert result.type == ToolProviderType.MCP
# Verify database state # Verify database state
from extensions.ext_database import db
created_provider = ( created_provider = (
db.session.query(MCPToolProvider) db_session_with_containers.query(MCPToolProvider)
.filter(MCPToolProvider.tenant_id == tenant.id, MCPToolProvider.name == "Test MCP Provider") .filter(MCPToolProvider.tenant_id == tenant.id, MCPToolProvider.name == "Test MCP Provider")
.first() .first()
) )
@ -410,7 +399,9 @@ class TestMCPToolManageService:
) )
mock_external_service_dependencies["tool_transform_service"].mcp_provider_to_user_provider.assert_called_once() mock_external_service_dependencies["tool_transform_service"].mcp_provider_to_user_provider.assert_called_once()
def test_create_mcp_provider_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies): def test_create_mcp_provider_duplicate_name(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test error handling when creating MCP provider with duplicate name. Test error handling when creating MCP provider with duplicate name.
@ -427,9 +418,8 @@ class TestMCPToolManageService:
# Create first provider # Create first provider
from core.entities.mcp_provider import MCPConfiguration from core.entities.mcp_provider import MCPConfiguration
from extensions.ext_database import db
service = MCPToolManageService(db.session()) service = MCPToolManageService(db_session_with_containers)
service.create_provider( service.create_provider(
tenant_id=tenant.id, tenant_id=tenant.id,
name="Test MCP Provider", name="Test MCP Provider",
@ -463,7 +453,7 @@ class TestMCPToolManageService:
) )
def test_create_mcp_provider_duplicate_server_url( def test_create_mcp_provider_duplicate_server_url(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test error handling when creating MCP provider with duplicate server URL. Test error handling when creating MCP provider with duplicate server URL.
@ -481,9 +471,8 @@ class TestMCPToolManageService:
# Create first provider # Create first provider
from core.entities.mcp_provider import MCPConfiguration from core.entities.mcp_provider import MCPConfiguration
from extensions.ext_database import db
service = MCPToolManageService(db.session()) service = MCPToolManageService(db_session_with_containers)
service.create_provider( service.create_provider(
tenant_id=tenant.id, tenant_id=tenant.id,
name="Test MCP Provider 1", name="Test MCP Provider 1",
@ -517,7 +506,7 @@ class TestMCPToolManageService:
) )
def test_create_mcp_provider_duplicate_server_identifier( def test_create_mcp_provider_duplicate_server_identifier(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test error handling when creating MCP provider with duplicate server identifier. Test error handling when creating MCP provider with duplicate server identifier.
@ -535,9 +524,8 @@ class TestMCPToolManageService:
# Create first provider # Create first provider
from core.entities.mcp_provider import MCPConfiguration from core.entities.mcp_provider import MCPConfiguration
from extensions.ext_database import db
service = MCPToolManageService(db.session()) service = MCPToolManageService(db_session_with_containers)
service.create_provider( service.create_provider(
tenant_id=tenant.id, tenant_id=tenant.id,
name="Test MCP Provider 1", name="Test MCP Provider 1",
@ -570,7 +558,7 @@ class TestMCPToolManageService:
), ),
) )
def test_retrieve_mcp_tools_success(self, db_session_with_containers, mock_external_service_dependencies): def test_retrieve_mcp_tools_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test successful retrieval of MCP tools for a tenant. Test successful retrieval of MCP tools for a tenant.
@ -602,9 +590,7 @@ class TestMCPToolManageService:
) )
provider3.name = "Gamma Provider" provider3.name = "Gamma Provider"
from extensions.ext_database import db db_session_with_containers.commit()
db.session.commit()
# Setup mock for transformation service # Setup mock for transformation service
from core.tools.entities.api_entities import ToolProviderApiEntity from core.tools.entities.api_entities import ToolProviderApiEntity
@ -647,9 +633,8 @@ class TestMCPToolManageService:
] ]
# Act: Execute the method under test # Act: Execute the method under test
from extensions.ext_database import db
service = MCPToolManageService(db.session()) service = MCPToolManageService(db_session_with_containers)
result = service.list_providers(tenant_id=tenant.id, for_list=True) result = service.list_providers(tenant_id=tenant.id, for_list=True)
# Assert: Verify the expected outcomes # Assert: Verify the expected outcomes
@ -666,7 +651,9 @@ class TestMCPToolManageService:
mock_external_service_dependencies["tool_transform_service"].mcp_provider_to_user_provider.call_count == 3 mock_external_service_dependencies["tool_transform_service"].mcp_provider_to_user_provider.call_count == 3
) )
def test_retrieve_mcp_tools_empty_list(self, db_session_with_containers, mock_external_service_dependencies): def test_retrieve_mcp_tools_empty_list(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test retrieval of MCP tools when tenant has no providers. Test retrieval of MCP tools when tenant has no providers.
@ -684,9 +671,8 @@ class TestMCPToolManageService:
# No MCP providers created for this tenant # No MCP providers created for this tenant
# Act: Execute the method under test # Act: Execute the method under test
from extensions.ext_database import db
service = MCPToolManageService(db.session()) service = MCPToolManageService(db_session_with_containers)
result = service.list_providers(tenant_id=tenant.id, for_list=False) result = service.list_providers(tenant_id=tenant.id, for_list=False)
# Assert: Verify the expected outcomes # Assert: Verify the expected outcomes
@ -697,7 +683,9 @@ class TestMCPToolManageService:
# Verify no transformation service calls for empty list # Verify no transformation service calls for empty list
mock_external_service_dependencies["tool_transform_service"].mcp_provider_to_user_provider.assert_not_called() mock_external_service_dependencies["tool_transform_service"].mcp_provider_to_user_provider.assert_not_called()
def test_retrieve_mcp_tools_tenant_isolation(self, db_session_with_containers, mock_external_service_dependencies): def test_retrieve_mcp_tools_tenant_isolation(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test tenant isolation when retrieving MCP tools. Test tenant isolation when retrieving MCP tools.
@ -756,9 +744,8 @@ class TestMCPToolManageService:
] ]
# Act: Execute the method under test for both tenants # Act: Execute the method under test for both tenants
from extensions.ext_database import db
service = MCPToolManageService(db.session()) service = MCPToolManageService(db_session_with_containers)
result1 = service.list_providers(tenant_id=tenant1.id, for_list=True) result1 = service.list_providers(tenant_id=tenant1.id, for_list=True)
result2 = service.list_providers(tenant_id=tenant2.id, for_list=True) result2 = service.list_providers(tenant_id=tenant2.id, for_list=True)
@ -769,7 +756,7 @@ class TestMCPToolManageService:
assert result2[0].id == provider2.id assert result2[0].id == provider2.id
def test_list_mcp_tool_from_remote_server_success( def test_list_mcp_tool_from_remote_server_success(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful listing of MCP tools from remote server. Test successful listing of MCP tools from remote server.
@ -797,9 +784,7 @@ class TestMCPToolManageService:
mcp_provider.authed = True # Provider must be authenticated to list tools mcp_provider.authed = True # Provider must be authenticated to list tools
mcp_provider.tools = "[]" mcp_provider.tools = "[]"
from extensions.ext_database import db db_session_with_containers.commit()
db.session.commit()
# Mock the decryption process at the rsa level to avoid key file issues # Mock the decryption process at the rsa level to avoid key file issues
with patch("libs.rsa.decrypt") as mock_decrypt: with patch("libs.rsa.decrypt") as mock_decrypt:
@ -821,9 +806,8 @@ class TestMCPToolManageService:
mock_client_instance.list_tools.return_value = mock_tools mock_client_instance.list_tools.return_value = mock_tools
# Act: Execute the method under test # Act: Execute the method under test
from extensions.ext_database import db
service = MCPToolManageService(db.session()) service = MCPToolManageService(db_session_with_containers)
result = service.list_provider_tools(tenant_id=tenant.id, provider_id=mcp_provider.id) result = service.list_provider_tools(tenant_id=tenant.id, provider_id=mcp_provider.id)
# Assert: Verify the expected outcomes # Assert: Verify the expected outcomes
@ -834,7 +818,7 @@ class TestMCPToolManageService:
# Note: server_url is mocked, so we skip that assertion to avoid encryption issues # Note: server_url is mocked, so we skip that assertion to avoid encryption issues
# Verify database state was updated # Verify database state was updated
db.session.refresh(mcp_provider) db_session_with_containers.refresh(mcp_provider)
assert mcp_provider.authed is True assert mcp_provider.authed is True
assert mcp_provider.tools != "[]" assert mcp_provider.tools != "[]"
assert mcp_provider.updated_at is not None assert mcp_provider.updated_at is not None
@ -844,7 +828,7 @@ class TestMCPToolManageService:
mock_mcp_client.assert_called_once() mock_mcp_client.assert_called_once()
def test_list_mcp_tool_from_remote_server_auth_error( def test_list_mcp_tool_from_remote_server_auth_error(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test error handling when MCP server requires authentication. Test error handling when MCP server requires authentication.
@ -871,9 +855,7 @@ class TestMCPToolManageService:
mcp_provider.authed = False mcp_provider.authed = False
mcp_provider.tools = "[]" mcp_provider.tools = "[]"
from extensions.ext_database import db db_session_with_containers.commit()
db.session.commit()
# Mock the decryption process at the rsa level to avoid key file issues # Mock the decryption process at the rsa level to avoid key file issues
with patch("libs.rsa.decrypt") as mock_decrypt: with patch("libs.rsa.decrypt") as mock_decrypt:
@ -887,19 +869,18 @@ class TestMCPToolManageService:
mock_client_instance.list_tools.side_effect = MCPAuthError("Authentication required") mock_client_instance.list_tools.side_effect = MCPAuthError("Authentication required")
# Act & Assert: Verify proper error handling # Act & Assert: Verify proper error handling
from extensions.ext_database import db
service = MCPToolManageService(db.session()) service = MCPToolManageService(db_session_with_containers)
with pytest.raises(ValueError, match="Please auth the tool first"): with pytest.raises(ValueError, match="Please auth the tool first"):
service.list_provider_tools(tenant_id=tenant.id, provider_id=mcp_provider.id) service.list_provider_tools(tenant_id=tenant.id, provider_id=mcp_provider.id)
# Verify database state was not changed # Verify database state was not changed
db.session.refresh(mcp_provider) db_session_with_containers.refresh(mcp_provider)
assert mcp_provider.authed is False assert mcp_provider.authed is False
assert mcp_provider.tools == "[]" assert mcp_provider.tools == "[]"
def test_list_mcp_tool_from_remote_server_connection_error( def test_list_mcp_tool_from_remote_server_connection_error(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test error handling when MCP server connection fails. Test error handling when MCP server connection fails.
@ -926,9 +907,7 @@ class TestMCPToolManageService:
mcp_provider.authed = True # Provider must be authenticated to test connection errors mcp_provider.authed = True # Provider must be authenticated to test connection errors
mcp_provider.tools = "[]" mcp_provider.tools = "[]"
from extensions.ext_database import db db_session_with_containers.commit()
db.session.commit()
# Mock the decryption process at the rsa level to avoid key file issues # Mock the decryption process at the rsa level to avoid key file issues
with patch("libs.rsa.decrypt") as mock_decrypt: with patch("libs.rsa.decrypt") as mock_decrypt:
@ -942,18 +921,17 @@ class TestMCPToolManageService:
mock_client_instance.list_tools.side_effect = MCPError("Connection failed") mock_client_instance.list_tools.side_effect = MCPError("Connection failed")
# Act & Assert: Verify proper error handling # Act & Assert: Verify proper error handling
from extensions.ext_database import db
service = MCPToolManageService(db.session()) service = MCPToolManageService(db_session_with_containers)
with pytest.raises(ValueError, match="Failed to connect to MCP server: Connection failed"): with pytest.raises(ValueError, match="Failed to connect to MCP server: Connection failed"):
service.list_provider_tools(tenant_id=tenant.id, provider_id=mcp_provider.id) service.list_provider_tools(tenant_id=tenant.id, provider_id=mcp_provider.id)
# Verify database state was not changed # Verify database state was not changed
db.session.refresh(mcp_provider) db_session_with_containers.refresh(mcp_provider)
assert mcp_provider.authed is True # Provider remains authenticated assert mcp_provider.authed is True # Provider remains authenticated
assert mcp_provider.tools == "[]" assert mcp_provider.tools == "[]"
def test_delete_mcp_tool_success(self, db_session_with_containers, mock_external_service_dependencies): def test_delete_mcp_tool_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test successful deletion of MCP tool. Test successful deletion of MCP tool.
@ -974,20 +952,19 @@ class TestMCPToolManageService:
) )
# Verify provider exists # Verify provider exists
from extensions.ext_database import db
assert db.session.query(MCPToolProvider).filter_by(id=mcp_provider.id).first() is not None assert db_session_with_containers.query(MCPToolProvider).filter_by(id=mcp_provider.id).first() is not None
# Act: Execute the method under test # Act: Execute the method under test
service = MCPToolManageService(db.session()) service = MCPToolManageService(db_session_with_containers)
service.delete_provider(tenant_id=tenant.id, provider_id=mcp_provider.id) service.delete_provider(tenant_id=tenant.id, provider_id=mcp_provider.id)
# Assert: Verify the expected outcomes # Assert: Verify the expected outcomes
# Provider should be deleted from database # Provider should be deleted from database
deleted_provider = db.session.query(MCPToolProvider).filter_by(id=mcp_provider.id).first() deleted_provider = db_session_with_containers.query(MCPToolProvider).filter_by(id=mcp_provider.id).first()
assert deleted_provider is None assert deleted_provider is None
def test_delete_mcp_tool_not_found(self, db_session_with_containers, mock_external_service_dependencies): def test_delete_mcp_tool_not_found(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test error handling when deleting non-existent MCP tool. Test error handling when deleting non-existent MCP tool.
@ -1005,13 +982,14 @@ class TestMCPToolManageService:
non_existent_id = str(fake.uuid4()) non_existent_id = str(fake.uuid4())
# Act & Assert: Verify proper error handling # Act & Assert: Verify proper error handling
from extensions.ext_database import db
service = MCPToolManageService(db.session()) service = MCPToolManageService(db_session_with_containers)
with pytest.raises(ValueError, match="MCP tool not found"): with pytest.raises(ValueError, match="MCP tool not found"):
service.delete_provider(tenant_id=tenant.id, provider_id=non_existent_id) service.delete_provider(tenant_id=tenant.id, provider_id=non_existent_id)
def test_delete_mcp_tool_tenant_isolation(self, db_session_with_containers, mock_external_service_dependencies): def test_delete_mcp_tool_tenant_isolation(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test tenant isolation when deleting MCP tool. Test tenant isolation when deleting MCP tool.
@ -1036,18 +1014,16 @@ class TestMCPToolManageService:
) )
# Act & Assert: Verify tenant isolation # Act & Assert: Verify tenant isolation
from extensions.ext_database import db
service = MCPToolManageService(db.session()) service = MCPToolManageService(db_session_with_containers)
with pytest.raises(ValueError, match="MCP tool not found"): with pytest.raises(ValueError, match="MCP tool not found"):
service.delete_provider(tenant_id=tenant2.id, provider_id=mcp_provider1.id) service.delete_provider(tenant_id=tenant2.id, provider_id=mcp_provider1.id)
# Verify provider still exists in tenant1 # Verify provider still exists in tenant1
from extensions.ext_database import db
assert db.session.query(MCPToolProvider).filter_by(id=mcp_provider1.id).first() is not None assert db_session_with_containers.query(MCPToolProvider).filter_by(id=mcp_provider1.id).first() is not None
def test_update_mcp_provider_success(self, db_session_with_containers, mock_external_service_dependencies): def test_update_mcp_provider_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test successful update of MCP provider. Test successful update of MCP provider.
@ -1070,14 +1046,12 @@ class TestMCPToolManageService:
original_name = mcp_provider.name original_name = mcp_provider.name
original_icon = mcp_provider.icon original_icon = mcp_provider.icon
from extensions.ext_database import db db_session_with_containers.commit()
db.session.commit()
# Act: Execute the method under test # Act: Execute the method under test
from core.entities.mcp_provider import MCPConfiguration from core.entities.mcp_provider import MCPConfiguration
service = MCPToolManageService(db.session()) service = MCPToolManageService(db_session_with_containers)
service.update_provider( service.update_provider(
tenant_id=tenant.id, tenant_id=tenant.id,
provider_id=mcp_provider.id, provider_id=mcp_provider.id,
@ -1094,7 +1068,7 @@ class TestMCPToolManageService:
) )
# Assert: Verify the expected outcomes # Assert: Verify the expected outcomes
db.session.refresh(mcp_provider) db_session_with_containers.refresh(mcp_provider)
assert mcp_provider.name == "Updated MCP Provider" assert mcp_provider.name == "Updated MCP Provider"
assert mcp_provider.server_identifier == "updated_identifier_123" assert mcp_provider.server_identifier == "updated_identifier_123"
assert mcp_provider.timeout == 45.0 assert mcp_provider.timeout == 45.0
@ -1108,7 +1082,9 @@ class TestMCPToolManageService:
assert icon_data["content"] == "🚀" assert icon_data["content"] == "🚀"
assert icon_data["background"] == "#4ECDC4" assert icon_data["background"] == "#4ECDC4"
def test_update_mcp_provider_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies): def test_update_mcp_provider_duplicate_name(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test error handling when updating MCP provider with duplicate name. Test error handling when updating MCP provider with duplicate name.
@ -1134,15 +1110,12 @@ class TestMCPToolManageService:
) )
provider2.name = "Second Provider" provider2.name = "Second Provider"
from extensions.ext_database import db db_session_with_containers.commit()
db.session.commit()
# Act & Assert: Verify proper error handling for duplicate name # Act & Assert: Verify proper error handling for duplicate name
from core.entities.mcp_provider import MCPConfiguration from core.entities.mcp_provider import MCPConfiguration
from extensions.ext_database import db
service = MCPToolManageService(db.session()) service = MCPToolManageService(db_session_with_containers)
with pytest.raises(ValueError, match="MCP tool First Provider already exists"): with pytest.raises(ValueError, match="MCP tool First Provider already exists"):
service.update_provider( service.update_provider(
tenant_id=tenant.id, tenant_id=tenant.id,
@ -1160,7 +1133,7 @@ class TestMCPToolManageService:
) )
def test_update_mcp_provider_credentials_success( def test_update_mcp_provider_credentials_success(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful update of MCP provider credentials. Test successful update of MCP provider credentials.
@ -1185,9 +1158,7 @@ class TestMCPToolManageService:
mcp_provider.authed = False mcp_provider.authed = False
mcp_provider.tools = "[]" mcp_provider.tools = "[]"
from extensions.ext_database import db db_session_with_containers.commit()
db.session.commit()
# Mock the provider controller and encryption # Mock the provider controller and encryption
with ( with (
@ -1202,9 +1173,8 @@ class TestMCPToolManageService:
mock_encrypter_instance.encrypt.return_value = {"new_key": "encrypted_value"} mock_encrypter_instance.encrypt.return_value = {"new_key": "encrypted_value"}
# Act: Execute the method under test # Act: Execute the method under test
from extensions.ext_database import db
service = MCPToolManageService(db.session()) service = MCPToolManageService(db_session_with_containers)
service.update_provider_credentials( service.update_provider_credentials(
provider_id=mcp_provider.id, provider_id=mcp_provider.id,
tenant_id=tenant.id, tenant_id=tenant.id,
@ -1213,7 +1183,7 @@ class TestMCPToolManageService:
) )
# Assert: Verify the expected outcomes # Assert: Verify the expected outcomes
db.session.refresh(mcp_provider) db_session_with_containers.refresh(mcp_provider)
assert mcp_provider.authed is True assert mcp_provider.authed is True
assert mcp_provider.updated_at is not None assert mcp_provider.updated_at is not None
@ -1225,7 +1195,7 @@ class TestMCPToolManageService:
assert "new_key" in credentials assert "new_key" in credentials
def test_update_mcp_provider_credentials_not_authed( def test_update_mcp_provider_credentials_not_authed(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test update of MCP provider credentials when not authenticated. Test update of MCP provider credentials when not authenticated.
@ -1249,9 +1219,7 @@ class TestMCPToolManageService:
mcp_provider.authed = True mcp_provider.authed = True
mcp_provider.tools = '[{"name": "test_tool"}]' mcp_provider.tools = '[{"name": "test_tool"}]'
from extensions.ext_database import db db_session_with_containers.commit()
db.session.commit()
# Mock the provider controller and encryption # Mock the provider controller and encryption
with ( with (
@ -1266,9 +1234,8 @@ class TestMCPToolManageService:
mock_encrypter_instance.encrypt.return_value = {"new_key": "encrypted_value"} mock_encrypter_instance.encrypt.return_value = {"new_key": "encrypted_value"}
# Act: Execute the method under test # Act: Execute the method under test
from extensions.ext_database import db
service = MCPToolManageService(db.session()) service = MCPToolManageService(db_session_with_containers)
service.update_provider_credentials( service.update_provider_credentials(
provider_id=mcp_provider.id, provider_id=mcp_provider.id,
tenant_id=tenant.id, tenant_id=tenant.id,
@ -1277,12 +1244,14 @@ class TestMCPToolManageService:
) )
# Assert: Verify the expected outcomes # Assert: Verify the expected outcomes
db.session.refresh(mcp_provider) db_session_with_containers.refresh(mcp_provider)
assert mcp_provider.authed is False assert mcp_provider.authed is False
assert mcp_provider.tools == "[]" assert mcp_provider.tools == "[]"
assert mcp_provider.updated_at is not None assert mcp_provider.updated_at is not None
def test_re_connect_mcp_provider_success(self, db_session_with_containers, mock_external_service_dependencies): def test_re_connect_mcp_provider_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful reconnection to MCP provider. Test successful reconnection to MCP provider.
@ -1343,7 +1312,9 @@ class TestMCPToolManageService:
sse_read_timeout=mcp_provider.sse_read_timeout, sse_read_timeout=mcp_provider.sse_read_timeout,
) )
def test_re_connect_mcp_provider_auth_error(self, db_session_with_containers, mock_external_service_dependencies): def test_re_connect_mcp_provider_auth_error(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test reconnection to MCP provider when authentication fails. Test reconnection to MCP provider when authentication fails.
@ -1385,7 +1356,7 @@ class TestMCPToolManageService:
assert result.encrypted_credentials == "{}" assert result.encrypted_credentials == "{}"
def test_re_connect_mcp_provider_connection_error( def test_re_connect_mcp_provider_connection_error(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test reconnection to MCP provider when connection fails. Test reconnection to MCP provider when connection fails.

View File

@ -2,6 +2,7 @@ from unittest.mock import Mock, patch
import pytest import pytest
from faker import Faker from faker import Faker
from sqlalchemy.orm import Session
from core.tools.entities.api_entities import ToolProviderApiEntity from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
@ -27,7 +28,7 @@ class TestToolTransformService:
} }
def _create_test_tool_provider( def _create_test_tool_provider(
self, db_session_with_containers, mock_external_service_dependencies, provider_type="api" self, db_session_with_containers: Session, mock_external_service_dependencies, provider_type="api"
): ):
""" """
Helper method to create a test tool provider for testing. Helper method to create a test tool provider for testing.
@ -89,14 +90,12 @@ class TestToolTransformService:
else: else:
raise ValueError(f"Unknown provider type: {provider_type}") raise ValueError(f"Unknown provider type: {provider_type}")
from extensions.ext_database import db db_session_with_containers.add(provider)
db_session_with_containers.commit()
db.session.add(provider)
db.session.commit()
return provider return provider
def test_get_plugin_icon_url_success(self, db_session_with_containers, mock_external_service_dependencies): def test_get_plugin_icon_url_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test successful plugin icon URL generation. Test successful plugin icon URL generation.
@ -126,7 +125,7 @@ class TestToolTransformService:
assert result == expected_url assert result == expected_url
def test_get_plugin_icon_url_with_empty_console_url( def test_get_plugin_icon_url_with_empty_console_url(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test plugin icon URL generation when CONSOLE_API_URL is empty. Test plugin icon URL generation when CONSOLE_API_URL is empty.
@ -156,7 +155,7 @@ class TestToolTransformService:
assert result == expected_url assert result == expected_url
def test_get_tool_provider_icon_url_builtin_success( def test_get_tool_provider_icon_url_builtin_success(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful tool provider icon URL generation for builtin providers. Test successful tool provider icon URL generation for builtin providers.
@ -194,7 +193,7 @@ class TestToolTransformService:
assert result == expected_encoded assert result == expected_encoded
def test_get_tool_provider_icon_url_api_success( def test_get_tool_provider_icon_url_api_success(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful tool provider icon URL generation for API providers. Test successful tool provider icon URL generation for API providers.
@ -220,7 +219,7 @@ class TestToolTransformService:
assert result["content"] == "🔧" assert result["content"] == "🔧"
def test_get_tool_provider_icon_url_api_invalid_json( def test_get_tool_provider_icon_url_api_invalid_json(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test tool provider icon URL generation for API providers with invalid JSON. Test tool provider icon URL generation for API providers with invalid JSON.
@ -246,7 +245,7 @@ class TestToolTransformService:
assert result["content"] == "😁" or result["content"] == "\ud83d\ude01" assert result["content"] == "😁" or result["content"] == "\ud83d\ude01"
def test_get_tool_provider_icon_url_workflow_success( def test_get_tool_provider_icon_url_workflow_success(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful tool provider icon URL generation for workflow providers. Test successful tool provider icon URL generation for workflow providers.
@ -271,7 +270,7 @@ class TestToolTransformService:
assert result["content"] == "🔧" assert result["content"] == "🔧"
def test_get_tool_provider_icon_url_mcp_success( def test_get_tool_provider_icon_url_mcp_success(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful tool provider icon URL generation for MCP providers. Test successful tool provider icon URL generation for MCP providers.
@ -296,7 +295,7 @@ class TestToolTransformService:
assert result["content"] == "🔧" assert result["content"] == "🔧"
def test_get_tool_provider_icon_url_unknown_type( def test_get_tool_provider_icon_url_unknown_type(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test tool provider icon URL generation for unknown provider types. Test tool provider icon URL generation for unknown provider types.
@ -317,7 +316,9 @@ class TestToolTransformService:
# Assert: Verify the expected outcomes # Assert: Verify the expected outcomes
assert result == "" assert result == ""
def test_repack_provider_dict_success(self, db_session_with_containers, mock_external_service_dependencies): def test_repack_provider_dict_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful provider repacking with dictionary input. Test successful provider repacking with dictionary input.
@ -341,7 +342,9 @@ class TestToolTransformService:
# Note: provider name may contain spaces that get URL encoded # Note: provider name may contain spaces that get URL encoded
assert provider["name"].replace(" ", "%20") in provider["icon"] or provider["name"] in provider["icon"] assert provider["name"].replace(" ", "%20") in provider["icon"] or provider["name"] in provider["icon"]
def test_repack_provider_entity_success(self, db_session_with_containers, mock_external_service_dependencies): def test_repack_provider_entity_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful provider repacking with ToolProviderApiEntity input. Test successful provider repacking with ToolProviderApiEntity input.
@ -389,7 +392,7 @@ class TestToolTransformService:
assert "test_icon_dark.png" in provider.icon_dark assert "test_icon_dark.png" in provider.icon_dark
def test_repack_provider_entity_no_plugin_success( def test_repack_provider_entity_no_plugin_success(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful provider repacking with ToolProviderApiEntity input without plugin_id. Test successful provider repacking with ToolProviderApiEntity input without plugin_id.
@ -435,7 +438,9 @@ class TestToolTransformService:
assert provider.icon_dark["background"] == "#252525" assert provider.icon_dark["background"] == "#252525"
assert provider.icon_dark["content"] == "🔧" assert provider.icon_dark["content"] == "🔧"
def test_repack_provider_entity_no_dark_icon(self, db_session_with_containers, mock_external_service_dependencies): def test_repack_provider_entity_no_dark_icon(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test provider repacking with ToolProviderApiEntity input without dark icon. Test provider repacking with ToolProviderApiEntity input without dark icon.
@ -477,7 +482,7 @@ class TestToolTransformService:
assert provider.icon_dark == "" assert provider.icon_dark == ""
def test_builtin_provider_to_user_provider_success( def test_builtin_provider_to_user_provider_success(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful conversion of builtin provider to user provider. Test successful conversion of builtin provider to user provider.
@ -545,7 +550,7 @@ class TestToolTransformService:
assert result.original_credentials == {"api_key": "decrypted_key"} assert result.original_credentials == {"api_key": "decrypted_key"}
def test_builtin_provider_to_user_provider_plugin_success( def test_builtin_provider_to_user_provider_plugin_success(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful conversion of builtin provider to user provider with plugin. Test successful conversion of builtin provider to user provider with plugin.
@ -589,7 +594,7 @@ class TestToolTransformService:
assert result.allow_delete is False assert result.allow_delete is False
def test_builtin_provider_to_user_provider_no_credentials( def test_builtin_provider_to_user_provider_no_credentials(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test conversion of builtin provider to user provider without credentials. Test conversion of builtin provider to user provider without credentials.
@ -630,7 +635,9 @@ class TestToolTransformService:
assert result.allow_delete is False assert result.allow_delete is False
assert result.masked_credentials == {"api_key": ""} assert result.masked_credentials == {"api_key": ""}
def test_api_provider_to_controller_success(self, db_session_with_containers, mock_external_service_dependencies): def test_api_provider_to_controller_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful conversion of API provider to controller. Test successful conversion of API provider to controller.
@ -655,10 +662,8 @@ class TestToolTransformService:
tools_str="[]", tools_str="[]",
) )
from extensions.ext_database import db db_session_with_containers.add(provider)
db_session_with_containers.commit()
db.session.add(provider)
db.session.commit()
# Act: Execute the method under test # Act: Execute the method under test
result = ToolTransformService.api_provider_to_controller(provider) result = ToolTransformService.api_provider_to_controller(provider)
@ -669,7 +674,7 @@ class TestToolTransformService:
# Additional assertions would depend on the actual controller implementation # Additional assertions would depend on the actual controller implementation
def test_api_provider_to_controller_api_key_query( def test_api_provider_to_controller_api_key_query(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test conversion of API provider to controller with api_key_query auth type. Test conversion of API provider to controller with api_key_query auth type.
@ -693,10 +698,8 @@ class TestToolTransformService:
tools_str="[]", tools_str="[]",
) )
from extensions.ext_database import db db_session_with_containers.add(provider)
db_session_with_containers.commit()
db.session.add(provider)
db.session.commit()
# Act: Execute the method under test # Act: Execute the method under test
result = ToolTransformService.api_provider_to_controller(provider) result = ToolTransformService.api_provider_to_controller(provider)
@ -706,7 +709,7 @@ class TestToolTransformService:
assert hasattr(result, "from_db") assert hasattr(result, "from_db")
def test_api_provider_to_controller_backward_compatibility( def test_api_provider_to_controller_backward_compatibility(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test conversion of API provider to controller with backward compatibility auth types. Test conversion of API provider to controller with backward compatibility auth types.
@ -731,10 +734,8 @@ class TestToolTransformService:
tools_str="[]", tools_str="[]",
) )
from extensions.ext_database import db db_session_with_containers.add(provider)
db_session_with_containers.commit()
db.session.add(provider)
db.session.commit()
# Act: Execute the method under test # Act: Execute the method under test
result = ToolTransformService.api_provider_to_controller(provider) result = ToolTransformService.api_provider_to_controller(provider)
@ -744,7 +745,7 @@ class TestToolTransformService:
assert hasattr(result, "from_db") assert hasattr(result, "from_db")
def test_workflow_provider_to_controller_success( def test_workflow_provider_to_controller_success(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful conversion of workflow provider to controller. Test successful conversion of workflow provider to controller.
@ -769,10 +770,8 @@ class TestToolTransformService:
parameter_configuration="[]", parameter_configuration="[]",
) )
from extensions.ext_database import db db_session_with_containers.add(provider)
db_session_with_containers.commit()
db.session.add(provider)
db.session.commit()
# Mock the WorkflowToolProviderController.from_db method to avoid app dependency # Mock the WorkflowToolProviderController.from_db method to avoid app dependency
with patch("services.tools.tools_transform_service.WorkflowToolProviderController.from_db") as mock_from_db: with patch("services.tools.tools_transform_service.WorkflowToolProviderController.from_db") as mock_from_db:

View File

@ -4,6 +4,7 @@ from unittest.mock import patch
import pytest import pytest
from faker import Faker from faker import Faker
from pydantic import ValidationError from pydantic import ValidationError
from sqlalchemy.orm import Session
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
from core.tools.errors import WorkflowToolHumanInputNotSupportedError from core.tools.errors import WorkflowToolHumanInputNotSupportedError
@ -63,7 +64,7 @@ class TestWorkflowToolManageService:
"tool_transform_service": mock_tool_transform_service, "tool_transform_service": mock_tool_transform_service,
} }
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Helper method to create a test app and account for testing. Helper method to create a test app and account for testing.
@ -119,14 +120,12 @@ class TestWorkflowToolManageService:
conversation_variables=[], conversation_variables=[],
) )
from extensions.ext_database import db db_session_with_containers.add(workflow)
db_session_with_containers.commit()
db.session.add(workflow)
db.session.commit()
# Update app to reference the workflow # Update app to reference the workflow
app.workflow_id = workflow.id app.workflow_id = workflow.id
db.session.commit() db_session_with_containers.commit()
return app, account, workflow return app, account, workflow
@ -153,7 +152,9 @@ class TestWorkflowToolManageService:
), ),
] ]
def test_create_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies): def test_create_workflow_tool_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful workflow tool creation with valid parameters. Test successful workflow tool creation with valid parameters.
@ -198,11 +199,10 @@ class TestWorkflowToolManageService:
assert result == {"result": "success"} assert result == {"result": "success"}
# Verify database state # Verify database state
from extensions.ext_database import db
# Check if workflow tool provider was created # Check if workflow tool provider was created
created_tool_provider = ( created_tool_provider = (
db.session.query(WorkflowToolProvider) db_session_with_containers.query(WorkflowToolProvider)
.where( .where(
WorkflowToolProvider.tenant_id == account.current_tenant.id, WorkflowToolProvider.tenant_id == account.current_tenant.id,
WorkflowToolProvider.app_id == app.id, WorkflowToolProvider.app_id == app.id,
@ -230,7 +230,7 @@ class TestWorkflowToolManageService:
].workflow_provider_to_controller.assert_called_once() ].workflow_provider_to_controller.assert_called_once()
def test_create_workflow_tool_duplicate_name_error( def test_create_workflow_tool_duplicate_name_error(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test workflow tool creation fails when name already exists. Test workflow tool creation fails when name already exists.
@ -280,10 +280,9 @@ class TestWorkflowToolManageService:
assert f"Tool with name {first_tool_name} or app_id {app.id} already exists" in str(exc_info.value) assert f"Tool with name {first_tool_name} or app_id {app.id} already exists" in str(exc_info.value)
# Verify only one tool was created # Verify only one tool was created
from extensions.ext_database import db
tool_count = ( tool_count = (
db.session.query(WorkflowToolProvider) db_session_with_containers.query(WorkflowToolProvider)
.where( .where(
WorkflowToolProvider.tenant_id == account.current_tenant.id, WorkflowToolProvider.tenant_id == account.current_tenant.id,
) )
@ -293,7 +292,7 @@ class TestWorkflowToolManageService:
assert tool_count == 1 assert tool_count == 1
def test_create_workflow_tool_invalid_app_error( def test_create_workflow_tool_invalid_app_error(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test workflow tool creation fails when app does not exist. Test workflow tool creation fails when app does not exist.
@ -331,10 +330,9 @@ class TestWorkflowToolManageService:
assert f"App {non_existent_app_id} not found" in str(exc_info.value) assert f"App {non_existent_app_id} not found" in str(exc_info.value)
# Verify no workflow tool was created # Verify no workflow tool was created
from extensions.ext_database import db
tool_count = ( tool_count = (
db.session.query(WorkflowToolProvider) db_session_with_containers.query(WorkflowToolProvider)
.where( .where(
WorkflowToolProvider.tenant_id == account.current_tenant.id, WorkflowToolProvider.tenant_id == account.current_tenant.id,
) )
@ -344,7 +342,7 @@ class TestWorkflowToolManageService:
assert tool_count == 0 assert tool_count == 0
def test_create_workflow_tool_invalid_parameters_error( def test_create_workflow_tool_invalid_parameters_error(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test workflow tool creation fails when parameters are invalid. Test workflow tool creation fails when parameters are invalid.
@ -387,10 +385,9 @@ class TestWorkflowToolManageService:
assert "validation error" in str(exc_info.value).lower() assert "validation error" in str(exc_info.value).lower()
# Verify no workflow tool was created # Verify no workflow tool was created
from extensions.ext_database import db
tool_count = ( tool_count = (
db.session.query(WorkflowToolProvider) db_session_with_containers.query(WorkflowToolProvider)
.where( .where(
WorkflowToolProvider.tenant_id == account.current_tenant.id, WorkflowToolProvider.tenant_id == account.current_tenant.id,
) )
@ -400,7 +397,7 @@ class TestWorkflowToolManageService:
assert tool_count == 0 assert tool_count == 0
def test_create_workflow_tool_duplicate_app_id_error( def test_create_workflow_tool_duplicate_app_id_error(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test workflow tool creation fails when app_id already exists. Test workflow tool creation fails when app_id already exists.
@ -450,10 +447,9 @@ class TestWorkflowToolManageService:
assert f"Tool with name {second_tool_name} or app_id {app.id} already exists" in str(exc_info.value) assert f"Tool with name {second_tool_name} or app_id {app.id} already exists" in str(exc_info.value)
# Verify only one tool was created # Verify only one tool was created
from extensions.ext_database import db
tool_count = ( tool_count = (
db.session.query(WorkflowToolProvider) db_session_with_containers.query(WorkflowToolProvider)
.where( .where(
WorkflowToolProvider.tenant_id == account.current_tenant.id, WorkflowToolProvider.tenant_id == account.current_tenant.id,
) )
@ -463,7 +459,7 @@ class TestWorkflowToolManageService:
assert tool_count == 1 assert tool_count == 1
def test_create_workflow_tool_workflow_not_found_error( def test_create_workflow_tool_workflow_not_found_error(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test workflow tool creation fails when app has no workflow. Test workflow tool creation fails when app has no workflow.
@ -481,10 +477,9 @@ class TestWorkflowToolManageService:
) )
# Remove workflow reference from app # Remove workflow reference from app
from extensions.ext_database import db
app.workflow_id = None app.workflow_id = None
db.session.commit() db_session_with_containers.commit()
# Attempt to create workflow tool for app without workflow # Attempt to create workflow tool for app without workflow
tool_parameters = self._create_test_workflow_tool_parameters() tool_parameters = self._create_test_workflow_tool_parameters()
@ -505,7 +500,7 @@ class TestWorkflowToolManageService:
# Verify no workflow tool was created # Verify no workflow tool was created
tool_count = ( tool_count = (
db.session.query(WorkflowToolProvider) db_session_with_containers.query(WorkflowToolProvider)
.where( .where(
WorkflowToolProvider.tenant_id == account.current_tenant.id, WorkflowToolProvider.tenant_id == account.current_tenant.id,
) )
@ -515,7 +510,7 @@ class TestWorkflowToolManageService:
assert tool_count == 0 assert tool_count == 0
def test_create_workflow_tool_human_input_node_error( def test_create_workflow_tool_human_input_node_error(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test workflow tool creation fails when workflow contains human input nodes. Test workflow tool creation fails when workflow contains human input nodes.
@ -558,10 +553,8 @@ class TestWorkflowToolManageService:
assert exc_info.value.error_code == "workflow_tool_human_input_not_supported" assert exc_info.value.error_code == "workflow_tool_human_input_not_supported"
from extensions.ext_database import db
tool_count = ( tool_count = (
db.session.query(WorkflowToolProvider) db_session_with_containers.query(WorkflowToolProvider)
.where( .where(
WorkflowToolProvider.tenant_id == account.current_tenant.id, WorkflowToolProvider.tenant_id == account.current_tenant.id,
) )
@ -570,7 +563,9 @@ class TestWorkflowToolManageService:
assert tool_count == 0 assert tool_count == 0
def test_update_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies): def test_update_workflow_tool_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful workflow tool update with valid parameters. Test successful workflow tool update with valid parameters.
@ -603,10 +598,9 @@ class TestWorkflowToolManageService:
) )
# Get the created tool # Get the created tool
from extensions.ext_database import db
created_tool = ( created_tool = (
db.session.query(WorkflowToolProvider) db_session_with_containers.query(WorkflowToolProvider)
.where( .where(
WorkflowToolProvider.tenant_id == account.current_tenant.id, WorkflowToolProvider.tenant_id == account.current_tenant.id,
WorkflowToolProvider.app_id == app.id, WorkflowToolProvider.app_id == app.id,
@ -641,7 +635,7 @@ class TestWorkflowToolManageService:
assert result == {"result": "success"} assert result == {"result": "success"}
# Verify database state was updated # Verify database state was updated
db.session.refresh(created_tool) db_session_with_containers.refresh(created_tool)
assert created_tool is not None assert created_tool is not None
assert created_tool.name == updated_tool_name assert created_tool.name == updated_tool_name
assert created_tool.label == updated_tool_label assert created_tool.label == updated_tool_label
@ -658,7 +652,7 @@ class TestWorkflowToolManageService:
mock_external_service_dependencies["tool_transform_service"].workflow_provider_to_controller.assert_called() mock_external_service_dependencies["tool_transform_service"].workflow_provider_to_controller.assert_called()
def test_update_workflow_tool_human_input_node_error( def test_update_workflow_tool_human_input_node_error(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test workflow tool update fails when workflow contains human input nodes. Test workflow tool update fails when workflow contains human input nodes.
@ -689,10 +683,8 @@ class TestWorkflowToolManageService:
parameters=initial_tool_parameters, parameters=initial_tool_parameters,
) )
from extensions.ext_database import db
created_tool = ( created_tool = (
db.session.query(WorkflowToolProvider) db_session_with_containers.query(WorkflowToolProvider)
.where( .where(
WorkflowToolProvider.tenant_id == account.current_tenant.id, WorkflowToolProvider.tenant_id == account.current_tenant.id,
WorkflowToolProvider.app_id == app.id, WorkflowToolProvider.app_id == app.id,
@ -712,7 +704,7 @@ class TestWorkflowToolManageService:
] ]
} }
) )
db.session.commit() db_session_with_containers.commit()
with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info: with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info:
WorkflowToolManageService.update_workflow_tool( WorkflowToolManageService.update_workflow_tool(
@ -728,10 +720,12 @@ class TestWorkflowToolManageService:
assert exc_info.value.error_code == "workflow_tool_human_input_not_supported" assert exc_info.value.error_code == "workflow_tool_human_input_not_supported"
db.session.refresh(created_tool) db_session_with_containers.refresh(created_tool)
assert created_tool.name == original_name assert created_tool.name == original_name
def test_update_workflow_tool_not_found_error(self, db_session_with_containers, mock_external_service_dependencies): def test_update_workflow_tool_not_found_error(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test workflow tool update fails when tool does not exist. Test workflow tool update fails when tool does not exist.
@ -768,10 +762,9 @@ class TestWorkflowToolManageService:
assert f"Tool {non_existent_tool_id} not found" in str(exc_info.value) assert f"Tool {non_existent_tool_id} not found" in str(exc_info.value)
# Verify no workflow tool was created # Verify no workflow tool was created
from extensions.ext_database import db
tool_count = ( tool_count = (
db.session.query(WorkflowToolProvider) db_session_with_containers.query(WorkflowToolProvider)
.where( .where(
WorkflowToolProvider.tenant_id == account.current_tenant.id, WorkflowToolProvider.tenant_id == account.current_tenant.id,
) )
@ -781,7 +774,7 @@ class TestWorkflowToolManageService:
assert tool_count == 0 assert tool_count == 0
def test_update_workflow_tool_same_name_success( def test_update_workflow_tool_same_name_success(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test workflow tool update succeeds when keeping the same name. Test workflow tool update succeeds when keeping the same name.
@ -813,10 +806,9 @@ class TestWorkflowToolManageService:
) )
# Get the created tool # Get the created tool
from extensions.ext_database import db
created_tool = ( created_tool = (
db.session.query(WorkflowToolProvider) db_session_with_containers.query(WorkflowToolProvider)
.where( .where(
WorkflowToolProvider.tenant_id == account.current_tenant.id, WorkflowToolProvider.tenant_id == account.current_tenant.id,
WorkflowToolProvider.app_id == app.id, WorkflowToolProvider.app_id == app.id,
@ -840,12 +832,12 @@ class TestWorkflowToolManageService:
assert result == {"result": "success"} assert result == {"result": "success"}
# Verify tool still exists with the same name # Verify tool still exists with the same name
db.session.refresh(created_tool) db_session_with_containers.refresh(created_tool)
assert created_tool.name == first_tool_name assert created_tool.name == first_tool_name
assert created_tool.updated_at is not None assert created_tool.updated_at is not None
def test_create_workflow_tool_with_file_parameter_default( def test_create_workflow_tool_with_file_parameter_default(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test workflow tool creation with FILE parameter having a file object as default. Test workflow tool creation with FILE parameter having a file object as default.
@ -916,7 +908,7 @@ class TestWorkflowToolManageService:
assert result == {"result": "success"} assert result == {"result": "success"}
def test_create_workflow_tool_with_files_parameter_default( def test_create_workflow_tool_with_files_parameter_default(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test workflow tool creation with FILES (Array[File]) parameter having file objects as default. Test workflow tool creation with FILES (Array[File]) parameter having file objects as default.
@ -991,7 +983,7 @@ class TestWorkflowToolManageService:
assert result == {"result": "success"} assert result == {"result": "success"}
def test_create_workflow_tool_db_commit_before_validation( def test_create_workflow_tool_db_commit_before_validation(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test that database commit happens before validation, causing DB pollution on validation failure. Test that database commit happens before validation, causing DB pollution on validation failure.
@ -1035,10 +1027,9 @@ class TestWorkflowToolManageService:
# Verify the tool was NOT created in database # Verify the tool was NOT created in database
# This is the expected behavior (no pollution) # This is the expected behavior (no pollution)
from extensions.ext_database import db
tool_count = ( tool_count = (
db.session.query(WorkflowToolProvider) db_session_with_containers.query(WorkflowToolProvider)
.where( .where(
WorkflowToolProvider.tenant_id == account.current_tenant.id, WorkflowToolProvider.tenant_id == account.current_tenant.id,
WorkflowToolProvider.name == tool_name, WorkflowToolProvider.name == tool_name,

View File

@ -3,6 +3,7 @@ from unittest.mock import patch
import pytest import pytest
from faker import Faker from faker import Faker
from sqlalchemy.orm import Session
from core.app.app_config.entities import ( from core.app.app_config.entities import (
DatasetEntity, DatasetEntity,
@ -79,7 +80,7 @@ class TestWorkflowConverter:
mock_config.app_model_config_dict = {} mock_config.app_model_config_dict = {}
return mock_config return mock_config
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Helper method to create a test account and tenant for testing. Helper method to create a test account and tenant for testing.
@ -100,18 +101,16 @@ class TestWorkflowConverter:
status="active", status="active",
) )
from extensions.ext_database import db db_session_with_containers.add(account)
db_session_with_containers.commit()
db.session.add(account)
db.session.commit()
# Create tenant for the account # Create tenant for the account
tenant = Tenant( tenant = Tenant(
name=fake.company(), name=fake.company(),
status="normal", status="normal",
) )
db.session.add(tenant) db_session_with_containers.add(tenant)
db.session.commit() db_session_with_containers.commit()
# Create tenant-account join # Create tenant-account join
from models.account import TenantAccountJoin, TenantAccountRole from models.account import TenantAccountJoin, TenantAccountRole
@ -122,15 +121,17 @@ class TestWorkflowConverter:
role=TenantAccountRole.OWNER, role=TenantAccountRole.OWNER,
current=True, current=True,
) )
db.session.add(join) db_session_with_containers.add(join)
db.session.commit() db_session_with_containers.commit()
# Set current tenant for account # Set current tenant for account
account.current_tenant = tenant account.current_tenant = tenant
return account, tenant return account, tenant
def _create_test_app(self, db_session_with_containers, mock_external_service_dependencies, tenant, account): def _create_test_app(
self, db_session_with_containers: Session, mock_external_service_dependencies, tenant, account
):
""" """
Helper method to create a test app for testing. Helper method to create a test app for testing.
@ -163,10 +164,8 @@ class TestWorkflowConverter:
updated_by=account.id, updated_by=account.id,
) )
from extensions.ext_database import db db_session_with_containers.add(app)
db_session_with_containers.commit()
db.session.add(app)
db.session.commit()
# Create app model config # Create app model config
app_model_config = AppModelConfig( app_model_config = AppModelConfig(
@ -177,16 +176,16 @@ class TestWorkflowConverter:
created_by=account.id, created_by=account.id,
updated_by=account.id, updated_by=account.id,
) )
db.session.add(app_model_config) db_session_with_containers.add(app_model_config)
db.session.commit() db_session_with_containers.commit()
# Link app model config to app # Link app model config to app
app.app_model_config_id = app_model_config.id app.app_model_config_id = app_model_config.id
db.session.commit() db_session_with_containers.commit()
return app return app
def test_convert_to_workflow_success(self, db_session_with_containers, mock_external_service_dependencies): def test_convert_to_workflow_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
Test successful conversion of app to workflow. Test successful conversion of app to workflow.
@ -225,19 +224,18 @@ class TestWorkflowConverter:
assert new_app.created_by == account.id assert new_app.created_by == account.id
# Verify database state # Verify database state
from extensions.ext_database import db
db.session.refresh(new_app) db_session_with_containers.refresh(new_app)
assert new_app.id is not None assert new_app.id is not None
# Verify workflow was created # Verify workflow was created
workflow = db.session.query(Workflow).where(Workflow.app_id == new_app.id).first() workflow = db_session_with_containers.query(Workflow).where(Workflow.app_id == new_app.id).first()
assert workflow is not None assert workflow is not None
assert workflow.tenant_id == app.tenant_id assert workflow.tenant_id == app.tenant_id
assert workflow.type == "chat" assert workflow.type == "chat"
def test_convert_to_workflow_without_app_model_config_error( def test_convert_to_workflow_without_app_model_config_error(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test error handling when app model config is missing. Test error handling when app model config is missing.
@ -270,16 +268,14 @@ class TestWorkflowConverter:
updated_by=account.id, updated_by=account.id,
) )
from extensions.ext_database import db db_session_with_containers.add(app)
db_session_with_containers.commit()
db.session.add(app)
db.session.commit()
# Act & Assert: Verify proper error handling # Act & Assert: Verify proper error handling
workflow_converter = WorkflowConverter() workflow_converter = WorkflowConverter()
# Check initial state # Check initial state
initial_workflow_count = db.session.query(Workflow).count() initial_workflow_count = db_session_with_containers.query(Workflow).count()
with pytest.raises(ValueError, match="App model config is required"): with pytest.raises(ValueError, match="App model config is required"):
workflow_converter.convert_to_workflow( workflow_converter.convert_to_workflow(
@ -294,12 +290,12 @@ class TestWorkflowConverter:
# Verify database state remains unchanged # Verify database state remains unchanged
# The workflow creation happens in convert_app_model_config_to_workflow # The workflow creation happens in convert_app_model_config_to_workflow
# which is called before the app_model_config check, so we need to clean up # which is called before the app_model_config check, so we need to clean up
db.session.rollback() db_session_with_containers.rollback()
final_workflow_count = db.session.query(Workflow).count() final_workflow_count = db_session_with_containers.query(Workflow).count()
assert final_workflow_count == initial_workflow_count assert final_workflow_count == initial_workflow_count
def test_convert_app_model_config_to_workflow_success( def test_convert_app_model_config_to_workflow_success(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful conversion of app model config to workflow. Test successful conversion of app model config to workflow.
@ -356,16 +352,17 @@ class TestWorkflowConverter:
assert answer_node["id"] == "answer" assert answer_node["id"] == "answer"
# Verify database state # Verify database state
from extensions.ext_database import db
db.session.refresh(workflow) db_session_with_containers.refresh(workflow)
assert workflow.id is not None assert workflow.id is not None
# Verify features were set # Verify features were set
features = json.loads(workflow._features) if workflow._features else {} features = json.loads(workflow._features) if workflow._features else {}
assert isinstance(features, dict) assert isinstance(features, dict)
def test_convert_to_start_node_success(self, db_session_with_containers, mock_external_service_dependencies): def test_convert_to_start_node_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful conversion to start node. Test successful conversion to start node.
@ -410,7 +407,9 @@ class TestWorkflowConverter:
assert second_variable["label"] == "Number Input" assert second_variable["label"] == "Number Input"
assert second_variable["type"] == "number" assert second_variable["type"] == "number"
def test_convert_to_http_request_node_success(self, db_session_with_containers, mock_external_service_dependencies): def test_convert_to_http_request_node_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful conversion to HTTP request node. Test successful conversion to HTTP request node.
@ -436,10 +435,8 @@ class TestWorkflowConverter:
api_endpoint="https://api.example.com/test", api_endpoint="https://api.example.com/test",
) )
from extensions.ext_database import db db_session_with_containers.add(api_based_extension)
db_session_with_containers.commit()
db.session.add(api_based_extension)
db.session.commit()
# Mock encrypter # Mock encrypter
mock_external_service_dependencies["encrypter"].decrypt_token.return_value = "decrypted_api_key" mock_external_service_dependencies["encrypter"].decrypt_token.return_value = "decrypted_api_key"
@ -489,7 +486,7 @@ class TestWorkflowConverter:
assert external_data_variable_node_mapping["external_data"] == code_node["id"] assert external_data_variable_node_mapping["external_data"] == code_node["id"]
def test_convert_to_knowledge_retrieval_node_success( def test_convert_to_knowledge_retrieval_node_success(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful conversion to knowledge retrieval node. Test successful conversion to knowledge retrieval node.

View File

@ -2,9 +2,9 @@ from unittest.mock import MagicMock, patch
import pytest import pytest
from faker import Faker from faker import Faker
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.constant.index_type import IndexStructureType
from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, DatasetAutoDisableLog, Document, DocumentSegment from models.dataset import Dataset, DatasetAutoDisableLog, Document, DocumentSegment
@ -31,7 +31,9 @@ class TestAddDocumentToIndexTask:
"index_processor": mock_processor, "index_processor": mock_processor,
} }
def _create_test_dataset_and_document(self, db_session_with_containers, mock_external_service_dependencies): def _create_test_dataset_and_document(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Helper method to create a test dataset and document for testing. Helper method to create a test dataset and document for testing.
@ -51,15 +53,15 @@ class TestAddDocumentToIndexTask:
interface_language="en-US", interface_language="en-US",
status="active", status="active",
) )
db.session.add(account) db_session_with_containers.add(account)
db.session.commit() db_session_with_containers.commit()
tenant = Tenant( tenant = Tenant(
name=fake.company(), name=fake.company(),
status="normal", status="normal",
) )
db.session.add(tenant) db_session_with_containers.add(tenant)
db.session.commit() db_session_with_containers.commit()
# Create tenant-account join # Create tenant-account join
join = TenantAccountJoin( join = TenantAccountJoin(
@ -68,8 +70,8 @@ class TestAddDocumentToIndexTask:
role=TenantAccountRole.OWNER, role=TenantAccountRole.OWNER,
current=True, current=True,
) )
db.session.add(join) db_session_with_containers.add(join)
db.session.commit() db_session_with_containers.commit()
# Create dataset # Create dataset
dataset = Dataset( dataset = Dataset(
@ -81,8 +83,8 @@ class TestAddDocumentToIndexTask:
indexing_technique="high_quality", indexing_technique="high_quality",
created_by=account.id, created_by=account.id,
) )
db.session.add(dataset) db_session_with_containers.add(dataset)
db.session.commit() db_session_with_containers.commit()
# Create document # Create document
document = Document( document = Document(
@ -99,15 +101,15 @@ class TestAddDocumentToIndexTask:
enabled=True, enabled=True,
doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_form=IndexStructureType.PARAGRAPH_INDEX,
) )
db.session.add(document) db_session_with_containers.add(document)
db.session.commit() db_session_with_containers.commit()
# Refresh dataset to ensure doc_form property works correctly # Refresh dataset to ensure doc_form property works correctly
db.session.refresh(dataset) db_session_with_containers.refresh(dataset)
return dataset, document return dataset, document
def _create_test_segments(self, db_session_with_containers, document, dataset): def _create_test_segments(self, db_session_with_containers: Session, document, dataset):
""" """
Helper method to create test document segments. Helper method to create test document segments.
@ -138,13 +140,15 @@ class TestAddDocumentToIndexTask:
status="completed", status="completed",
created_by=document.created_by, created_by=document.created_by,
) )
db.session.add(segment) db_session_with_containers.add(segment)
segments.append(segment) segments.append(segment)
db.session.commit() db_session_with_containers.commit()
return segments return segments
def test_add_document_to_index_success(self, db_session_with_containers, mock_external_service_dependencies): def test_add_document_to_index_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful document indexing with paragraph index type. Test successful document indexing with paragraph index type.
@ -180,9 +184,9 @@ class TestAddDocumentToIndexTask:
mock_external_service_dependencies["index_processor"].load.assert_called_once() mock_external_service_dependencies["index_processor"].load.assert_called_once()
# Verify database state changes # Verify database state changes
db.session.refresh(document) db_session_with_containers.refresh(document)
for segment in segments: for segment in segments:
db.session.refresh(segment) db_session_with_containers.refresh(segment)
assert segment.enabled is True assert segment.enabled is True
assert segment.disabled_at is None assert segment.disabled_at is None
assert segment.disabled_by is None assert segment.disabled_by is None
@ -191,7 +195,7 @@ class TestAddDocumentToIndexTask:
assert redis_client.exists(indexing_cache_key) == 0 assert redis_client.exists(indexing_cache_key) == 0
def test_add_document_to_index_with_different_index_type( def test_add_document_to_index_with_different_index_type(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test document indexing with different index types. Test document indexing with different index types.
@ -209,10 +213,10 @@ class TestAddDocumentToIndexTask:
# Update document to use different index type # Update document to use different index type
document.doc_form = IndexStructureType.QA_INDEX document.doc_form = IndexStructureType.QA_INDEX
db.session.commit() db_session_with_containers.commit()
# Refresh dataset to ensure doc_form property reflects the updated document # Refresh dataset to ensure doc_form property reflects the updated document
db.session.refresh(dataset) db_session_with_containers.refresh(dataset)
# Create segments # Create segments
segments = self._create_test_segments(db_session_with_containers, document, dataset) segments = self._create_test_segments(db_session_with_containers, document, dataset)
@ -237,9 +241,9 @@ class TestAddDocumentToIndexTask:
assert len(documents) == 3 assert len(documents) == 3
# Verify database state changes # Verify database state changes
db.session.refresh(document) db_session_with_containers.refresh(document)
for segment in segments: for segment in segments:
db.session.refresh(segment) db_session_with_containers.refresh(segment)
assert segment.enabled is True assert segment.enabled is True
assert segment.disabled_at is None assert segment.disabled_at is None
assert segment.disabled_by is None assert segment.disabled_by is None
@ -248,7 +252,7 @@ class TestAddDocumentToIndexTask:
assert redis_client.exists(indexing_cache_key) == 0 assert redis_client.exists(indexing_cache_key) == 0
def test_add_document_to_index_document_not_found( def test_add_document_to_index_document_not_found(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test handling of non-existent document. Test handling of non-existent document.
@ -275,7 +279,7 @@ class TestAddDocumentToIndexTask:
# because indexing_cache_key is not defined in that case # because indexing_cache_key is not defined in that case
def test_add_document_to_index_invalid_indexing_status( def test_add_document_to_index_invalid_indexing_status(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test handling of document with invalid indexing status. Test handling of document with invalid indexing status.
@ -294,7 +298,7 @@ class TestAddDocumentToIndexTask:
# Set invalid indexing status # Set invalid indexing status
document.indexing_status = "processing" document.indexing_status = "processing"
db.session.commit() db_session_with_containers.commit()
# Act: Execute the task # Act: Execute the task
add_document_to_index_task(document.id) add_document_to_index_task(document.id)
@ -304,7 +308,7 @@ class TestAddDocumentToIndexTask:
mock_external_service_dependencies["index_processor"].load.assert_not_called() mock_external_service_dependencies["index_processor"].load.assert_not_called()
def test_add_document_to_index_dataset_not_found( def test_add_document_to_index_dataset_not_found(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test handling when document's dataset doesn't exist. Test handling when document's dataset doesn't exist.
@ -326,14 +330,14 @@ class TestAddDocumentToIndexTask:
redis_client.set(indexing_cache_key, "processing", ex=300) redis_client.set(indexing_cache_key, "processing", ex=300)
# Delete the dataset to simulate dataset not found scenario # Delete the dataset to simulate dataset not found scenario
db.session.delete(dataset) db_session_with_containers.delete(dataset)
db.session.commit() db_session_with_containers.commit()
# Act: Execute the task # Act: Execute the task
add_document_to_index_task(document.id) add_document_to_index_task(document.id)
# Assert: Verify error handling # Assert: Verify error handling
db.session.refresh(document) db_session_with_containers.refresh(document)
assert document.enabled is False assert document.enabled is False
assert document.indexing_status == "error" assert document.indexing_status == "error"
assert document.error is not None assert document.error is not None
@ -348,7 +352,7 @@ class TestAddDocumentToIndexTask:
assert redis_client.exists(indexing_cache_key) == 0 assert redis_client.exists(indexing_cache_key) == 0
def test_add_document_to_index_with_parent_child_structure( def test_add_document_to_index_with_parent_child_structure(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test document indexing with parent-child structure. Test document indexing with parent-child structure.
@ -367,10 +371,10 @@ class TestAddDocumentToIndexTask:
# Update document to use parent-child index type # Update document to use parent-child index type
document.doc_form = IndexStructureType.PARENT_CHILD_INDEX document.doc_form = IndexStructureType.PARENT_CHILD_INDEX
db.session.commit() db_session_with_containers.commit()
# Refresh dataset to ensure doc_form property reflects the updated document # Refresh dataset to ensure doc_form property reflects the updated document
db.session.refresh(dataset) db_session_with_containers.refresh(dataset)
# Create segments with mock child chunks # Create segments with mock child chunks
segments = self._create_test_segments(db_session_with_containers, document, dataset) segments = self._create_test_segments(db_session_with_containers, document, dataset)
@ -413,9 +417,9 @@ class TestAddDocumentToIndexTask:
assert len(doc.children) == 2 # Each document has 2 children assert len(doc.children) == 2 # Each document has 2 children
# Verify database state changes # Verify database state changes
db.session.refresh(document) db_session_with_containers.refresh(document)
for segment in segments: for segment in segments:
db.session.refresh(segment) db_session_with_containers.refresh(segment)
assert segment.enabled is True assert segment.enabled is True
assert segment.disabled_at is None assert segment.disabled_at is None
assert segment.disabled_by is None assert segment.disabled_by is None
@ -424,7 +428,7 @@ class TestAddDocumentToIndexTask:
assert redis_client.exists(indexing_cache_key) == 0 assert redis_client.exists(indexing_cache_key) == 0
def test_add_document_to_index_with_already_enabled_segments( def test_add_document_to_index_with_already_enabled_segments(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test document indexing when segments are already enabled. Test document indexing when segments are already enabled.
@ -459,10 +463,10 @@ class TestAddDocumentToIndexTask:
status="completed", status="completed",
created_by=document.created_by, created_by=document.created_by,
) )
db.session.add(segment) db_session_with_containers.add(segment)
segments.append(segment) segments.append(segment)
db.session.commit() db_session_with_containers.commit()
# Set up Redis cache key # Set up Redis cache key
indexing_cache_key = f"document_{document.id}_indexing" indexing_cache_key = f"document_{document.id}_indexing"
@ -488,7 +492,7 @@ class TestAddDocumentToIndexTask:
assert redis_client.exists(indexing_cache_key) == 0 assert redis_client.exists(indexing_cache_key) == 0
def test_add_document_to_index_auto_disable_log_deletion( def test_add_document_to_index_auto_disable_log_deletion(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test that auto disable logs are properly deleted during indexing. Test that auto disable logs are properly deleted during indexing.
@ -515,10 +519,10 @@ class TestAddDocumentToIndexTask:
document_id=document.id, document_id=document.id,
) )
log_entry.id = str(fake.uuid4()) log_entry.id = str(fake.uuid4())
db.session.add(log_entry) db_session_with_containers.add(log_entry)
auto_disable_logs.append(log_entry) auto_disable_logs.append(log_entry)
db.session.commit() db_session_with_containers.commit()
# Set up Redis cache key # Set up Redis cache key
indexing_cache_key = f"document_{document.id}_indexing" indexing_cache_key = f"document_{document.id}_indexing"
@ -526,7 +530,9 @@ class TestAddDocumentToIndexTask:
# Verify logs exist before processing # Verify logs exist before processing
existing_logs = ( existing_logs = (
db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == document.id).all() db_session_with_containers.query(DatasetAutoDisableLog)
.where(DatasetAutoDisableLog.document_id == document.id)
.all()
) )
assert len(existing_logs) == 2 assert len(existing_logs) == 2
@ -535,7 +541,9 @@ class TestAddDocumentToIndexTask:
# Assert: Verify auto disable logs were deleted # Assert: Verify auto disable logs were deleted
remaining_logs = ( remaining_logs = (
db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == document.id).all() db_session_with_containers.query(DatasetAutoDisableLog)
.where(DatasetAutoDisableLog.document_id == document.id)
.all()
) )
assert len(remaining_logs) == 0 assert len(remaining_logs) == 0
@ -547,14 +555,14 @@ class TestAddDocumentToIndexTask:
# Verify segments were enabled # Verify segments were enabled
for segment in segments: for segment in segments:
db.session.refresh(segment) db_session_with_containers.refresh(segment)
assert segment.enabled is True assert segment.enabled is True
# Verify redis cache was cleared # Verify redis cache was cleared
assert redis_client.exists(indexing_cache_key) == 0 assert redis_client.exists(indexing_cache_key) == 0
def test_add_document_to_index_general_exception_handling( def test_add_document_to_index_general_exception_handling(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test general exception handling during indexing process. Test general exception handling during indexing process.
@ -584,7 +592,7 @@ class TestAddDocumentToIndexTask:
add_document_to_index_task(document.id) add_document_to_index_task(document.id)
# Assert: Verify error handling # Assert: Verify error handling
db.session.refresh(document) db_session_with_containers.refresh(document)
assert document.enabled is False assert document.enabled is False
assert document.indexing_status == "error" assert document.indexing_status == "error"
assert document.error is not None assert document.error is not None
@ -593,14 +601,14 @@ class TestAddDocumentToIndexTask:
# Verify segments were not enabled due to error # Verify segments were not enabled due to error
for segment in segments: for segment in segments:
db.session.refresh(segment) db_session_with_containers.refresh(segment)
assert segment.enabled is False # Should remain disabled due to error assert segment.enabled is False # Should remain disabled due to error
# Verify redis cache was still cleared despite error # Verify redis cache was still cleared despite error
assert redis_client.exists(indexing_cache_key) == 0 assert redis_client.exists(indexing_cache_key) == 0
def test_add_document_to_index_segment_filtering_edge_cases( def test_add_document_to_index_segment_filtering_edge_cases(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test segment filtering with various edge cases. Test segment filtering with various edge cases.
@ -638,7 +646,7 @@ class TestAddDocumentToIndexTask:
status="completed", status="completed",
created_by=document.created_by, created_by=document.created_by,
) )
db.session.add(segment1) db_session_with_containers.add(segment1)
segments.append(segment1) segments.append(segment1)
# Segment 2: Should be processed (enabled=True, status="completed") # Segment 2: Should be processed (enabled=True, status="completed")
@ -658,7 +666,7 @@ class TestAddDocumentToIndexTask:
status="completed", status="completed",
created_by=document.created_by, created_by=document.created_by,
) )
db.session.add(segment2) db_session_with_containers.add(segment2)
segments.append(segment2) segments.append(segment2)
# Segment 3: Should NOT be processed (enabled=False, status="processing") # Segment 3: Should NOT be processed (enabled=False, status="processing")
@ -677,7 +685,7 @@ class TestAddDocumentToIndexTask:
status="processing", # Not completed status="processing", # Not completed
created_by=document.created_by, created_by=document.created_by,
) )
db.session.add(segment3) db_session_with_containers.add(segment3)
segments.append(segment3) segments.append(segment3)
# Segment 4: Should be processed (enabled=False, status="completed") # Segment 4: Should be processed (enabled=False, status="completed")
@ -696,10 +704,10 @@ class TestAddDocumentToIndexTask:
status="completed", status="completed",
created_by=document.created_by, created_by=document.created_by,
) )
db.session.add(segment4) db_session_with_containers.add(segment4)
segments.append(segment4) segments.append(segment4)
db.session.commit() db_session_with_containers.commit()
# Set up Redis cache key # Set up Redis cache key
indexing_cache_key = f"document_{document.id}_indexing" indexing_cache_key = f"document_{document.id}_indexing"
@ -728,11 +736,11 @@ class TestAddDocumentToIndexTask:
assert documents[2].metadata["doc_id"] == "node_3" # segment4, position 3 assert documents[2].metadata["doc_id"] == "node_3" # segment4, position 3
# Verify database state changes # Verify database state changes
db.session.refresh(document) db_session_with_containers.refresh(document)
db.session.refresh(segment1) db_session_with_containers.refresh(segment1)
db.session.refresh(segment2) db_session_with_containers.refresh(segment2)
db.session.refresh(segment3) db_session_with_containers.refresh(segment3)
db.session.refresh(segment4) db_session_with_containers.refresh(segment4)
# All segments should be enabled because the task updates ALL segments for the document # All segments should be enabled because the task updates ALL segments for the document
assert segment1.enabled is True assert segment1.enabled is True
@ -744,7 +752,7 @@ class TestAddDocumentToIndexTask:
assert redis_client.exists(indexing_cache_key) == 0 assert redis_client.exists(indexing_cache_key) == 0
def test_add_document_to_index_comprehensive_error_scenarios( def test_add_document_to_index_comprehensive_error_scenarios(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test comprehensive error scenarios and recovery. Test comprehensive error scenarios and recovery.
@ -779,7 +787,7 @@ class TestAddDocumentToIndexTask:
document.indexing_status = "completed" document.indexing_status = "completed"
document.error = None document.error = None
document.disabled_at = None document.disabled_at = None
db.session.commit() db_session_with_containers.commit()
# Set up Redis cache key # Set up Redis cache key
indexing_cache_key = f"document_{document.id}_indexing" indexing_cache_key = f"document_{document.id}_indexing"
@ -789,7 +797,7 @@ class TestAddDocumentToIndexTask:
add_document_to_index_task(document.id) add_document_to_index_task(document.id)
# Assert: Verify consistent error handling # Assert: Verify consistent error handling
db.session.refresh(document) db_session_with_containers.refresh(document)
assert document.enabled is False, f"Document should be disabled for {error_name}" assert document.enabled is False, f"Document should be disabled for {error_name}"
assert document.indexing_status == "error", f"Document status should be error for {error_name}" assert document.indexing_status == "error", f"Document status should be error for {error_name}"
assert document.error is not None, f"Error should be recorded for {error_name}" assert document.error is not None, f"Error should be recorded for {error_name}"
@ -798,7 +806,7 @@ class TestAddDocumentToIndexTask:
# Verify segments remain disabled due to error # Verify segments remain disabled due to error
for segment in segments: for segment in segments:
db.session.refresh(segment) db_session_with_containers.refresh(segment)
assert segment.enabled is False, f"Segments should remain disabled for {error_name}" assert segment.enabled is False, f"Segments should remain disabled for {error_name}"
# Verify redis cache was still cleared despite error # Verify redis cache was still cleared despite error

View File

@ -11,8 +11,8 @@ from unittest.mock import Mock, patch
import pytest import pytest
from faker import Faker from faker import Faker
from sqlalchemy.orm import Session
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment from models.dataset import Dataset, Document, DocumentSegment
@ -49,7 +49,7 @@ class TestBatchCleanDocumentTask:
"get_image_ids": mock_get_image_ids, "get_image_ids": mock_get_image_ids,
} }
def _create_test_account(self, db_session_with_containers): def _create_test_account(self, db_session_with_containers: Session):
""" """
Helper method to create a test account for testing. Helper method to create a test account for testing.
@ -69,16 +69,16 @@ class TestBatchCleanDocumentTask:
status="active", status="active",
) )
db.session.add(account) db_session_with_containers.add(account)
db.session.commit() db_session_with_containers.commit()
# Create tenant for the account # Create tenant for the account
tenant = Tenant( tenant = Tenant(
name=fake.company(), name=fake.company(),
status="normal", status="normal",
) )
db.session.add(tenant) db_session_with_containers.add(tenant)
db.session.commit() db_session_with_containers.commit()
# Create tenant-account join # Create tenant-account join
join = TenantAccountJoin( join = TenantAccountJoin(
@ -87,15 +87,15 @@ class TestBatchCleanDocumentTask:
role=TenantAccountRole.OWNER, role=TenantAccountRole.OWNER,
current=True, current=True,
) )
db.session.add(join) db_session_with_containers.add(join)
db.session.commit() db_session_with_containers.commit()
# Set current tenant for account # Set current tenant for account
account.current_tenant = tenant account.current_tenant = tenant
return account return account
def _create_test_dataset(self, db_session_with_containers, account): def _create_test_dataset(self, db_session_with_containers: Session, account):
""" """
Helper method to create a test dataset for testing. Helper method to create a test dataset for testing.
@ -119,12 +119,12 @@ class TestBatchCleanDocumentTask:
embedding_model_provider="openai", embedding_model_provider="openai",
) )
db.session.add(dataset) db_session_with_containers.add(dataset)
db.session.commit() db_session_with_containers.commit()
return dataset return dataset
def _create_test_document(self, db_session_with_containers, dataset, account): def _create_test_document(self, db_session_with_containers: Session, dataset, account):
""" """
Helper method to create a test document for testing. Helper method to create a test document for testing.
@ -153,12 +153,12 @@ class TestBatchCleanDocumentTask:
doc_form="text_model", doc_form="text_model",
) )
db.session.add(document) db_session_with_containers.add(document)
db.session.commit() db_session_with_containers.commit()
return document return document
def _create_test_document_segment(self, db_session_with_containers, document, account): def _create_test_document_segment(self, db_session_with_containers: Session, document, account):
""" """
Helper method to create a test document segment for testing. Helper method to create a test document segment for testing.
@ -186,12 +186,12 @@ class TestBatchCleanDocumentTask:
status="completed", status="completed",
) )
db.session.add(segment) db_session_with_containers.add(segment)
db.session.commit() db_session_with_containers.commit()
return segment return segment
def _create_test_upload_file(self, db_session_with_containers, account): def _create_test_upload_file(self, db_session_with_containers: Session, account):
""" """
Helper method to create a test upload file for testing. Helper method to create a test upload file for testing.
@ -220,13 +220,13 @@ class TestBatchCleanDocumentTask:
used=False, used=False,
) )
db.session.add(upload_file) db_session_with_containers.add(upload_file)
db.session.commit() db_session_with_containers.commit()
return upload_file return upload_file
def test_batch_clean_document_task_successful_cleanup( def test_batch_clean_document_task_successful_cleanup(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful cleanup of documents with segments and files. Test successful cleanup of documents with segments and files.
@ -245,7 +245,7 @@ class TestBatchCleanDocumentTask:
# Update document to reference the upload file # Update document to reference the upload file
document.data_source_info = json.dumps({"upload_file_id": upload_file.id}) document.data_source_info = json.dumps({"upload_file_id": upload_file.id})
db.session.commit() db_session_with_containers.commit()
# Store original IDs for verification # Store original IDs for verification
document_id = document.id document_id = document.id
@ -261,18 +261,18 @@ class TestBatchCleanDocumentTask:
# The task should have processed the segment and cleaned up the database # The task should have processed the segment and cleaned up the database
# Verify database cleanup # Verify database cleanup
db.session.commit() # Ensure all changes are committed db_session_with_containers.commit() # Ensure all changes are committed
# Check that segment is deleted # Check that segment is deleted
deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first() deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
assert deleted_segment is None assert deleted_segment is None
# Check that upload file is deleted # Check that upload file is deleted
deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first() deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
assert deleted_file is None assert deleted_file is None
def test_batch_clean_document_task_with_image_files( def test_batch_clean_document_task_with_image_files(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test cleanup of documents containing image references. Test cleanup of documents containing image references.
@ -300,8 +300,8 @@ class TestBatchCleanDocumentTask:
status="completed", status="completed",
) )
db.session.add(segment) db_session_with_containers.add(segment)
db.session.commit() db_session_with_containers.commit()
# Store original IDs for verification # Store original IDs for verification
segment_id = segment.id segment_id = segment.id
@ -313,17 +313,17 @@ class TestBatchCleanDocumentTask:
) )
# Verify database cleanup # Verify database cleanup
db.session.commit() db_session_with_containers.commit()
# Check that segment is deleted # Check that segment is deleted
deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first() deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
assert deleted_segment is None assert deleted_segment is None
# Verify that the task completed successfully by checking the log output # Verify that the task completed successfully by checking the log output
# The task should have processed the segment and cleaned up the database # The task should have processed the segment and cleaned up the database
def test_batch_clean_document_task_no_segments( def test_batch_clean_document_task_no_segments(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test cleanup when document has no segments. Test cleanup when document has no segments.
@ -339,7 +339,7 @@ class TestBatchCleanDocumentTask:
# Update document to reference the upload file # Update document to reference the upload file
document.data_source_info = json.dumps({"upload_file_id": upload_file.id}) document.data_source_info = json.dumps({"upload_file_id": upload_file.id})
db.session.commit() db_session_with_containers.commit()
# Store original IDs for verification # Store original IDs for verification
document_id = document.id document_id = document.id
@ -354,21 +354,21 @@ class TestBatchCleanDocumentTask:
# Since there are no segments, the task should handle this gracefully # Since there are no segments, the task should handle this gracefully
# Verify database cleanup # Verify database cleanup
db.session.commit() db_session_with_containers.commit()
# Check that upload file is deleted # Check that upload file is deleted
deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first() deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
assert deleted_file is None assert deleted_file is None
# Verify database cleanup # Verify database cleanup
db.session.commit() db_session_with_containers.commit()
# Check that upload file is deleted # Check that upload file is deleted
deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first() deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
assert deleted_file is None assert deleted_file is None
def test_batch_clean_document_task_dataset_not_found( def test_batch_clean_document_task_dataset_not_found(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test cleanup when dataset is not found. Test cleanup when dataset is not found.
@ -386,8 +386,8 @@ class TestBatchCleanDocumentTask:
dataset_id = dataset.id dataset_id = dataset.id
# Delete the dataset to simulate not found scenario # Delete the dataset to simulate not found scenario
db.session.delete(dataset) db_session_with_containers.delete(dataset)
db.session.commit() db_session_with_containers.commit()
# Execute the task with non-existent dataset # Execute the task with non-existent dataset
batch_clean_document_task(document_ids=[document_id], dataset_id=dataset_id, doc_form="text_model", file_ids=[]) batch_clean_document_task(document_ids=[document_id], dataset_id=dataset_id, doc_form="text_model", file_ids=[])
@ -399,14 +399,14 @@ class TestBatchCleanDocumentTask:
mock_external_service_dependencies["storage"].delete.assert_not_called() mock_external_service_dependencies["storage"].delete.assert_not_called()
# Verify that no database cleanup occurred # Verify that no database cleanup occurred
db.session.commit() db_session_with_containers.commit()
# Document should still exist since cleanup failed # Document should still exist since cleanup failed
existing_document = db.session.query(Document).filter_by(id=document_id).first() existing_document = db_session_with_containers.query(Document).filter_by(id=document_id).first()
assert existing_document is not None assert existing_document is not None
def test_batch_clean_document_task_storage_cleanup_failure( def test_batch_clean_document_task_storage_cleanup_failure(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test cleanup when storage operations fail. Test cleanup when storage operations fail.
@ -423,7 +423,7 @@ class TestBatchCleanDocumentTask:
# Update document to reference the upload file # Update document to reference the upload file
document.data_source_info = json.dumps({"upload_file_id": upload_file.id}) document.data_source_info = json.dumps({"upload_file_id": upload_file.id})
db.session.commit() db_session_with_containers.commit()
# Store original IDs for verification # Store original IDs for verification
document_id = document.id document_id = document.id
@ -442,18 +442,18 @@ class TestBatchCleanDocumentTask:
# The task should continue processing even when storage operations fail # The task should continue processing even when storage operations fail
# Verify database cleanup still occurred despite storage failure # Verify database cleanup still occurred despite storage failure
db.session.commit() db_session_with_containers.commit()
# Check that segment is deleted from database # Check that segment is deleted from database
deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first() deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
assert deleted_segment is None assert deleted_segment is None
# Check that upload file is deleted from database # Check that upload file is deleted from database
deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first() deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
assert deleted_file is None assert deleted_file is None
def test_batch_clean_document_task_multiple_documents( def test_batch_clean_document_task_multiple_documents(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test cleanup of multiple documents in a single batch operation. Test cleanup of multiple documents in a single batch operation.
@ -482,7 +482,7 @@ class TestBatchCleanDocumentTask:
segments.append(segment) segments.append(segment)
upload_files.append(upload_file) upload_files.append(upload_file)
db.session.commit() db_session_with_containers.commit()
# Store original IDs for verification # Store original IDs for verification
document_ids = [doc.id for doc in documents] document_ids = [doc.id for doc in documents]
@ -498,20 +498,20 @@ class TestBatchCleanDocumentTask:
# The task should process all documents and clean up all associated resources # The task should process all documents and clean up all associated resources
# Verify database cleanup for all resources # Verify database cleanup for all resources
db.session.commit() db_session_with_containers.commit()
# Check that all segments are deleted # Check that all segments are deleted
for segment_id in segment_ids: for segment_id in segment_ids:
deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first() deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
assert deleted_segment is None assert deleted_segment is None
# Check that all upload files are deleted # Check that all upload files are deleted
for file_id in file_ids: for file_id in file_ids:
deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first() deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
assert deleted_file is None assert deleted_file is None
def test_batch_clean_document_task_different_doc_forms( def test_batch_clean_document_task_different_doc_forms(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test cleanup with different document form types. Test cleanup with different document form types.
@ -527,12 +527,12 @@ class TestBatchCleanDocumentTask:
for doc_form in doc_forms: for doc_form in doc_forms:
dataset = self._create_test_dataset(db_session_with_containers, account) dataset = self._create_test_dataset(db_session_with_containers, account)
db.session.commit() db_session_with_containers.commit()
document = self._create_test_document(db_session_with_containers, dataset, account) document = self._create_test_document(db_session_with_containers, dataset, account)
# Update document doc_form # Update document doc_form
document.doc_form = doc_form document.doc_form = doc_form
db.session.commit() db_session_with_containers.commit()
segment = self._create_test_document_segment(db_session_with_containers, document, account) segment = self._create_test_document_segment(db_session_with_containers, document, account)
@ -549,20 +549,20 @@ class TestBatchCleanDocumentTask:
# The task should handle different document forms correctly # The task should handle different document forms correctly
# Verify database cleanup # Verify database cleanup
db.session.commit() db_session_with_containers.commit()
# Check that segment is deleted # Check that segment is deleted
deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first() deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
assert deleted_segment is None assert deleted_segment is None
except Exception as e: except Exception as e:
# If the task fails due to external service issues (e.g., plugin daemon), # If the task fails due to external service issues (e.g., plugin daemon),
# we should still verify that the database state is consistent # we should still verify that the database state is consistent
# This is a common scenario in test environments where external services may not be available # This is a common scenario in test environments where external services may not be available
db.session.commit() db_session_with_containers.commit()
# Check if the segment still exists (task may have failed before deletion) # Check if the segment still exists (task may have failed before deletion)
existing_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first() existing_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
if existing_segment is not None: if existing_segment is not None:
# If segment still exists, the task failed before deletion # If segment still exists, the task failed before deletion
# This is acceptable in test environments with external service issues # This is acceptable in test environments with external service issues
@ -572,7 +572,7 @@ class TestBatchCleanDocumentTask:
pass pass
def test_batch_clean_document_task_large_batch_performance( def test_batch_clean_document_task_large_batch_performance(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test cleanup performance with a large batch of documents. Test cleanup performance with a large batch of documents.
@ -604,7 +604,7 @@ class TestBatchCleanDocumentTask:
segments.append(segment) segments.append(segment)
upload_files.append(upload_file) upload_files.append(upload_file)
db.session.commit() db_session_with_containers.commit()
# Store original IDs for verification # Store original IDs for verification
document_ids = [doc.id for doc in documents] document_ids = [doc.id for doc in documents]
@ -629,20 +629,20 @@ class TestBatchCleanDocumentTask:
# The task should handle large batches efficiently # The task should handle large batches efficiently
# Verify database cleanup for all resources # Verify database cleanup for all resources
db.session.commit() db_session_with_containers.commit()
# Check that all segments are deleted # Check that all segments are deleted
for segment_id in segment_ids: for segment_id in segment_ids:
deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first() deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
assert deleted_segment is None assert deleted_segment is None
# Check that all upload files are deleted # Check that all upload files are deleted
for file_id in file_ids: for file_id in file_ids:
deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first() deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
assert deleted_file is None assert deleted_file is None
def test_batch_clean_document_task_integration_with_real_database( def test_batch_clean_document_task_integration_with_real_database(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test full integration with real database operations. Test full integration with real database operations.
@ -683,12 +683,12 @@ class TestBatchCleanDocumentTask:
# Add all to database # Add all to database
for segment in segments: for segment in segments:
db.session.add(segment) db_session_with_containers.add(segment)
db.session.commit() db_session_with_containers.commit()
# Verify initial state # Verify initial state
assert db.session.query(DocumentSegment).filter_by(document_id=document.id).count() == 3 assert db_session_with_containers.query(DocumentSegment).filter_by(document_id=document.id).count() == 3
assert db.session.query(UploadFile).filter_by(id=upload_file.id).first() is not None assert db_session_with_containers.query(UploadFile).filter_by(id=upload_file.id).first() is not None
# Store original IDs for verification # Store original IDs for verification
document_id = document.id document_id = document.id
@ -704,17 +704,17 @@ class TestBatchCleanDocumentTask:
# The task should process all segments and clean up all associated resources # The task should process all segments and clean up all associated resources
# Verify database cleanup # Verify database cleanup
db.session.commit() db_session_with_containers.commit()
# Check that all segments are deleted # Check that all segments are deleted
for segment_id in segment_ids: for segment_id in segment_ids:
deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first() deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
assert deleted_segment is None assert deleted_segment is None
# Check that upload file is deleted # Check that upload file is deleted
deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first() deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
assert deleted_file is None assert deleted_file is None
# Verify final database state # Verify final database state
assert db.session.query(DocumentSegment).filter_by(document_id=document_id).count() == 0 assert db_session_with_containers.query(DocumentSegment).filter_by(document_id=document_id).count() == 0
assert db.session.query(UploadFile).filter_by(id=file_id).first() is None assert db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() is None

View File

@ -17,6 +17,7 @@ from unittest.mock import MagicMock, patch
import pytest import pytest
from faker import Faker from faker import Faker
from sqlalchemy.orm import Session
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment from models.dataset import Dataset, Document, DocumentSegment
@ -29,20 +30,19 @@ class TestBatchCreateSegmentToIndexTask:
"""Integration tests for batch_create_segment_to_index_task using testcontainers.""" """Integration tests for batch_create_segment_to_index_task using testcontainers."""
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def cleanup_database(self, db_session_with_containers): def cleanup_database(self, db_session_with_containers: Session):
"""Clean up database before each test to ensure isolation.""" """Clean up database before each test to ensure isolation."""
from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
# Clear all test data # Clear all test data
db.session.query(DocumentSegment).delete() db_session_with_containers.query(DocumentSegment).delete()
db.session.query(Document).delete() db_session_with_containers.query(Document).delete()
db.session.query(Dataset).delete() db_session_with_containers.query(Dataset).delete()
db.session.query(UploadFile).delete() db_session_with_containers.query(UploadFile).delete()
db.session.query(TenantAccountJoin).delete() db_session_with_containers.query(TenantAccountJoin).delete()
db.session.query(Tenant).delete() db_session_with_containers.query(Tenant).delete()
db.session.query(Account).delete() db_session_with_containers.query(Account).delete()
db.session.commit() db_session_with_containers.commit()
# Clear Redis cache # Clear Redis cache
redis_client.flushdb() redis_client.flushdb()
@ -75,7 +75,7 @@ class TestBatchCreateSegmentToIndexTask:
"embedding_model": mock_embedding_model, "embedding_model": mock_embedding_model,
} }
def _create_test_account_and_tenant(self, db_session_with_containers): def _create_test_account_and_tenant(self, db_session_with_containers: Session):
""" """
Helper method to create a test account and tenant for testing. Helper method to create a test account and tenant for testing.
@ -95,18 +95,16 @@ class TestBatchCreateSegmentToIndexTask:
status="active", status="active",
) )
from extensions.ext_database import db db_session_with_containers.add(account)
db_session_with_containers.commit()
db.session.add(account)
db.session.commit()
# Create tenant for the account # Create tenant for the account
tenant = Tenant( tenant = Tenant(
name=fake.company(), name=fake.company(),
status="normal", status="normal",
) )
db.session.add(tenant) db_session_with_containers.add(tenant)
db.session.commit() db_session_with_containers.commit()
# Create tenant-account join # Create tenant-account join
join = TenantAccountJoin( join = TenantAccountJoin(
@ -115,15 +113,15 @@ class TestBatchCreateSegmentToIndexTask:
role=TenantAccountRole.OWNER, role=TenantAccountRole.OWNER,
current=True, current=True,
) )
db.session.add(join) db_session_with_containers.add(join)
db.session.commit() db_session_with_containers.commit()
# Set current tenant for account # Set current tenant for account
account.current_tenant = tenant account.current_tenant = tenant
return account, tenant return account, tenant
def _create_test_dataset(self, db_session_with_containers, account, tenant): def _create_test_dataset(self, db_session_with_containers: Session, account, tenant):
""" """
Helper method to create a test dataset for testing. Helper method to create a test dataset for testing.
@ -148,14 +146,12 @@ class TestBatchCreateSegmentToIndexTask:
created_by=account.id, created_by=account.id,
) )
from extensions.ext_database import db db_session_with_containers.add(dataset)
db_session_with_containers.commit()
db.session.add(dataset)
db.session.commit()
return dataset return dataset
def _create_test_document(self, db_session_with_containers, account, tenant, dataset): def _create_test_document(self, db_session_with_containers: Session, account, tenant, dataset):
""" """
Helper method to create a test document for testing. Helper method to create a test document for testing.
@ -186,14 +182,12 @@ class TestBatchCreateSegmentToIndexTask:
word_count=0, word_count=0,
) )
from extensions.ext_database import db db_session_with_containers.add(document)
db_session_with_containers.commit()
db.session.add(document)
db.session.commit()
return document return document
def _create_test_upload_file(self, db_session_with_containers, account, tenant): def _create_test_upload_file(self, db_session_with_containers: Session, account, tenant):
""" """
Helper method to create a test upload file for testing. Helper method to create a test upload file for testing.
@ -221,10 +215,8 @@ class TestBatchCreateSegmentToIndexTask:
used=False, used=False,
) )
from extensions.ext_database import db db_session_with_containers.add(upload_file)
db_session_with_containers.commit()
db.session.add(upload_file)
db.session.commit()
return upload_file return upload_file
@ -252,7 +244,7 @@ class TestBatchCreateSegmentToIndexTask:
return csv_content return csv_content
def test_batch_create_segment_to_index_task_success_text_model( def test_batch_create_segment_to_index_task_success_text_model(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful batch creation of segments for text model documents. Test successful batch creation of segments for text model documents.
@ -293,11 +285,10 @@ class TestBatchCreateSegmentToIndexTask:
) )
# Verify results # Verify results
from extensions.ext_database import db
# Check that segments were created # Check that segments were created
segments = ( segments = (
db.session.query(DocumentSegment) db_session_with_containers.query(DocumentSegment)
.filter_by(document_id=document.id) .filter_by(document_id=document.id)
.order_by(DocumentSegment.position) .order_by(DocumentSegment.position)
.all() .all()
@ -316,7 +307,7 @@ class TestBatchCreateSegmentToIndexTask:
assert segment.answer is None # text_model doesn't have answers assert segment.answer is None # text_model doesn't have answers
# Check that document word count was updated # Check that document word count was updated
db.session.refresh(document) db_session_with_containers.refresh(document)
assert document.word_count > 0 assert document.word_count > 0
# Verify vector service was called # Verify vector service was called
@ -331,7 +322,7 @@ class TestBatchCreateSegmentToIndexTask:
assert cache_value == b"completed" assert cache_value == b"completed"
def test_batch_create_segment_to_index_task_dataset_not_found( def test_batch_create_segment_to_index_task_dataset_not_found(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test task failure when dataset does not exist. Test task failure when dataset does not exist.
@ -370,17 +361,16 @@ class TestBatchCreateSegmentToIndexTask:
assert cache_value == b"error" assert cache_value == b"error"
# Verify no segments were created (since dataset doesn't exist) # Verify no segments were created (since dataset doesn't exist)
from extensions.ext_database import db
segments = db.session.query(DocumentSegment).all() segments = db_session_with_containers.query(DocumentSegment).all()
assert len(segments) == 0 assert len(segments) == 0
# Verify no documents were modified # Verify no documents were modified
documents = db.session.query(Document).all() documents = db_session_with_containers.query(Document).all()
assert len(documents) == 0 assert len(documents) == 0
def test_batch_create_segment_to_index_task_document_not_found( def test_batch_create_segment_to_index_task_document_not_found(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test task failure when document does not exist. Test task failure when document does not exist.
@ -419,18 +409,17 @@ class TestBatchCreateSegmentToIndexTask:
assert cache_value == b"error" assert cache_value == b"error"
# Verify no segments were created # Verify no segments were created
from extensions.ext_database import db
segments = db.session.query(DocumentSegment).all() segments = db_session_with_containers.query(DocumentSegment).all()
assert len(segments) == 0 assert len(segments) == 0
# Verify dataset remains unchanged (no segments were added to the dataset) # Verify dataset remains unchanged (no segments were added to the dataset)
db.session.refresh(dataset) db_session_with_containers.refresh(dataset)
segments_for_dataset = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() segments_for_dataset = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
assert len(segments_for_dataset) == 0 assert len(segments_for_dataset) == 0
def test_batch_create_segment_to_index_task_document_not_available( def test_batch_create_segment_to_index_task_document_not_available(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test task failure when document is not available for indexing. Test task failure when document is not available for indexing.
@ -498,11 +487,9 @@ class TestBatchCreateSegmentToIndexTask:
), ),
] ]
from extensions.ext_database import db
for document in test_cases: for document in test_cases:
db.session.add(document) db_session_with_containers.add(document)
db.session.commit() db_session_with_containers.commit()
# Test each unavailable document # Test each unavailable document
for document in test_cases: for document in test_cases:
@ -524,11 +511,11 @@ class TestBatchCreateSegmentToIndexTask:
assert cache_value == b"error" assert cache_value == b"error"
# Verify no segments were created # Verify no segments were created
segments = db.session.query(DocumentSegment).filter_by(document_id=document.id).all() segments = db_session_with_containers.query(DocumentSegment).filter_by(document_id=document.id).all()
assert len(segments) == 0 assert len(segments) == 0
def test_batch_create_segment_to_index_task_upload_file_not_found( def test_batch_create_segment_to_index_task_upload_file_not_found(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test task failure when upload file does not exist. Test task failure when upload file does not exist.
@ -567,17 +554,16 @@ class TestBatchCreateSegmentToIndexTask:
assert cache_value == b"error" assert cache_value == b"error"
# Verify no segments were created # Verify no segments were created
from extensions.ext_database import db
segments = db.session.query(DocumentSegment).all() segments = db_session_with_containers.query(DocumentSegment).all()
assert len(segments) == 0 assert len(segments) == 0
# Verify document remains unchanged # Verify document remains unchanged
db.session.refresh(document) db_session_with_containers.refresh(document)
assert document.word_count == 0 assert document.word_count == 0
def test_batch_create_segment_to_index_task_empty_csv_file( def test_batch_create_segment_to_index_task_empty_csv_file(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test task failure when CSV file is empty. Test task failure when CSV file is empty.
@ -619,17 +605,16 @@ class TestBatchCreateSegmentToIndexTask:
# Verify error handling # Verify error handling
# Since exception was raised, no segments should be created # Since exception was raised, no segments should be created
from extensions.ext_database import db
segments = db.session.query(DocumentSegment).all() segments = db_session_with_containers.query(DocumentSegment).all()
assert len(segments) == 0 assert len(segments) == 0
# Verify document remains unchanged # Verify document remains unchanged
db.session.refresh(document) db_session_with_containers.refresh(document)
assert document.word_count == 0 assert document.word_count == 0
def test_batch_create_segment_to_index_task_position_calculation( def test_batch_create_segment_to_index_task_position_calculation(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test proper position calculation for segments when existing segments exist. Test proper position calculation for segments when existing segments exist.
@ -664,11 +649,9 @@ class TestBatchCreateSegmentToIndexTask:
) )
existing_segments.append(segment) existing_segments.append(segment)
from extensions.ext_database import db
for segment in existing_segments: for segment in existing_segments:
db.session.add(segment) db_session_with_containers.add(segment)
db.session.commit() db_session_with_containers.commit()
# Create CSV content # Create CSV content
csv_content = self._create_test_csv_content("text_model") csv_content = self._create_test_csv_content("text_model")
@ -695,7 +678,7 @@ class TestBatchCreateSegmentToIndexTask:
# Verify results # Verify results
# Check that new segments were created with correct positions # Check that new segments were created with correct positions
all_segments = ( all_segments = (
db.session.query(DocumentSegment) db_session_with_containers.query(DocumentSegment)
.filter_by(document_id=document.id) .filter_by(document_id=document.id)
.order_by(DocumentSegment.position) .order_by(DocumentSegment.position)
.all() .all()
@ -716,7 +699,7 @@ class TestBatchCreateSegmentToIndexTask:
assert segment.completed_at is not None assert segment.completed_at is not None
# Check that document word count was updated # Check that document word count was updated
db.session.refresh(document) db_session_with_containers.refresh(document)
assert document.word_count > 0 assert document.word_count > 0
# Verify vector service was called # Verify vector service was called

View File

@ -16,6 +16,7 @@ from unittest.mock import MagicMock, patch
import pytest import pytest
from faker import Faker from faker import Faker
from sqlalchemy.orm import Session
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import ( from models.dataset import (
@ -37,7 +38,7 @@ class TestCleanDatasetTask:
"""Integration tests for clean_dataset_task using testcontainers.""" """Integration tests for clean_dataset_task using testcontainers."""
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def cleanup_database(self, db_session_with_containers): def cleanup_database(self, db_session_with_containers: Session):
"""Clean up database before each test to ensure isolation.""" """Clean up database before each test to ensure isolation."""
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
@ -82,7 +83,7 @@ class TestCleanDatasetTask:
"index_processor": mock_index_processor, "index_processor": mock_index_processor,
} }
def _create_test_account_and_tenant(self, db_session_with_containers): def _create_test_account_and_tenant(self, db_session_with_containers: Session):
""" """
Helper method to create a test account and tenant for testing. Helper method to create a test account and tenant for testing.
@ -127,7 +128,7 @@ class TestCleanDatasetTask:
return account, tenant return account, tenant
def _create_test_dataset(self, db_session_with_containers, account, tenant): def _create_test_dataset(self, db_session_with_containers: Session, account, tenant):
""" """
Helper method to create a test dataset for testing. Helper method to create a test dataset for testing.
@ -157,7 +158,7 @@ class TestCleanDatasetTask:
return dataset return dataset
def _create_test_document(self, db_session_with_containers, account, tenant, dataset): def _create_test_document(self, db_session_with_containers: Session, account, tenant, dataset):
""" """
Helper method to create a test document for testing. Helper method to create a test document for testing.
@ -194,7 +195,7 @@ class TestCleanDatasetTask:
return document return document
def _create_test_segment(self, db_session_with_containers, account, tenant, dataset, document): def _create_test_segment(self, db_session_with_containers: Session, account, tenant, dataset, document):
""" """
Helper method to create a test document segment for testing. Helper method to create a test document segment for testing.
@ -230,7 +231,7 @@ class TestCleanDatasetTask:
return segment return segment
def _create_test_upload_file(self, db_session_with_containers, account, tenant): def _create_test_upload_file(self, db_session_with_containers: Session, account, tenant):
""" """
Helper method to create a test upload file for testing. Helper method to create a test upload file for testing.
@ -264,7 +265,7 @@ class TestCleanDatasetTask:
return upload_file return upload_file
def test_clean_dataset_task_success_basic_cleanup( def test_clean_dataset_task_success_basic_cleanup(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful basic dataset cleanup with minimal data. Test successful basic dataset cleanup with minimal data.
@ -325,7 +326,7 @@ class TestCleanDatasetTask:
mock_storage.delete.assert_not_called() mock_storage.delete.assert_not_called()
def test_clean_dataset_task_success_with_documents_and_segments( def test_clean_dataset_task_success_with_documents_and_segments(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful dataset cleanup with documents and segments. Test successful dataset cleanup with documents and segments.
@ -433,7 +434,7 @@ class TestCleanDatasetTask:
assert mock_storage.delete.call_count == 3 assert mock_storage.delete.call_count == 3
def test_clean_dataset_task_success_with_invalid_doc_form( def test_clean_dataset_task_success_with_invalid_doc_form(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful dataset cleanup with invalid doc_form handling. Test successful dataset cleanup with invalid doc_form handling.
@ -493,7 +494,7 @@ class TestCleanDatasetTask:
assert mock_factory.call_count == 4 assert mock_factory.call_count == 4
def test_clean_dataset_task_error_handling_and_rollback( def test_clean_dataset_task_error_handling_and_rollback(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test error handling and rollback mechanism when database operations fail. Test error handling and rollback mechanism when database operations fail.
@ -542,7 +543,7 @@ class TestCleanDatasetTask:
# This demonstrates the resilience of the cleanup process # This demonstrates the resilience of the cleanup process
def test_clean_dataset_task_with_image_file_references( def test_clean_dataset_task_with_image_file_references(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test dataset cleanup with image file references in document segments. Test dataset cleanup with image file references in document segments.
@ -634,7 +635,7 @@ class TestCleanDatasetTask:
mock_get_image_ids.assert_called_once() mock_get_image_ids.assert_called_once()
def test_clean_dataset_task_performance_with_large_dataset( def test_clean_dataset_task_performance_with_large_dataset(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test dataset cleanup performance with large amounts of data. Test dataset cleanup performance with large amounts of data.
@ -704,11 +705,9 @@ class TestCleanDatasetTask:
binding.created_at = datetime.now() binding.created_at = datetime.now()
bindings.append(binding) bindings.append(binding)
from extensions.ext_database import db db_session_with_containers.add_all(metadata_items)
db_session_with_containers.add_all(bindings)
db.session.add_all(metadata_items) db_session_with_containers.commit()
db.session.add_all(bindings)
db.session.commit()
# Measure cleanup performance # Measure cleanup performance
import time import time
@ -772,7 +771,7 @@ class TestCleanDatasetTask:
print(f"Average time per document: {cleanup_duration / len(documents):.3f} seconds") print(f"Average time per document: {cleanup_duration / len(documents):.3f} seconds")
def test_clean_dataset_task_storage_exception_handling( def test_clean_dataset_task_storage_exception_handling(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test dataset cleanup when storage operations fail. Test dataset cleanup when storage operations fail.
@ -838,7 +837,7 @@ class TestCleanDatasetTask:
# consistency in the database # consistency in the database
def test_clean_dataset_task_edge_cases_and_boundary_conditions( def test_clean_dataset_task_edge_cases_and_boundary_conditions(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test dataset cleanup with edge cases and boundary conditions. Test dataset cleanup with edge cases and boundary conditions.

View File

@ -13,8 +13,8 @@ from unittest.mock import patch
import pytest import pytest
from faker import Faker from faker import Faker
from sqlalchemy.orm import Session
from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment from models.dataset import Dataset, Document, DocumentSegment
@ -34,7 +34,7 @@ class TestDisableSegmentFromIndexTask:
mock_processor.clean.return_value = None mock_processor.clean.return_value = None
yield mock_processor yield mock_processor
def _create_test_account_and_tenant(self, db_session_with_containers) -> tuple[Account, Tenant]: def _create_test_account_and_tenant(self, db_session_with_containers: Session) -> tuple[Account, Tenant]:
""" """
Helper method to create a test account and tenant for testing. Helper method to create a test account and tenant for testing.
@ -53,8 +53,8 @@ class TestDisableSegmentFromIndexTask:
interface_language="en-US", interface_language="en-US",
status="active", status="active",
) )
db.session.add(account) db_session_with_containers.add(account)
db.session.commit() db_session_with_containers.commit()
# Create tenant # Create tenant
tenant = Tenant( tenant = Tenant(
@ -62,8 +62,8 @@ class TestDisableSegmentFromIndexTask:
status="normal", status="normal",
plan="basic", plan="basic",
) )
db.session.add(tenant) db_session_with_containers.add(tenant)
db.session.commit() db_session_with_containers.commit()
# Create tenant-account join with owner role # Create tenant-account join with owner role
join = TenantAccountJoin( join = TenantAccountJoin(
@ -72,15 +72,15 @@ class TestDisableSegmentFromIndexTask:
role=TenantAccountRole.OWNER, role=TenantAccountRole.OWNER,
current=True, current=True,
) )
db.session.add(join) db_session_with_containers.add(join)
db.session.commit() db_session_with_containers.commit()
# Set current tenant for account # Set current tenant for account
account.current_tenant = tenant account.current_tenant = tenant
return account, tenant return account, tenant
def _create_test_dataset(self, tenant: Tenant, account: Account) -> Dataset: def _create_test_dataset(self, db_session_with_containers: Session, tenant: Tenant, account: Account) -> Dataset:
""" """
Helper method to create a test dataset. Helper method to create a test dataset.
@ -101,13 +101,18 @@ class TestDisableSegmentFromIndexTask:
indexing_technique="high_quality", indexing_technique="high_quality",
created_by=account.id, created_by=account.id,
) )
db.session.add(dataset) db_session_with_containers.add(dataset)
db.session.commit() db_session_with_containers.commit()
return dataset return dataset
def _create_test_document( def _create_test_document(
self, dataset: Dataset, tenant: Tenant, account: Account, doc_form: str = "text_model" self,
db_session_with_containers: Session,
dataset: Dataset,
tenant: Tenant,
account: Account,
doc_form: str = "text_model",
) -> Document: ) -> Document:
""" """
Helper method to create a test document. Helper method to create a test document.
@ -140,13 +145,14 @@ class TestDisableSegmentFromIndexTask:
tokens=500, tokens=500,
completed_at=datetime.now(UTC), completed_at=datetime.now(UTC),
) )
db.session.add(document) db_session_with_containers.add(document)
db.session.commit() db_session_with_containers.commit()
return document return document
def _create_test_segment( def _create_test_segment(
self, self,
db_session_with_containers: Session,
document: Document, document: Document,
dataset: Dataset, dataset: Dataset,
tenant: Tenant, tenant: Tenant,
@ -185,12 +191,12 @@ class TestDisableSegmentFromIndexTask:
created_by=account.id, created_by=account.id,
completed_at=datetime.now(UTC) if status == "completed" else None, completed_at=datetime.now(UTC) if status == "completed" else None,
) )
db.session.add(segment) db_session_with_containers.add(segment)
db.session.commit() db_session_with_containers.commit()
return segment return segment
def test_disable_segment_success(self, db_session_with_containers, mock_index_processor): def test_disable_segment_success(self, db_session_with_containers: Session, mock_index_processor):
""" """
Test successful segment disabling from index. Test successful segment disabling from index.
@ -202,9 +208,9 @@ class TestDisableSegmentFromIndexTask:
""" """
# Arrange: Create test data # Arrange: Create test data
account, tenant = self._create_test_account_and_tenant(db_session_with_containers) account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset = self._create_test_dataset(tenant, account) dataset = self._create_test_dataset(db_session_with_containers, tenant, account)
document = self._create_test_document(dataset, tenant, account) document = self._create_test_document(db_session_with_containers, dataset, tenant, account)
segment = self._create_test_segment(document, dataset, tenant, account) segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account)
# Set up Redis cache # Set up Redis cache
indexing_cache_key = f"segment_{segment.id}_indexing" indexing_cache_key = f"segment_{segment.id}_indexing"
@ -226,10 +232,10 @@ class TestDisableSegmentFromIndexTask:
assert redis_client.get(indexing_cache_key) is None assert redis_client.get(indexing_cache_key) is None
# Verify segment is still in database # Verify segment is still in database
db.session.refresh(segment) db_session_with_containers.refresh(segment)
assert segment.id is not None assert segment.id is not None
def test_disable_segment_not_found(self, db_session_with_containers, mock_index_processor): def test_disable_segment_not_found(self, db_session_with_containers: Session, mock_index_processor):
""" """
Test handling when segment is not found. Test handling when segment is not found.
@ -251,7 +257,7 @@ class TestDisableSegmentFromIndexTask:
# Verify index processor was not called # Verify index processor was not called
mock_index_processor.clean.assert_not_called() mock_index_processor.clean.assert_not_called()
def test_disable_segment_not_completed(self, db_session_with_containers, mock_index_processor): def test_disable_segment_not_completed(self, db_session_with_containers: Session, mock_index_processor):
""" """
Test handling when segment is not in completed status. Test handling when segment is not in completed status.
@ -262,9 +268,11 @@ class TestDisableSegmentFromIndexTask:
""" """
# Arrange: Create test data with non-completed segment # Arrange: Create test data with non-completed segment
account, tenant = self._create_test_account_and_tenant(db_session_with_containers) account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset = self._create_test_dataset(tenant, account) dataset = self._create_test_dataset(db_session_with_containers, tenant, account)
document = self._create_test_document(dataset, tenant, account) document = self._create_test_document(db_session_with_containers, dataset, tenant, account)
segment = self._create_test_segment(document, dataset, tenant, account, status="indexing", enabled=True) segment = self._create_test_segment(
db_session_with_containers, document, dataset, tenant, account, status="indexing", enabled=True
)
# Act: Execute the task # Act: Execute the task
result = disable_segment_from_index_task(segment.id) result = disable_segment_from_index_task(segment.id)
@ -275,7 +283,7 @@ class TestDisableSegmentFromIndexTask:
# Verify index processor was not called # Verify index processor was not called
mock_index_processor.clean.assert_not_called() mock_index_processor.clean.assert_not_called()
def test_disable_segment_no_dataset(self, db_session_with_containers, mock_index_processor): def test_disable_segment_no_dataset(self, db_session_with_containers: Session, mock_index_processor):
""" """
Test handling when segment has no associated dataset. Test handling when segment has no associated dataset.
@ -286,13 +294,13 @@ class TestDisableSegmentFromIndexTask:
""" """
# Arrange: Create test data # Arrange: Create test data
account, tenant = self._create_test_account_and_tenant(db_session_with_containers) account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset = self._create_test_dataset(tenant, account) dataset = self._create_test_dataset(db_session_with_containers, tenant, account)
document = self._create_test_document(dataset, tenant, account) document = self._create_test_document(db_session_with_containers, dataset, tenant, account)
segment = self._create_test_segment(document, dataset, tenant, account) segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account)
# Manually remove dataset association # Manually remove dataset association
segment.dataset_id = "00000000-0000-0000-0000-000000000000" segment.dataset_id = "00000000-0000-0000-0000-000000000000"
db.session.commit() db_session_with_containers.commit()
# Act: Execute the task # Act: Execute the task
result = disable_segment_from_index_task(segment.id) result = disable_segment_from_index_task(segment.id)
@ -303,7 +311,7 @@ class TestDisableSegmentFromIndexTask:
# Verify index processor was not called # Verify index processor was not called
mock_index_processor.clean.assert_not_called() mock_index_processor.clean.assert_not_called()
def test_disable_segment_no_document(self, db_session_with_containers, mock_index_processor): def test_disable_segment_no_document(self, db_session_with_containers: Session, mock_index_processor):
""" """
Test handling when segment has no associated document. Test handling when segment has no associated document.
@ -314,13 +322,13 @@ class TestDisableSegmentFromIndexTask:
""" """
# Arrange: Create test data # Arrange: Create test data
account, tenant = self._create_test_account_and_tenant(db_session_with_containers) account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset = self._create_test_dataset(tenant, account) dataset = self._create_test_dataset(db_session_with_containers, tenant, account)
document = self._create_test_document(dataset, tenant, account) document = self._create_test_document(db_session_with_containers, dataset, tenant, account)
segment = self._create_test_segment(document, dataset, tenant, account) segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account)
# Manually remove document association # Manually remove document association
segment.document_id = "00000000-0000-0000-0000-000000000000" segment.document_id = "00000000-0000-0000-0000-000000000000"
db.session.commit() db_session_with_containers.commit()
# Act: Execute the task # Act: Execute the task
result = disable_segment_from_index_task(segment.id) result = disable_segment_from_index_task(segment.id)
@ -331,7 +339,7 @@ class TestDisableSegmentFromIndexTask:
# Verify index processor was not called # Verify index processor was not called
mock_index_processor.clean.assert_not_called() mock_index_processor.clean.assert_not_called()
def test_disable_segment_document_disabled(self, db_session_with_containers, mock_index_processor): def test_disable_segment_document_disabled(self, db_session_with_containers: Session, mock_index_processor):
""" """
Test handling when document is disabled. Test handling when document is disabled.
@ -342,12 +350,12 @@ class TestDisableSegmentFromIndexTask:
""" """
# Arrange: Create test data with disabled document # Arrange: Create test data with disabled document
account, tenant = self._create_test_account_and_tenant(db_session_with_containers) account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset = self._create_test_dataset(tenant, account) dataset = self._create_test_dataset(db_session_with_containers, tenant, account)
document = self._create_test_document(dataset, tenant, account) document = self._create_test_document(db_session_with_containers, dataset, tenant, account)
document.enabled = False document.enabled = False
db.session.commit() db_session_with_containers.commit()
segment = self._create_test_segment(document, dataset, tenant, account) segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account)
# Act: Execute the task # Act: Execute the task
result = disable_segment_from_index_task(segment.id) result = disable_segment_from_index_task(segment.id)
@ -358,7 +366,7 @@ class TestDisableSegmentFromIndexTask:
# Verify index processor was not called # Verify index processor was not called
mock_index_processor.clean.assert_not_called() mock_index_processor.clean.assert_not_called()
def test_disable_segment_document_archived(self, db_session_with_containers, mock_index_processor): def test_disable_segment_document_archived(self, db_session_with_containers: Session, mock_index_processor):
""" """
Test handling when document is archived. Test handling when document is archived.
@ -369,12 +377,12 @@ class TestDisableSegmentFromIndexTask:
""" """
# Arrange: Create test data with archived document # Arrange: Create test data with archived document
account, tenant = self._create_test_account_and_tenant(db_session_with_containers) account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset = self._create_test_dataset(tenant, account) dataset = self._create_test_dataset(db_session_with_containers, tenant, account)
document = self._create_test_document(dataset, tenant, account) document = self._create_test_document(db_session_with_containers, dataset, tenant, account)
document.archived = True document.archived = True
db.session.commit() db_session_with_containers.commit()
segment = self._create_test_segment(document, dataset, tenant, account) segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account)
# Act: Execute the task # Act: Execute the task
result = disable_segment_from_index_task(segment.id) result = disable_segment_from_index_task(segment.id)
@ -385,7 +393,9 @@ class TestDisableSegmentFromIndexTask:
# Verify index processor was not called # Verify index processor was not called
mock_index_processor.clean.assert_not_called() mock_index_processor.clean.assert_not_called()
def test_disable_segment_document_indexing_not_completed(self, db_session_with_containers, mock_index_processor): def test_disable_segment_document_indexing_not_completed(
self, db_session_with_containers: Session, mock_index_processor
):
""" """
Test handling when document indexing is not completed. Test handling when document indexing is not completed.
@ -396,12 +406,12 @@ class TestDisableSegmentFromIndexTask:
""" """
# Arrange: Create test data with incomplete indexing # Arrange: Create test data with incomplete indexing
account, tenant = self._create_test_account_and_tenant(db_session_with_containers) account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset = self._create_test_dataset(tenant, account) dataset = self._create_test_dataset(db_session_with_containers, tenant, account)
document = self._create_test_document(dataset, tenant, account) document = self._create_test_document(db_session_with_containers, dataset, tenant, account)
document.indexing_status = "indexing" document.indexing_status = "indexing"
db.session.commit() db_session_with_containers.commit()
segment = self._create_test_segment(document, dataset, tenant, account) segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account)
# Act: Execute the task # Act: Execute the task
result = disable_segment_from_index_task(segment.id) result = disable_segment_from_index_task(segment.id)
@ -412,7 +422,7 @@ class TestDisableSegmentFromIndexTask:
# Verify index processor was not called # Verify index processor was not called
mock_index_processor.clean.assert_not_called() mock_index_processor.clean.assert_not_called()
def test_disable_segment_index_processor_exception(self, db_session_with_containers, mock_index_processor): def test_disable_segment_index_processor_exception(self, db_session_with_containers: Session, mock_index_processor):
""" """
Test handling when index processor raises an exception. Test handling when index processor raises an exception.
@ -424,9 +434,9 @@ class TestDisableSegmentFromIndexTask:
""" """
# Arrange: Create test data # Arrange: Create test data
account, tenant = self._create_test_account_and_tenant(db_session_with_containers) account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset = self._create_test_dataset(tenant, account) dataset = self._create_test_dataset(db_session_with_containers, tenant, account)
document = self._create_test_document(dataset, tenant, account) document = self._create_test_document(db_session_with_containers, dataset, tenant, account)
segment = self._create_test_segment(document, dataset, tenant, account) segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account)
# Set up Redis cache # Set up Redis cache
indexing_cache_key = f"segment_{segment.id}_indexing" indexing_cache_key = f"segment_{segment.id}_indexing"
@ -449,13 +459,13 @@ class TestDisableSegmentFromIndexTask:
assert call_args[0][1] == [segment.index_node_id] # Check index node IDs assert call_args[0][1] == [segment.index_node_id] # Check index node IDs
# Verify segment was re-enabled # Verify segment was re-enabled
db.session.refresh(segment) db_session_with_containers.refresh(segment)
assert segment.enabled is True assert segment.enabled is True
# Verify Redis cache was still cleared # Verify Redis cache was still cleared
assert redis_client.get(indexing_cache_key) is None assert redis_client.get(indexing_cache_key) is None
def test_disable_segment_different_doc_forms(self, db_session_with_containers, mock_index_processor): def test_disable_segment_different_doc_forms(self, db_session_with_containers: Session, mock_index_processor):
""" """
Test disabling segments with different document forms. Test disabling segments with different document forms.
@ -470,9 +480,11 @@ class TestDisableSegmentFromIndexTask:
for doc_form in doc_forms: for doc_form in doc_forms:
# Arrange: Create test data for each form # Arrange: Create test data for each form
account, tenant = self._create_test_account_and_tenant(db_session_with_containers) account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset = self._create_test_dataset(tenant, account) dataset = self._create_test_dataset(db_session_with_containers, tenant, account)
document = self._create_test_document(dataset, tenant, account, doc_form=doc_form) document = self._create_test_document(
segment = self._create_test_segment(document, dataset, tenant, account) db_session_with_containers, dataset, tenant, account, doc_form=doc_form
)
segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account)
# Reset mock for each iteration # Reset mock for each iteration
mock_index_processor.reset_mock() mock_index_processor.reset_mock()
@ -489,7 +501,7 @@ class TestDisableSegmentFromIndexTask:
assert call_args[0][0].id == dataset.id # Check dataset ID assert call_args[0][0].id == dataset.id # Check dataset ID
assert call_args[0][1] == [segment.index_node_id] # Check index node IDs assert call_args[0][1] == [segment.index_node_id] # Check index node IDs
def test_disable_segment_redis_cache_handling(self, db_session_with_containers, mock_index_processor): def test_disable_segment_redis_cache_handling(self, db_session_with_containers: Session, mock_index_processor):
""" """
Test Redis cache handling during segment disabling. Test Redis cache handling during segment disabling.
@ -500,9 +512,9 @@ class TestDisableSegmentFromIndexTask:
""" """
# Arrange: Create test data # Arrange: Create test data
account, tenant = self._create_test_account_and_tenant(db_session_with_containers) account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset = self._create_test_dataset(tenant, account) dataset = self._create_test_dataset(db_session_with_containers, tenant, account)
document = self._create_test_document(dataset, tenant, account) document = self._create_test_document(db_session_with_containers, dataset, tenant, account)
segment = self._create_test_segment(document, dataset, tenant, account) segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account)
# Test with cache present # Test with cache present
indexing_cache_key = f"segment_{segment.id}_indexing" indexing_cache_key = f"segment_{segment.id}_indexing"
@ -517,13 +529,13 @@ class TestDisableSegmentFromIndexTask:
assert redis_client.get(indexing_cache_key) is None assert redis_client.get(indexing_cache_key) is None
# Test with no cache present # Test with no cache present
segment2 = self._create_test_segment(document, dataset, tenant, account) segment2 = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account)
result2 = disable_segment_from_index_task(segment2.id) result2 = disable_segment_from_index_task(segment2.id)
# Assert: Verify task still works without cache # Assert: Verify task still works without cache
assert result2 is None assert result2 is None
def test_disable_segment_performance_timing(self, db_session_with_containers, mock_index_processor): def test_disable_segment_performance_timing(self, db_session_with_containers: Session, mock_index_processor):
""" """
Test performance timing of segment disabling task. Test performance timing of segment disabling task.
@ -534,9 +546,9 @@ class TestDisableSegmentFromIndexTask:
""" """
# Arrange: Create test data # Arrange: Create test data
account, tenant = self._create_test_account_and_tenant(db_session_with_containers) account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset = self._create_test_dataset(tenant, account) dataset = self._create_test_dataset(db_session_with_containers, tenant, account)
document = self._create_test_document(dataset, tenant, account) document = self._create_test_document(db_session_with_containers, dataset, tenant, account)
segment = self._create_test_segment(document, dataset, tenant, account) segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account)
# Act: Execute the task and measure time # Act: Execute the task and measure time
start_time = time.perf_counter() start_time = time.perf_counter()
@ -548,7 +560,9 @@ class TestDisableSegmentFromIndexTask:
execution_time = end_time - start_time execution_time = end_time - start_time
assert execution_time < 5.0 # Should complete within 5 seconds assert execution_time < 5.0 # Should complete within 5 seconds
def test_disable_segment_database_session_management(self, db_session_with_containers, mock_index_processor): def test_disable_segment_database_session_management(
self, db_session_with_containers: Session, mock_index_processor
):
""" """
Test database session management during task execution. Test database session management during task execution.
@ -559,9 +573,9 @@ class TestDisableSegmentFromIndexTask:
""" """
# Arrange: Create test data # Arrange: Create test data
account, tenant = self._create_test_account_and_tenant(db_session_with_containers) account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset = self._create_test_dataset(tenant, account) dataset = self._create_test_dataset(db_session_with_containers, tenant, account)
document = self._create_test_document(dataset, tenant, account) document = self._create_test_document(db_session_with_containers, dataset, tenant, account)
segment = self._create_test_segment(document, dataset, tenant, account) segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account)
# Act: Execute the task # Act: Execute the task
result = disable_segment_from_index_task(segment.id) result = disable_segment_from_index_task(segment.id)
@ -570,10 +584,10 @@ class TestDisableSegmentFromIndexTask:
assert result is None assert result is None
# Verify segment is still accessible (session was properly managed) # Verify segment is still accessible (session was properly managed)
db.session.refresh(segment) db_session_with_containers.refresh(segment)
assert segment.id is not None assert segment.id is not None
def test_disable_segment_concurrent_execution(self, db_session_with_containers, mock_index_processor): def test_disable_segment_concurrent_execution(self, db_session_with_containers: Session, mock_index_processor):
""" """
Test concurrent execution of segment disabling tasks. Test concurrent execution of segment disabling tasks.
@ -584,12 +598,12 @@ class TestDisableSegmentFromIndexTask:
""" """
# Arrange: Create multiple test segments # Arrange: Create multiple test segments
account, tenant = self._create_test_account_and_tenant(db_session_with_containers) account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset = self._create_test_dataset(tenant, account) dataset = self._create_test_dataset(db_session_with_containers, tenant, account)
document = self._create_test_document(dataset, tenant, account) document = self._create_test_document(db_session_with_containers, dataset, tenant, account)
segments = [] segments = []
for i in range(3): for i in range(3):
segment = self._create_test_segment(document, dataset, tenant, account) segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account)
segments.append(segment) segments.append(segment)
# Act: Execute tasks concurrently (simulated) # Act: Execute tasks concurrently (simulated)

View File

@ -9,6 +9,7 @@ The task is responsible for removing document segments from the search index whe
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from faker import Faker from faker import Faker
from sqlalchemy.orm import Session
from models import Account, Dataset, DocumentSegment from models import Account, Dataset, DocumentSegment
from models import Document as DatasetDocument from models import Document as DatasetDocument
@ -31,7 +32,7 @@ class TestDisableSegmentsFromIndexTask:
and realistic testing environment with actual database interactions. and realistic testing environment with actual database interactions.
""" """
def _create_test_account(self, db_session_with_containers, fake=None): def _create_test_account(self, db_session_with_containers: Session, fake=None):
""" """
Helper method to create a test account with realistic data. Helper method to create a test account with realistic data.
@ -79,7 +80,7 @@ class TestDisableSegmentsFromIndexTask:
return account return account
def _create_test_dataset(self, db_session_with_containers, account, fake=None): def _create_test_dataset(self, db_session_with_containers: Session, account, fake=None):
""" """
Helper method to create a test dataset with realistic data. Helper method to create a test dataset with realistic data.
@ -113,7 +114,7 @@ class TestDisableSegmentsFromIndexTask:
return dataset return dataset
def _create_test_document(self, db_session_with_containers, dataset, account, fake=None): def _create_test_document(self, db_session_with_containers: Session, dataset, account, fake=None):
""" """
Helper method to create a test document with realistic data. Helper method to create a test document with realistic data.
@ -158,7 +159,9 @@ class TestDisableSegmentsFromIndexTask:
return document return document
def _create_test_segments(self, db_session_with_containers, document, dataset, account, count=3, fake=None): def _create_test_segments(
self, db_session_with_containers: Session, document, dataset, account, count=3, fake=None
):
""" """
Helper method to create test document segments with realistic data. Helper method to create test document segments with realistic data.
@ -210,7 +213,7 @@ class TestDisableSegmentsFromIndexTask:
return segments return segments
def _create_dataset_process_rule(self, db_session_with_containers, dataset, fake=None): def _create_dataset_process_rule(self, db_session_with_containers: Session, dataset, fake=None):
""" """
Helper method to create a dataset process rule. Helper method to create a dataset process rule.
@ -239,14 +242,12 @@ class TestDisableSegmentsFromIndexTask:
process_rule.created_by = dataset.created_by process_rule.created_by = dataset.created_by
process_rule.updated_by = dataset.updated_by process_rule.updated_by = dataset.updated_by
from extensions.ext_database import db db_session_with_containers.add(process_rule)
db_session_with_containers.commit()
db.session.add(process_rule)
db.session.commit()
return process_rule return process_rule
def test_disable_segments_success(self, db_session_with_containers): def test_disable_segments_success(self, db_session_with_containers: Session):
""" """
Test successful disabling of segments from index. Test successful disabling of segments from index.
@ -297,7 +298,7 @@ class TestDisableSegmentsFromIndexTask:
expected_key = f"segment_{segment.id}_indexing" expected_key = f"segment_{segment.id}_indexing"
mock_redis.delete.assert_any_call(expected_key) mock_redis.delete.assert_any_call(expected_key)
def test_disable_segments_dataset_not_found(self, db_session_with_containers): def test_disable_segments_dataset_not_found(self, db_session_with_containers: Session):
""" """
Test handling when dataset is not found. Test handling when dataset is not found.
@ -320,7 +321,7 @@ class TestDisableSegmentsFromIndexTask:
# Redis should not be called when dataset is not found # Redis should not be called when dataset is not found
mock_redis.delete.assert_not_called() mock_redis.delete.assert_not_called()
def test_disable_segments_document_not_found(self, db_session_with_containers): def test_disable_segments_document_not_found(self, db_session_with_containers: Session):
""" """
Test handling when document is not found. Test handling when document is not found.
@ -344,7 +345,7 @@ class TestDisableSegmentsFromIndexTask:
# Redis should not be called when document is not found # Redis should not be called when document is not found
mock_redis.delete.assert_not_called() mock_redis.delete.assert_not_called()
def test_disable_segments_document_invalid_status(self, db_session_with_containers): def test_disable_segments_document_invalid_status(self, db_session_with_containers: Session):
""" """
Test handling when document has invalid status for disabling. Test handling when document has invalid status for disabling.
@ -360,9 +361,8 @@ class TestDisableSegmentsFromIndexTask:
# Test case 1: Document not enabled # Test case 1: Document not enabled
document.enabled = False document.enabled = False
from extensions.ext_database import db
db.session.commit() db_session_with_containers.commit()
segment_ids = [segment.id for segment in segments] segment_ids = [segment.id for segment in segments]
@ -379,7 +379,7 @@ class TestDisableSegmentsFromIndexTask:
# Test case 2: Document archived # Test case 2: Document archived
document.enabled = True document.enabled = True
document.archived = True document.archived = True
db.session.commit() db_session_with_containers.commit()
with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis:
# Act # Act
@ -393,7 +393,7 @@ class TestDisableSegmentsFromIndexTask:
document.enabled = True document.enabled = True
document.archived = False document.archived = False
document.indexing_status = "indexing" document.indexing_status = "indexing"
db.session.commit() db_session_with_containers.commit()
with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis:
# Act # Act
@ -403,7 +403,7 @@ class TestDisableSegmentsFromIndexTask:
assert result is None # Task should complete without returning a value assert result is None # Task should complete without returning a value
mock_redis.delete.assert_not_called() mock_redis.delete.assert_not_called()
def test_disable_segments_no_segments_found(self, db_session_with_containers): def test_disable_segments_no_segments_found(self, db_session_with_containers: Session):
""" """
Test handling when no segments are found for the given IDs. Test handling when no segments are found for the given IDs.
@ -430,7 +430,7 @@ class TestDisableSegmentsFromIndexTask:
# Redis should not be called when no segments are found # Redis should not be called when no segments are found
mock_redis.delete.assert_not_called() mock_redis.delete.assert_not_called()
def test_disable_segments_index_processor_error(self, db_session_with_containers): def test_disable_segments_index_processor_error(self, db_session_with_containers: Session):
""" """
Test handling when index processor encounters an error. Test handling when index processor encounters an error.
@ -464,13 +464,14 @@ class TestDisableSegmentsFromIndexTask:
assert result is None # Task should complete without returning a value assert result is None # Task should complete without returning a value
# Verify segments were rolled back to enabled state # Verify segments were rolled back to enabled state
from extensions.ext_database import db
db.session.refresh(segments[0]) db_session_with_containers.refresh(segments[0])
db.session.refresh(segments[1]) db_session_with_containers.refresh(segments[1])
# Check that segments are re-enabled after error # Check that segments are re-enabled after error
updated_segments = db.session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).all() updated_segments = (
db_session_with_containers.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).all()
)
for segment in updated_segments: for segment in updated_segments:
assert segment.enabled is True assert segment.enabled is True
@ -480,7 +481,7 @@ class TestDisableSegmentsFromIndexTask:
# Verify Redis cache cleanup was still called # Verify Redis cache cleanup was still called
assert mock_redis.delete.call_count == len(segments) assert mock_redis.delete.call_count == len(segments)
def test_disable_segments_with_different_doc_forms(self, db_session_with_containers): def test_disable_segments_with_different_doc_forms(self, db_session_with_containers: Session):
""" """
Test disabling segments with different document forms. Test disabling segments with different document forms.
@ -503,9 +504,8 @@ class TestDisableSegmentsFromIndexTask:
for doc_form in doc_forms: for doc_form in doc_forms:
# Update document form # Update document form
document.doc_form = doc_form document.doc_form = doc_form
from extensions.ext_database import db
db.session.commit() db_session_with_containers.commit()
# Mock the index processor factory # Mock the index processor factory
with patch("tasks.disable_segments_from_index_task.IndexProcessorFactory") as mock_factory: with patch("tasks.disable_segments_from_index_task.IndexProcessorFactory") as mock_factory:
@ -523,7 +523,7 @@ class TestDisableSegmentsFromIndexTask:
assert result is None # Task should complete without returning a value assert result is None # Task should complete without returning a value
mock_factory.assert_called_with(doc_form) mock_factory.assert_called_with(doc_form)
def test_disable_segments_performance_timing(self, db_session_with_containers): def test_disable_segments_performance_timing(self, db_session_with_containers: Session):
""" """
Test that the task properly measures and logs performance timing. Test that the task properly measures and logs performance timing.
@ -568,7 +568,7 @@ class TestDisableSegmentsFromIndexTask:
assert performance_log is not None assert performance_log is not None
assert "0.5" in performance_log # Should log the execution time assert "0.5" in performance_log # Should log the execution time
def test_disable_segments_redis_cache_cleanup(self, db_session_with_containers): def test_disable_segments_redis_cache_cleanup(self, db_session_with_containers: Session):
""" """
Test that Redis cache is properly cleaned up for all segments. Test that Redis cache is properly cleaned up for all segments.
@ -610,7 +610,7 @@ class TestDisableSegmentsFromIndexTask:
for expected_key in expected_keys: for expected_key in expected_keys:
assert expected_key in actual_calls assert expected_key in actual_calls
def test_disable_segments_database_session_cleanup(self, db_session_with_containers): def test_disable_segments_database_session_cleanup(self, db_session_with_containers: Session):
""" """
Test that database session is properly closed after task execution. Test that database session is properly closed after task execution.
@ -643,7 +643,7 @@ class TestDisableSegmentsFromIndexTask:
assert result is None # Task should complete without returning a value assert result is None # Task should complete without returning a value
# Session lifecycle is managed by context manager; no explicit close assertion # Session lifecycle is managed by context manager; no explicit close assertion
def test_disable_segments_empty_segment_ids(self, db_session_with_containers): def test_disable_segments_empty_segment_ids(self, db_session_with_containers: Session):
""" """
Test handling when empty segment IDs list is provided. Test handling when empty segment IDs list is provided.
@ -669,7 +669,7 @@ class TestDisableSegmentsFromIndexTask:
# Redis should not be called when no segments are provided # Redis should not be called when no segments are provided
mock_redis.delete.assert_not_called() mock_redis.delete.assert_not_called()
def test_disable_segments_mixed_valid_invalid_ids(self, db_session_with_containers): def test_disable_segments_mixed_valid_invalid_ids(self, db_session_with_containers: Session):
""" """
Test handling when some segment IDs are valid and others are invalid. Test handling when some segment IDs are valid and others are invalid.

View File

@ -2,9 +2,9 @@ from unittest.mock import MagicMock, patch
import pytest import pytest
from faker import Faker from faker import Faker
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.constant.index_type import IndexStructureType
from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment from models.dataset import Dataset, Document, DocumentSegment
@ -31,7 +31,9 @@ class TestEnableSegmentsToIndexTask:
"index_processor": mock_processor, "index_processor": mock_processor,
} }
def _create_test_dataset_and_document(self, db_session_with_containers, mock_external_service_dependencies): def _create_test_dataset_and_document(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Helper method to create a test dataset and document for testing. Helper method to create a test dataset and document for testing.
@ -51,15 +53,15 @@ class TestEnableSegmentsToIndexTask:
interface_language="en-US", interface_language="en-US",
status="active", status="active",
) )
db.session.add(account) db_session_with_containers.add(account)
db.session.commit() db_session_with_containers.commit()
tenant = Tenant( tenant = Tenant(
name=fake.company(), name=fake.company(),
status="normal", status="normal",
) )
db.session.add(tenant) db_session_with_containers.add(tenant)
db.session.commit() db_session_with_containers.commit()
# Create tenant-account join # Create tenant-account join
join = TenantAccountJoin( join = TenantAccountJoin(
@ -68,8 +70,8 @@ class TestEnableSegmentsToIndexTask:
role=TenantAccountRole.OWNER, role=TenantAccountRole.OWNER,
current=True, current=True,
) )
db.session.add(join) db_session_with_containers.add(join)
db.session.commit() db_session_with_containers.commit()
# Create dataset # Create dataset
dataset = Dataset( dataset = Dataset(
@ -81,8 +83,8 @@ class TestEnableSegmentsToIndexTask:
indexing_technique="high_quality", indexing_technique="high_quality",
created_by=account.id, created_by=account.id,
) )
db.session.add(dataset) db_session_with_containers.add(dataset)
db.session.commit() db_session_with_containers.commit()
# Create document # Create document
document = Document( document = Document(
@ -99,16 +101,16 @@ class TestEnableSegmentsToIndexTask:
enabled=True, enabled=True,
doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_form=IndexStructureType.PARAGRAPH_INDEX,
) )
db.session.add(document) db_session_with_containers.add(document)
db.session.commit() db_session_with_containers.commit()
# Refresh dataset to ensure doc_form property works correctly # Refresh dataset to ensure doc_form property works correctly
db.session.refresh(dataset) db_session_with_containers.refresh(dataset)
return dataset, document return dataset, document
def _create_test_segments( def _create_test_segments(
self, db_session_with_containers, document, dataset, count=3, enabled=False, status="completed" self, db_session_with_containers: Session, document, dataset, count=3, enabled=False, status="completed"
): ):
""" """
Helper method to create test document segments. Helper method to create test document segments.
@ -144,14 +146,14 @@ class TestEnableSegmentsToIndexTask:
status=status, status=status,
created_by=document.created_by, created_by=document.created_by,
) )
db.session.add(segment) db_session_with_containers.add(segment)
segments.append(segment) segments.append(segment)
db.session.commit() db_session_with_containers.commit()
return segments return segments
def test_enable_segments_to_index_with_different_index_type( def test_enable_segments_to_index_with_different_index_type(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test segments indexing with different index types. Test segments indexing with different index types.
@ -169,10 +171,10 @@ class TestEnableSegmentsToIndexTask:
# Update document to use different index type # Update document to use different index type
document.doc_form = IndexStructureType.QA_INDEX document.doc_form = IndexStructureType.QA_INDEX
db.session.commit() db_session_with_containers.commit()
# Refresh dataset to ensure doc_form property reflects the updated document # Refresh dataset to ensure doc_form property reflects the updated document
db.session.refresh(dataset) db_session_with_containers.refresh(dataset)
# Create segments # Create segments
segments = self._create_test_segments(db_session_with_containers, document, dataset) segments = self._create_test_segments(db_session_with_containers, document, dataset)
@ -204,7 +206,7 @@ class TestEnableSegmentsToIndexTask:
assert redis_client.exists(indexing_cache_key) == 0 assert redis_client.exists(indexing_cache_key) == 0
def test_enable_segments_to_index_dataset_not_found( def test_enable_segments_to_index_dataset_not_found(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test handling of non-existent dataset. Test handling of non-existent dataset.
@ -229,7 +231,7 @@ class TestEnableSegmentsToIndexTask:
mock_external_service_dependencies["index_processor"].load.assert_not_called() mock_external_service_dependencies["index_processor"].load.assert_not_called()
def test_enable_segments_to_index_document_not_found( def test_enable_segments_to_index_document_not_found(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test handling of non-existent document. Test handling of non-existent document.
@ -256,7 +258,7 @@ class TestEnableSegmentsToIndexTask:
mock_external_service_dependencies["index_processor"].load.assert_not_called() mock_external_service_dependencies["index_processor"].load.assert_not_called()
def test_enable_segments_to_index_invalid_document_status( def test_enable_segments_to_index_invalid_document_status(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test handling of document with invalid status. Test handling of document with invalid status.
@ -284,12 +286,12 @@ class TestEnableSegmentsToIndexTask:
document.enabled = True document.enabled = True
document.archived = False document.archived = False
document.indexing_status = "completed" document.indexing_status = "completed"
db.session.commit() db_session_with_containers.commit()
# Set invalid status # Set invalid status
for attr, value in status_attrs.items(): for attr, value in status_attrs.items():
setattr(document, attr, value) setattr(document, attr, value)
db.session.commit() db_session_with_containers.commit()
# Create segments # Create segments
segments = self._create_test_segments(db_session_with_containers, document, dataset) segments = self._create_test_segments(db_session_with_containers, document, dataset)
@ -304,11 +306,11 @@ class TestEnableSegmentsToIndexTask:
# Clean up segments for next iteration # Clean up segments for next iteration
for segment in segments: for segment in segments:
db.session.delete(segment) db_session_with_containers.delete(segment)
db.session.commit() db_session_with_containers.commit()
def test_enable_segments_to_index_segments_not_found( def test_enable_segments_to_index_segments_not_found(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test handling when no segments are found. Test handling when no segments are found.
@ -338,7 +340,7 @@ class TestEnableSegmentsToIndexTask:
mock_external_service_dependencies["index_processor"].load.assert_not_called() mock_external_service_dependencies["index_processor"].load.assert_not_called()
def test_enable_segments_to_index_with_parent_child_structure( def test_enable_segments_to_index_with_parent_child_structure(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test segments indexing with parent-child structure. Test segments indexing with parent-child structure.
@ -357,10 +359,10 @@ class TestEnableSegmentsToIndexTask:
# Update document to use parent-child index type # Update document to use parent-child index type
document.doc_form = IndexStructureType.PARENT_CHILD_INDEX document.doc_form = IndexStructureType.PARENT_CHILD_INDEX
db.session.commit() db_session_with_containers.commit()
# Refresh dataset to ensure doc_form property reflects the updated document # Refresh dataset to ensure doc_form property reflects the updated document
db.session.refresh(dataset) db_session_with_containers.refresh(dataset)
# Create segments with mock child chunks # Create segments with mock child chunks
segments = self._create_test_segments(db_session_with_containers, document, dataset) segments = self._create_test_segments(db_session_with_containers, document, dataset)
@ -410,7 +412,7 @@ class TestEnableSegmentsToIndexTask:
assert redis_client.exists(indexing_cache_key) == 0 assert redis_client.exists(indexing_cache_key) == 0
def test_enable_segments_to_index_general_exception_handling( def test_enable_segments_to_index_general_exception_handling(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test general exception handling during indexing process. Test general exception handling during indexing process.
@ -443,7 +445,7 @@ class TestEnableSegmentsToIndexTask:
# Assert: Verify error handling # Assert: Verify error handling
for segment in segments: for segment in segments:
db.session.refresh(segment) db_session_with_containers.refresh(segment)
assert segment.enabled is False assert segment.enabled is False
assert segment.status == "error" assert segment.status == "error"
assert segment.error is not None assert segment.error is not None

View File

@ -2,8 +2,8 @@ from unittest.mock import patch
import pytest import pytest
from faker import Faker from faker import Faker
from sqlalchemy.orm import Session
from extensions.ext_database import db
from libs.email_i18n import EmailType from libs.email_i18n import EmailType
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from tasks.mail_account_deletion_task import send_account_deletion_verification_code, send_deletion_success_task from tasks.mail_account_deletion_task import send_account_deletion_verification_code, send_deletion_success_task
@ -30,7 +30,7 @@ class TestMailAccountDeletionTask:
"email_service": mock_email_service, "email_service": mock_email_service,
} }
def _create_test_account(self, db_session_with_containers): def _create_test_account(self, db_session_with_containers: Session):
""" """
Helper method to create a test account for testing. Helper method to create a test account for testing.
@ -49,16 +49,16 @@ class TestMailAccountDeletionTask:
interface_language="en-US", interface_language="en-US",
status="active", status="active",
) )
db.session.add(account) db_session_with_containers.add(account)
db.session.commit() db_session_with_containers.commit()
# Create tenant # Create tenant
tenant = Tenant( tenant = Tenant(
name=fake.company(), name=fake.company(),
status="normal", status="normal",
) )
db.session.add(tenant) db_session_with_containers.add(tenant)
db.session.commit() db_session_with_containers.commit()
# Create tenant-account join # Create tenant-account join
join = TenantAccountJoin( join = TenantAccountJoin(
@ -67,12 +67,14 @@ class TestMailAccountDeletionTask:
role=TenantAccountRole.OWNER, role=TenantAccountRole.OWNER,
current=True, current=True,
) )
db.session.add(join) db_session_with_containers.add(join)
db.session.commit() db_session_with_containers.commit()
return account return account
def test_send_deletion_success_task_success(self, db_session_with_containers, mock_external_service_dependencies): def test_send_deletion_success_task_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
""" """
Test successful account deletion success email sending. Test successful account deletion success email sending.
@ -109,7 +111,7 @@ class TestMailAccountDeletionTask:
) )
def test_send_deletion_success_task_mail_not_initialized( def test_send_deletion_success_task_mail_not_initialized(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test account deletion success email when mail service is not initialized. Test account deletion success email when mail service is not initialized.
@ -132,7 +134,7 @@ class TestMailAccountDeletionTask:
mock_external_service_dependencies["email_service"].send_email.assert_not_called() mock_external_service_dependencies["email_service"].send_email.assert_not_called()
def test_send_deletion_success_task_email_service_exception( def test_send_deletion_success_task_email_service_exception(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test account deletion success email when email service raises exception. Test account deletion success email when email service raises exception.
@ -154,7 +156,7 @@ class TestMailAccountDeletionTask:
mock_external_service_dependencies["email_service"].send_email.assert_called_once() mock_external_service_dependencies["email_service"].send_email.assert_called_once()
def test_send_account_deletion_verification_code_success( def test_send_account_deletion_verification_code_success(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test successful account deletion verification code email sending. Test successful account deletion verification code email sending.
@ -193,7 +195,7 @@ class TestMailAccountDeletionTask:
) )
def test_send_account_deletion_verification_code_mail_not_initialized( def test_send_account_deletion_verification_code_mail_not_initialized(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test account deletion verification code email when mail service is not initialized. Test account deletion verification code email when mail service is not initialized.
@ -217,7 +219,7 @@ class TestMailAccountDeletionTask:
mock_external_service_dependencies["email_service"].send_email.assert_not_called() mock_external_service_dependencies["email_service"].send_email.assert_not_called()
def test_send_account_deletion_verification_code_email_service_exception( def test_send_account_deletion_verification_code_email_service_exception(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
): ):
""" """
Test account deletion verification code email when email service raises exception. Test account deletion verification code email when email service raises exception.

View File

@ -4,11 +4,11 @@ from unittest.mock import patch
import pytest import pytest
from faker import Faker from faker import Faker
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.rag.pipeline.queue import TenantIsolatedTaskQueue from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from extensions.ext_database import db
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Pipeline from models.dataset import Pipeline
from models.workflow import Workflow from models.workflow import Workflow
@ -52,7 +52,7 @@ class TestRagPipelineRunTasks:
"delete_file": mock_delete_file, "delete_file": mock_delete_file,
} }
def _create_test_pipeline_and_workflow(self, db_session_with_containers): def _create_test_pipeline_and_workflow(self, db_session_with_containers: Session):
""" """
Helper method to create test pipeline and workflow for testing. Helper method to create test pipeline and workflow for testing.
@ -71,15 +71,15 @@ class TestRagPipelineRunTasks:
interface_language="en-US", interface_language="en-US",
status="active", status="active",
) )
db.session.add(account) db_session_with_containers.add(account)
db.session.commit() db_session_with_containers.commit()
tenant = Tenant( tenant = Tenant(
name=fake.company(), name=fake.company(),
status="normal", status="normal",
) )
db.session.add(tenant) db_session_with_containers.add(tenant)
db.session.commit() db_session_with_containers.commit()
# Create tenant-account join # Create tenant-account join
join = TenantAccountJoin( join = TenantAccountJoin(
@ -88,8 +88,8 @@ class TestRagPipelineRunTasks:
role=TenantAccountRole.OWNER, role=TenantAccountRole.OWNER,
current=True, current=True,
) )
db.session.add(join) db_session_with_containers.add(join)
db.session.commit() db_session_with_containers.commit()
# Create workflow # Create workflow
workflow = Workflow( workflow = Workflow(
@ -107,8 +107,8 @@ class TestRagPipelineRunTasks:
conversation_variables=[], conversation_variables=[],
rag_pipeline_variables=[], rag_pipeline_variables=[],
) )
db.session.add(workflow) db_session_with_containers.add(workflow)
db.session.commit() db_session_with_containers.commit()
# Create pipeline # Create pipeline
pipeline = Pipeline( pipeline = Pipeline(
@ -119,14 +119,14 @@ class TestRagPipelineRunTasks:
created_by=account.id, created_by=account.id,
) )
pipeline.id = str(uuid.uuid4()) pipeline.id = str(uuid.uuid4())
db.session.add(pipeline) db_session_with_containers.add(pipeline)
db.session.commit() db_session_with_containers.commit()
# Refresh entities to ensure they're properly loaded # Refresh entities to ensure they're properly loaded
db.session.refresh(account) db_session_with_containers.refresh(account)
db.session.refresh(tenant) db_session_with_containers.refresh(tenant)
db.session.refresh(workflow) db_session_with_containers.refresh(workflow)
db.session.refresh(pipeline) db_session_with_containers.refresh(pipeline)
return account, tenant, pipeline, workflow return account, tenant, pipeline, workflow
@ -209,7 +209,7 @@ class TestRagPipelineRunTasks:
return json.dumps(entities_data) return json.dumps(entities_data)
def test_priority_rag_pipeline_run_task_success( def test_priority_rag_pipeline_run_task_success(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service
): ):
""" """
Test successful priority RAG pipeline run task execution. Test successful priority RAG pipeline run task execution.
@ -254,7 +254,7 @@ class TestRagPipelineRunTasks:
assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity) assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity)
def test_rag_pipeline_run_task_success( def test_rag_pipeline_run_task_success(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service
): ):
""" """
Test successful regular RAG pipeline run task execution. Test successful regular RAG pipeline run task execution.
@ -299,7 +299,7 @@ class TestRagPipelineRunTasks:
assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity) assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity)
def test_priority_rag_pipeline_run_task_with_waiting_tasks( def test_priority_rag_pipeline_run_task_with_waiting_tasks(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service
): ):
""" """
Test priority RAG pipeline run task with waiting tasks in queue using real Redis. Test priority RAG pipeline run task with waiting tasks in queue using real Redis.
@ -351,7 +351,7 @@ class TestRagPipelineRunTasks:
assert len(remaining_tasks) == 1 # 2 original - 1 pulled = 1 remaining assert len(remaining_tasks) == 1 # 2 original - 1 pulled = 1 remaining
def test_rag_pipeline_run_task_legacy_compatibility( def test_rag_pipeline_run_task_legacy_compatibility(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service
): ):
""" """
Test regular RAG pipeline run task with legacy Redis queue format for backward compatibility. Test regular RAG pipeline run task with legacy Redis queue format for backward compatibility.
@ -419,7 +419,7 @@ class TestRagPipelineRunTasks:
redis_client.delete(legacy_task_key) redis_client.delete(legacy_task_key)
def test_rag_pipeline_run_task_with_waiting_tasks( def test_rag_pipeline_run_task_with_waiting_tasks(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service
): ):
""" """
Test regular RAG pipeline run task with waiting tasks in queue using real Redis. Test regular RAG pipeline run task with waiting tasks in queue using real Redis.
@ -469,7 +469,7 @@ class TestRagPipelineRunTasks:
assert len(remaining_tasks) == 2 # 3 original - 1 pulled = 2 remaining assert len(remaining_tasks) == 2 # 3 original - 1 pulled = 2 remaining
def test_priority_rag_pipeline_run_task_error_handling( def test_priority_rag_pipeline_run_task_error_handling(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service
): ):
""" """
Test error handling in priority RAG pipeline run task using real Redis. Test error handling in priority RAG pipeline run task using real Redis.
@ -526,7 +526,7 @@ class TestRagPipelineRunTasks:
assert len(remaining_tasks) == 0 assert len(remaining_tasks) == 0
def test_rag_pipeline_run_task_error_handling( def test_rag_pipeline_run_task_error_handling(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service
): ):
""" """
Test error handling in regular RAG pipeline run task using real Redis. Test error handling in regular RAG pipeline run task using real Redis.
@ -581,7 +581,7 @@ class TestRagPipelineRunTasks:
assert len(remaining_tasks) == 0 assert len(remaining_tasks) == 0
def test_priority_rag_pipeline_run_task_tenant_isolation( def test_priority_rag_pipeline_run_task_tenant_isolation(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service
): ):
""" """
Test tenant isolation in priority RAG pipeline run task using real Redis. Test tenant isolation in priority RAG pipeline run task using real Redis.
@ -648,7 +648,7 @@ class TestRagPipelineRunTasks:
assert queue1._task_key != queue2._task_key assert queue1._task_key != queue2._task_key
def test_rag_pipeline_run_task_tenant_isolation( def test_rag_pipeline_run_task_tenant_isolation(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service
): ):
""" """
Test tenant isolation in regular RAG pipeline run task using real Redis. Test tenant isolation in regular RAG pipeline run task using real Redis.
@ -713,7 +713,7 @@ class TestRagPipelineRunTasks:
assert queue1._task_key != queue2._task_key assert queue1._task_key != queue2._task_key
def test_run_single_rag_pipeline_task_success( def test_run_single_rag_pipeline_task_success(
self, db_session_with_containers, mock_pipeline_generator, flask_app_with_containers self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers
): ):
""" """
Test successful run_single_rag_pipeline_task execution. Test successful run_single_rag_pipeline_task execution.
@ -748,7 +748,7 @@ class TestRagPipelineRunTasks:
assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity) assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity)
def test_run_single_rag_pipeline_task_entity_validation_error( def test_run_single_rag_pipeline_task_entity_validation_error(
self, db_session_with_containers, mock_pipeline_generator, flask_app_with_containers self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers
): ):
""" """
Test run_single_rag_pipeline_task with invalid entity data. Test run_single_rag_pipeline_task with invalid entity data.
@ -793,7 +793,7 @@ class TestRagPipelineRunTasks:
mock_pipeline_generator.assert_not_called() mock_pipeline_generator.assert_not_called()
def test_run_single_rag_pipeline_task_database_entity_not_found( def test_run_single_rag_pipeline_task_database_entity_not_found(
self, db_session_with_containers, mock_pipeline_generator, flask_app_with_containers self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers
): ):
""" """
Test run_single_rag_pipeline_task with non-existent database entities. Test run_single_rag_pipeline_task with non-existent database entities.
@ -838,7 +838,7 @@ class TestRagPipelineRunTasks:
mock_pipeline_generator.assert_not_called() mock_pipeline_generator.assert_not_called()
def test_priority_rag_pipeline_run_task_file_not_found( def test_priority_rag_pipeline_run_task_file_not_found(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service
): ):
""" """
Test priority RAG pipeline run task with non-existent file. Test priority RAG pipeline run task with non-existent file.
@ -888,7 +888,7 @@ class TestRagPipelineRunTasks:
assert len(remaining_tasks) == 0 assert len(remaining_tasks) == 0
def test_rag_pipeline_run_task_file_not_found( def test_rag_pipeline_run_task_file_not_found(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service
): ):
""" """
Test regular RAG pipeline run task with non-existent file. Test regular RAG pipeline run task with non-existent file.